LLVM 23.0.0git
ExpandReductions.cpp
Go to the documentation of this file.
1//===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements IR expansion for reduction intrinsics, allowing targets
10// to enable the intrinsics until just before codegen.
11//
12//===----------------------------------------------------------------------===//
13
#include "llvm/CodeGen/ExpandReductions.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
26
27using namespace llvm;
28
29namespace {
30
31bool expandReductions(Function &F, const TargetTransformInfo *TTI,
32 DominatorTree *DT, LoopInfo *LI) {
33 bool Changed = false;
35 for (auto &I : instructions(F)) {
36 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
37 switch (II->getIntrinsicID()) {
38 default: break;
39 case Intrinsic::vector_reduce_fadd:
40 case Intrinsic::vector_reduce_fmul:
41 case Intrinsic::vector_reduce_add:
42 case Intrinsic::vector_reduce_mul:
43 case Intrinsic::vector_reduce_and:
44 case Intrinsic::vector_reduce_or:
45 case Intrinsic::vector_reduce_xor:
46 case Intrinsic::vector_reduce_smax:
47 case Intrinsic::vector_reduce_smin:
48 case Intrinsic::vector_reduce_umax:
49 case Intrinsic::vector_reduce_umin:
50 case Intrinsic::vector_reduce_fmax:
51 case Intrinsic::vector_reduce_fmin:
52 if (TTI->shouldExpandReduction(II))
53 Worklist.push_back(II);
54
55 break;
56 }
57 }
58 }
59
60 for (auto *II : Worklist) {
61 FastMathFlags FMF = II->getFastMathFlagsOrNone();
62 Intrinsic::ID ID = II->getIntrinsicID();
65 TTI->getPreferredExpandedReductionShuffle(II);
66
67 Value *Rdx = nullptr;
68 IRBuilder<> Builder(II);
69 IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
70 Builder.setFastMathFlags(FMF);
71 switch (ID) {
72 default: llvm_unreachable("Unexpected intrinsic!");
73 case Intrinsic::vector_reduce_fadd:
74 case Intrinsic::vector_reduce_fmul: {
75 // FMFs must be attached to the call, otherwise it's an ordered reduction
76 // and it can't be handled by generating a shuffle sequence.
77 Value *Acc = II->getArgOperand(0);
78 Value *Vec = II->getArgOperand(1);
79 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
80 if (isa<ScalableVectorType>(Vec->getType())) {
81 Rdx = expandReductionViaLoop(Builder, Vec, RdxOpcode, Acc, DT, LI);
82 break;
83 }
84 if (!FMF.allowReassoc())
85 Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
86 else {
87 if (!isPowerOf2_32(
88 cast<FixedVectorType>(Vec->getType())->getNumElements()))
89 continue;
90 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
91 Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
92 "bin.rdx");
93 }
94 break;
95 }
96 case Intrinsic::vector_reduce_and:
97 case Intrinsic::vector_reduce_or: {
98 // Canonicalize logical or/and reductions:
99 // Or reduction for i1 is represented as:
100 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
101 // %res = cmp ne iReduxWidth %val, 0
102 // And reduction for i1 is represented as:
103 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
104 // %res = cmp eq iReduxWidth %val, 11111
105 Value *Vec = II->getArgOperand(0);
106 auto *FTy = cast<FixedVectorType>(Vec->getType());
107 unsigned NumElts = FTy->getNumElements();
108 if (!isPowerOf2_32(NumElts))
109 continue;
110
111 if (FTy->getElementType() == Builder.getInt1Ty()) {
112 Rdx = Builder.CreateBitCast(Vec, Builder.getIntNTy(NumElts));
113 if (ID == Intrinsic::vector_reduce_and) {
114 Rdx = Builder.CreateICmpEQ(
116 } else {
117 assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
118 Rdx = Builder.CreateIsNotNull(Rdx);
119 }
120 break;
121 }
122 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
123 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
124 break;
125 }
126 case Intrinsic::vector_reduce_add:
127 case Intrinsic::vector_reduce_mul:
128 case Intrinsic::vector_reduce_xor:
129 case Intrinsic::vector_reduce_smax:
130 case Intrinsic::vector_reduce_smin:
131 case Intrinsic::vector_reduce_umax:
132 case Intrinsic::vector_reduce_umin: {
133 Value *Vec = II->getArgOperand(0);
134 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
135 if (isa<ScalableVectorType>(Vec->getType())) {
136 Type *EltTy = Vec->getType()->getScalarType();
137 Value *Ident = getReductionIdentity(ID, EltTy, FMF);
138 Rdx = expandReductionViaLoop(Builder, Vec, RdxOpcode, Ident, DT, LI);
139 break;
140 }
141 if (!isPowerOf2_32(
142 cast<FixedVectorType>(Vec->getType())->getNumElements()))
143 continue;
144 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
145 break;
146 }
147 case Intrinsic::vector_reduce_fmax:
148 case Intrinsic::vector_reduce_fmin: {
149 // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
150 // semantics of the reduction.
151 Value *Vec = II->getArgOperand(0);
152 if (!isPowerOf2_32(
153 cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
154 !FMF.noNaNs())
155 continue;
156 unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
157 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
158 break;
159 }
160 }
161 II->replaceAllUsesWith(Rdx);
162 II->eraseFromParent();
163 Changed = true;
164 }
165 return Changed;
166}
167
168class ExpandReductions : public FunctionPass {
169public:
170 static char ID;
171 ExpandReductions() : FunctionPass(ID) {}
172
173 bool runOnFunction(Function &F) override {
174 const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
175 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
176 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
177 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
178 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
179 return expandReductions(F, TTI, DT, LI);
180 }
181
182 void getAnalysisUsage(AnalysisUsage &AU) const override {
183 AU.addRequired<TargetTransformInfoWrapperPass>();
184 AU.addPreserved<DominatorTreeWrapperPass>();
185 AU.addPreserved<LoopInfoWrapperPass>();
186 }
187};
188}
189
190char ExpandReductions::ID;
191INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
192 "Expand reduction intrinsics", false, false)
194INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
195 "Expand reduction intrinsics", false, false)
196
198 return new ExpandReductions();
199}
200
203 const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
205 auto *LI = AM.getCachedResult<LoopAnalysis>(F);
206 if (!expandReductions(F, &TTI, DT, LI))
207 return PreservedAnalyses::all();
211 return PA;
212}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Expand Atomic instructions
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This pass exposes codegen information to IR-level passes.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
static LLVM_ABI Constant * getAllOnesValue(Type *Ty)
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Convenience struct for specifying and reasoning about fast-math flags.
Definition FMF.h:23
bool allowReassoc() const
Flag queries.
Definition FMF.h:64
bool noNaNs() const
Definition FMF.h:65
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2858
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:587
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:368
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI Value * getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, FastMathFlags FMF)
Given information about an @llvm.vector.reduce.
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID)
Returns the arithmetic instruction opcode used when expanding a reduction.
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition MathExtras.h:279
LLVM_ABI Value * getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op, TargetTransformInfo::ReductionShuffle RS, RecurKind MinMaxKind=RecurKind::None)
Generates a vector reduction using shufflevectors to reduce the value.
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
TargetTransformInfo TTI
RecurKind
These are the kinds of recurrences that we support.
LLVM_ABI FunctionPass * createExpandReductionsPass()
This pass expands the reduction intrinsics into sequences of shuffles.
LLVM_ABI Value * expandReductionViaLoop(IRBuilderBase &Builder, Value *Vec, unsigned RdxOpcode, Value *Acc, DominatorTree *DT=nullptr, LoopInfo *LI=nullptr)
Expand a scalable vector reduction into a runtime loop that applies RdxOpcode element by element,...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
LLVM_ABI RecurKind getMinMaxReductionRecurKind(Intrinsic::ID RdxID)
Returns the recurrence kind used when expanding a min/max reduction.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI Value * getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src, unsigned Op, RecurKind MinMaxKind=RecurKind::None)
Generates an ordered vector reduction using extracts to reduce the value.