aboutsummaryrefslogtreecommitdiff
path: root/llvm_mode/compare-transform-pass.so.cc
diff options
context:
space:
mode:
authorAndrea Fioraldi <andreafioraldi@gmail.com>2019-09-02 18:49:43 +0200
committerAndrea Fioraldi <andreafioraldi@gmail.com>2019-09-02 18:49:43 +0200
commitb24639d0113e15933e749ea0f96abe3f25a134a0 (patch)
tree4272020625c80c0d6982d3787bebc573c0da01b8 /llvm_mode/compare-transform-pass.so.cc
parent2ae4ca91b48407add0e940ee13bd8b385e319a7a (diff)
downloadafl++-b24639d0113e15933e749ea0f96abe3f25a134a0.tar.gz
run code formatter
Diffstat (limited to 'llvm_mode/compare-transform-pass.so.cc')
-rw-r--r--llvm_mode/compare-transform-pass.so.cc286
1 files changed, 165 insertions, 121 deletions
diff --git a/llvm_mode/compare-transform-pass.so.cc b/llvm_mode/compare-transform-pass.so.cc
index e7886db1..e1b6e671 100644
--- a/llvm_mode/compare-transform-pass.so.cc
+++ b/llvm_mode/compare-transform-pass.so.cc
@@ -36,202 +36,236 @@ using namespace llvm;
namespace {
- class CompareTransform : public ModulePass {
+class CompareTransform : public ModulePass {
- public:
- static char ID;
- CompareTransform() : ModulePass(ID) {
- }
+ public:
+ static char ID;
+ CompareTransform() : ModulePass(ID) {
- bool runOnModule(Module &M) override;
+ }
+
+ bool runOnModule(Module &M) override;
#if LLVM_VERSION_MAJOR < 4
- const char * getPassName() const override {
+ const char *getPassName() const override {
+
#else
- StringRef getPassName() const override {
+ StringRef getPassName() const override {
+
#endif
- return "transforms compare functions";
- }
- private:
- bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
- ,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
- };
-}
+ return "transforms compare functions";
+ }
+
+ private:
+ bool transformCmps(Module &M, const bool processStrcmp,
+ const bool processMemcmp, const bool processStrncmp,
+ const bool processStrcasecmp,
+ const bool processStrncasecmp);
+
+};
+
+} // namespace
char CompareTransform::ID = 0;
-bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
- , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
+bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
+ const bool processMemcmp,
+ const bool processStrncmp,
+ const bool processStrcasecmp,
+ const bool processStrncasecmp) {
- std::vector<CallInst*> calls;
- LLVMContext &C = M.getContext();
- IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
- IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
- IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
+ std::vector<CallInst *> calls;
+ LLVMContext & C = M.getContext();
+ IntegerType * Int8Ty = IntegerType::getInt8Ty(C);
+ IntegerType * Int32Ty = IntegerType::getInt32Ty(C);
+ IntegerType * Int64Ty = IntegerType::getInt64Ty(C);
#if LLVM_VERSION_MAJOR < 9
- Constant*
+ Constant *
#else
FunctionCallee
#endif
- c = M.getOrInsertFunction("tolower",
- Int32Ty,
- Int32Ty
+ c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
#if LLVM_VERSION_MAJOR < 5
- , nullptr
+ ,
+ nullptr
#endif
- );
+ );
#if LLVM_VERSION_MAJOR < 9
- Function* tolowerFn = cast<Function>(c);
+ Function *tolowerFn = cast<Function>(c);
#else
FunctionCallee tolowerFn = c;
#endif
- /* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
+ /* iterate over all functions, bbs and instruction and add suitable calls to
+ * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
for (auto &F : M) {
+
for (auto &BB : F) {
- for(auto &IN: BB) {
- CallInst* callInst = nullptr;
+
+ for (auto &IN : BB) {
+
+ CallInst *callInst = nullptr;
if ((callInst = dyn_cast<CallInst>(&IN))) {
- bool isStrcmp = processStrcmp;
- bool isMemcmp = processMemcmp;
- bool isStrncmp = processStrncmp;
- bool isStrcasecmp = processStrcasecmp;
+ bool isStrcmp = processStrcmp;
+ bool isMemcmp = processMemcmp;
+ bool isStrncmp = processStrncmp;
+ bool isStrcasecmp = processStrcasecmp;
bool isStrncasecmp = processStrncasecmp;
Function *Callee = callInst->getCalledFunction();
- if (!Callee)
- continue;
- if (callInst->getCallingConv() != llvm::CallingConv::C)
- continue;
+ if (!Callee) continue;
+ if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
StringRef FuncName = Callee->getName();
- isStrcmp &= !FuncName.compare(StringRef("strcmp"));
- isMemcmp &= !FuncName.compare(StringRef("memcmp"));
- isStrncmp &= !FuncName.compare(StringRef("strncmp"));
- isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
+ isStrcmp &= !FuncName.compare(StringRef("strcmp"));
+ isMemcmp &= !FuncName.compare(StringRef("memcmp"));
+ isStrncmp &= !FuncName.compare(StringRef("strncmp"));
+ isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
- if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
+ if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
+ !isStrncasecmp)
continue;
- /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
+ /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
+ * prototype */
FunctionType *FT = Callee->getFunctionType();
-
- isStrcmp &= FT->getNumParams() == 2 &&
- FT->getReturnType()->isIntegerTy(32) &&
- FT->getParamType(0) == FT->getParamType(1) &&
- FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
- isStrcasecmp &= FT->getNumParams() == 2 &&
- FT->getReturnType()->isIntegerTy(32) &&
- FT->getParamType(0) == FT->getParamType(1) &&
- FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
- isMemcmp &= FT->getNumParams() == 3 &&
+ isStrcmp &=
+ FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
+ FT->getParamType(0) == FT->getParamType(1) &&
+ FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+ isStrcasecmp &=
+ FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
+ FT->getParamType(0) == FT->getParamType(1) &&
+ FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+ isMemcmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0)->isPointerTy() &&
FT->getParamType(1)->isPointerTy() &&
FT->getParamType(2)->isIntegerTy();
- isStrncmp &= FT->getNumParams() == 3 &&
- FT->getReturnType()->isIntegerTy(32) &&
- FT->getParamType(0) == FT->getParamType(1) &&
- FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
- FT->getParamType(2)->isIntegerTy();
+ isStrncmp &= FT->getNumParams() == 3 &&
+ FT->getReturnType()->isIntegerTy(32) &&
+ FT->getParamType(0) == FT->getParamType(1) &&
+ FT->getParamType(0) ==
+ IntegerType::getInt8PtrTy(M.getContext()) &&
+ FT->getParamType(2)->isIntegerTy();
isStrncasecmp &= FT->getNumParams() == 3 &&
- FT->getReturnType()->isIntegerTy(32) &&
- FT->getParamType(0) == FT->getParamType(1) &&
- FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
- FT->getParamType(2)->isIntegerTy();
-
- if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
+ FT->getReturnType()->isIntegerTy(32) &&
+ FT->getParamType(0) == FT->getParamType(1) &&
+ FT->getParamType(0) ==
+ IntegerType::getInt8PtrTy(M.getContext()) &&
+ FT->getParamType(2)->isIntegerTy();
+
+ if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
+ !isStrncasecmp)
continue;
/* is a str{n,}{case,}cmp/memcmp, check if we have
* str{case,}cmp(x, "const") or str{case,}cmp("const", x)
* strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
* memcmp(x, "const", ..) or memcmp("const", x, ..) */
- Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
+ Value *Str1P = callInst->getArgOperand(0),
+ *Str2P = callInst->getArgOperand(1);
StringRef Str1, Str2;
- bool HasStr1 = getConstantStringInfo(Str1P, Str1);
- bool HasStr2 = getConstantStringInfo(Str2P, Str2);
+ bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+ bool HasStr2 = getConstantStringInfo(Str2P, Str2);
/* handle cases of one string is const, one string is variable */
- if (!(HasStr1 ^ HasStr2))
- continue;
+ if (!(HasStr1 ^ HasStr2)) continue;
if (isMemcmp || isStrncmp || isStrncasecmp) {
+
/* check if third operand is a constant integer
* strlen("constStr") and sizeof() are treated as constant */
- Value *op2 = callInst->getArgOperand(2);
- ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
- if (!ilen)
- continue;
- /* final precaution: if size of compare is larger than constant string skip it*/
- uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
- if (literalLength < ilen->getZExtValue())
- continue;
+ Value * op2 = callInst->getArgOperand(2);
+ ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
+ if (!ilen) continue;
+ /* final precaution: if size of compare is larger than constant
+ * string skip it*/
+ uint64_t literalLength =
+ HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
+ if (literalLength < ilen->getZExtValue()) continue;
+
}
calls.push_back(callInst);
+
}
+
}
+
}
+
}
- if (!calls.size())
- return false;
- errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
+ if (!calls.size()) return false;
+ errs() << "Replacing " << calls.size()
+ << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
- for (auto &callInst: calls) {
+ for (auto &callInst : calls) {
- Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
- StringRef Str1, Str2, ConstStr;
+ Value *Str1P = callInst->getArgOperand(0),
+ *Str2P = callInst->getArgOperand(1);
+ StringRef Str1, Str2, ConstStr;
std::string TmpConstStr;
- Value *VarStr;
- bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+ Value * VarStr;
+ bool HasStr1 = getConstantStringInfo(Str1P, Str1);
getConstantStringInfo(Str2P, Str2);
uint64_t constLen, sizedLen;
- bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
- bool isSizedcmp = isMemcmp
- || !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
- || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
- bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
- || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
+ bool isMemcmp =
+ !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
+ bool isSizedcmp = isMemcmp ||
+ !callInst->getCalledFunction()->getName().compare(
+ StringRef("strncmp")) ||
+ !callInst->getCalledFunction()->getName().compare(
+ StringRef("strncasecmp"));
+ bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(
+ StringRef("strcasecmp")) ||
+ !callInst->getCalledFunction()->getName().compare(
+ StringRef("strncasecmp"));
if (isSizedcmp) {
- Value *op2 = callInst->getArgOperand(2);
- ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
+
+ Value * op2 = callInst->getArgOperand(2);
+ ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
sizedLen = ilen->getZExtValue();
+
}
if (HasStr1) {
+
TmpConstStr = Str1.str();
VarStr = Str2P;
constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
- }
- else {
+
+ } else {
+
TmpConstStr = Str2.str();
VarStr = Str1P;
constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
+
}
/* properly handle zero terminated C strings by adding the terminating 0 to
* the StringRef (in comparison to std::string a StringRef has built-in
* runtime bounds checking, which makes debugging easier) */
- TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr);
+ TmpConstStr.append("\0", 1);
+ ConstStr = StringRef(TmpConstStr);
- if (isSizedcmp && constLen > sizedLen) {
- constLen = sizedLen;
- }
+ if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; }
- errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
+ errs() << callInst->getCalledFunction()->getName() << ": len " << constLen
+ << ": " << ConstStr << "\n";
/* split before the call instruction */
BasicBlock *bb = callInst->getParent();
BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
- BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+ BasicBlock *next_bb =
+ BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BranchInst::Create(end_bb, next_bb);
PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
@@ -249,71 +283,81 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const
char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
-
BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
- IRBuilder<> IRB(&*IP);
+ IRBuilder<> IRB(&*IP);
- Value* v = ConstantInt::get(Int64Ty, i);
- Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
+ Value *v = ConstantInt::get(Int64Ty, i);
+ Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
Value *load = IRB.CreateLoad(ele);
if (isCaseInsensitive) {
+
// load >= 'A' && load <= 'Z' ? load | 0x020 : load
std::vector<Value *> args;
args.push_back(load);
load = IRB.CreateCall(tolowerFn, args, "tmp");
load = IRB.CreateTrunc(load, Int8Ty);
+
}
+
Value *isub;
if (HasStr1)
isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
else
isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
- Value *sext = IRB.CreateSExt(isub, Int32Ty);
+ Value *sext = IRB.CreateSExt(isub, Int32Ty);
PN->addIncoming(sext, cur_bb);
-
if (i < constLen - 1) {
- next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+
+ next_bb =
+ BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BranchInst::Create(end_bb, next_bb);
Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
IRB.CreateCondBr(icmp, next_bb, end_bb);
cur_bb->getTerminator()->eraseFromParent();
+
} else {
- //IRB.CreateBr(end_bb);
+
+ // IRB.CreateBr(end_bb);
+
}
- //add offset to varstr
- //create load
- //create signed isub
- //create icmp
- //create jcc
- //create next_bb
+ // add offset to varstr
+ // create load
+ // create signed isub
+ // create icmp
+ // create jcc
+ // create next_bb
+
}
/* since the call is the first instruction of the bb it is safe to
* replace it with a phi instruction */
BasicBlock::iterator ii(callInst);
ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
- }
+ }
return true;
+
}
bool CompareTransform::runOnModule(Module &M) {
if (getenv("AFL_QUIET") == NULL)
- llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
+ llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, "
+ "extended by heiko@hexco.de\n";
transformCmps(M, true, true, true, true, true);
verifyModule(M);
return true;
+
}
static void registerCompTransPass(const PassManagerBuilder &,
- legacy::PassManagerBase &PM) {
+ legacy::PassManagerBase &PM) {
auto p = new CompareTransform();
PM.add(p);