/* * Copyright 2016 laf-intel * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include "llvm/Config/llvm-config.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Pass.h" #include "llvm/Analysis/ValueTracking.h" #if LLVM_VERSION_MAJOR > 3 || \ (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) #include "llvm/IR/Verifier.h" #include "llvm/IR/DebugInfo.h" #else #include "llvm/Analysis/Verifier.h" #include "llvm/DebugInfo.h" #define nullptr 0 #endif #include #include "afl-llvm-common.h" using namespace llvm; namespace { class CompareTransform : public ModulePass { public: static char ID; CompareTransform() : ModulePass(ID) { initInstrumentList(); } bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR < 4 const char *getPassName() const override { #else StringRef getPassName() const override { #endif return "transforms compare functions"; } private: bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp, const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp); }; } // namespace char CompareTransform::ID = 0; bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp, const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) { DenseMap valueMap; std::vector calls; LLVMContext & C = M.getContext(); IntegerType * Int8Ty = IntegerType::getInt8Ty(C); IntegerType * Int32Ty = IntegerType::getInt32Ty(C); IntegerType * Int64Ty = IntegerType::getInt64Ty(C); #if LLVM_VERSION_MAJOR < 9 Constant * #else FunctionCallee #endif c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty #if LLVM_VERSION_MAJOR < 5 , NULL #endif ); #if LLVM_VERSION_MAJOR < 9 Function *tolowerFn = cast(c); #else FunctionCallee tolowerFn = c; #endif /* iterate over all functions, bbs and instruction and add suitable calls to * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ for (auto &F : M) { if (!isInInstrumentList(&F)) continue; for (auto &BB : F) { for (auto &IN : BB) { CallInst *callInst = nullptr; if ((callInst = dyn_cast(&IN))) { bool isStrcmp = processStrcmp; bool isMemcmp = processMemcmp; bool isStrncmp = processStrncmp; bool isStrcasecmp = processStrcasecmp; bool isStrncasecmp = processStrncasecmp; bool isIntMemcpy = true; Function *Callee = callInst->getCalledFunction(); if (!Callee) continue; if (callInst->getCallingConv() != llvm::CallingConv::C) continue; StringRef FuncName = Callee->getName(); isStrcmp &= !FuncName.compare(StringRef("strcmp")); isMemcmp &= !FuncName.compare(StringRef("memcmp")); isStrncmp &= !FuncName.compare(StringRef("strncmp")); isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp")); isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64"); if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp && !isIntMemcpy) continue; /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function * prototype */ FunctionType *FT = Callee->getFunctionType(); isStrcmp &= FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0) == FT->getParamType(1) && FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); isStrcasecmp &= FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0) == FT->getParamType(1) && FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); isMemcmp &= FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0)->isPointerTy() && FT->getParamType(1)->isPointerTy() && FT->getParamType(2)->isIntegerTy(); isStrncmp &= FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0) == FT->getParamType(1) && FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && FT->getParamType(2)->isIntegerTy(); isStrncasecmp &= FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0) == FT->getParamType(1) && FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && FT->getParamType(2)->isIntegerTy(); if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp && !isIntMemcpy) continue; /* is a str{n,}{case,}cmp/memcmp, check if we have * str{case,}cmp(x, "const") or str{case,}cmp("const", x) * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..) * memcmp(x, "const", ..) or memcmp("const", x, ..) */ Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); StringRef Str1, Str2; bool HasStr1 = getConstantStringInfo(Str1P, Str1); bool HasStr2 = getConstantStringInfo(Str2P, Str2); if (isIntMemcpy && HasStr2) { valueMap[Str1P] = new std::string(Str2.str()); // fprintf(stderr, "saved %s for %p\n", Str2.str().c_str(), Str1P); continue; } // not literal? maybe global or local variable if (!(HasStr1 || HasStr2)) { auto *Ptr = dyn_cast(Str2P); if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) { if (auto *Var = dyn_cast(Ptr->getOperand(0))) { if (Var->hasInitializer()) { if (auto *Array = dyn_cast(Var->getInitializer())) { HasStr2 = true; Str2 = Array->getAsString(); valueMap[Str2P] = new std::string(Str2.str()); fprintf(stderr, "glo2 %s\n", Str2.str().c_str()); } } } } if (!HasStr2) { auto *Ptr = dyn_cast(Str1P); if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) { if (auto *Var = dyn_cast(Ptr->getOperand(0))) { if (Var->hasInitializer()) { if (auto *Array = dyn_cast( Var->getInitializer())) { HasStr1 = true; Str1 = Array->getAsString(); valueMap[Str1P] = new std::string(Str1.str()); // fprintf(stderr, "glo1 %s\n", Str1.str().c_str()); } } } } } else if (isIntMemcpy) { valueMap[Str1P] = new std::string(Str2.str()); // fprintf(stderr, "saved\n"); } } if (isIntMemcpy) continue; if (!(HasStr1 || HasStr2)) { // do we have a saved local variable initialization? std::string *val = valueMap[Str1P]; if (val && !val->empty()) { Str1 = StringRef(*val); HasStr1 = true; // fprintf(stderr, "loaded1 %s\n", Str1.str().c_str()); } else { val = valueMap[Str2P]; if (val && !val->empty()) { Str2 = StringRef(*val); HasStr2 = true; // fprintf(stderr, "loaded2 %s\n", Str2.str().c_str()); } } } /* handle cases of one string is const, one string is variable */ if (!(HasStr1 || HasStr2)) continue; if (isMemcmp || isStrncmp || isStrncasecmp) { /* check if third operand is a constant integer * strlen("constStr") and sizeof() are treated as constant */ Value * op2 = callInst->getArgOperand(2); ConstantInt *ilen = dyn_cast(op2); if (ilen) { uint64_t len = ilen->getZExtValue(); // if len is zero this is a pointless call but allow real // implementation to worry about that if (!len) continue; if (isMemcmp) { // if size of compare is larger than constant string this is // likely a bug but allow real implementation to worry about // that uint64_t literalLength = HasStr1 ? Str1.size() : Str2.size(); if (literalLength + 1 < ilen->getZExtValue()) continue; } } else if (isMemcmp) // this *may* supply a len greater than the constant string at // runtime so similarly we don't want to have to handle that continue; } calls.push_back(callInst); } } } } if (!calls.size()) return false; if (!be_quiet) errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; for (auto &callInst : calls) { Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); StringRef Str1, Str2, ConstStr; std::string TmpConstStr; Value * VarStr; bool HasStr1 = getConstantStringInfo(Str1P, Str1); bool HasStr2 = getConstantStringInfo(Str2P, Str2); uint64_t constStrLen, unrollLen, constSizedLen = 0; bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); bool isSizedcmp = isMemcmp || !callInst->getCalledFunction()->getName().compare( StringRef("strncmp")) || !callInst->getCalledFunction()->getName().compare( StringRef("strncasecmp")); Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL; bool isConstSized = sizedValue && isa(sizedValue); bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare( StringRef("strcasecmp")) || !callInst->getCalledFunction()->getName().compare( StringRef("strncasecmp")); if (!(HasStr1 || HasStr2)) { // do we have a saved local or global variable initialization? std::string *val = valueMap[Str1P]; if (val && !val->empty()) { Str1 = StringRef(*val); HasStr1 = true; } else { val = valueMap[Str2P]; if (val && !val->empty()) { Str2 = StringRef(*val); HasStr2 = true; } } } if (isConstSized) { constSizedLen = dyn_cast(sizedValue)->getZExtValue(); } if (HasStr1) { TmpConstStr = Str1.str(); VarStr = Str2P; } else { TmpConstStr = Str2.str(); VarStr = Str1P; } // add null termination character implicit in c strings TmpConstStr.append("\0", 1); // in the unusual case the const str has embedded null // characters, the string comparison functions should terminate // at the first null if (!isMemcmp) TmpConstStr.assign(TmpConstStr, 0, TmpConstStr.find('\0') + 1); constStrLen = TmpConstStr.length(); // prefer use of StringRef (in comparison to std::string a StringRef has // built-in runtime bounds checking, which makes debugging easier) ConstStr = StringRef(TmpConstStr); if (isConstSized) unrollLen = constSizedLen < constStrLen ? constSizedLen : constStrLen; else unrollLen = constStrLen; if (!be_quiet) errs() << callInst->getCalledFunction()->getName() << ": unroll len " << unrollLen << ((isSizedcmp && !isConstSized) ? ", variable n" : "") << ": " << ConstStr << "\n"; /* split before the call instruction */ BasicBlock *bb = callInst->getParent(); BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst)); BasicBlock *next_lenchk_bb = NULL; if (isSizedcmp && !isConstSized) { next_lenchk_bb = BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_lenchk_bb); } BasicBlock *next_cmp_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_cmp_bb); PHINode *PN = PHINode::Create( Int32Ty, (next_lenchk_bb ? 2 : 1) * unrollLen + 1, "cmp_phi"); #if LLVM_VERSION_MAJOR < 8 TerminatorInst *term = bb->getTerminator(); #else Instruction *term = bb->getTerminator(); #endif BranchInst::Create(next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, bb); term->eraseFromParent(); for (uint64_t i = 0; i < unrollLen; i++) { BasicBlock * cur_cmp_bb = next_cmp_bb, *cur_lenchk_bb = next_lenchk_bb; unsigned char c; if (cur_lenchk_bb) { IRBuilder<> cur_lenchk_IRB(&*(cur_lenchk_bb->getFirstInsertionPt())); Value * icmp = cur_lenchk_IRB.CreateICmpEQ( sizedValue, ConstantInt::get(sizedValue->getType(), i)); cur_lenchk_IRB.CreateCondBr(icmp, end_bb, cur_cmp_bb); cur_lenchk_bb->getTerminator()->eraseFromParent(); PN->addIncoming(ConstantInt::get(Int32Ty, 0), cur_lenchk_bb); } if (isCaseInsensitive) c = (unsigned char)(tolower((int)ConstStr[i]) & 0xff); else c = (unsigned char)ConstStr[i]; IRBuilder<> cur_cmp_IRB(&*(cur_cmp_bb->getFirstInsertionPt())); Value *v = ConstantInt::get(Int64Ty, i); Value *ele = cur_cmp_IRB.CreateInBoundsGEP(VarStr, v, "empty"); Value *load = cur_cmp_IRB.CreateLoad(ele); if (isCaseInsensitive) { // load >= 'A' && load <= 'Z' ? load | 0x020 : load load = cur_cmp_IRB.CreateZExt(load, Int32Ty); std::vector args; args.push_back(load); load = cur_cmp_IRB.CreateCall(tolowerFn, args); load = cur_cmp_IRB.CreateTrunc(load, Int8Ty); } Value *isub; if (HasStr1) isub = cur_cmp_IRB.CreateSub(ConstantInt::get(Int8Ty, c), load); else isub = cur_cmp_IRB.CreateSub(load, ConstantInt::get(Int8Ty, c)); Value *sext = cur_cmp_IRB.CreateSExt(isub, Int32Ty); PN->addIncoming(sext, cur_cmp_bb); if (i < unrollLen - 1) { if (cur_lenchk_bb) { next_lenchk_bb = BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_lenchk_bb); } next_cmp_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_cmp_bb); Value *icmp = cur_cmp_IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0)); cur_cmp_IRB.CreateCondBr( icmp, next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, end_bb); cur_cmp_bb->getTerminator()->eraseFromParent(); } else { // IRB.CreateBr(end_bb); } // add offset to varstr // create load // create signed isub // create icmp // create jcc // create next_bb } /* since the call is the first instruction of the bb it is safe to * replace it with a phi instruction */ BasicBlock::iterator ii(callInst); ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN); } return true; } bool CompareTransform::runOnModule(Module &M) { if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, " "extended by heiko@hexco.de\n"; else be_quiet = 1; transformCmps(M, true, true, true, true, true); verifyModule(M); return true; } static void registerCompTransPass(const PassManagerBuilder &, legacy::PassManagerBase &PM) { auto p = new CompareTransform(); PM.add(p); } static RegisterStandardPasses RegisterCompTransPass( PassManagerBuilder::EP_OptimizerLast, registerCompTransPass); static RegisterStandardPasses RegisterCompTransPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass); #if LLVM_VERSION_MAJOR >= 11 static RegisterStandardPasses RegisterCompTransPassLTO( PassManagerBuilder::EP_FullLinkTimeOptimizationLast, registerCompTransPass); #endif