#define DEBUG_TYPE "scalarize-masked-mem-intrin"
class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
  // ...
  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }
  // ...
};

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}
static bool isConstantIntVector(Value *Mask) {
  // ...
  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }
  return true;
}
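// For example, a mask such as <i1 1, i1 0, i1 1, i1 1> satisfies
// isConstantIntVector; that is what lets the scalarizers below take the
// branch-free constant-mask path instead of emitting per-lane control flow.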
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}
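// Worked example: once the mask is bitcast to an iN scalar, lane Idx is
// probed via bit adjustForEndian(DL, N, Idx). On a little-endian target lane
// 1 of a 4-wide mask is bit 1; on a big-endian target it is bit 4 - 1 - 1 = 2.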
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ... (operand extraction and IRBuilder setup elided)
  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  Type *EltTy = VecType->getElementType();
  Builder.SetInsertPoint(InsertPt);

  // Short-cut: an all-ones mask is just an ordinary vector load.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    // ...
  }

  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector.
  Value *VResult = Src0;
  // Constant mask: emit unconditional loads for the active lanes only.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    // ...
  }
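// With a compile-time-constant mask no control flow is emitted at all: a
// <2 x i1> mask of <1, 0> becomes a single scalar load plus one
// insertelement, and the masked-off lane keeps its Src0 value.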
  // Variable mask wider than v1i1: bitcast it to an integer so lanes can be
  // tested with scalar bit operations.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }
    CondBlock->setName("cond.load");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    IfBlock = NewIfBlock;
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    // ... (a phi merges Load into VResult for the next iteration)
  }
  // ...
}
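// The variable-mask path above unrolls into a chain of tested lanes. A
// hand-written sketch (illustrative only, not verbatim compiler output) for
// llvm.masked.load on <2 x i32>:
//
//   %scalar_mask = bitcast <2 x i1> %mask to i2
//   %bit0 = and i2 %scalar_mask, 1
//   %cond0 = icmp ne i2 %bit0, 0
//   br i1 %cond0, label %cond.load, label %else
// cond.load:
//   %gep0 = getelementptr inbounds i32, i32* %first_elt, i32 0
//   %load0 = load i32, i32* %gep0
//   %ins0 = insertelement <2 x i32> %src0, i32 %load0, i32 0
//   br label %else
// else:
//   %res.phi.else = phi <2 x i32> [ %ins0, %cond.load ], [ %src0, %entry ]
//   ; ...the same pattern repeats for lane 1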
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());
  Type *EltTy = VecType->getElementType();
  Builder.SetInsertPoint(InsertPt);

  // Short-cut: an all-ones mask is just an ordinary vector store.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    // ...
  }

  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    // ...
  }
  // As in the load case, test lanes through a bitcast integer mask.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }
    CondBlock->setName("cond.store");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  // ...
}
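// The store side needs no phi nodes: a masked-off lane simply branches
// around its cond.store block. Sketch of one lane (illustrative only):
//
//   br i1 %cond0, label %cond.store, label %else
// cond.store:
//   %elt0 = extractelement <2 x i32> %src, i32 0
//   store i32 %elt0, i32* %gep0
//   br label %else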
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();
  // ...
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  // The result vector.
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();
  // Constant mask: load through each active lane's own pointer.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      // ... (extract the lane's pointer, load, insertelement)
    }
    // ...
  }
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.load");
    // ...
    IfBlock = NewIfBlock;
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  // ...
}
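// Gather differs from the plain masked load only in addressing: each lane
// extracts its own pointer from the incoming vector of pointers rather than
// GEPing off a shared base. One plausible call shape, assuming typed-pointer
// IR of this file's era (illustrative only):
//   %r = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)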
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  // ...
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx);
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    // ...
  }
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.store");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  // ...
}
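// Scatter is the store-side dual of gather: each active lane extracts both
// its value and its destination pointer before the guarded scalar store, so
// no phi or result vector needs to be threaded through the blocks.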
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();
  // ...
  Builder.SetInsertPoint(InsertPt);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector.
  Value *VResult = PassThru;
  // Constant mask: read contiguous memory, expand it into the active lanes,
  // then blend with PassThru for the masked-off lanes.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    // ...
  }
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.load");
    // ... (guarded scalar load and insertelement, as above)

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      // ...
    }
  }
  // ...
}
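// MemIndex semantics: memory is consumed contiguously while lanes are not.
// For a mask of <1,0,1,1>, expand-load reads memory slots 0, 1, 2 into
// vector lanes 0, 2, 3; the final shuffle keeps PassThru in lane 1.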
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  // ...
  auto *VecType = cast<FixedVectorType>(Src->getType());
  Builder.SetInsertPoint(InsertPt);
  Type *EltTy = VecType->getElementType();
  unsigned VectorWidth = VecType->getNumElements();
  // Constant mask: pack the active lanes into consecutive memory slots.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    // ...
  }
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.store");
    // ... (guarded extractelement and scalar store, as above)

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
    IfBlock = NewIfBlock;
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      // ...
    }
  }
  // ...
}
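// Compress-store is the inverse of expand-load: the same <1,0,1,1> mask
// writes vector lanes 0, 2, 3 to memory offsets 0, 1, 2, leaving no gaps in
// memory.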
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was
      // changed.
      if (ModifiedDTOnIteration)
        break;
    }
    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}
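// When scalarization splits a block, optimizeBlock reports it through
// ModifiedDTOnIteration; the break above abandons the early-inc block range
// and the outer while loop re-walks the function with fresh iterators.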
bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }
  return MadeChange;
}
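// Note the post-increment inside the dyn_cast: CurInstIterator already
// points past CI before optimizeCallInst runs, so the iterator stays valid
// even when scalarization erases and replaces the call.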
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    case Intrinsic::masked_load:
      // ... (skip if legal for the target; otherwise scalarizeMaskedLoad)
    case Intrinsic::masked_store:
      // ... (skip if legal for the target; otherwise scalarizeMaskedStore)
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      // ... (skip if legal and not forced; otherwise scalarizeMaskedGather)
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      // ... (skip if legal and not forced; otherwise scalarizeMaskedScatter)
    }
    case Intrinsic::masked_expandload:
      // ... (scalarizeMaskedExpandLoad)
    case Intrinsic::masked_compressstore:
      // ... (scalarizeMaskedCompressStore)
    }
  }
  return false;
}
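// Usage sketch (assumptions: the pass name below matches DEBUG_TYPE, and the
// new-PM class is declared in
// llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h; illustrative only):
//
//   $ opt -passes=scalarize-masked-mem-intrin -S in.ll -o out.ll
//
//   #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
//   FunctionPassManager FPM;
//   FPM.addPass(ScalarizeMaskedMemIntrinPass());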