84 #include "llvm/IR/IntrinsicsBPF.h"
94 #define DEBUG_TYPE "bpf-abstract-member-access"
114 using namespace llvm;
// Implementation class of the pass. Only part of the declaration is visible in
// this excerpt.
117 class BPFAbstractMemberAccess final {
126 Align RecordAlignment;
130 typedef std::stack<std::pair<CallInst *, CallInfo>> CallInfoStack;
// Kind tags assigned to CallInfo::Kind by IsPreserveDIAccessIndexCall.
134 BPFPreserveArrayAI = 1,
135 BPFPreserveUnionAI = 2,
136 BPFPreserveStructAI = 3,
137 BPFPreserveFieldInfoAI = 4,
// Globals keyed by access-key string (see transformGEPChain). Being static, the
// map is shared by all instances of this class.
144 static std::map<std::string, GlobalVariable *> GEPGlobals;
// For a chained call: the call it was reached from, plus that call's info.
146 std::map<CallInst *, std::pair<CallInst *, CallInfo>> AIChain;
// Calls at which chain tracing stopped; doTransformation hands each entry to
// transformGEPChain.
150 std::map<CallInst *, CallInfo> BaseAICalls;
164 bool removePreserveAccessIndexIntrinsic(
Function &
F);
165 void replaceWithGEP(std::vector<CallInst *> &CallList,
167 bool HasPreserveFieldInfoCall(CallInfoStack &CallStack);
172 Align RecordAlignment);
175 std::string &AccessKey,
MDNode *&BaseMeta);
177 std::string &AccessKey,
bool &IsInt32Ret);
178 uint64_t getConstant(
const Value *IndexValue);
// Out-of-class definition of the static GEPGlobals member declared in
// BPFAbstractMemberAccess.
182 std::map<std::string, GlobalVariable *> BPFAbstractMemberAccess::GEPGlobals;
// FunctionPass wrapper around BPFAbstractMemberAccess (partial view).
184 class BPFAbstractMemberAccessLegacyPass final :
public FunctionPass {
// Delegates to a freshly constructed BPFAbstractMemberAccess built from TM and
// returns whatever its run(F) returns.
188 return BPFAbstractMemberAccess(
TM).run(
F);
206 "BPF Abstract Member Access",
false,
false)
209 return new BPFAbstractMemberAccessLegacyPass(
TM);
// Entry point of the pass for one function. Returns the result of
// doTransformation(F) on the path visible here.
212 bool BPFAbstractMemberAccess::run(
Function &
F) {
// Debug-only banner.
213 LLVM_DEBUG(
dbgs() <<
"********** Abstract Member Accesses **********\n");
// Tests whether the module has no debug compile units.
// NOTE(review): the action taken in that case is not visible in this excerpt
// -- presumably an early exit; confirm.
220 if (
M->debug_compile_units().empty())
// Cache the module's data layout in DL; IsPreserveDIAccessIndexCall reads it.
223 DL = &
M->getDataLayout();
224 return doTransformation(
F);
228 if (Tag != dwarf::DW_TAG_typedef && Tag != dwarf::DW_TAG_const_type &&
229 Tag != dwarf::DW_TAG_volatile_type &&
230 Tag != dwarf::DW_TAG_restrict_type &&
231 Tag != dwarf::DW_TAG_member)
233 if (Tag == dwarf::DW_TAG_typedef && !skipTypedef)
239 while (
auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
242 Ty = DTy->getBaseType();
248 while (
auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
251 Ty = DTy->getBaseType();
259 for (
uint32_t I = StartDim;
I < Elements.size(); ++
I) {
260 if (
auto *Element = dyn_cast_or_null<DINode>(Elements[
I]))
261 if (Element->getTag() == dwarf::DW_TAG_subrange_type) {
262 const DISubrange *SR = cast<DISubrange>(Element);
264 DimSize *= CI->getSExtValue();
// Classifies Call by the name prefix of its callee and fills CInfo accordingly.
// Six intrinsic name prefixes are recognised below; the last three all use the
// BPFPreserveFieldInfoAI kind.
// NOTE(review): the return statements and any null checks on Call/GV are not
// visible in this excerpt.
272 bool BPFAbstractMemberAccess::IsPreserveDIAccessIndexCall(
const CallInst *Call,
277 const auto *GV = dyn_cast<GlobalValue>(
Call->getCalledOperand());
// Array access: access index is operand 2, base pointer is operand 0.
280 if (GV->getName().startswith(
"llvm.preserve.array.access.index")) {
281 CInfo.Kind = BPFPreserveArrayAI;
282 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
// Fatal error path for a call without the metadata (guard line not shown).
284 report_fatal_error(
"Missing metadata for llvm.preserve.array.access.index intrinsic");
285 CInfo.AccessIndex = getConstant(
Call->getArgOperand(2));
286 CInfo.Base =
Call->getArgOperand(0);
// Alignment is the ABI alignment of the type the base pointer points to.
287 CInfo.RecordAlignment =
288 DL->getABITypeAlign(CInfo.Base->getType()->getPointerElementType());
// Union access: access index is operand 1, base pointer is operand 0.
291 if (GV->getName().startswith(
"llvm.preserve.union.access.index")) {
292 CInfo.Kind = BPFPreserveUnionAI;
293 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
295 report_fatal_error(
"Missing metadata for llvm.preserve.union.access.index intrinsic");
296 CInfo.AccessIndex = getConstant(
Call->getArgOperand(1));
297 CInfo.Base =
Call->getArgOperand(0);
298 CInfo.RecordAlignment =
299 DL->getABITypeAlign(CInfo.Base->getType()->getPointerElementType());
// Struct access: access index is operand 2, base pointer is operand 0.
302 if (GV->getName().startswith(
"llvm.preserve.struct.access.index")) {
303 CInfo.Kind = BPFPreserveStructAI;
304 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
306 report_fatal_error(
"Missing metadata for llvm.preserve.struct.access.index intrinsic");
307 CInfo.AccessIndex = getConstant(
Call->getArgOperand(2));
308 CInfo.Base =
Call->getArgOperand(0);
309 CInfo.RecordAlignment =
310 DL->getABITypeAlign(CInfo.Base->getType()->getPointerElementType());
// Field info: no metadata is attached; the info kind (operand 1) is stored in
// the AccessIndex slot.
313 if (GV->getName().startswith(
"llvm.bpf.preserve.field.info")) {
314 CInfo.Kind = BPFPreserveFieldInfoAI;
315 CInfo.Metadata =
nullptr;
317 uint64_t InfoKind = getConstant(
Call->getArgOperand(1));
320 CInfo.AccessIndex = InfoKind;
// Type info: metadata comes from the call; a flag is read from operand 1.
323 if (GV->getName().startswith(
"llvm.bpf.preserve.type.info")) {
324 CInfo.Kind = BPFPreserveFieldInfoAI;
325 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
328 uint64_t
Flag = getConstant(
Call->getArgOperand(1));
// Enum value: metadata comes from the call; a flag is read from operand 2.
337 if (GV->getName().startswith(
"llvm.bpf.preserve.enum.value")) {
338 CInfo.Kind = BPFPreserveFieldInfoAI;
339 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
342 uint64_t
Flag = getConstant(
Call->getArgOperand(2));
// Processes every call in CallList and erases it from its parent afterwards.
// Callers in this file pass (1, 2) for array calls and (0, 1) for struct calls
// as the two index arguments.
// NOTE(review): the GEP construction and use replacement are not visible in
// this excerpt.
355 void BPFAbstractMemberAccess::replaceWithGEP(std::vector<CallInst *> &CallList,
358 for (
auto Call : CallList) {
// A dimension operand is only read when DimensionIndex is positive.
360 if (DimensionIndex > 0)
361 Dimension = getConstant(
Call->getArgOperand(DimensionIndex));
// Index list: Zero entries followed by the call operand at GEPIndex.
// NOTE(review): how many Zero entries are pushed depends on lines not shown.
367 IdxList.push_back(Zero);
368 IdxList.push_back(
Call->getArgOperand(GEPIndex));
373 Call->eraseFromParent();
// Collects the recognised preserve-access-index calls in F into per-kind lists
// and then removes them: array and struct calls go through replaceWithGEP,
// union calls are replaced by their first operand.
// NOTE(review): the return value computation is not visible in this excerpt.
377 bool BPFAbstractMemberAccess::removePreserveAccessIndexIntrinsic(
Function &
F) {
378 std::vector<CallInst *> PreserveArrayIndexCalls;
379 std::vector<CallInst *> PreserveUnionIndexCalls;
380 std::vector<CallInst *> PreserveStructIndexCalls;
385 auto *
Call = dyn_cast<CallInst>(&
I);
387 if (!IsPreserveDIAccessIndexCall(Call, CInfo))
// Bucket the call by kind; removal happens after collection.
391 if (CInfo.Kind == BPFPreserveArrayAI)
392 PreserveArrayIndexCalls.push_back(Call);
393 else if (CInfo.Kind == BPFPreserveUnionAI)
394 PreserveUnionIndexCalls.push_back(Call);
396 PreserveStructIndexCalls.push_back(Call);
// Array calls: dimension operand 1, index operand 2.
409 replaceWithGEP(PreserveArrayIndexCalls, 1, 2);
// Struct calls: DimensionIndex 0 (no dimension operand is read), index
// operand 1.
410 replaceWithGEP(PreserveStructIndexCalls, 0, 1);
// Union calls: forward operand 0 to all users, then drop the call.
411 for (
auto Call : PreserveUnionIndexCalls) {
412 Call->replaceAllUsesWith(
Call->getArgOperand(0));
413 Call->eraseFromParent();
// Decides whether a child access with debug type ChildType may be chained onto
// a parent access with debug type ParentType. The trace* functions call this
// with the parent's metadata, the parent's access index and the child's
// metadata.
// NOTE(review): several return statements and the derivation of PType/CType
// are not visible in this excerpt.
422 bool BPFAbstractMemberAccess::IsValidAIChain(
const MDNode *ParentType,
424 const MDNode *ChildType) {
433 if (isa<DIDerivedType>(CType))
// A derived parent type is checked for being a pointer type.
437 if (
const auto *PtrTy = dyn_cast<DIDerivedType>(PType)) {
438 if (PtrTy->getTag() != dwarf::DW_TAG_pointer_type)
// From here on both types are expected to be composite: array, structure or
// union, as enforced by the asserts.
444 const auto *PTy = dyn_cast<DICompositeType>(PType);
445 const auto *CTy = dyn_cast<DICompositeType>(CType);
446 assert(PTy && CTy &&
"ParentType or ChildType is null or not composite");
449 assert(PTyTag == dwarf::DW_TAG_array_type ||
450 PTyTag == dwarf::DW_TAG_structure_type ||
451 PTyTag == dwarf::DW_TAG_union_type);
454 assert(CTyTag == dwarf::DW_TAG_array_type ||
455 CTyTag == dwarf::DW_TAG_structure_type ||
456 CTyTag == dwarf::DW_TAG_union_type);
// Array parent with array child: valid exactly when the base types match.
459 if (PTyTag == dwarf::DW_TAG_array_type && PTyTag == CTyTag)
460 return PTy->getBaseType() == CTy->getBaseType();
// Otherwise pick the type to compare: the base type for an array parent, or
// the parent's element at index ParentAI.
463 if (PTyTag == dwarf::DW_TAG_array_type)
464 Ty = PTy->getBaseType();
466 Ty = dyn_cast<DIType>(PTy->getElements()[ParentAI]);
// Dispatches on the kind of an instruction Inst examined for Call, extending
// the access-index chain where possible.
// NOTE(review): the loop producing Inst is not visible here -- presumably the
// users of Call; confirm.
471 void BPFAbstractMemberAccess::traceAICall(
CallInst *Call,
// Bitcast: keep tracing through it with the same parent.
478 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
479 traceBitCast(BI, Call, ParentInfo);
480 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
// A recognised call that forms a valid chain with the parent is linked to
// Call in AIChain and traced recursively.
483 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
484 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
485 ChildInfo.Metadata)) {
486 AIChain[CI] = std::make_pair(Call, ParentInfo);
487 traceAICall(CI, ChildInfo);
// Remaining visible paths record Call in BaseAICalls.
489 BaseAICalls[
Call] = ParentInfo;
491 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
// Only all-zero-index GEPs are traced through.
492 if (GI->hasAllZeroIndices())
493 traceGEP(GI, Call, ParentInfo);
495 BaseAICalls[
Call] = ParentInfo;
497 BaseAICalls[
Call] = ParentInfo;
// Same dispatch as traceAICall, but for an instruction Inst examined for a
// bitcast; chain links and base-call records are attributed to Parent rather
// than to the bitcast itself.
// NOTE(review): the loop producing Inst is not visible here -- presumably the
// users of BitCast; confirm.
502 void BPFAbstractMemberAccess::traceBitCast(
BitCastInst *BitCast,
510 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
511 traceBitCast(BI, Parent, ParentInfo);
512 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
// A recognised, chain-compatible call is linked to Parent and traced on.
514 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
515 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
516 ChildInfo.Metadata)) {
517 AIChain[CI] = std::make_pair(Parent, ParentInfo);
518 traceAICall(CI, ChildInfo);
// Remaining visible paths record Parent in BaseAICalls.
520 BaseAICalls[Parent] = ParentInfo;
522 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
// Only all-zero-index GEPs are traced through.
523 if (GI->hasAllZeroIndices())
524 traceGEP(GI, Parent, ParentInfo);
526 BaseAICalls[Parent] = ParentInfo;
528 BaseAICalls[Parent] = ParentInfo;
540 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
541 traceBitCast(BI, Parent, ParentInfo);
542 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
544 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
545 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
546 ChildInfo.Metadata)) {
547 AIChain[CI] = std::make_pair(Parent, ParentInfo);
548 traceAICall(CI, ChildInfo);
550 BaseAICalls[Parent] = ParentInfo;
552 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
553 if (GI->hasAllZeroIndices())
554 traceGEP(GI, Parent, ParentInfo);
556 BaseAICalls[Parent] = ParentInfo;
558 BaseAICalls[Parent] = ParentInfo;
// Starts chain tracing from recognised calls in F that are not already
// recorded as a chained call in AIChain.
// NOTE(review): the iteration over F and the statement taken when the guard
// below holds are not visible in this excerpt.
563 void BPFAbstractMemberAccess::collectAICallChains(
Function &
F) {
570 auto *
Call = dyn_cast<CallInst>(&
I);
// Guard: not a recognised call, or already present in AIChain.
571 if (!IsPreserveDIAccessIndexCall(Call, CInfo) ||
572 AIChain.find(Call) != AIChain.end())
575 traceAICall(Call, CInfo);
// Interprets IndexValue as a ConstantInt and yields a uint64_t.
// NOTE(review): handling of a failed cast and the value extraction are not
// visible in this excerpt.
579 uint64_t BPFAbstractMemberAccess::getConstant(
const Value *IndexValue) {
580 const ConstantInt *CV = dyn_cast<ConstantInt>(IndexValue);
// Computes the bit range [StartBitOffset, EndBitOffset) of the AlignBits-wide
// storage unit that contains a member.
// NOTE(review): how AlignBits, MemberBitOffset and MemberBitSize are derived
// is not visible in this excerpt.
586 void BPFAbstractMemberAccess::GetStorageBitRange(
DIDerivedType *MemberTy,
587 Align RecordAlignment,
// Rejects an alignment above 8 or a member wider than AlignBits.
593 if (RecordAlignment > 8 || MemberBitSize > AlignBits)
595 "requiring too big alignment");
// Clears the low bits of the member offset; this rounds down to a multiple of
// AlignBits assuming AlignBits is a power of two -- TODO confirm.
597 StartBitOffset = MemberBitOffset & ~(AlignBits - 1);
// Rejects a member that extends past the end of that storage unit.
598 if ((StartBitOffset + AlignBits) < (MemberBitOffset + MemberBitSize))
600 "cross alignment boundary");
601 EndBitOffset = StartBitOffset + AlignBits;
608 Align RecordAlignment) {
614 if (Tag == dwarf::DW_TAG_array_type) {
617 (EltTy->getSizeInBits() >> 3);
618 }
else if (Tag == dwarf::DW_TAG_structure_type) {
619 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
623 unsigned SBitOffset, NextSBitOffset;
624 GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset,
626 PatchImm += SBitOffset >> 3;
633 if (Tag == dwarf::DW_TAG_array_type) {
635 return calcArraySize(CTy, 1) * (EltTy->getSizeInBits() >> 3);
637 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
640 return SizeInBits >> 3;
642 unsigned SBitOffset, NextSBitOffset;
643 GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset, NextSBitOffset);
644 SizeInBits = NextSBitOffset - SBitOffset;
645 if (SizeInBits & (SizeInBits - 1))
647 return SizeInBits >> 3;
653 if (Tag == dwarf::DW_TAG_array_type) {
659 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
664 const auto *BTy = dyn_cast<DIBasicType>(BaseTy);
666 const auto *CompTy = dyn_cast<DICompositeType>(BaseTy);
668 if (!CompTy || CompTy->getTag() != dwarf::DW_TAG_enumeration_type)
671 BTy = dyn_cast<DIBasicType>(BaseTy);
673 uint32_t Encoding = BTy->getEncoding();
674 return (Encoding == dwarf::DW_ATE_signed || Encoding == dwarf::DW_ATE_signed_char);
684 bool IsBitField =
false;
687 if (Tag == dwarf::DW_TAG_array_type) {
691 MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
699 return 64 - SizeInBits;
702 unsigned SBitOffset, NextSBitOffset;
703 GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset, NextSBitOffset);
704 if (NextSBitOffset - SBitOffset > 64)
709 return SBitOffset + 64 - OffsetInBits - SizeInBits;
711 return OffsetInBits + 64 - NextSBitOffset;
716 bool IsBitField =
false;
718 if (Tag == dwarf::DW_TAG_array_type) {
722 MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
730 return 64 - SizeInBits;
733 unsigned SBitOffset, NextSBitOffset;
734 GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset, NextSBitOffset);
735 if (NextSBitOffset - SBitOffset > 64)
738 return 64 - SizeInBits;
// Looks through CallStack for an entry whose kind is BPFPreserveFieldInfoAI.
// NOTE(review): CallStack is taken by reference; whether entries are popped
// while scanning, and the return statements, are not visible in this excerpt.
744 bool BPFAbstractMemberAccess::HasPreserveFieldInfoCall(CallInfoStack &CallStack) {
746 while (CallStack.size()) {
747 auto StackElem = CallStack.top();
748 if (StackElem.second.Kind == BPFPreserveFieldInfoAI)
// Walks the access chain that ends at Call and yields a Value* (Base), writing
// the access key through AccessKey. transformGEPChain calls this on the path
// that does not go through computeAccessKey.
// NOTE(review): large parts of this function are not visible in this excerpt;
// the comments below describe only the visible statements.
758 Value *BPFAbstractMemberAccess::computeBaseAndAccessKey(
CallInst *Call,
760 std::string &AccessKey,
762 Value *Base =
nullptr;
764 CallInfoStack CallStack;
// Push the current call, then continue with the info AIChain holds for it.
768 CallStack.push(std::make_pair(Call, CInfo));
769 CInfo = AIChain[
Call].second;
// First pass over the stacked calls.
786 while (CallStack.size()) {
787 auto StackElem = CallStack.top();
788 Call = StackElem.first;
789 CInfo = StackElem.second;
797 if (CInfo.Kind == BPFPreserveUnionAI ||
798 CInfo.Kind == BPFPreserveStructAI) {
802 TypeMeta = PossibleTypeDef;
// Past the union/struct case only array accesses are expected here.
807 assert(CInfo.Kind == BPFPreserveArrayAI);
813 uint64_t AccessIndex = CInfo.AccessIndex;
816 bool CheckElemType =
false;
817 if (
const auto *CTy = dyn_cast<DICompositeType>(Ty)) {
// A non-composite type is cast to a derived type and asserted to be a pointer.
827 auto *DTy = cast<DIDerivedType>(Ty);
828 assert(DTy->getTag() == dwarf::DW_TAG_pointer_type);
831 CTy = dyn_cast<DICompositeType>(BaseTy);
833 CheckElemType =
true;
834 }
else if (CTy->
getTag() != dwarf::DW_TAG_array_type) {
835 FirstIndex += AccessIndex;
836 CheckElemType =
true;
843 auto *CTy = dyn_cast<DICompositeType>(BaseTy);
845 if (HasPreserveFieldInfoCall(CallStack))
850 unsigned CTag = CTy->
getTag();
851 if (CTag == dwarf::DW_TAG_structure_type || CTag == dwarf::DW_TAG_union_type) {
854 if (HasPreserveFieldInfoCall(CallStack))
// Second pass over the remaining stacked calls.
868 while (CallStack.size()) {
869 auto StackElem = CallStack.top();
870 CInfo = StackElem.second;
// For a field-info entry the AccessIndex slot holds the info kind.
873 if (CInfo.Kind == BPFPreserveFieldInfoAI) {
874 InfoKind = CInfo.AccessIndex;
// Look at the next entry; a field-info call is asserted to be the only one
// left on the stack.
880 if (CallStack.size()) {
881 auto StackElem2 = CallStack.top();
882 CallInfo CInfo2 = StackElem2.second;
883 if (CInfo2.Kind == BPFPreserveFieldInfoAI) {
884 InfoKind = CInfo2.AccessIndex;
885 assert(CallStack.size() == 1);
890 uint64_t AccessIndex = CInfo.AccessIndex;
893 MDNode *MDN = CInfo.Metadata;
// PatchImm is updated from GetFieldInfo for the current access.
896 PatchImm = GetFieldInfo(InfoKind, CTy, AccessIndex, PatchImm,
897 CInfo.RecordAlignment);
914 std::string &AccessKey,
920 std::string AccessStr(
"0");
931 const auto *
CE = cast<ConstantExpr>(
Call->getArgOperand(1));
944 const auto *CTy = cast<DICompositeType>(BaseTy);
945 assert(CTy->
getTag() == dwarf::DW_TAG_enumeration_type);
948 const auto *
Enum = cast<DIEnumerator>(Element);
949 if (
Enum->getName() == EnumeratorStr) {
958 PatchImm = std::stoll(std::string(EValueStr));
964 AccessKey =
"llvm." + Ty->
getName().
str() +
":" +
// Rewrites one base call: computes its access key, obtains the global for
// that key from GEPGlobals, and replaces the call with PassThroughInst.
// NOTE(review): creation of the global, of PassThroughInst and of the
// inserted instructions, and the return statements, are not visible in this
// excerpt.
973 bool BPFAbstractMemberAccess::transformGEPChain(
CallInst *Call,
975 std::string AccessKey;
977 Value *Base =
nullptr;
// Initial value: true exactly for field-info-kind calls; passed by reference
// to computeAccessKey.
980 IsInt32Ret = CInfo.Kind == BPFPreserveFieldInfoAI;
// Field-info-kind calls that carry metadata use computeAccessKey; the other
// visible path uses computeBaseAndAccessKey.
981 if (CInfo.Kind == BPFPreserveFieldInfoAI && CInfo.Metadata) {
982 TypeMeta = computeAccessKey(Call, CInfo, AccessKey, IsInt32Ret);
984 Base = computeBaseAndAccessKey(Call, CInfo, AccessKey, TypeMeta);
// One global per distinct access key: a new global is tagged with TypeMeta and
// stored in the map, an existing one is reused.
992 if (GEPGlobals.find(AccessKey) == GEPGlobals.end()) {
1002 GV->
setMetadata(LLVMContext::MD_preserve_access_index, TypeMeta);
1003 GEPGlobals[AccessKey] = GV;
1005 GV = GEPGlobals[AccessKey];
// Field-info kind: the call is replaced by PassThroughInst and erased.
1008 if (CInfo.Kind == BPFPreserveFieldInfoAI) {
1018 Call->replaceAllUsesWith(PassThroughInst);
1019 Call->eraseFromParent();
// Other kinds: BCInst, GEP and BCInst2 are inserted before Call, which is then
// replaced by PassThroughInst and erased.
1038 BB->getInstList().insert(
Call->getIterator(), BCInst);
1043 BB->getInstList().insert(
Call->getIterator(),
GEP);
1047 BB->getInstList().insert(
Call->getIterator(), BCInst2);
1096 Call->replaceAllUsesWith(PassThroughInst);
1097 Call->eraseFromParent();
// Drives the pass for one function: collect the access chains, transform every
// recorded base call, then remove the remaining preserve-access-index calls.
// Returns true if any of the steps reported a change.
1102 bool BPFAbstractMemberAccess::doTransformation(
Function &
F) {
1103 bool Transformed =
false;
1108 collectAICallChains(
F);
// The call is the left operand of ||, so every base call is transformed even
// after Transformed has become true.
1110 for (
auto &
C : BaseAICalls)
1111 Transformed = transformGEPChain(
C.first,
C.second) || Transformed;
// Likewise the removal step always runs before the result is combined.
1113 return removePreserveAccessIndexIntrinsic(
F) || Transformed;