85 #include "llvm/IR/IntrinsicsBPF.h"
95 #define DEBUG_TYPE "bpf-abstract-member-access"
105 M, Intrinsic::bpf_passthrough, {Input->getType(), Input->getType()});
115 using namespace llvm;
118 class BPFAbstractMemberAccess final {
131 typedef std::stack<std::pair<CallInst *, CallInfo>> CallInfoStack;
135 BPFPreserveArrayAI = 1,
136 BPFPreserveUnionAI = 2,
137 BPFPreserveStructAI = 3,
138 BPFPreserveFieldInfoAI = 4,
145 static std::map<std::string, GlobalVariable *> GEPGlobals;
147 std::map<CallInst *, std::pair<CallInst *, CallInfo>> AIChain;
151 std::map<CallInst *, CallInfo> BaseAICalls;
165 bool removePreserveAccessIndexIntrinsic(
Function &
F);
166 void replaceWithGEP(std::vector<CallInst *> &CallList,
168 bool HasPreserveFieldInfoCall(CallInfoStack &CallStack);
176 std::string &AccessKey,
MDNode *&BaseMeta);
178 std::string &AccessKey,
bool &IsInt32Ret);
183 std::map<std::string, GlobalVariable *> BPFAbstractMemberAccess::GEPGlobals;
185 class BPFAbstractMemberAccessLegacyPass final :
public FunctionPass {
189 return BPFAbstractMemberAccess(
TM).run(
F);
207 "BPF Abstract Member Access",
false,
false)
210 return new BPFAbstractMemberAccessLegacyPass(
TM);
214 LLVM_DEBUG(
dbgs() <<
"********** Abstract Member Accesses **********\n");
221 if (
M->debug_compile_units().empty())
224 DL = &
M->getDataLayout();
225 return doTransformation(
F);
229 if (
Tag != dwarf::DW_TAG_typedef &&
Tag != dwarf::DW_TAG_const_type &&
230 Tag != dwarf::DW_TAG_volatile_type &&
231 Tag != dwarf::DW_TAG_restrict_type &&
232 Tag != dwarf::DW_TAG_member)
234 if (
Tag == dwarf::DW_TAG_typedef && !skipTypedef)
240 while (
auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
243 Ty = DTy->getBaseType();
249 while (
auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
252 Ty = DTy->getBaseType();
260 for (
uint32_t I = StartDim;
I < Elements.size(); ++
I) {
261 if (
auto *Element = dyn_cast_or_null<DINode>(Elements[
I]))
262 if (Element->getTag() == dwarf::DW_TAG_subrange_type) {
263 const DISubrange *SR = cast<DISubrange>(Element);
265 DimSize *= CI->getSExtValue();
274 return Call->getParamElementType(0);
278 bool BPFAbstractMemberAccess::IsPreserveDIAccessIndexCall(
const CallInst *Call,
283 const auto *GV = dyn_cast<GlobalValue>(
Call->getCalledOperand());
286 if (GV->getName().startswith(
"llvm.preserve.array.access.index")) {
287 CInfo.Kind = BPFPreserveArrayAI;
288 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
290 report_fatal_error(
"Missing metadata for llvm.preserve.array.access.index intrinsic");
291 CInfo.AccessIndex = getConstant(
Call->getArgOperand(2));
292 CInfo.Base =
Call->getArgOperand(0);
296 if (GV->getName().startswith(
"llvm.preserve.union.access.index")) {
297 CInfo.Kind = BPFPreserveUnionAI;
298 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
300 report_fatal_error(
"Missing metadata for llvm.preserve.union.access.index intrinsic");
301 CInfo.AccessIndex = getConstant(
Call->getArgOperand(1));
302 CInfo.Base =
Call->getArgOperand(0);
305 if (GV->getName().startswith(
"llvm.preserve.struct.access.index")) {
306 CInfo.Kind = BPFPreserveStructAI;
307 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
309 report_fatal_error(
"Missing metadata for llvm.preserve.struct.access.index intrinsic");
310 CInfo.AccessIndex = getConstant(
Call->getArgOperand(2));
311 CInfo.Base =
Call->getArgOperand(0);
315 if (GV->getName().startswith(
"llvm.bpf.preserve.field.info")) {
316 CInfo.Kind = BPFPreserveFieldInfoAI;
317 CInfo.Metadata =
nullptr;
319 uint64_t InfoKind = getConstant(
Call->getArgOperand(1));
322 CInfo.AccessIndex = InfoKind;
325 if (GV->getName().startswith(
"llvm.bpf.preserve.type.info")) {
326 CInfo.Kind = BPFPreserveFieldInfoAI;
327 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
339 if (GV->getName().startswith(
"llvm.bpf.preserve.enum.value")) {
340 CInfo.Kind = BPFPreserveFieldInfoAI;
341 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
357 void BPFAbstractMemberAccess::replaceWithGEP(std::vector<CallInst *> &CallList,
360 for (
auto Call : CallList) {
362 if (DimensionIndex > 0)
363 Dimension = getConstant(
Call->getArgOperand(DimensionIndex));
369 IdxList.push_back(Zero);
370 IdxList.push_back(
Call->getArgOperand(GEPIndex));
375 Call->eraseFromParent();
379 bool BPFAbstractMemberAccess::removePreserveAccessIndexIntrinsic(
Function &
F) {
380 std::vector<CallInst *> PreserveArrayIndexCalls;
381 std::vector<CallInst *> PreserveUnionIndexCalls;
382 std::vector<CallInst *> PreserveStructIndexCalls;
387 auto *
Call = dyn_cast<CallInst>(&
I);
389 if (!IsPreserveDIAccessIndexCall(Call, CInfo))
393 if (CInfo.Kind == BPFPreserveArrayAI)
394 PreserveArrayIndexCalls.push_back(Call);
395 else if (CInfo.Kind == BPFPreserveUnionAI)
396 PreserveUnionIndexCalls.push_back(Call);
398 PreserveStructIndexCalls.push_back(Call);
411 replaceWithGEP(PreserveArrayIndexCalls, 1, 2);
412 replaceWithGEP(PreserveStructIndexCalls, 0, 1);
413 for (
auto Call : PreserveUnionIndexCalls) {
414 Call->replaceAllUsesWith(
Call->getArgOperand(0));
415 Call->eraseFromParent();
424 bool BPFAbstractMemberAccess::IsValidAIChain(
const MDNode *ParentType,
426 const MDNode *ChildType) {
435 if (isa<DIDerivedType>(CType))
439 if (
const auto *PtrTy = dyn_cast<DIDerivedType>(PType)) {
440 if (PtrTy->getTag() != dwarf::DW_TAG_pointer_type)
446 const auto *PTy = dyn_cast<DICompositeType>(PType);
447 const auto *CTy = dyn_cast<DICompositeType>(CType);
448 assert(PTy && CTy &&
"ParentType or ChildType is null or not composite");
451 assert(PTyTag == dwarf::DW_TAG_array_type ||
452 PTyTag == dwarf::DW_TAG_structure_type ||
453 PTyTag == dwarf::DW_TAG_union_type);
456 assert(CTyTag == dwarf::DW_TAG_array_type ||
457 CTyTag == dwarf::DW_TAG_structure_type ||
458 CTyTag == dwarf::DW_TAG_union_type);
461 if (PTyTag == dwarf::DW_TAG_array_type && PTyTag == CTyTag)
462 return PTy->getBaseType() == CTy->getBaseType();
465 if (PTyTag == dwarf::DW_TAG_array_type)
466 Ty = PTy->getBaseType();
468 Ty = dyn_cast<DIType>(PTy->getElements()[ParentAI]);
473 void BPFAbstractMemberAccess::traceAICall(
CallInst *Call,
480 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
481 traceBitCast(BI, Call, ParentInfo);
482 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
485 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
486 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
487 ChildInfo.Metadata)) {
488 AIChain[CI] = std::make_pair(Call, ParentInfo);
489 traceAICall(CI, ChildInfo);
491 BaseAICalls[
Call] = ParentInfo;
493 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
494 if (GI->hasAllZeroIndices())
495 traceGEP(GI, Call, ParentInfo);
497 BaseAICalls[
Call] = ParentInfo;
499 BaseAICalls[
Call] = ParentInfo;
504 void BPFAbstractMemberAccess::traceBitCast(
BitCastInst *BitCast,
512 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
513 traceBitCast(BI, Parent, ParentInfo);
514 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
516 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
517 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
518 ChildInfo.Metadata)) {
519 AIChain[CI] = std::make_pair(Parent, ParentInfo);
520 traceAICall(CI, ChildInfo);
522 BaseAICalls[Parent] = ParentInfo;
524 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
525 if (GI->hasAllZeroIndices())
526 traceGEP(GI, Parent, ParentInfo);
528 BaseAICalls[Parent] = ParentInfo;
530 BaseAICalls[Parent] = ParentInfo;
542 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
543 traceBitCast(BI, Parent, ParentInfo);
544 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
546 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
547 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
548 ChildInfo.Metadata)) {
549 AIChain[CI] = std::make_pair(Parent, ParentInfo);
550 traceAICall(CI, ChildInfo);
552 BaseAICalls[Parent] = ParentInfo;
554 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
555 if (GI->hasAllZeroIndices())
556 traceGEP(GI, Parent, ParentInfo);
558 BaseAICalls[Parent] = ParentInfo;
560 BaseAICalls[Parent] = ParentInfo;
565 void BPFAbstractMemberAccess::collectAICallChains(
Function &
F) {
572 auto *
Call = dyn_cast<CallInst>(&
I);
573 if (!IsPreserveDIAccessIndexCall(Call, CInfo) ||
574 AIChain.find(Call) != AIChain.end())
577 traceAICall(Call, CInfo);
581 uint64_t BPFAbstractMemberAccess::getConstant(
const Value *IndexValue) {
582 const ConstantInt *CV = dyn_cast<ConstantInt>(IndexValue);
588 void BPFAbstractMemberAccess::GetStorageBitRange(
DIDerivedType *MemberTy,
589 Align RecordAlignment,
595 if (RecordAlignment > 8) {
598 if (MemberBitOffset / 64 != (MemberBitOffset + MemberBitSize) / 64)
600 "requiring too big alignment");
601 RecordAlignment =
Align(8);
605 if (MemberBitSize > AlignBits)
607 "bitfield size greater than record alignment");
609 StartBitOffset = MemberBitOffset & ~(AlignBits - 1);
610 if ((StartBitOffset + AlignBits) < (MemberBitOffset + MemberBitSize))
612 "cross alignment boundary");
613 EndBitOffset = StartBitOffset + AlignBits;
626 if (
Tag == dwarf::DW_TAG_array_type) {
629 (EltTy->getSizeInBits() >> 3);
630 }
else if (
Tag == dwarf::DW_TAG_structure_type) {
631 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
635 unsigned SBitOffset, NextSBitOffset;
636 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset,
638 PatchImm += SBitOffset >> 3;
645 if (
Tag == dwarf::DW_TAG_array_type) {
647 return calcArraySize(CTy, 1) * (EltTy->getSizeInBits() >> 3);
649 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
652 return SizeInBits >> 3;
654 unsigned SBitOffset, NextSBitOffset;
655 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset,
657 SizeInBits = NextSBitOffset - SBitOffset;
658 if (SizeInBits & (SizeInBits - 1))
660 return SizeInBits >> 3;
666 if (
Tag == dwarf::DW_TAG_array_type) {
672 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
677 const auto *BTy = dyn_cast<DIBasicType>(BaseTy);
679 const auto *CompTy = dyn_cast<DICompositeType>(BaseTy);
681 if (!CompTy || CompTy->getTag() != dwarf::DW_TAG_enumeration_type)
684 BTy = dyn_cast<DIBasicType>(BaseTy);
686 uint32_t Encoding = BTy->getEncoding();
687 return (Encoding == dwarf::DW_ATE_signed || Encoding == dwarf::DW_ATE_signed_char);
697 bool IsBitField =
false;
700 if (
Tag == dwarf::DW_TAG_array_type) {
704 MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
712 return 64 - SizeInBits;
715 unsigned SBitOffset, NextSBitOffset;
716 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset, NextSBitOffset);
717 if (NextSBitOffset - SBitOffset > 64)
722 return SBitOffset + 64 - OffsetInBits - SizeInBits;
724 return OffsetInBits + 64 - NextSBitOffset;
729 bool IsBitField =
false;
731 if (
Tag == dwarf::DW_TAG_array_type) {
735 MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
743 return 64 - SizeInBits;
746 unsigned SBitOffset, NextSBitOffset;
747 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset, NextSBitOffset);
748 if (NextSBitOffset - SBitOffset > 64)
751 return 64 - SizeInBits;
757 bool BPFAbstractMemberAccess::HasPreserveFieldInfoCall(CallInfoStack &CallStack) {
759 while (CallStack.size()) {
760 auto StackElem = CallStack.top();
761 if (StackElem.second.Kind == BPFPreserveFieldInfoAI)
771 Value *BPFAbstractMemberAccess::computeBaseAndAccessKey(
CallInst *Call,
773 std::string &AccessKey,
777 CallInfoStack CallStack;
781 CallStack.push(std::make_pair(Call, CInfo));
782 CInfo = AIChain[
Call].second;
799 while (CallStack.size()) {
800 auto StackElem = CallStack.top();
801 Call = StackElem.first;
802 CInfo = StackElem.second;
810 if (CInfo.Kind == BPFPreserveUnionAI ||
811 CInfo.Kind == BPFPreserveStructAI) {
815 TypeMeta = PossibleTypeDef;
820 assert(CInfo.Kind == BPFPreserveArrayAI);
826 uint64_t AccessIndex = CInfo.AccessIndex;
829 bool CheckElemType =
false;
830 if (
const auto *CTy = dyn_cast<DICompositeType>(Ty)) {
840 auto *DTy = cast<DIDerivedType>(Ty);
841 assert(DTy->getTag() == dwarf::DW_TAG_pointer_type);
844 CTy = dyn_cast<DICompositeType>(BaseTy);
846 CheckElemType =
true;
847 }
else if (CTy->
getTag() != dwarf::DW_TAG_array_type) {
848 FirstIndex += AccessIndex;
849 CheckElemType =
true;
856 auto *CTy = dyn_cast<DICompositeType>(BaseTy);
858 if (HasPreserveFieldInfoCall(CallStack))
863 unsigned CTag = CTy->
getTag();
864 if (CTag == dwarf::DW_TAG_structure_type || CTag == dwarf::DW_TAG_union_type) {
867 if (HasPreserveFieldInfoCall(CallStack))
881 while (CallStack.size()) {
882 auto StackElem = CallStack.top();
883 CInfo = StackElem.second;
886 if (CInfo.Kind == BPFPreserveFieldInfoAI) {
887 InfoKind = CInfo.AccessIndex;
895 if (CallStack.size()) {
896 auto StackElem2 = CallStack.top();
897 CallInfo CInfo2 = StackElem2.second;
898 if (CInfo2.Kind == BPFPreserveFieldInfoAI) {
899 InfoKind = CInfo2.AccessIndex;
900 assert(CallStack.size() == 1);
905 uint64_t AccessIndex = CInfo.AccessIndex;
908 MDNode *MDN = CInfo.Metadata;
911 PatchImm = GetFieldInfo(InfoKind, CTy, AccessIndex, PatchImm,
912 CInfo.RecordAlignment);
929 std::string &AccessKey,
935 std::string AccessStr(
"0");
950 cast<GlobalVariable>(
Call->getArgOperand(1)->stripPointerCasts());
962 const auto *CTy = cast<DICompositeType>(BaseTy);
963 assert(CTy->
getTag() == dwarf::DW_TAG_enumeration_type);
966 const auto *
Enum = cast<DIEnumerator>(Element);
967 if (
Enum->getName() == EnumeratorStr) {
976 PatchImm = std::stoll(std::string(EValueStr));
982 AccessKey =
"llvm." + Ty->
getName().
str() +
":" +
991 bool BPFAbstractMemberAccess::transformGEPChain(
CallInst *Call,
993 std::string AccessKey;
998 IsInt32Ret = CInfo.Kind == BPFPreserveFieldInfoAI;
999 if (CInfo.Kind == BPFPreserveFieldInfoAI && CInfo.Metadata) {
1000 TypeMeta = computeAccessKey(Call, CInfo, AccessKey, IsInt32Ret);
1002 Base = computeBaseAndAccessKey(Call, CInfo, AccessKey, TypeMeta);
1010 if (GEPGlobals.find(AccessKey) == GEPGlobals.end()) {
1018 nullptr, AccessKey);
1020 GV->
setMetadata(LLVMContext::MD_preserve_access_index, TypeMeta);
1021 GEPGlobals[AccessKey] = GV;
1023 GV = GEPGlobals[AccessKey];
1026 if (CInfo.Kind == BPFPreserveFieldInfoAI) {
1036 Call->replaceAllUsesWith(PassThroughInst);
1037 Call->eraseFromParent();
1056 BB->getInstList().insert(
Call->getIterator(), BCInst);
1061 BB->getInstList().insert(
Call->getIterator(),
GEP);
1065 BB->getInstList().insert(
Call->getIterator(), BCInst2);
1114 Call->replaceAllUsesWith(PassThroughInst);
1115 Call->eraseFromParent();
1120 bool BPFAbstractMemberAccess::doTransformation(
Function &
F) {
1121 bool Transformed =
false;
1126 collectAICallChains(
F);
1128 for (
auto &
C : BaseAICalls)
1129 Transformed = transformGEPChain(
C.first,
C.second) || Transformed;
1131 return removePreserveAccessIndexIntrinsic(
F) || Transformed;