LLVM 22.0.0git
NVPTXIRPeephole.cpp
Go to the documentation of this file.
1//===------ NVPTXIRPeephole.cpp - NVPTX IR Peephole --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements IR-level peephole optimizations. These transformations
10// run late in the NVPTX IR pass pipeline, just before instruction selection.
11//
12// Currently, it implements the following transformation(s):
13// 1. FMA folding (float/double types):
14// Transforms FMUL+FADD/FSUB sequences into FMA intrinsics when the
15// 'contract' fast-math flag is present. Supported patterns:
16// - fadd(fmul(a, b), c) => fma(a, b, c)
17// - fadd(c, fmul(a, b)) => fma(a, b, c)
18// - fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
19// - fsub(fmul(a, b), c) => fma(a, b, fneg(c))
20// - fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
21// - fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
22//
23//===----------------------------------------------------------------------===//
24
25#include "NVPTXUtilities.h"
26#include "llvm/IR/IRBuilder.h"
29#include "llvm/IR/Intrinsics.h"
30
31#define DEBUG_TYPE "nvptx-ir-peephole"
32
33using namespace llvm;
34
36 Value *Op0 = BI->getOperand(0);
37 Value *Op1 = BI->getOperand(1);
38
39 auto *FMul0 = dyn_cast<BinaryOperator>(Op0);
40 auto *FMul1 = dyn_cast<BinaryOperator>(Op1);
41
42 BinaryOperator *FMul = nullptr;
43 Value *OtherOperand = nullptr;
44 bool IsFirstOperand = false;
45
46 // Either Op0 or Op1 should be a valid FMul
47 if (FMul0 && FMul0->getOpcode() == Instruction::FMul && FMul0->hasOneUse() &&
48 FMul0->hasAllowContract()) {
49 FMul = FMul0;
50 OtherOperand = Op1;
51 IsFirstOperand = true;
52 } else if (FMul1 && FMul1->getOpcode() == Instruction::FMul &&
53 FMul1->hasOneUse() && FMul1->hasAllowContract()) {
54 FMul = FMul1;
55 OtherOperand = Op0;
56 IsFirstOperand = false;
57 } else {
58 return false;
59 }
60
61 bool IsFSub = BI->getOpcode() == Instruction::FSub;
63 const char *OpName = IsFSub ? "FSub" : "FAdd";
64 dbgs() << "Found " << OpName << " with FMul (single use) as "
65 << (IsFirstOperand ? "first" : "second") << " operand: " << *BI
66 << "\n";
67 });
68
69 Value *MulOp0 = FMul->getOperand(0);
70 Value *MulOp1 = FMul->getOperand(1);
71 IRBuilder<> Builder(BI);
72 Value *FMA = nullptr;
73
74 if (!IsFSub) {
75 // fadd(fmul(a, b), c) => fma(a, b, c)
76 // fadd(c, fmul(a, b)) => fma(a, b, c)
77 FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
78 {MulOp0, MulOp1, OtherOperand});
79 } else {
80 if (IsFirstOperand) {
81 // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
82 Value *NegOtherOp =
83 Builder.CreateFNegFMF(OtherOperand, BI->getFastMathFlags());
84 FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
85 {MulOp0, MulOp1, NegOtherOp});
86 } else {
87 // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
88 Value *NegMulOp0 =
89 Builder.CreateFNegFMF(MulOp0, FMul->getFastMathFlags());
90 FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
91 {NegMulOp0, MulOp1, OtherOperand});
92 }
93 }
94
95 // Combine fast-math flags from the original instructions
96 auto *FMAInst = cast<Instruction>(FMA);
97 FastMathFlags BinaryFMF = BI->getFastMathFlags();
98 FastMathFlags FMulFMF = FMul->getFastMathFlags();
99 FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
100 FastMathFlags::unionValue(BinaryFMF, FMulFMF);
101 FMAInst->setFastMathFlags(NewFMF);
102
103 LLVM_DEBUG({
104 const char *OpName = IsFSub ? "FSub" : "FAdd";
105 dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
106 });
107 BI->replaceAllUsesWith(FMA);
108 BI->eraseFromParent();
109 FMul->eraseFromParent();
110 return true;
111}
112
113static bool foldFMA(Function &F) {
114 bool Changed = false;
115
116 // Iterate and process float/double FAdd/FSub instructions with allow-contract
118 if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
119 // Only FAdd and FSub are supported.
120 if (BI->getOpcode() != Instruction::FAdd &&
121 BI->getOpcode() != Instruction::FSub)
122 continue;
123
124 // At minimum, the instruction should have allow-contract.
125 if (!BI->hasAllowContract())
126 continue;
127
128 // Only float and double are supported.
129 if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
130 continue;
131
132 if (tryFoldBinaryFMul(BI))
133 Changed = true;
134 }
135 }
136 return Changed;
137}
138
139namespace {
140
141struct NVPTXIRPeephole : public FunctionPass {
142 static char ID;
143 NVPTXIRPeephole() : FunctionPass(ID) {}
144 bool runOnFunction(Function &F) override;
145};
146
147} // namespace
148
149char NVPTXIRPeephole::ID = 0;
150INITIALIZE_PASS(NVPTXIRPeephole, "nvptx-ir-peephole", "NVPTX IR Peephole",
151 false, false)
152
153bool NVPTXIRPeephole::runOnFunction(Function &F) { return foldFMA(F); }
154
156 return new NVPTXIRPeephole();
157}
158
Expand Atomic instructions
static bool runOnFunction(Function &F, bool PostInlining)
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
static bool tryFoldBinaryFMul(BinaryOperator *BI)
static bool foldFMA(Function &F)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
#define LLVM_DEBUG(...)
Definition Debug.h:114
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
Convenience struct for specifying and reasoning about fast-math flags.
Definition FMF.h:22
static FastMathFlags intersectRewrite(FastMathFlags LHS, FastMathFlags RHS)
Intersect rewrite-based flags.
Definition FMF.h:112
static FastMathFlags unionValue(FastMathFlags LHS, FastMathFlags RHS)
Union value flags.
Definition FMF.h:120
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2788
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
Changed
This is an optimization pass for GlobalISel generic memory operations.
FunctionPass * createNVPTXIRPeepholePass()
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:632
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
@ FMul
Product of floats.
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)