LLVM  16.0.0git
ExpandReductions.cpp
Go to the documentation of this file.
1 //===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements IR expansion for reduction intrinsics, allowing targets
10 // to enable the intrinsics until just before codegen.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/CodeGen/ExpandReductions.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
24 
25 using namespace llvm;
26 
27 namespace {
28 
29 unsigned getOpcode(Intrinsic::ID ID) {
30  switch (ID) {
31  case Intrinsic::vector_reduce_fadd:
32  return Instruction::FAdd;
33  case Intrinsic::vector_reduce_fmul:
34  return Instruction::FMul;
35  case Intrinsic::vector_reduce_add:
36  return Instruction::Add;
37  case Intrinsic::vector_reduce_mul:
38  return Instruction::Mul;
39  case Intrinsic::vector_reduce_and:
40  return Instruction::And;
41  case Intrinsic::vector_reduce_or:
42  return Instruction::Or;
43  case Intrinsic::vector_reduce_xor:
44  return Instruction::Xor;
45  case Intrinsic::vector_reduce_smax:
46  case Intrinsic::vector_reduce_smin:
47  case Intrinsic::vector_reduce_umax:
48  case Intrinsic::vector_reduce_umin:
49  return Instruction::ICmp;
50  case Intrinsic::vector_reduce_fmax:
51  case Intrinsic::vector_reduce_fmin:
52  return Instruction::FCmp;
53  default:
54  llvm_unreachable("Unexpected ID");
55  }
56 }
57 
58 RecurKind getRK(Intrinsic::ID ID) {
59  switch (ID) {
60  case Intrinsic::vector_reduce_smax:
61  return RecurKind::SMax;
62  case Intrinsic::vector_reduce_smin:
63  return RecurKind::SMin;
64  case Intrinsic::vector_reduce_umax:
65  return RecurKind::UMax;
66  case Intrinsic::vector_reduce_umin:
67  return RecurKind::UMin;
68  case Intrinsic::vector_reduce_fmax:
69  return RecurKind::FMax;
70  case Intrinsic::vector_reduce_fmin:
71  return RecurKind::FMin;
72  default:
73  return RecurKind::None;
74  }
75 }
76 
77 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
78  bool Changed = false;
80  for (auto &I : instructions(F)) {
81  if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
82  switch (II->getIntrinsicID()) {
83  default: break;
84  case Intrinsic::vector_reduce_fadd:
85  case Intrinsic::vector_reduce_fmul:
86  case Intrinsic::vector_reduce_add:
87  case Intrinsic::vector_reduce_mul:
88  case Intrinsic::vector_reduce_and:
89  case Intrinsic::vector_reduce_or:
90  case Intrinsic::vector_reduce_xor:
91  case Intrinsic::vector_reduce_smax:
92  case Intrinsic::vector_reduce_smin:
93  case Intrinsic::vector_reduce_umax:
94  case Intrinsic::vector_reduce_umin:
95  case Intrinsic::vector_reduce_fmax:
96  case Intrinsic::vector_reduce_fmin:
97  if (TTI->shouldExpandReduction(II))
98  Worklist.push_back(II);
99 
100  break;
101  }
102  }
103  }
104 
105  for (auto *II : Worklist) {
106  FastMathFlags FMF =
107  isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
108  Intrinsic::ID ID = II->getIntrinsicID();
109  RecurKind RK = getRK(ID);
110 
111  Value *Rdx = nullptr;
112  IRBuilder<> Builder(II);
114  Builder.setFastMathFlags(FMF);
115  switch (ID) {
116  default: llvm_unreachable("Unexpected intrinsic!");
117  case Intrinsic::vector_reduce_fadd:
118  case Intrinsic::vector_reduce_fmul: {
119  // FMFs must be attached to the call, otherwise it's an ordered reduction
120  // and it can't be handled by generating a shuffle sequence.
121  Value *Acc = II->getArgOperand(0);
122  Value *Vec = II->getArgOperand(1);
123  if (!FMF.allowReassoc())
124  Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
125  else {
126  if (!isPowerOf2_32(
127  cast<FixedVectorType>(Vec->getType())->getNumElements()))
128  continue;
129 
130  Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
131  Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
132  Acc, Rdx, "bin.rdx");
133  }
134  break;
135  }
136  case Intrinsic::vector_reduce_add:
137  case Intrinsic::vector_reduce_mul:
138  case Intrinsic::vector_reduce_and:
139  case Intrinsic::vector_reduce_or:
140  case Intrinsic::vector_reduce_xor:
141  case Intrinsic::vector_reduce_smax:
142  case Intrinsic::vector_reduce_smin:
143  case Intrinsic::vector_reduce_umax:
144  case Intrinsic::vector_reduce_umin: {
145  Value *Vec = II->getArgOperand(0);
146  if (!isPowerOf2_32(
147  cast<FixedVectorType>(Vec->getType())->getNumElements()))
148  continue;
149 
150  Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
151  break;
152  }
153  case Intrinsic::vector_reduce_fmax:
154  case Intrinsic::vector_reduce_fmin: {
155  // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
156  // semantics of the reduction.
157  Value *Vec = II->getArgOperand(0);
158  if (!isPowerOf2_32(
159  cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
160  !FMF.noNaNs())
161  continue;
162 
163  Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
164  break;
165  }
166  }
167  II->replaceAllUsesWith(Rdx);
168  II->eraseFromParent();
169  Changed = true;
170  }
171  return Changed;
172 }
173 
174 class ExpandReductions : public FunctionPass {
175 public:
176  static char ID;
177  ExpandReductions() : FunctionPass(ID) {
179  }
180 
181  bool runOnFunction(Function &F) override {
182  const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
183  return expandReductions(F, TTI);
184  }
185 
186  void getAnalysisUsage(AnalysisUsage &AU) const override {
188  AU.setPreservesCFG();
189  }
190 };
191 }
192 
194 INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
195  "Expand reduction intrinsics", false, false)
199 
201  return new ExpandReductions();
202 }
203 
206  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
207  if (!expandReductions(F, &TTI))
208  return PreservedAnalyses::all();
210  PA.preserveSet<CFGAnalyses>();
211  return PA;
212 }
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
llvm::TargetIRAnalysis
Analysis pass providing the TargetTransformInfo.
Definition: TargetTransformInfo.h:2592
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
llvm::AArch64PACKey::ID
ID
Definition: AArch64BaseInfo.h:818
llvm::getShuffleReduction
Value * getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op, RecurKind MinMaxKind=RecurKind::None)
Generates a vector reduction using shufflevectors to reduce the value.
Definition: LoopUtils.cpp:946
IntrinsicInst.h
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:774
InstIterator.h
llvm::Function
Definition: Function.h:60
Pass.h
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1199
llvm::TargetTransformInfo
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
Definition: TargetTransformInfo.h:173
llvm::IRBuilder<>
llvm::isPowerOf2_32
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition: MathExtras.h:458
llvm::FastMathFlags
Convenience struct for specifying and reasoning about fast-math flags.
Definition: FMF.h:21
llvm::RecurKind::SMin
@ SMin
Signed integer min implemented in terms of select(cmp()).
llvm::FastMathFlags::noNaNs
bool noNaNs() const
Definition: FMF.h:67
llvm::ExpandReductionsPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: ExpandReductions.cpp:204
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::RecurKind
RecurKind
These are the kinds of recurrences that we support.
Definition: IVDescriptors.h:35
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:24
llvm::getOrderedReduction
Value * getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src, unsigned Op, RecurKind MinMaxKind=RecurKind::None)
Generates an ordered vector reduction using extracts to reduce the value.
Definition: LoopUtils.cpp:921
Intrinsics.h
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::FastMathFlags::allowReassoc
bool allowReassoc() const
Flag queries.
Definition: FMF.h:66
false
Definition: StackSlotColoring.cpp:141
LoopUtils.h
llvm::RecurKind::None
@ None
Not a recurrence.
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
llvm::RecurKind::UMin
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
Passes.h
llvm::initializeExpandReductionsPass
void initializeExpandReductionsPass(PassRegistry &)
reductions
expand reductions
Definition: ExpandReductions.cpp:197
llvm::instructions
inst_range instructions(Function *F)
Definition: InstIterator.h:133
expand
static Expected< BitVector > expand(StringRef S, StringRef Original)
Definition: GlobPattern.cpp:26
llvm::TargetTransformInfoWrapperPass
Wrapper pass for TargetTransformInfo.
Definition: TargetTransformInfo.h:2648
intrinsics
expand Expand reduction intrinsics
Definition: ExpandReductions.cpp:198
INITIALIZE_PASS_DEPENDENCY
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
I
#define I(x, y, z)
Definition: MD5.cpp:58
IRBuilder.h
llvm::createExpandReductionsPass
FunctionPass * createExpandReductionsPass()
This pass expands the vector reduction intrinsics into sequences of shuffles (or, for ordered floating-point reductions, sequential extracts).
Definition: ExpandReductions.cpp:200
reduction
Straight line strength reduction
Definition: StraightLineStrengthReduce.cpp:266
getOpcode
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:651
llvm::RecurKind::UMax
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
Mul
BinaryOperator * Mul
Definition: X86PartialReduction.cpp:70
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:265
llvm_unreachable
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Definition: ErrorHandling.h:143
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::CFGAnalyses
Represents analyses that only rely on functions' control flow.
Definition: PassManager.h:113
llvm::RecurKind::FMax
@ FMax
FP max implemented in terms of select(cmp()).
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions", "Expand reduction intrinsics", false, false) INITIALIZE_PASS_END(ExpandReductions
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:85
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:185
llvm::Instruction::BinaryOps
BinaryOps
Definition: Instruction.h:793
llvm::PreservedAnalyses::preserveSet
void preserveSet()
Mark an analysis set as preserved.
Definition: PassManager.h:188
ExpandReductions.h
TargetTransformInfo.h
llvm::TargetTransformInfo::shouldExpandReduction
bool shouldExpandReduction(const IntrinsicInst *II) const
Definition: TargetTransformInfo.cpp:1163
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:42
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
llvm::RecurKind::FMin
@ FMin
FP min implemented in terms of select(cmp()).
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
InitializePasses.h
llvm::RecurKind::SMax
@ SMax
Signed integer max implemented in terms of select(cmp()).
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:39