LLVM 18.0.0git
BPFCheckAndAdjustIR.cpp
Go to the documentation of this file.
//===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Check the IR and adjust it so that verifier-friendly code is generated.
// The following is done for IR checking:
//   - no relocation globals in PHI nodes.
// The following is done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - remove __builtin_bpf_compare builtins, replacing each one with the
//     icmp instruction it stands for.
//   - sink min/max calls that were hoisted out of a loop back into the
//     comparisons inside the loop that use them.
//
//===----------------------------------------------------------------------===//
17
18#include "BPF.h"
19#include "BPFCORE.h"
20#include "BPFTargetMachine.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Instruction.h"
27#include "llvm/IR/Module.h"
28#include "llvm/IR/Type.h"
29#include "llvm/IR/User.h"
30#include "llvm/IR/Value.h"
31#include "llvm/Pass.h"
33
34#define DEBUG_TYPE "bpf-check-and-opt-ir"
35
36using namespace llvm;
37
38namespace {
39
40class BPFCheckAndAdjustIR final : public ModulePass {
41 bool runOnModule(Module &F) override;
42
43public:
44 static char ID;
45 BPFCheckAndAdjustIR() : ModulePass(ID) {}
46 virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
47
48private:
49 void checkIR(Module &M);
50 bool adjustIR(Module &M);
51 bool removePassThroughBuiltin(Module &M);
52 bool removeCompareBuiltin(Module &M);
53 bool sinkMinMax(Module &M);
54};
55} // End anonymous namespace
56
57char BPFCheckAndAdjustIR::ID = 0;
58INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
59 false, false)
60
62 return new BPFCheckAndAdjustIR();
63}
64
65void BPFCheckAndAdjustIR::checkIR(Module &M) {
66 // Ensure relocation global won't appear in PHI node
67 // This may happen if the compiler generated the following code:
68 // B1:
69 // g1 = @llvm.skb_buff:0:1...
70 // ...
71 // goto B_COMMON
72 // B2:
73 // g2 = @llvm.skb_buff:0:2...
74 // ...
75 // goto B_COMMON
76 // B_COMMON:
77 // g = PHI(g1, g2)
78 // x = load g
79 // ...
80 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
81 for (Function &F : M)
82 for (auto &BB : F)
83 for (auto &I : BB) {
84 PHINode *PN = dyn_cast<PHINode>(&I);
85 if (!PN || PN->use_empty())
86 continue;
87 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
88 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
89 if (!GV)
90 continue;
91 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
92 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
93 report_fatal_error("relocation global in PHI node");
94 }
95 }
96}
97
98bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
99 // Remove __builtin_bpf_passthrough()'s which are used to prevent
100 // certain IR optimizations. Now major IR optimizations are done,
101 // remove them.
102 bool Changed = false;
103 CallInst *ToBeDeleted = nullptr;
104 for (Function &F : M)
105 for (auto &BB : F)
106 for (auto &I : BB) {
107 if (ToBeDeleted) {
108 ToBeDeleted->eraseFromParent();
109 ToBeDeleted = nullptr;
110 }
111
112 auto *Call = dyn_cast<CallInst>(&I);
113 if (!Call)
114 continue;
115 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
116 if (!GV)
117 continue;
118 if (!GV->getName().startswith("llvm.bpf.passthrough"))
119 continue;
120 Changed = true;
121 Value *Arg = Call->getArgOperand(1);
122 Call->replaceAllUsesWith(Arg);
123 ToBeDeleted = Call;
124 }
125 return Changed;
126}
127
128bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
129 // Remove __builtin_bpf_compare()'s which are used to prevent
130 // certain IR optimizations. Now major IR optimizations are done,
131 // remove them.
132 bool Changed = false;
133 CallInst *ToBeDeleted = nullptr;
134 for (Function &F : M)
135 for (auto &BB : F)
136 for (auto &I : BB) {
137 if (ToBeDeleted) {
138 ToBeDeleted->eraseFromParent();
139 ToBeDeleted = nullptr;
140 }
141
142 auto *Call = dyn_cast<CallInst>(&I);
143 if (!Call)
144 continue;
145 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
146 if (!GV)
147 continue;
148 if (!GV->getName().startswith("llvm.bpf.compare"))
149 continue;
150
151 Changed = true;
152 Value *Arg0 = Call->getArgOperand(0);
153 Value *Arg1 = Call->getArgOperand(1);
154 Value *Arg2 = Call->getArgOperand(2);
155
156 auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
158
159 auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
160 ICmp->insertBefore(Call);
161
162 Call->replaceAllUsesWith(ICmp);
163 ToBeDeleted = Call;
164 }
165 return Changed;
166}
167
175
177 : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
178 ZExt(nullptr), SExt(nullptr) {}
179};
180
182 const std::function<bool(Instruction *)> &Filter) {
183 // Check if V is:
184 // (fn %a %b) or (ext (fn %a %b))
185 // Where:
186 // ext := sext | zext
187 // fn := smin | umin | smax | umax
188 auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
189 if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
190 V = ZExt->getOperand(0);
191 Info.ZExt = ZExt;
192 } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
193 V = SExt->getOperand(0);
194 Info.SExt = SExt;
195 }
196
197 auto *Call = dyn_cast<CallInst>(V);
198 if (!Call)
199 return false;
200
201 auto *Called = dyn_cast<Function>(Call->getCalledOperand());
202 if (!Called)
203 return false;
204
205 switch (Called->getIntrinsicID()) {
206 case Intrinsic::smin:
207 case Intrinsic::umin:
208 case Intrinsic::smax:
209 case Intrinsic::umax:
210 break;
211 default:
212 return false;
213 }
214
215 if (!Filter(Call))
216 return false;
217
218 Info.MinMax = Call;
219
220 return true;
221 };
222
223 auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
225 if (Info.SExt) {
226 if (Info.SExt->getType() == V->getType())
227 return V;
228 return Builder.CreateSExt(V, Info.SExt->getType());
229 }
230 if (Info.ZExt) {
231 if (Info.ZExt->getType() == V->getType())
232 return V;
233 return Builder.CreateZExt(V, Info.ZExt->getType());
234 }
235 return V;
236 };
237
238 bool Changed = false;
240
241 // Check BB for instructions like:
242 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
243 //
244 // Where:
245 // fn := min | max | (sext (min ...)) | (sext (max ...))
246 //
247 // Put such instructions to SinkList.
248 for (Instruction &I : BB) {
249 ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
250 if (!ICmp)
251 continue;
252 if (!ICmp->isRelational())
253 continue;
254 MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
255 ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
256 MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
257 bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
258 bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
259 if (!(FirstMinMax ^ SecondMinMax))
260 continue;
261 SinkList.push_back(FirstMinMax ? First : Second);
262 }
263
264 // Iterate SinkList and replace each (icmp ...) with corresponding
265 // `x < a && x < b` or similar expression.
266 for (auto &Info : SinkList) {
267 ICmpInst *ICmp = Info.ICmp;
268 CallInst *MinMax = Info.MinMax;
269 Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
270 ICmpInst::Predicate P = Info.Predicate;
271 if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
272 IID != Intrinsic::smax)
273 continue;
274
275 IRBuilder<> Builder(ICmp);
276 Value *X = Info.Other;
277 Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
278 Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
279 bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
280 bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
281 bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
282 bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
283 assert(IsMin ^ IsMax);
284 assert(IsLess ^ IsGreater);
285
286 Value *Replacement;
287 Value *LHS = Builder.CreateICmp(P, X, A);
288 Value *RHS = Builder.CreateICmp(P, X, B);
289 if ((IsLess && IsMin) || (IsGreater && IsMax))
290 // x < min(a, b) -> x < a && x < b
291 // x > max(a, b) -> x > a && x > b
292 Replacement = Builder.CreateLogicalAnd(LHS, RHS);
293 else
294 // x > min(a, b) -> x > a || x > b
295 // x < max(a, b) -> x < a || x < b
296 Replacement = Builder.CreateLogicalOr(LHS, RHS);
297
298 ICmp->replaceAllUsesWith(Replacement);
299
300 Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
301 for (Instruction *I : ToRemove)
302 if (I && I->use_empty())
303 I->eraseFromParent();
304
305 Changed = true;
306 }
307
308 return Changed;
309}
310
311// Do the following transformation:
312//
313// x < min(a, b) -> x < a && x < b
314// x > min(a, b) -> x > a || x > b
315// x < max(a, b) -> x < a || x < b
316// x > max(a, b) -> x > a && x > b
317//
318// Such patterns are introduced by LICM.cpp:hoistMinMax()
319// transformation and might lead to BPF verification failures for
320// older kernels.
321//
322// To minimize "collateral" changes only do it for icmp + min/max
323// calls when icmp is inside a loop and min/max is outside of that
324// loop.
325//
326// Verification failure happens when:
327// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
328// - verifier can recognize RHS as a constant scalar in some context;
329// - verifier can't recognize RHS1 as a constant scalar in the same
330// context;
331//
332// The "constant scalar" is not a compile time constant, but a register
333// that holds a scalar value known to verifier at some point in time
334// during abstract interpretation.
335//
336// See also:
337// https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
338bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
339 bool Changed = false;
340
341 for (Function &F : M) {
342 if (F.isDeclaration())
343 continue;
344
345 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
346 for (Loop *L : LI)
347 for (BasicBlock *BB : L->blocks()) {
348 // Filter out instructions coming from the same loop
349 Loop *BBLoop = LI.getLoopFor(BB);
350 auto OtherLoopFilter = [&](Instruction *I) {
351 return LI.getLoopFor(I->getParent()) != BBLoop;
352 };
353 Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
354 }
355 }
356
357 return Changed;
358}
359
360void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
362}
363
364bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
365 bool Changed = removePassThroughBuiltin(M);
366 Changed = removeCompareBuiltin(M) || Changed;
367 Changed = sinkMinMax(M) || Changed;
368 return Changed;
369}
370
371bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
372 checkIR(M);
373 return adjustIR(M);
374}
ReachingDefAnalysis InstSet & ToRemove
assume Assume Builder
static bool sinkMinMaxInBB(BasicBlock &BB, const std::function< bool(Instruction *)> &Filter)
#define DEBUG_TYPE
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
#define P(N)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Value * RHS
Value * LHS
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
static constexpr StringRef TypeIdAttr
The attribute attached to globals representing a type id.
Definition: BPFCORE.h:47
static constexpr StringRef AmaAttr
The attribute attached to globals representing a field access.
Definition: BPFCORE.h:45
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
This class represents a function call, abstracting a target machine's calling convention.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:711
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:801
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2628
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:83
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:594
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:47
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
This class represents a sign extension of integer types.
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:535
bool use_empty() const
Definition: Value.h:344
This class represents zero extension of integer types.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
ModulePass * createBPFCheckAndAdjustIR()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
ICmpInst::Predicate Predicate
MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)