LLVM 19.0.0git
BPFCheckAndAdjustIR.cpp
Go to the documentation of this file.
1//===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Check IR and adjust IR to produce verifier-friendly code.
// The following checks are done on the IR:
// - no relocation globals in PHI nodes.
// The following adjustments are done on the IR:
// - remove __builtin_bpf_passthrough builtins. Target-independent IR
//   optimizations are done by now, so those builtins can be removed.
// - remove __builtin_bpf_compare builtins, replacing each with an icmp.
// - rewrite `icmp x, min/max(a, b)` inside loops (with min/max computed
//   outside the loop) into a pair of compares joined by and/or.
// - remove llvm.bpf.getelementptr.and.load builtins.
// - remove llvm.bpf.getelementptr.and.store builtins.
17//
18//===----------------------------------------------------------------------===//
19
20#include "BPF.h"
21#include "BPFCORE.h"
22#include "BPFTargetMachine.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/Instruction.h"
29#include "llvm/IR/IntrinsicsBPF.h"
30#include "llvm/IR/Module.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/User.h"
33#include "llvm/IR/Value.h"
34#include "llvm/Pass.h"
36
37#define DEBUG_TYPE "bpf-check-and-opt-ir"
38
39using namespace llvm;
40
41namespace {
42
43class BPFCheckAndAdjustIR final : public ModulePass {
44 bool runOnModule(Module &F) override;
45
46public:
47 static char ID;
48 BPFCheckAndAdjustIR() : ModulePass(ID) {}
49 virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
50
51private:
52 void checkIR(Module &M);
53 bool adjustIR(Module &M);
54 bool removePassThroughBuiltin(Module &M);
55 bool removeCompareBuiltin(Module &M);
56 bool sinkMinMax(Module &M);
57 bool removeGEPBuiltins(Module &M);
58};
59} // End anonymous namespace
60
61char BPFCheckAndAdjustIR::ID = 0;
62INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
63 false, false)
64
66 return new BPFCheckAndAdjustIR();
67}
68
69void BPFCheckAndAdjustIR::checkIR(Module &M) {
70 // Ensure relocation global won't appear in PHI node
71 // This may happen if the compiler generated the following code:
72 // B1:
73 // g1 = @llvm.skb_buff:0:1...
74 // ...
75 // goto B_COMMON
76 // B2:
77 // g2 = @llvm.skb_buff:0:2...
78 // ...
79 // goto B_COMMON
80 // B_COMMON:
81 // g = PHI(g1, g2)
82 // x = load g
83 // ...
84 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
85 for (Function &F : M)
86 for (auto &BB : F)
87 for (auto &I : BB) {
88 PHINode *PN = dyn_cast<PHINode>(&I);
89 if (!PN || PN->use_empty())
90 continue;
91 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
92 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
93 if (!GV)
94 continue;
95 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
96 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
97 report_fatal_error("relocation global in PHI node");
98 }
99 }
100}
101
102bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
103 // Remove __builtin_bpf_passthrough()'s which are used to prevent
104 // certain IR optimizations. Now major IR optimizations are done,
105 // remove them.
106 bool Changed = false;
107 CallInst *ToBeDeleted = nullptr;
108 for (Function &F : M)
109 for (auto &BB : F)
110 for (auto &I : BB) {
111 if (ToBeDeleted) {
112 ToBeDeleted->eraseFromParent();
113 ToBeDeleted = nullptr;
114 }
115
116 auto *Call = dyn_cast<CallInst>(&I);
117 if (!Call)
118 continue;
119 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
120 if (!GV)
121 continue;
122 if (!GV->getName().starts_with("llvm.bpf.passthrough"))
123 continue;
124 Changed = true;
125 Value *Arg = Call->getArgOperand(1);
126 Call->replaceAllUsesWith(Arg);
127 ToBeDeleted = Call;
128 }
129 return Changed;
130}
131
132bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
133 // Remove __builtin_bpf_compare()'s which are used to prevent
134 // certain IR optimizations. Now major IR optimizations are done,
135 // remove them.
136 bool Changed = false;
137 CallInst *ToBeDeleted = nullptr;
138 for (Function &F : M)
139 for (auto &BB : F)
140 for (auto &I : BB) {
141 if (ToBeDeleted) {
142 ToBeDeleted->eraseFromParent();
143 ToBeDeleted = nullptr;
144 }
145
146 auto *Call = dyn_cast<CallInst>(&I);
147 if (!Call)
148 continue;
149 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
150 if (!GV)
151 continue;
152 if (!GV->getName().starts_with("llvm.bpf.compare"))
153 continue;
154
155 Changed = true;
156 Value *Arg0 = Call->getArgOperand(0);
157 Value *Arg1 = Call->getArgOperand(1);
158 Value *Arg2 = Call->getArgOperand(2);
159
160 auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
162
163 auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
164 ICmp->insertBefore(Call);
165
166 Call->replaceAllUsesWith(ICmp);
167 ToBeDeleted = Call;
168 }
169 return Changed;
170}
171
179
181 : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
182 ZExt(nullptr), SExt(nullptr) {}
183};
184
186 const std::function<bool(Instruction *)> &Filter) {
187 // Check if V is:
188 // (fn %a %b) or (ext (fn %a %b))
189 // Where:
190 // ext := sext | zext
191 // fn := smin | umin | smax | umax
192 auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
193 if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
194 V = ZExt->getOperand(0);
195 Info.ZExt = ZExt;
196 } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
197 V = SExt->getOperand(0);
198 Info.SExt = SExt;
199 }
200
201 auto *Call = dyn_cast<CallInst>(V);
202 if (!Call)
203 return false;
204
205 auto *Called = dyn_cast<Function>(Call->getCalledOperand());
206 if (!Called)
207 return false;
208
209 switch (Called->getIntrinsicID()) {
210 case Intrinsic::smin:
211 case Intrinsic::umin:
212 case Intrinsic::smax:
213 case Intrinsic::umax:
214 break;
215 default:
216 return false;
217 }
218
219 if (!Filter(Call))
220 return false;
221
222 Info.MinMax = Call;
223
224 return true;
225 };
226
227 auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
229 if (Info.SExt) {
230 if (Info.SExt->getType() == V->getType())
231 return V;
232 return Builder.CreateSExt(V, Info.SExt->getType());
233 }
234 if (Info.ZExt) {
235 if (Info.ZExt->getType() == V->getType())
236 return V;
237 return Builder.CreateZExt(V, Info.ZExt->getType());
238 }
239 return V;
240 };
241
242 bool Changed = false;
244
245 // Check BB for instructions like:
246 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
247 //
248 // Where:
249 // fn := min | max | (sext (min ...)) | (sext (max ...))
250 //
251 // Put such instructions to SinkList.
252 for (Instruction &I : BB) {
253 ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
254 if (!ICmp)
255 continue;
256 if (!ICmp->isRelational())
257 continue;
258 MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
259 ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
260 MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
261 bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
262 bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
263 if (!(FirstMinMax ^ SecondMinMax))
264 continue;
265 SinkList.push_back(FirstMinMax ? First : Second);
266 }
267
268 // Iterate SinkList and replace each (icmp ...) with corresponding
269 // `x < a && x < b` or similar expression.
270 for (auto &Info : SinkList) {
271 ICmpInst *ICmp = Info.ICmp;
272 CallInst *MinMax = Info.MinMax;
273 Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
274 ICmpInst::Predicate P = Info.Predicate;
275 if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
276 IID != Intrinsic::smax)
277 continue;
278
279 IRBuilder<> Builder(ICmp);
280 Value *X = Info.Other;
281 Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
282 Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
283 bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
284 bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
285 bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
286 bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
287 assert(IsMin ^ IsMax);
288 assert(IsLess ^ IsGreater);
289
290 Value *Replacement;
291 Value *LHS = Builder.CreateICmp(P, X, A);
292 Value *RHS = Builder.CreateICmp(P, X, B);
293 if ((IsLess && IsMin) || (IsGreater && IsMax))
294 // x < min(a, b) -> x < a && x < b
295 // x > max(a, b) -> x > a && x > b
296 Replacement = Builder.CreateLogicalAnd(LHS, RHS);
297 else
298 // x > min(a, b) -> x > a || x > b
299 // x < max(a, b) -> x < a || x < b
300 Replacement = Builder.CreateLogicalOr(LHS, RHS);
301
302 ICmp->replaceAllUsesWith(Replacement);
303
304 Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
305 for (Instruction *I : ToRemove)
306 if (I && I->use_empty())
307 I->eraseFromParent();
308
309 Changed = true;
310 }
311
312 return Changed;
313}
314
315// Do the following transformation:
316//
317// x < min(a, b) -> x < a && x < b
318// x > min(a, b) -> x > a || x > b
319// x < max(a, b) -> x < a || x < b
320// x > max(a, b) -> x > a && x > b
321//
322// Such patterns are introduced by LICM.cpp:hoistMinMax()
323// transformation and might lead to BPF verification failures for
324// older kernels.
325//
326// To minimize "collateral" changes only do it for icmp + min/max
327// calls when icmp is inside a loop and min/max is outside of that
328// loop.
329//
330// Verification failure happens when:
331// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
332// - verifier can recognize RHS as a constant scalar in some context;
333// - verifier can't recognize RHS1 as a constant scalar in the same
334// context;
335//
336// The "constant scalar" is not a compile time constant, but a register
337// that holds a scalar value known to verifier at some point in time
338// during abstract interpretation.
339//
340// See also:
341// https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
342bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
343 bool Changed = false;
344
345 for (Function &F : M) {
346 if (F.isDeclaration())
347 continue;
348
349 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
350 for (Loop *L : LI)
351 for (BasicBlock *BB : L->blocks()) {
352 // Filter out instructions coming from the same loop
353 Loop *BBLoop = LI.getLoopFor(BB);
354 auto OtherLoopFilter = [&](Instruction *I) {
355 return LI.getLoopFor(I->getParent()) != BBLoop;
356 };
357 Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
358 }
359 }
360
361 return Changed;
362}
363
364void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
366}
367
368static void unrollGEPLoad(CallInst *Call) {
370 GEP->insertBefore(Call);
371 Load->insertBefore(Call);
372 Call->replaceAllUsesWith(Load);
373 Call->eraseFromParent();
374}
375
376static void unrollGEPStore(CallInst *Call) {
378 GEP->insertBefore(Call);
379 Store->insertBefore(Call);
380 Call->eraseFromParent();
381}
382
385 SmallVector<CallInst *> GEPStores;
386 for (auto &BB : F)
387 for (auto &Insn : BB)
388 if (auto *Call = dyn_cast<CallInst>(&Insn))
389 if (auto *Called = Call->getCalledFunction())
390 switch (Called->getIntrinsicID()) {
391 case Intrinsic::bpf_getelementptr_and_load:
392 GEPLoads.push_back(Call);
393 break;
394 case Intrinsic::bpf_getelementptr_and_store:
395 GEPStores.push_back(Call);
396 break;
397 }
398
399 if (GEPLoads.empty() && GEPStores.empty())
400 return false;
401
402 for_each(GEPLoads, unrollGEPLoad);
403 for_each(GEPStores, unrollGEPStore);
404
405 return true;
406}
407
408// Rewrites the following builtins:
409// - llvm.bpf.getelementptr.and.load
410// - llvm.bpf.getelementptr.and.store
411// As (load (getelementptr ...)) or (store (getelementptr ...)).
412bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
413 bool Changed = false;
414 for (auto &F : M)
415 Changed = removeGEPBuiltinsInFunc(F) || Changed;
416 return Changed;
417}
418
419bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
420 bool Changed = removePassThroughBuiltin(M);
421 Changed = removeCompareBuiltin(M) || Changed;
422 Changed = sinkMinMax(M) || Changed;
423 Changed = removeGEPBuiltins(M) || Changed;
424 return Changed;
425}
426
427bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
428 checkIR(M);
429 return adjustIR(M);
430}
SmallVector< AArch64_IMM::ImmInsnModel, 4 > Insn
ReachingDefAnalysis InstSet & ToRemove
static bool sinkMinMaxInBB(BasicBlock &BB, const std::function< bool(Instruction *)> &Filter)
static void unrollGEPStore(CallInst *Call)
static void unrollGEPLoad(CallInst *Call)
static bool removeGEPBuiltinsInFunc(Function &F)
#define DEBUG_TYPE
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
Hexagon Common GEP
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
#define P(N)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Value * RHS
Value * LHS
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
static constexpr StringRef TypeIdAttr
The attribute attached to globals representing a type id.
Definition: BPFCORE.h:48
static constexpr StringRef AmaAttr
The attribute attached to globals representing a field access.
Definition: BPFCORE.h:46
static std::pair< GetElementPtrInst *, StoreInst * > reconstructStore(CallInst *Call)
static std::pair< GetElementPtrInst *, LoadInst * > reconstructLoad(CallInst *Call)
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
This class represents a function call, abstracting a target machine's calling convention.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:965
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:1066
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
Value * CreateSExt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2022
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Definition: IRBuilder.h:2010
Value * CreateLogicalAnd(Value *Cond1, Value *Cond2, const Twine &Name="")
Definition: IRBuilder.h:1670
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2334
Value * CreateLogicalOr(Value *Cond1, Value *Cond2, const Twine &Name="")
Definition: IRBuilder.h:1676
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2649
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:44
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
Definition: Pass.cpp:98
This class represents a sign extension of integer types.
bool empty() const
Definition: SmallVector.h:94
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
bool use_empty() const
Definition: Value.h:344
This class represents zero extension of integer types.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
UnaryFunction for_each(R &&Range, UnaryFunction F)
Provide wrappers to std::for_each which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1724
ModulePass * createBPFCheckAndAdjustIR()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
ICmpInst::Predicate Predicate
MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)