diff options
author | Andrea Fioraldi <andreafioraldi@gmail.com> | 2019-09-02 18:49:43 +0200 |
---|---|---|
committer | Andrea Fioraldi <andreafioraldi@gmail.com> | 2019-09-02 18:49:43 +0200 |
commit | b24639d0113e15933e749ea0f96abe3f25a134a0 (patch) | |
tree | 4272020625c80c0d6982d3787bebc573c0da01b8 /llvm_mode/compare-transform-pass.so.cc | |
parent | 2ae4ca91b48407add0e940ee13bd8b385e319a7a (diff) | |
download | afl++-b24639d0113e15933e749ea0f96abe3f25a134a0.tar.gz |
run code formatter
Diffstat (limited to 'llvm_mode/compare-transform-pass.so.cc')
-rw-r--r-- | llvm_mode/compare-transform-pass.so.cc | 286 |
1 files changed, 165 insertions, 121 deletions
diff --git a/llvm_mode/compare-transform-pass.so.cc b/llvm_mode/compare-transform-pass.so.cc index e7886db1..e1b6e671 100644 --- a/llvm_mode/compare-transform-pass.so.cc +++ b/llvm_mode/compare-transform-pass.so.cc @@ -36,202 +36,236 @@ using namespace llvm; namespace { - class CompareTransform : public ModulePass { +class CompareTransform : public ModulePass { - public: - static char ID; - CompareTransform() : ModulePass(ID) { - } + public: + static char ID; + CompareTransform() : ModulePass(ID) { - bool runOnModule(Module &M) override; + } + + bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR < 4 - const char * getPassName() const override { + const char *getPassName() const override { + #else - StringRef getPassName() const override { + StringRef getPassName() const override { + #endif - return "transforms compare functions"; - } - private: - bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp - ,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp); - }; -} + return "transforms compare functions"; + } + + private: + bool transformCmps(Module &M, const bool processStrcmp, + const bool processMemcmp, const bool processStrncmp, + const bool processStrcasecmp, + const bool processStrncasecmp); + +}; + +} // namespace char CompareTransform::ID = 0; -bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp - , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) { +bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, + const bool processMemcmp, + const bool processStrncmp, + const bool processStrcasecmp, + const bool processStrncasecmp) { - std::vector<CallInst*> calls; - LLVMContext &C = M.getContext(); - IntegerType *Int8Ty = IntegerType::getInt8Ty(C); - IntegerType *Int32Ty = IntegerType::getInt32Ty(C); - IntegerType *Int64Ty = IntegerType::getInt64Ty(C); + std::vector<CallInst *> calls; + LLVMContext & C = M.getContext(); + IntegerType * Int8Ty = IntegerType::getInt8Ty(C); + IntegerType * Int32Ty = IntegerType::getInt32Ty(C); + IntegerType * Int64Ty = IntegerType::getInt64Ty(C); #if LLVM_VERSION_MAJOR < 9 - Constant* + Constant * #else FunctionCallee #endif - c = M.getOrInsertFunction("tolower", - Int32Ty, - Int32Ty + c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty #if LLVM_VERSION_MAJOR < 5 - , nullptr + , + nullptr #endif - ); + ); #if LLVM_VERSION_MAJOR < 9 - Function* tolowerFn = cast<Function>(c); + Function *tolowerFn = cast<Function>(c); #else FunctionCallee tolowerFn = c; #endif - /* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ + /* iterate over all functions, bbs and instruction and add suitable calls to + * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ for (auto &F : M) { + for (auto &BB : F) { - for(auto &IN: BB) { - CallInst* callInst = nullptr; + + for (auto &IN : BB) { + + CallInst *callInst = nullptr; if ((callInst = dyn_cast<CallInst>(&IN))) { - bool isStrcmp = processStrcmp; - bool isMemcmp = processMemcmp; - bool isStrncmp = processStrncmp; - bool isStrcasecmp = processStrcasecmp; + bool isStrcmp = processStrcmp; + bool isMemcmp = processMemcmp; + bool isStrncmp = processStrncmp; + bool isStrcasecmp = processStrcasecmp; bool isStrncasecmp = processStrncasecmp; Function *Callee = callInst->getCalledFunction(); - if (!Callee) - continue; - if (callInst->getCallingConv() != llvm::CallingConv::C) - continue; + if (!Callee) continue; + if (callInst->getCallingConv() != llvm::CallingConv::C) continue; StringRef FuncName = Callee->getName(); - isStrcmp &= !FuncName.compare(StringRef("strcmp")); - isMemcmp &= !FuncName.compare(StringRef("memcmp")); - isStrncmp &= !FuncName.compare(StringRef("strncmp")); - isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); + isStrcmp &= !FuncName.compare(StringRef("strcmp")); + isMemcmp &= !FuncName.compare(StringRef("memcmp")); + isStrncmp &= !FuncName.compare(StringRef("strncmp")); + isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp")); - if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp) + if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && + !isStrncasecmp) continue; - /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */ + /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function + * prototype */ FunctionType *FT = Callee->getFunctionType(); - - isStrcmp &= FT->getNumParams() == 2 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); - isStrcasecmp &= FT->getNumParams() == 2 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); - isMemcmp &= FT->getNumParams() == 3 && + isStrcmp &= + FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); + isStrcasecmp &= + FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); + isMemcmp &= FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0)->isPointerTy() && FT->getParamType(1)->isPointerTy() && FT->getParamType(2)->isIntegerTy(); - isStrncmp &= FT->getNumParams() == 3 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && - FT->getParamType(2)->isIntegerTy(); + isStrncmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == + IntegerType::getInt8PtrTy(M.getContext()) && + FT->getParamType(2)->isIntegerTy(); isStrncasecmp &= FT->getNumParams() == 3 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && - FT->getParamType(2)->isIntegerTy(); - - if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp) + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == + IntegerType::getInt8PtrTy(M.getContext()) && + FT->getParamType(2)->isIntegerTy(); + + if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && + !isStrncasecmp) continue; /* is a str{n,}{case,}cmp/memcmp, check if we have * str{case,}cmp(x, "const") or str{case,}cmp("const", x) * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..) * memcmp(x, "const", ..) or memcmp("const", x, ..) */ - Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); + Value *Str1P = callInst->getArgOperand(0), + *Str2P = callInst->getArgOperand(1); StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); /* handle cases of one string is const, one string is variable */ - if (!(HasStr1 ^ HasStr2)) - continue; + if (!(HasStr1 ^ HasStr2)) continue; if (isMemcmp || isStrncmp || isStrncasecmp) { + /* check if third operand is a constant integer * strlen("constStr") and sizeof() are treated as constant */ - Value *op2 = callInst->getArgOperand(2); - ConstantInt* ilen = dyn_cast<ConstantInt>(op2); - if (!ilen) - continue; - /* final precaution: if size of compare is larger than constant string skip it*/ - uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P); - if (literalLength < ilen->getZExtValue()) - continue; + Value * op2 = callInst->getArgOperand(2); + ConstantInt *ilen = dyn_cast<ConstantInt>(op2); + if (!ilen) continue; + /* final precaution: if size of compare is larger than constant + * string skip it*/ + uint64_t literalLength = + HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P); + if (literalLength < ilen->getZExtValue()) continue; + } calls.push_back(callInst); + } + } + } + } - if (!calls.size()) - return false; - errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; + if (!calls.size()) return false; + errs() << "Replacing " << calls.size() + << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; - for (auto &callInst: calls) { + for (auto &callInst : calls) { - Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); - StringRef Str1, Str2, ConstStr; + Value *Str1P = callInst->getArgOperand(0), + *Str2P = callInst->getArgOperand(1); + StringRef Str1, Str2, ConstStr; std::string TmpConstStr; - Value *VarStr; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); + Value * VarStr; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); getConstantStringInfo(Str2P, Str2); uint64_t constLen, sizedLen; - bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); - bool isSizedcmp = isMemcmp - || !callInst->getCalledFunction()->getName().compare(StringRef("strncmp")) - || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp")); - bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp")) - || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp")); + bool isMemcmp = + !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); + bool isSizedcmp = isMemcmp || + !callInst->getCalledFunction()->getName().compare( + StringRef("strncmp")) || + !callInst->getCalledFunction()->getName().compare( + StringRef("strncasecmp")); + bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare( + StringRef("strcasecmp")) || + !callInst->getCalledFunction()->getName().compare( + StringRef("strncasecmp")); if (isSizedcmp) { - Value *op2 = callInst->getArgOperand(2); - ConstantInt* ilen = dyn_cast<ConstantInt>(op2); + + Value * op2 = callInst->getArgOperand(2); + ConstantInt *ilen = dyn_cast<ConstantInt>(op2); sizedLen = ilen->getZExtValue(); + } if (HasStr1) { + TmpConstStr = Str1.str(); VarStr = Str2P; constLen = isMemcmp ? sizedLen : GetStringLength(Str1P); - } - else { + + } else { + TmpConstStr = Str2.str(); VarStr = Str1P; constLen = isMemcmp ? sizedLen : GetStringLength(Str2P); + } /* properly handle zero terminated C strings by adding the terminating 0 to * the StringRef (in comparison to std::string a StringRef has built-in * runtime bounds checking, which makes debugging easier) */ - TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr); + TmpConstStr.append("\0", 1); + ConstStr = StringRef(TmpConstStr); - if (isSizedcmp && constLen > sizedLen) { - constLen = sizedLen; - } + if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; } - errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n"; + errs() << callInst->getCalledFunction()->getName() << ": len " << constLen + << ": " << ConstStr << "\n"; /* split before the call instruction */ BasicBlock *bb = callInst->getParent(); BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst)); - BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); + BasicBlock *next_bb = + BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_bb); PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi"); @@ -249,71 +283,81 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i]; - BasicBlock::iterator IP = next_bb->getFirstInsertionPt(); - IRBuilder<> IRB(&*IP); + IRBuilder<> IRB(&*IP); - Value* v = ConstantInt::get(Int64Ty, i); - Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty"); + Value *v = ConstantInt::get(Int64Ty, i); + Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty"); Value *load = IRB.CreateLoad(ele); if (isCaseInsensitive) { + // load >= 'A' && load <= 'Z' ? load | 0x020 : load std::vector<Value *> args; args.push_back(load); load = IRB.CreateCall(tolowerFn, args, "tmp"); load = IRB.CreateTrunc(load, Int8Ty); + } + Value *isub; if (HasStr1) isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load); else isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c)); - Value *sext = IRB.CreateSExt(isub, Int32Ty); + Value *sext = IRB.CreateSExt(isub, Int32Ty); PN->addIncoming(sext, cur_bb); - if (i < constLen - 1) { - next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); + + next_bb = + BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_bb); Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0)); IRB.CreateCondBr(icmp, next_bb, end_bb); cur_bb->getTerminator()->eraseFromParent(); + } else { - //IRB.CreateBr(end_bb); + + // IRB.CreateBr(end_bb); + } - //add offset to varstr - //create load - //create signed isub - //create icmp - //create jcc - //create next_bb + // add offset to varstr + // create load + // create signed isub + // create icmp + // create jcc + // create next_bb + } /* since the call is the first instruction of the bb it is safe to * replace it with a phi instruction */ BasicBlock::iterator ii(callInst); ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN); - } + } return true; + } bool CompareTransform::runOnModule(Module &M) { if (getenv("AFL_QUIET") == NULL) - llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n"; + llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, " + "extended by heiko@hexco.de\n"; transformCmps(M, true, true, true, true, true); verifyModule(M); return true; + } static void registerCompTransPass(const PassManagerBuilder &, - legacy::PassManagerBase &PM) { + legacy::PassManagerBase &PM) { auto p = new CompareTransform(); PM.add(p); |