42#define DEBUG_TYPE "scalarize-masked-mem-intrin"
46class ScalarizeMaskedMemIntrinLegacyPass :
public FunctionPass {
58 return "Scalarize Masked Memory Intrinsics";
76char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79 "Scalarize unsupported masked memory intrinsics",
false,
88 return new ScalarizeMaskedMemIntrinLegacyPass();
96 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
97 for (
unsigned i = 0; i != NumElts; ++i) {
98 Constant *CElt =
C->getAggregateElement(i);
99 if (!CElt || !isa<ConstantInt>(CElt))
108 return DL.isBigEndian() ? VectorWidth - 1 -
Idx :
Idx;
150 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
153 Type *EltTy = VecType->getElementType();
159 Builder.SetInsertPoint(InsertPt);
163 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
171 const Align AdjustedAlignVal =
173 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
176 Value *VResult = Src0;
179 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
180 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
183 LoadInst *Load =
Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
184 VResult =
Builder.CreateInsertElement(VResult, Load,
Idx);
194 if (VectorWidth != 1) {
196 SclrMask =
Builder.CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
199 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
208 if (VectorWidth != 1) {
211 Predicate =
Builder.CreateICmpNE(
Builder.CreateAnd(SclrMask, Mask),
212 Builder.getIntN(VectorWidth, 0));
214 Predicate =
Builder.CreateExtractElement(Mask,
Idx);
228 CondBlock->
setName(
"cond.load");
232 LoadInst *Load =
Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
239 IfBlock = NewIfBlock;
242 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->
begin());
244 Phi->addIncoming(NewVResult, CondBlock);
245 Phi->addIncoming(VResult, PrevIfBlock);
288 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
289 auto *VecType = cast<VectorType>(Src->getType());
291 Type *EltTy = VecType->getElementType();
295 Builder.SetInsertPoint(InsertPt);
299 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
300 Builder.CreateAlignedStore(Src,
Ptr, AlignVal);
306 const Align AdjustedAlignVal =
308 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
311 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
312 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
316 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
325 if (VectorWidth != 1) {
327 SclrMask =
Builder.CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
330 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
338 if (VectorWidth != 1) {
341 Predicate =
Builder.CreateICmpNE(
Builder.CreateAnd(SclrMask, Mask),
342 Builder.getIntN(VectorWidth, 0));
344 Predicate =
Builder.CreateExtractElement(Mask,
Idx);
358 CondBlock->
setName(
"cond.store");
363 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
369 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->
begin());
412 auto *VecType = cast<FixedVectorType>(CI->
getType());
413 Type *EltTy = VecType->getElementType();
418 Builder.SetInsertPoint(InsertPt);
419 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
424 Value *VResult = Src0;
425 unsigned VectorWidth = VecType->getNumElements();
429 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
430 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
446 if (VectorWidth != 1) {
448 SclrMask =
Builder.CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
451 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
460 if (VectorWidth != 1) {
463 Predicate =
Builder.CreateICmpNE(
Builder.CreateAnd(SclrMask, Mask),
464 Builder.getIntN(VectorWidth, 0));
480 CondBlock->
setName(
"cond.load");
493 IfBlock = NewIfBlock;
496 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->
begin());
498 Phi->addIncoming(NewVResult, CondBlock);
499 Phi->addIncoming(VResult, PrevIfBlock);
542 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
545 isa<VectorType>(Ptrs->
getType()) &&
546 isa<PointerType>(cast<VectorType>(Ptrs->
getType())->getElementType()) &&
547 "Vector of pointers is expected in masked scatter intrinsic");
551 Builder.SetInsertPoint(InsertPt);
554 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
555 unsigned VectorWidth = SrcFVTy->getNumElements();
559 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
560 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
565 Builder.CreateAlignedStore(OneElt,
Ptr, AlignVal);
574 if (VectorWidth != 1) {
576 SclrMask =
Builder.CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
579 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
587 if (VectorWidth != 1) {
590 Predicate =
Builder.CreateICmpNE(
Builder.CreateAnd(SclrMask, Mask),
591 Builder.getIntN(VectorWidth, 0));
607 CondBlock->
setName(
"cond.store");
612 Builder.CreateAlignedStore(OneElt,
Ptr, AlignVal);
618 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->
begin());
631 auto *VecType = cast<FixedVectorType>(CI->
getType());
633 Type *EltTy = VecType->getElementType();
639 Builder.SetInsertPoint(InsertPt);
642 unsigned VectorWidth = VecType->getNumElements();
645 Value *VResult = PassThru;
651 unsigned MemIndex = 0;
654 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
656 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue()) {
658 ShuffleMask[
Idx] =
Idx + VectorWidth;
661 Builder.CreateConstInBoundsGEP1_32(EltTy,
Ptr, MemIndex);
662 InsertElt =
Builder.CreateAlignedLoad(EltTy, NewPtr,
Align(1),
667 VResult =
Builder.CreateInsertElement(VResult, InsertElt,
Idx,
670 VResult =
Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
679 if (VectorWidth != 1) {
681 SclrMask =
Builder.CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
684 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
693 if (VectorWidth != 1) {
696 Predicate =
Builder.CreateICmpNE(
Builder.CreateAnd(SclrMask, Mask),
697 Builder.getIntN(VectorWidth, 0));
713 CondBlock->
setName(
"cond.load");
721 if ((
Idx + 1) != VectorWidth)
722 NewPtr =
Builder.CreateConstInBoundsGEP1_32(EltTy,
Ptr, 1);
728 IfBlock = NewIfBlock;
731 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->
begin());
732 PHINode *ResultPhi =
Builder.CreatePHI(VecType, 2,
"res.phi.else");
738 if ((
Idx + 1) != VectorWidth) {
759 auto *VecType = cast<FixedVectorType>(Src->getType());
765 Builder.SetInsertPoint(InsertPt);
768 Type *EltTy = VecType->getElementType();
770 unsigned VectorWidth = VecType->getNumElements();
774 unsigned MemIndex = 0;
775 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
776 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
780 Value *NewPtr =
Builder.CreateConstInBoundsGEP1_32(EltTy,
Ptr, MemIndex);
791 if (VectorWidth != 1) {
793 SclrMask =
Builder.CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
796 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
803 if (VectorWidth != 1) {
806 Predicate =
Builder.CreateICmpNE(
Builder.CreateAnd(SclrMask, Mask),
807 Builder.getIntN(VectorWidth, 0));
823 CondBlock->
setName(
"cond.store");
831 if ((
Idx + 1) != VectorWidth)
832 NewPtr =
Builder.CreateConstInBoundsGEP1_32(EltTy,
Ptr, 1);
838 IfBlock = NewIfBlock;
840 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->
begin());
843 if ((
Idx + 1) != VectorWidth) {
857 std::optional<DomTreeUpdater> DTU;
859 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
861 bool EverMadeChange =
false;
862 bool MadeChange =
true;
863 auto &
DL =
F.getParent()->getDataLayout();
867 bool ModifiedDTOnIteration =
false;
869 DTU ? &*DTU :
nullptr);
872 if (ModifiedDTOnIteration)
876 EverMadeChange |= MadeChange;
878 return EverMadeChange;
881bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(
Function &
F) {
882 auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
884 if (
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
885 DT = &DTWP->getDomTree();
904 bool MadeChange =
false;
907 while (CurInstIterator != BB.
end()) {
908 if (
CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
923 if (isa<ScalableVectorType>(II->
getType()) ||
925 [](
Value *V) { return isa<ScalableVectorType>(V->getType()); }))
931 case Intrinsic::masked_load:
939 case Intrinsic::masked_store:
946 case Intrinsic::masked_gather: {
948 cast<ConstantInt>(CI->
getArgOperand(1))->getMaybeAlignValue();
950 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
958 case Intrinsic::masked_scatter: {
960 cast<ConstantInt>(CI->
getArgOperand(2))->getMaybeAlignValue();
962 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
971 case Intrinsic::masked_expandload:
976 case Intrinsic::masked_compressstore:
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
static bool runImpl(Function &F, const TargetLowering &TLI)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, unsigned Idx)
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, DomTreeUpdater *DTU)
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, DomTreeUpdater *DTU)
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT)
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool isConstantIntVector(Value *Mask)
Scalarize unsupported masked memory intrinsics
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Value * getArgOperand(unsigned i) const
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
A parsed version of the target data layout string in and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
const BasicBlock * getParent() const
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
A wrapper class for inspecting calls to intrinsic functions.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
An instruction for reading from memory.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetTransformInfo.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVMContext & getContext() const
All values hold a context through their type.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
FunctionPass * createScalarizeMaskedMemIntrinLegacyPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
void initializeScalarizeMaskedMemIntrinLegacyPassPass(PassRegistry &)
constexpr int PoisonMaskElem
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Instruction * SplitBlockAndInsertIfThen(Value *Cond, BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights=nullptr, DomTreeUpdater *DTU=nullptr, LoopInfo *LI=nullptr, BasicBlock *ThenBlock=nullptr)
Split the containing block at the specified instruction - everything before SplitBefore stays in the ...
This struct is a compact representation of a valid (non-zero power of two) alignment.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)