diff options
Diffstat (limited to 'instrumentation')
-rw-r--r-- | instrumentation/SanitizerCoverageLTO.so.cc | 14 | ||||
-rw-r--r-- | instrumentation/afl-compiler-rt.o.c | 160 | ||||
-rw-r--r-- | instrumentation/afl-llvm-common.cc | 5 | ||||
-rw-r--r-- | instrumentation/afl-llvm-dict2file.so.cc | 7 | ||||
-rw-r--r-- | instrumentation/afl-llvm-lto-instrumentation.so.cc | 7 | ||||
-rw-r--r-- | instrumentation/afl-llvm-pass.so.cc | 4 | ||||
-rw-r--r-- | instrumentation/cmplog-instructions-pass.cc | 184 | ||||
-rw-r--r-- | instrumentation/cmplog-routines-pass.cc | 67 | ||||
-rw-r--r-- | instrumentation/cmplog-switches-pass.cc | 414 | ||||
-rw-r--r-- | instrumentation/compare-transform-pass.so.cc | 19 | ||||
-rw-r--r-- | instrumentation/split-compares-pass.so.cc | 987 |
11 files changed, 1186 insertions, 682 deletions
diff --git a/instrumentation/SanitizerCoverageLTO.so.cc b/instrumentation/SanitizerCoverageLTO.so.cc index 20f1856e..eddbfcc8 100644 --- a/instrumentation/SanitizerCoverageLTO.so.cc +++ b/instrumentation/SanitizerCoverageLTO.so.cc @@ -516,6 +516,8 @@ bool ModuleSanitizerCoverage::instrumentModule( for (auto &F : M) { + if (!isInInstrumentList(&F) || !F.size()) { continue; } + for (auto &BB : F) { for (auto &IN : BB) { @@ -759,6 +761,12 @@ bool ModuleSanitizerCoverage::instrumentModule( uint64_t literalLength = Str2.size(); uint64_t optLength = ilen->getZExtValue(); + if (optLength > literalLength + 1) { + + optLength = Str2.length() + 1; + + } + if (literalLength + 1 == optLength) { Str2.append("\0", 1); // add null byte @@ -862,6 +870,12 @@ bool ModuleSanitizerCoverage::instrumentModule( uint64_t literalLength = optLen; optLen = ilen->getZExtValue(); + if (optLen > thestring.length() + 1) { + + optLen = thestring.length() + 1; + + } + if (optLen < 2) { continue; } if (literalLength + 1 == optLen) { // add null byte thestring.append("\0", 1); diff --git a/instrumentation/afl-compiler-rt.o.c b/instrumentation/afl-compiler-rt.o.c index 2089ce78..3fec291c 100644 --- a/instrumentation/afl-compiler-rt.o.c +++ b/instrumentation/afl-compiler-rt.o.c @@ -98,9 +98,9 @@ u32 __afl_dictionary_len; u64 __afl_map_addr; // for the __AFL_COVERAGE_ON/__AFL_COVERAGE_OFF features to work: -int __afl_selective_coverage __attribute__((weak)); -int __afl_selective_coverage_start_off __attribute__((weak)); -int __afl_selective_coverage_temp = 1; +int __afl_selective_coverage __attribute__((weak)); +int __afl_selective_coverage_start_off __attribute__((weak)); +static int __afl_selective_coverage_temp = 1; #if defined(__ANDROID__) || defined(__HAIKU__) PREV_LOC_T __afl_prev_loc[NGRAM_SIZE_MAX]; @@ -147,7 +147,7 @@ static int __afl_dummy_fd[2] = {2, 2}; /* ensure we kill the child on termination */ -void at_exit(int signal) { +static void at_exit(int signal) { if (child_pid > 0) { kill(child_pid, SIGKILL); } @@ -179,7 +179,7 @@ void __afl_trace(const u32 x) { /* Error reporting to forkserver controller */ -void send_forkserver_error(int error) { +static void send_forkserver_error(int error) { u32 status; if (!error || error > 0xffff) return; @@ -270,12 +270,6 @@ static void __afl_map_shm(void) { if (__afl_final_loc) { - if (__afl_final_loc % 64) { - - __afl_final_loc = (((__afl_final_loc + 63) >> 6) << 6); - - } - __afl_map_size = __afl_final_loc; if (__afl_final_loc > MAP_SIZE) { @@ -304,8 +298,9 @@ static void __afl_map_shm(void) { if (!getenv("AFL_QUIET")) fprintf(stderr, - "Warning: AFL++ tools will need to set AFL_MAP_SIZE to %u " - "to be able to run this instrumented program!\n", + "Warning: AFL++ tools might need to set AFL_MAP_SIZE to %u " + "to be able to run this instrumented program if this " + "crashes!\n", __afl_final_loc); } @@ -622,6 +617,7 @@ static void __afl_unmap_shm(void) { #endif __afl_cmp_map = NULL; + __afl_cmp_map_backup = NULL; } @@ -629,6 +625,34 @@ static void __afl_unmap_shm(void) { } +#define write_error(text) write_error_with_location(text, __FILE__, __LINE__) + +void write_error_with_location(char *text, char *filename, int linenumber) { + + u8 * o = getenv("__AFL_OUT_DIR"); + char *e = strerror(errno); + + if (o) { + + char buf[4096]; + snprintf(buf, sizeof(buf), "%s/error.txt", o); + FILE *f = fopen(buf, "a"); + + if (f) { + + fprintf(f, "File %s, line %d: Error(%s): %s\n", filename, linenumber, + text, e); + fclose(f); + + } + + } + + fprintf(stderr, "File %s, line %d: Error(%s): %s\n", filename, linenumber, + text, e); + +} + #ifdef __linux__ static void __afl_start_snapshots(void) { @@ -655,7 +679,12 @@ static void __afl_start_snapshots(void) { if (__afl_sharedmem_fuzzing || (__afl_dictionary_len && __afl_dictionary)) { - if (read(FORKSRV_FD, &was_killed, 4) != 4) { _exit(1); } + if (read(FORKSRV_FD, &was_killed, 4) != 4) { + + write_error("read to afl-fuzz"); + _exit(1); + + } if (__afl_debug) { @@ -724,7 +753,12 @@ static void __afl_start_snapshots(void) { } else { /* Wait for parent by reading from the pipe. Abort if read fails. */ - if (read(FORKSRV_FD, &was_killed, 4) != 4) _exit(1); + if (read(FORKSRV_FD, &was_killed, 4) != 4) { + + write_error("reading from afl-fuzz"); + _exit(1); + + } } @@ -761,7 +795,12 @@ static void __afl_start_snapshots(void) { if (child_stopped && was_killed) { child_stopped = 0; - if (waitpid(child_pid, &status, 0) < 0) _exit(1); + if (waitpid(child_pid, &status, 0) < 0) { + + write_error("child_stopped && was_killed"); + _exit(1); // TODO why exit? + + } } @@ -770,7 +809,12 @@ static void __afl_start_snapshots(void) { /* Once woken up, create a clone of our process. */ child_pid = fork(); - if (child_pid < 0) _exit(1); + if (child_pid < 0) { + + write_error("fork"); + _exit(1); + + } /* In child process: close fds, resume execution. */ @@ -810,9 +854,19 @@ static void __afl_start_snapshots(void) { /* In parent process: write PID to pipe, then wait for child. */ - if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) _exit(1); + if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) { + + write_error("write to afl-fuzz"); + _exit(1); + + } - if (waitpid(child_pid, &status, WUNTRACED) < 0) _exit(1); + if (waitpid(child_pid, &status, WUNTRACED) < 0) { + + write_error("waitpid"); + _exit(1); + + } /* In persistent mode, the child stops itself with SIGSTOP to indicate a successful run. In this case, we want to wake it up without forking @@ -822,7 +876,12 @@ static void __afl_start_snapshots(void) { /* Relay wait status to pipe, then loop back. */ - if (write(FORKSRV_FD + 1, &status, 4) != 4) _exit(1); + if (write(FORKSRV_FD + 1, &status, 4) != 4) { + + write_error("writing to afl-fuzz"); + _exit(1); + + } } @@ -955,7 +1014,12 @@ static void __afl_start_forkserver(void) { } else { - if (read(FORKSRV_FD, &was_killed, 4) != 4) _exit(1); + if (read(FORKSRV_FD, &was_killed, 4) != 4) { + + // write_error("read from afl-fuzz"); + _exit(1); + + } } @@ -992,7 +1056,12 @@ static void __afl_start_forkserver(void) { if (child_stopped && was_killed) { child_stopped = 0; - if (waitpid(child_pid, &status, 0) < 0) _exit(1); + if (waitpid(child_pid, &status, 0) < 0) { + + write_error("child_stopped && was_killed"); + _exit(1); + + } } @@ -1001,7 +1070,12 @@ static void __afl_start_forkserver(void) { /* Once woken up, create a clone of our process. */ child_pid = fork(); - if (child_pid < 0) _exit(1); + if (child_pid < 0) { + + write_error("fork"); + _exit(1); + + } /* In child process: close fds, resume execution. */ @@ -1030,11 +1104,20 @@ static void __afl_start_forkserver(void) { /* In parent process: write PID to pipe, then wait for child. */ - if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) _exit(1); + if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) { - if (waitpid(child_pid, &status, is_persistent ? WUNTRACED : 0) < 0) + write_error("write to afl-fuzz"); _exit(1); + } + + if (waitpid(child_pid, &status, is_persistent ? WUNTRACED : 0) < 0) { + + write_error("waitpid"); + _exit(1); + + } + /* In persistent mode, the child stops itself with SIGSTOP to indicate a successful run. In this case, we want to wake it up without forking again. */ @@ -1043,7 +1126,12 @@ static void __afl_start_forkserver(void) { /* Relay wait status to pipe, then loop back. */ - if (write(FORKSRV_FD + 1, &status, 4) != 4) _exit(1); + if (write(FORKSRV_FD + 1, &status, 4) != 4) { + + write_error("writing to afl-fuzz"); + _exit(1); + + } } @@ -1599,7 +1687,7 @@ void __cmplog_ins_hookN(uint128_t arg1, uint128_t arg2, uint8_t attr, void __cmplog_ins_hook16(uint128_t arg1, uint128_t arg2, uint8_t attr) { - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; uintptr_t k = (uintptr_t)__builtin_return_address(0); k = (k >> 4) ^ (k << 8); @@ -1668,7 +1756,7 @@ void __sanitizer_cov_trace_cmp4(uint32_t arg1, uint32_t arg2) { } -void __sanitizer_cov_trace_cost_cmp4(uint32_t arg1, uint32_t arg2) { +void __sanitizer_cov_trace_const_cmp4(uint32_t arg1, uint32_t arg2) { __cmplog_ins_hook4(arg1, arg2, 0); @@ -1703,7 +1791,7 @@ void __sanitizer_cov_trace_const_cmp16(uint128_t arg1, uint128_t arg2) { void __sanitizer_cov_trace_switch(uint64_t val, uint64_t *cases) { - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; for (uint64_t i = 0; i < cases[0]; i++) { @@ -1800,7 +1888,7 @@ void __cmplog_rtn_hook(u8 *ptr1, u8 *ptr2) { fprintf(stderr, "\n"); */ - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; // fprintf(stderr, "RTN1 %p %p\n", ptr1, ptr2); int l1, l2; if ((l1 = area_is_valid(ptr1, 32)) <= 0 || @@ -1884,7 +1972,7 @@ static u8 *get_llvm_stdstring(u8 *string) { void __cmplog_rtn_gcc_stdstring_cstring(u8 *stdstring, u8 *cstring) { - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; if (area_is_valid(stdstring, 32) <= 0 || area_is_valid(cstring, 32) <= 0) return; @@ -1894,7 +1982,7 @@ void __cmplog_rtn_gcc_stdstring_cstring(u8 *stdstring, u8 *cstring) { void __cmplog_rtn_gcc_stdstring_stdstring(u8 *stdstring1, u8 *stdstring2) { - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; if (area_is_valid(stdstring1, 32) <= 0 || area_is_valid(stdstring2, 32) <= 0) return; @@ -1905,7 +1993,7 @@ void __cmplog_rtn_gcc_stdstring_stdstring(u8 *stdstring1, u8 *stdstring2) { void __cmplog_rtn_llvm_stdstring_cstring(u8 *stdstring, u8 *cstring) { - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; if (area_is_valid(stdstring, 32) <= 0 || area_is_valid(cstring, 32) <= 0) return; @@ -1915,7 +2003,7 @@ void __cmplog_rtn_llvm_stdstring_cstring(u8 *stdstring, u8 *cstring) { void __cmplog_rtn_llvm_stdstring_stdstring(u8 *stdstring1, u8 *stdstring2) { - if (unlikely(!__afl_cmp_map)) return; + if (likely(!__afl_cmp_map)) return; if (area_is_valid(stdstring1, 32) <= 0 || area_is_valid(stdstring2, 32) <= 0) return; @@ -1949,7 +2037,7 @@ void __afl_coverage_on() { if (likely(__afl_selective_coverage && __afl_selective_coverage_temp)) { __afl_area_ptr = __afl_area_ptr_backup; - __afl_cmp_map = __afl_cmp_map_backup; + if (__afl_cmp_map_backup) { __afl_cmp_map = __afl_cmp_map_backup; } } @@ -1990,3 +2078,5 @@ void __afl_coverage_interesting(u8 val, u32 id) { } +#undef write_error + diff --git a/instrumentation/afl-llvm-common.cc b/instrumentation/afl-llvm-common.cc index af32e2f9..3239ea91 100644 --- a/instrumentation/afl-llvm-common.cc +++ b/instrumentation/afl-llvm-common.cc @@ -96,9 +96,8 @@ bool isIgnoreFunction(const llvm::Function *F) { static constexpr const char *ignoreSubstringList[] = { - "__asan", "__msan", "__ubsan", "__lsan", - "__san", "__sanitize", "__cxx", "_GLOBAL__", - "DebugCounter", "DwarfDebug", "DebugLoc" + "__asan", "__msan", "__ubsan", "__lsan", "__san", "__sanitize", + "__cxx", "DebugCounter", "DwarfDebug", "DebugLoc" }; diff --git a/instrumentation/afl-llvm-dict2file.so.cc b/instrumentation/afl-llvm-dict2file.so.cc index e2b44b21..58f01920 100644 --- a/instrumentation/afl-llvm-dict2file.so.cc +++ b/instrumentation/afl-llvm-dict2file.so.cc @@ -154,6 +154,7 @@ bool AFLdict2filePass::runOnModule(Module &M) { for (auto &F : M) { if (isIgnoreFunction(&F)) continue; + if (!isInInstrumentList(&F) || !F.size()) { continue; } /* Some implementation notes. * @@ -428,6 +429,12 @@ bool AFLdict2filePass::runOnModule(Module &M) { uint64_t literalLength = Str2.length(); uint64_t optLength = ilen->getZExtValue(); + if (optLength > literalLength + 1) { + + optLength = Str2.length() + 1; + + } + if (literalLength + 1 == optLength) { Str2.append("\0", 1); // add null byte diff --git a/instrumentation/afl-llvm-lto-instrumentation.so.cc b/instrumentation/afl-llvm-lto-instrumentation.so.cc index fe43fbe5..46aa388e 100644 --- a/instrumentation/afl-llvm-lto-instrumentation.so.cc +++ b/instrumentation/afl-llvm-lto-instrumentation.so.cc @@ -546,6 +546,12 @@ bool AFLLTOPass::runOnModule(Module &M) { uint64_t literalLength = Str2.size(); uint64_t optLength = ilen->getZExtValue(); + if (optLength > literalLength + 1) { + + optLength = Str2.length() + 1; + + } + if (literalLength + 1 == optLength) { Str2.append("\0", 1); // add null byte @@ -649,6 +655,7 @@ bool AFLLTOPass::runOnModule(Module &M) { uint64_t literalLength = optLen; optLen = ilen->getZExtValue(); + if (optLen > literalLength + 1) { optLen = literalLength + 1; } if (optLen < 2) { continue; } if (literalLength + 1 == optLen) { // add null byte thestring.append("\0", 1); diff --git a/instrumentation/afl-llvm-pass.so.cc b/instrumentation/afl-llvm-pass.so.cc index a8f1baff..b673d815 100644 --- a/instrumentation/afl-llvm-pass.so.cc +++ b/instrumentation/afl-llvm-pass.so.cc @@ -438,9 +438,9 @@ bool AFLCoverage::runOnModule(Module &M) { fprintf(stderr, "FUNCTION: %s (%zu)\n", F.getName().str().c_str(), F.size()); - if (!isInInstrumentList(&F)) continue; + if (!isInInstrumentList(&F)) { continue; } - if (F.size() < function_minimum_size) continue; + if (F.size() < function_minimum_size) { continue; } std::list<Value *> todo; for (auto &BB : F) { diff --git a/instrumentation/cmplog-instructions-pass.cc b/instrumentation/cmplog-instructions-pass.cc index ad334d3b..0562c5b2 100644 --- a/instrumentation/cmplog-instructions-pass.cc +++ b/instrumentation/cmplog-instructions-pass.cc @@ -104,7 +104,6 @@ Iterator Unique(Iterator first, Iterator last) { bool CmpLogInstructions::hookInstrs(Module &M) { std::vector<Instruction *> icomps; - std::vector<SwitchInst *> switches; LLVMContext & C = M.getContext(); Type * VoidTy = Type::getVoidTy(C); @@ -222,6 +221,18 @@ bool CmpLogInstructions::hookInstrs(Module &M) { FunctionCallee cmplogHookInsN = cN; #endif + GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); + + if (!AFLCmplogPtr) { + + AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, + GlobalValue::ExternalWeakLinkage, 0, + "__afl_cmp_map"); + + } + + Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); + /* iterate over all functions, bbs and instruction and add suitable calls */ for (auto &F : M) { @@ -238,164 +249,6 @@ bool CmpLogInstructions::hookInstrs(Module &M) { } - SwitchInst *switchInst = nullptr; - if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { - - if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); } - - } - - } - - } - - } - - // unique the collected switches - switches.erase(Unique(switches.begin(), switches.end()), switches.end()); - - // Instrument switch values for cmplog - if (switches.size()) { - - if (!be_quiet) - errs() << "Hooking " << switches.size() << " switch instructions\n"; - - for (auto &SI : switches) { - - Value * Val = SI->getCondition(); - unsigned int max_size = Val->getType()->getIntegerBitWidth(), cast_size; - unsigned char do_cast = 0; - - if (!SI->getNumCases() || max_size < 16) { - - // if (!be_quiet) errs() << "skip trivial switch..\n"; - continue; - - } - - if (max_size % 8) { - - max_size = (((max_size / 8) + 1) * 8); - do_cast = 1; - - } - - IRBuilder<> IRB(SI->getParent()); - IRB.SetInsertPoint(SI); - - if (max_size > 128) { - - if (!be_quiet) { - - fprintf(stderr, - "Cannot handle this switch bit size: %u (truncating)\n", - max_size); - - } - - max_size = 128; - do_cast = 1; - - } - - // do we need to cast? - switch (max_size) { - - case 8: - case 16: - case 32: - case 64: - case 128: - cast_size = max_size; - break; - default: - cast_size = 128; - do_cast = 1; - - } - - Value *CompareTo = Val; - - if (do_cast) { - - CompareTo = - IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false); - - } - - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; - ++i) { - -#if LLVM_VERSION_MAJOR < 5 - ConstantInt *cint = i.getCaseValue(); -#else - ConstantInt *cint = i->getCaseValue(); -#endif - - if (cint) { - - std::vector<Value *> args; - args.push_back(CompareTo); - - Value *new_param = cint; - - if (do_cast) { - - new_param = - IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false); - - } - - if (new_param) { - - args.push_back(new_param); - ConstantInt *attribute = ConstantInt::get(Int8Ty, 1); - args.push_back(attribute); - if (cast_size != max_size) { - - ConstantInt *bitsize = - ConstantInt::get(Int8Ty, (max_size / 8) - 1); - args.push_back(bitsize); - - } - - switch (cast_size) { - - case 8: - IRB.CreateCall(cmplogHookIns1, args); - break; - case 16: - IRB.CreateCall(cmplogHookIns2, args); - break; - case 32: - IRB.CreateCall(cmplogHookIns4, args); - break; - case 64: - IRB.CreateCall(cmplogHookIns8, args); - break; - case 128: -#ifdef WORD_SIZE_64 - if (max_size == 128) { - - IRB.CreateCall(cmplogHookIns16, args); - - } else { - - IRB.CreateCall(cmplogHookInsN, args); - - } - -#endif - break; - default: - break; - - } - - } - - } - } } @@ -409,8 +262,15 @@ bool CmpLogInstructions::hookInstrs(Module &M) { for (auto &selectcmpInst : icomps) { - IRBuilder<> IRB(selectcmpInst->getParent()); - IRB.SetInsertPoint(selectcmpInst); + IRBuilder<> IRB2(selectcmpInst->getParent()); + IRB2.SetInsertPoint(selectcmpInst); + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = + SplitBlockAndInsertIfThen(is_not_null, selectcmpInst, false); + + IRBuilder<> IRB(ThenTerm); Value *op0 = selectcmpInst->getOperand(0); Value *op1 = selectcmpInst->getOperand(1); @@ -601,7 +461,7 @@ bool CmpLogInstructions::hookInstrs(Module &M) { } - if (switches.size() || icomps.size()) + if (icomps.size()) return true; else return false; diff --git a/instrumentation/cmplog-routines-pass.cc b/instrumentation/cmplog-routines-pass.cc index a5992c9a..1e2610f2 100644 --- a/instrumentation/cmplog-routines-pass.cc +++ b/instrumentation/cmplog-routines-pass.cc @@ -184,6 +184,18 @@ bool CmpLogRoutines::hookRtns(Module &M) { FunctionCallee cmplogGccStdC = c4; #endif + GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); + + if (!AFLCmplogPtr) { + + AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, + GlobalValue::ExternalWeakLinkage, 0, + "__afl_cmp_map"); + + } + + Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); + /* iterate over all functions, bbs and instruction and add suitable calls */ for (auto &F : M) { @@ -289,8 +301,15 @@ bool CmpLogRoutines::hookRtns(Module &M) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); - IRBuilder<> IRB(callInst->getParent()); - IRB.SetInsertPoint(callInst); + IRBuilder<> IRB2(callInst->getParent()); + IRB2.SetInsertPoint(callInst); + + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); + + IRBuilder<> IRB(ThenTerm); std::vector<Value *> args; Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); @@ -308,8 +327,15 @@ bool CmpLogRoutines::hookRtns(Module &M) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); - IRBuilder<> IRB(callInst->getParent()); - IRB.SetInsertPoint(callInst); + IRBuilder<> IRB2(callInst->getParent()); + IRB2.SetInsertPoint(callInst); + + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); + + IRBuilder<> IRB(ThenTerm); std::vector<Value *> args; Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); @@ -327,8 +353,15 @@ bool CmpLogRoutines::hookRtns(Module &M) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); - IRBuilder<> IRB(callInst->getParent()); - IRB.SetInsertPoint(callInst); + IRBuilder<> IRB2(callInst->getParent()); + IRB2.SetInsertPoint(callInst); + + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); + + IRBuilder<> IRB(ThenTerm); std::vector<Value *> args; Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); @@ -346,8 +379,15 @@ bool CmpLogRoutines::hookRtns(Module &M) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); - IRBuilder<> IRB(callInst->getParent()); - IRB.SetInsertPoint(callInst); + IRBuilder<> IRB2(callInst->getParent()); + IRB2.SetInsertPoint(callInst); + + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); + + IRBuilder<> IRB(ThenTerm); std::vector<Value *> args; Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); @@ -365,8 +405,15 @@ bool CmpLogRoutines::hookRtns(Module &M) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); - IRBuilder<> IRB(callInst->getParent()); - IRB.SetInsertPoint(callInst); + IRBuilder<> IRB2(callInst->getParent()); + IRB2.SetInsertPoint(callInst); + + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); + + IRBuilder<> IRB(ThenTerm); std::vector<Value *> args; Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); diff --git a/instrumentation/cmplog-switches-pass.cc b/instrumentation/cmplog-switches-pass.cc new file mode 100644 index 00000000..c42d44fe --- /dev/null +++ b/instrumentation/cmplog-switches-pass.cc @@ -0,0 +1,414 @@ +/* + american fuzzy lop++ - LLVM CmpLog instrumentation + -------------------------------------------------- + + Written by Andrea Fioraldi <andreafioraldi@gmail.com> + + Copyright 2015, 2016 Google Inc. All rights reserved. + Copyright 2019-2020 AFLplusplus Project. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +*/ + +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> + +#include <iostream> +#include <list> +#include <string> +#include <fstream> +#include <sys/time.h> + +#include "llvm/Config/llvm-config.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/ValueTracking.h" + +#if LLVM_VERSION_MAJOR > 3 || \ + (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) + #include "llvm/IR/Verifier.h" + #include "llvm/IR/DebugInfo.h" +#else + #include "llvm/Analysis/Verifier.h" + #include "llvm/DebugInfo.h" + #define nullptr 0 +#endif + +#include <set> +#include "afl-llvm-common.h" + +using namespace llvm; + +namespace { + +class CmpLogInstructions : public ModulePass { + + public: + static char ID; + CmpLogInstructions() : ModulePass(ID) { + + initInstrumentList(); + + } + + bool runOnModule(Module &M) override; + +#if LLVM_VERSION_MAJOR < 4 + const char *getPassName() const override { + +#else + StringRef getPassName() const override { + +#endif + return "cmplog instructions"; + + } + + private: + bool hookInstrs(Module &M); + +}; + +} // namespace + +char CmpLogInstructions::ID = 0; + +template <class Iterator> +Iterator Unique(Iterator first, Iterator last) { + + while (first != last) { + + Iterator next(first); + last = std::remove(++next, last, *first); + first = next; + + } + + return last; + +} + +bool CmpLogInstructions::hookInstrs(Module &M) { + + std::vector<SwitchInst *> switches; + LLVMContext & C = M.getContext(); + + Type * VoidTy = Type::getVoidTy(C); + IntegerType *Int8Ty = IntegerType::getInt8Ty(C); + IntegerType *Int16Ty = IntegerType::getInt16Ty(C); + IntegerType *Int32Ty = IntegerType::getInt32Ty(C); + IntegerType *Int64Ty = IntegerType::getInt64Ty(C); + +#if LLVM_VERSION_MAJOR < 9 + Constant * +#else + FunctionCallee +#endif + c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty, + Int8Ty +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR < 9 + Function *cmplogHookIns1 = cast<Function>(c1); +#else + FunctionCallee cmplogHookIns1 = c1; +#endif + +#if LLVM_VERSION_MAJOR < 9 + Constant * +#else + FunctionCallee +#endif + c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty, + Int8Ty +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR < 9 + Function *cmplogHookIns2 = cast<Function>(c2); +#else + FunctionCallee cmplogHookIns2 = c2; +#endif + +#if LLVM_VERSION_MAJOR < 9 + Constant * +#else + FunctionCallee +#endif + c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty, + Int8Ty +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR < 9 + Function *cmplogHookIns4 = cast<Function>(c4); +#else + FunctionCallee cmplogHookIns4 = c4; +#endif + +#if LLVM_VERSION_MAJOR < 9 + Constant * +#else + FunctionCallee +#endif + c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty, + Int8Ty +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR < 9 + Function *cmplogHookIns8 = cast<Function>(c8); +#else + FunctionCallee cmplogHookIns8 = c8; +#endif + + GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); + + if (!AFLCmplogPtr) { + + AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, + GlobalValue::ExternalWeakLinkage, 0, + "__afl_cmp_map"); + + } + + Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); + + /* iterate over all functions, bbs and instruction and add suitable calls */ + for (auto &F : M) { + + if (!isInInstrumentList(&F)) continue; + + for (auto &BB : F) { + + SwitchInst *switchInst = nullptr; + if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { + + if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); } + + } + + } + + } + + // unique the collected switches + switches.erase(Unique(switches.begin(), switches.end()), switches.end()); + + // Instrument switch values for cmplog + if (switches.size()) { + + if (!be_quiet) + errs() << "Hooking " << switches.size() << " switch instructions\n"; + + for (auto &SI : switches) { + + Value * Val = SI->getCondition(); + unsigned int max_size = Val->getType()->getIntegerBitWidth(), cast_size; + unsigned char do_cast = 0; + + if (!SI->getNumCases() || max_size < 16) { + + // if (!be_quiet) errs() << "skip trivial switch..\n"; + continue; + + } + + if (max_size % 8) { + + max_size = (((max_size / 8) + 1) * 8); + do_cast = 1; + + } + + IRBuilder<> IRB2(SI->getParent()); + IRB2.SetInsertPoint(SI); + + LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr); + CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); + auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, SI, false); + + IRBuilder<> IRB(ThenTerm); + + if (max_size > 128) { + + if (!be_quiet) { + + fprintf(stderr, + "Cannot handle this switch bit size: %u (truncating)\n", + max_size); + + } + + max_size = 128; + do_cast = 1; + + } + + // do we need to cast? + switch (max_size) { + + case 8: + case 16: + case 32: + case 64: + case 128: + cast_size = max_size; + break; + default: + cast_size = 128; + do_cast = 1; + + } + + Value *CompareTo = Val; + + if (do_cast) { + + CompareTo = + IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false); + + } + + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; + ++i) { + +#if LLVM_VERSION_MAJOR < 5 + ConstantInt *cint = i.getCaseValue(); +#else + ConstantInt *cint = i->getCaseValue(); +#endif + + if (cint) { + + std::vector<Value *> args; + args.push_back(CompareTo); + + Value *new_param = cint; + + if (do_cast) { + + new_param = + IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false); + + } + + if (new_param) { + + args.push_back(new_param); + ConstantInt *attribute = ConstantInt::get(Int8Ty, 1); + args.push_back(attribute); + if (cast_size != max_size) { + + ConstantInt *bitsize = + ConstantInt::get(Int8Ty, (max_size / 8) - 1); + args.push_back(bitsize); + + } + + switch (cast_size) { + + case 8: + IRB.CreateCall(cmplogHookIns1, args); + break; + case 16: + IRB.CreateCall(cmplogHookIns2, args); + break; + case 32: + IRB.CreateCall(cmplogHookIns4, args); + break; + case 64: + IRB.CreateCall(cmplogHookIns8, args); + break; + case 128: +#ifdef WORD_SIZE_64 + if (max_size == 128) { + + IRB.CreateCall(cmplogHookIns16, args); + + } else { + + IRB.CreateCall(cmplogHookInsN, args); + + } + +#endif + break; + default: + break; + + } + + } + + } + + } + + } + + } + + if (switches.size()) + return true; + else + return false; + +} + +bool CmpLogInstructions::runOnModule(Module &M) { + + if (getenv("AFL_QUIET") == NULL) + printf("Running cmplog-switches-pass by andreafioraldi@gmail.com\n"); + else + be_quiet = 1; + hookInstrs(M); + verifyModule(M); + + return true; + +} + +static void registerCmpLogInstructionsPass(const PassManagerBuilder &, + legacy::PassManagerBase &PM) { + + auto p = new CmpLogInstructions(); + PM.add(p); + +} + +static RegisterStandardPasses RegisterCmpLogInstructionsPass( + PassManagerBuilder::EP_OptimizerLast, registerCmpLogInstructionsPass); + +static RegisterStandardPasses RegisterCmpLogInstructionsPass0( + PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogInstructionsPass); + +#if LLVM_VERSION_MAJOR >= 11 +static RegisterStandardPasses RegisterCmpLogInstructionsPassLTO( + PassManagerBuilder::EP_FullLinkTimeOptimizationLast, + registerCmpLogInstructionsPass); +#endif + diff --git a/instrumentation/compare-transform-pass.so.cc b/instrumentation/compare-transform-pass.so.cc index 3ecba4e6..f5dd4a53 100644 --- a/instrumentation/compare-transform-pass.so.cc +++ b/instrumentation/compare-transform-pass.so.cc @@ -313,27 +313,18 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, ConstantInt *ilen = dyn_cast<ConstantInt>(op2); if (ilen) { - uint64_t len = ilen->getZExtValue(); // if len is zero this is a pointless call but allow real // implementation to worry about that - if (len < 2) continue; + if (ilen->getZExtValue() < 2) { continue; } - if (isMemcmp) { - - // if size of compare is larger than constant string this is - // likely a bug but allow real implementation to worry about - // that - uint64_t literalLength = HasStr1 ? Str1.size() : Str2.size(); - if (literalLength + 1 < ilen->getZExtValue()) continue; - - } - - } else if (isMemcmp) + } else if (isMemcmp) { // this *may* supply a len greater than the constant string at // runtime so similarly we don't want to have to handle that continue; + } + } calls.push_back(callInst); @@ -421,7 +412,7 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, } if (TmpConstStr.length() < 2 || - (TmpConstStr.length() == 2 && !TmpConstStr[1])) { + (TmpConstStr.length() == 2 && TmpConstStr[1] == 0)) { continue; diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc index b02a89fb..13f45b69 100644 --- a/instrumentation/split-compares-pass.so.cc +++ b/instrumentation/split-compares-pass.so.cc @@ -47,6 +47,10 @@ using namespace llvm; #include "afl-llvm-common.h" +// uncomment this toggle function verification at each step. horribly slow, but +// helps to pinpoint a potential problem in the splitting code. +//#define VERIFY_TOO_MUCH 1 + namespace { class SplitComparesTransform : public ModulePass { @@ -67,28 +71,101 @@ class SplitComparesTransform : public ModulePass { const char *getPassName() const override { #endif - return "simplifies and splits ICMP instructions"; + return "AFL_SplitComparesTransform"; } private: int enableFPSplit; - size_t splitIntCompares(Module &M, unsigned bitw); + unsigned target_bitwidth = 8; + + size_t count = 0; + size_t splitFPCompares(Module &M); - bool simplifyCompares(Module &M); bool simplifyFPCompares(Module &M); - bool simplifyIntSignedness(Module &M); size_t nextPowerOfTwo(size_t in); + using CmpWorklist = SmallVector<CmpInst *, 8>; + + /// simplify the comparison and then split the comparison until the + /// target_bitwidth is reached. + bool simplifyAndSplit(CmpInst *I, Module &M); + /// simplify a non-strict comparison (e.g., less than or equals) + bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist); + /// simplify a signed comparison (signed less or greater than) + bool simplifySignedCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist); + /// splits an icmp into nested icmps recursivly until target_bitwidth is + /// reached + bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist); + + /// print an error to llvm's errs stream, but only if not ordered to be quiet + void reportError(const StringRef msg, Instruction *I, Module &M) { + + if (!be_quiet) { + + errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n"; + if (debug) { + + if (I) { + + errs() << "Instruction = " << *I << "\n"; + if (auto BB = I->getParent()) { + + if (auto F = BB->getParent()) { + + if (F->hasName()) { + + errs() << "|-> in function " << F->getName() << " "; + + } + + } + + } + + } + + auto n = M.getName(); + if (n.size() > 0) { errs() << "in module " << n << "\n"; } + + } + + } + + } + + bool isSupportedBitWidth(unsigned bitw) { + + // IDK whether the icmp code works on other bitwidths. I guess not? So we + // try to avoid dealing with other weird icmp's that llvm might use (looking + // at you `icmp i0`). + switch (bitw) { + + case 8: + case 16: + case 32: + case 64: + case 128: + case 256: + return true; + default: + return false; + + } + + } + }; } // namespace char SplitComparesTransform::ID = 0; -/* This function splits FCMP instructions with xGE or xLE predicates into two - * FCMP instructions with predicate xGT or xLT and EQ */ +/// This function splits FCMP instructions with xGE or xLE predicates into two +/// FCMP instructions with predicate xGT or xLT and EQ bool SplitComparesTransform::simplifyFPCompares(Module &M) { LLVMContext & C = M.getContext(); @@ -221,292 +298,481 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { } -/* This function splits ICMP instructions with xGE or xLE predicates into two - * ICMP instructions with predicate xGT or xLT and EQ */ -bool SplitComparesTransform::simplifyCompares(Module &M) { +/// This function splits ICMP instructions with xGE or xLE predicates into two +/// ICMP instructions with predicate xGT or xLT and EQ +bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst * IcmpInst, + Module & M, + CmpWorklist &worklist) { - LLVMContext & C = M.getContext(); - std::vector<Instruction *> icomps; - IntegerType * Int1Ty = IntegerType::getInt1Ty(C); + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - /* iterate over all functions, bbs and instruction and add - * all integer comparisons with >= and <= predicates to the icomps vector */ - for (auto &F : M) { + /* find out what the new predicate is going to be */ + auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); + if (!cmp_inst) { return false; } - if (!isInInstrumentList(&F)) continue; + BasicBlock *bb = IcmpInst->getParent(); - for (auto &BB : F) { + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); - for (auto &IN : BB) { + CmpInst::Predicate pred = cmp_inst->getPredicate(); + CmpInst::Predicate new_pred; - CmpInst *selectcmpInst = nullptr; + switch (pred) { - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + case CmpInst::ICMP_UGE: + new_pred = CmpInst::ICMP_UGT; + break; + case CmpInst::ICMP_SGE: + new_pred = CmpInst::ICMP_SGT; + break; + case CmpInst::ICMP_ULE: + new_pred = CmpInst::ICMP_ULT; + break; + case CmpInst::ICMP_SLE: + new_pred = CmpInst::ICMP_SLT; + break; + default: // keep the compiler happy + return false; - if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE || - selectcmpInst->getPredicate() == CmpInst::ICMP_SGE || - selectcmpInst->getPredicate() == CmpInst::ICMP_ULE || - selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) { + } - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + /* split before the icmp instruction */ + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* the old bb now contains a unconditional jump to the new one (end_bb) + * we need to delete it later */ + + /* create the ICMP instruction with new_pred and add it to the old basic + * block bb it is now at the position where the old IcmpInst was */ + CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np); + + /* create a new basic block which holds the new EQ icmp */ + CmpInst *icmp_eq; + /* insert middle_bb before end_bb */ + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); + middle_bb->getInstList().push_back(icmp_eq); + /* add an unconditional branch to the end of middle_bb with destination + * end_bb */ + BranchInst::Create(end_bb, middle_bb); + + /* replace the uncond branch with a conditional one, which depends on the + * new_pred icmp. True goes to end, false to the middle (injected) bb */ + auto term = bb->getTerminator(); + BranchInst::Create(end_bb, middle_bb, icmp_np, bb); + term->eraseFromParent(); + + /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI + * inst to wire up the loose ends */ + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + /* the first result depends on the outcome of icmp_eq */ + PN->addIncoming(icmp_eq, middle_bb); + /* if the source was the original bb we know that the icmp_np yielded true + * hence we can hardcode this value */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + /* replace the old IcmpInst with our new and shiny PHI inst */ + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + + worklist.push_back(icmp_np); + worklist.push_back(icmp_eq); - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + return true; - /* this is probably not needed but we do it anyway */ - if (!intTyOp0 || !intTyOp1) { continue; } +} - icomps.push_back(selectcmpInst); +/// Simplify a signed comparison operator by splitting it into a unsigned and +/// bit comparison. add all resulting comparisons to +/// the worklist passed as a reference. +bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist) { - } + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - } + BasicBlock *bb = IcmpInst->getParent(); - } + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); - } + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + if (!intTyOp0) { return false; } + unsigned bitw = intTyOp0->getBitWidth(); + IntegerType *IntType = IntegerType::get(C, bitw); - } + /* get the new predicate */ + auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); + if (!cmp_inst) { return false; } + auto pred = cmp_inst->getPredicate(); + CmpInst::Predicate new_pred; - if (!icomps.size()) { return false; } + if (pred == CmpInst::ICMP_SGT) { - for (auto &IcmpInst : icomps) { + new_pred = CmpInst::ICMP_UGT; - BasicBlock *bb = IcmpInst->getParent(); + } else { - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); + new_pred = CmpInst::ICMP_ULT; - /* find out what the new predicate is going to be */ - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - CmpInst::Predicate new_pred; + } - switch (pred) { + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* create a 1 bit compare for the sign bit. to do this shift and trunc + * the original operands so only the first bit remains.*/ + Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; + + IRBuilder<> IRB(bb->getTerminator()); + s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1)); + t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty); + s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1)); + t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty); + /* compare of the sign bits */ + icmp_sign_bit = IRB.CreateICmp(CmpInst::ICMP_EQ, t_op0, t_op1); + + /* create a new basic block which is executed if the signedness bit is + * different */ + CmpInst * icmp_inv_sig_cmp; + BasicBlock *sign_bb = + BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_SGT) { + + /* if we check for > and the op0 positive and op1 negative then the final + * result is true. if op0 negative and op1 pos, the cmp must result + * in false + */ + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); - case CmpInst::ICMP_UGE: - new_pred = CmpInst::ICMP_UGT; - break; - case CmpInst::ICMP_SGE: - new_pred = CmpInst::ICMP_SGT; - break; - case CmpInst::ICMP_ULE: - new_pred = CmpInst::ICMP_ULT; - break; - case CmpInst::ICMP_SLE: - new_pred = CmpInst::ICMP_SLT; - break; - default: // keep the compiler happy - continue; + } else { - } + /* just the inverse of the above statement */ + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); - /* split before the icmp instruction */ - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + } - /* the old bb now contains a unconditional jump to the new one (end_bb) - * we need to delete it later */ + sign_bb->getInstList().push_back(icmp_inv_sig_cmp); + BranchInst::Create(end_bb, sign_bb); - /* create the ICMP instruction with new_pred and add it to the old basic - * block bb it is now at the position where the old IcmpInst was */ - Instruction *icmp_np; - icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_np); + /* create a new bb which is executed if signedness is equal */ + CmpInst * icmp_usign_cmp; + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + /* we can do a normal unsigned compare now */ + icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - /* create a new basic block which holds the new EQ icmp */ - Instruction *icmp_eq; - /* insert middle_bb before end_bb */ - BasicBlock *middle_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); - middle_bb->getInstList().push_back(icmp_eq); - /* add an unconditional branch to the end of middle_bb with destination - * end_bb */ - BranchInst::Create(end_bb, middle_bb); + middle_bb->getInstList().push_back(icmp_usign_cmp); + BranchInst::Create(end_bb, middle_bb); - /* replace the uncond branch with a conditional one, which depends on the - * new_pred icmp. True goes to end, false to the middle (injected) bb */ - auto term = bb->getTerminator(); - BranchInst::Create(end_bb, middle_bb, icmp_np, bb); - term->eraseFromParent(); + auto term = bb->getTerminator(); + /* if the sign is eq do a normal unsigned cmp, else we have to check the + * signedness bit */ + BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); + term->eraseFromParent(); - /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI - * inst to wire up the loose ends */ - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - /* the first result depends on the outcome of icmp_eq */ - PN->addIncoming(icmp_eq, middle_bb); - /* if the source was the original bb we know that the icmp_np yielded true - * hence we can hardcode this value */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - /* replace the old IcmpInst with our new and shiny PHI inst */ - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - } + PN->addIncoming(icmp_usign_cmp, middle_bb); + PN->addIncoming(icmp_inv_sig_cmp, sign_bb); + + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + + // save for later + worklist.push_back(icmp_usign_cmp); + + // signed comparisons are not supported by the splitting code, so we must not + // add it to the worklist. + // worklist.push_back(icmp_inv_sig_cmp); return true; } -/* this function transforms signed compares to equivalent unsigned compares */ -bool SplitComparesTransform::simplifyIntSignedness(Module &M) { +bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M, + CmpWorklist &worklist) { - LLVMContext & C = M.getContext(); - std::vector<Instruction *> icomps; - IntegerType * Int1Ty = IntegerType::getInt1Ty(C); + auto pred = cmp_inst->getPredicate(); + switch (pred) { - /* iterate over all functions, bbs and instructions and add - * all signed compares to icomps vector */ - for (auto &F : M) { + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_ULT: + break; + default: + // unsupported predicate! + return false; - if (!isInInstrumentList(&F)) continue; + } - for (auto &BB : F) { + auto op0 = cmp_inst->getOperand(0); + auto op1 = cmp_inst->getOperand(1); - for (auto &IN : BB) { + // get bitwidth by checking the bitwidth of the first operator + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + if (!intTyOp0) { - CmpInst *selectcmpInst = nullptr; + // not an integer type + return false; - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + } - if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT || - selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) { + unsigned bitw = intTyOp0->getBitWidth(); + if (bitw == target_bitwidth) { - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + // already the target bitwidth so we have to do nothing here. + return true; + + } + + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + BasicBlock * bb = cmp_inst->getParent(); + IntegerType *OldIntType = IntegerType::get(C, bitw); + IntegerType *NewIntType = IntegerType::get(C, bitw / 2); + BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst)); + CmpInst * icmp_high, *icmp_low; - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + /* create the comparison of the top halves of the original operands */ + Value *s_op0, *op0_high, *s_op1, *op1_high; - /* see above */ - if (!intTyOp0 || !intTyOp1) { continue; } + IRBuilder<> IRB(bb->getTerminator()); - /* i think this is not possible but to lazy to look it up */ - if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { + s_op0 = IRB.CreateBinOp(Instruction::LShr, op0, + ConstantInt::get(OldIntType, bitw / 2)); + op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType); - continue; + s_op1 = IRB.CreateBinOp(Instruction::LShr, op1, + ConstantInt::get(OldIntType, bitw / 2)); + op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType); + icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high)); - } + PHINode *PN = nullptr; - icomps.push_back(selectcmpInst); + /* now we have to destinguish between == != and > < */ + switch (pred) { - } + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: { - } + /* transformation for == and != icmps */ + + /* create a compare for the lower half of the original operands */ + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb); + + Value * op0_low, *op1_low; + IRBuilder<> Builder(cmp_low_bb); + + op0_low = Builder.CreateTrunc(op0, NewIntType); + op1_low = Builder.CreateTrunc(op1, NewIntType); + icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low)); + + BranchInst::Create(end_bb, cmp_low_bb); + + /* dependent on the cmp of the high parts go to the end or go on with + * the comparison */ + auto term = bb->getTerminator(); + BranchInst *br = nullptr; + if (pred == CmpInst::ICMP_EQ) { + + br = BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); + + } else { + + /* CmpInst::ICMP_NE */ + br = BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); + + } + + term->eraseFromParent(); + + /* create the PHI and connect the edges accordingly */ + PN = PHINode::Create(Int1Ty, 2, ""); + PN->addIncoming(icmp_low, cmp_low_bb); + Value *val = nullptr; + if (pred == CmpInst::ICMP_EQ) { + + val = ConstantInt::get(Int1Ty, 0); + + } else { + + /* CmpInst::ICMP_NE */ + val = ConstantInt::get(Int1Ty, 1); } + PN->addIncoming(val, icmp_high->getParent()); + break; + } - } + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_ULT: { - if (!icomps.size()) { return false; } + /* transformations for < and > */ - for (auto &IcmpInst : icomps) { + /* create a basic block which checks for the inverse predicate. + * if this is true we can go to the end if not we have to go to the + * bb which checks the lower half of the operands */ + Instruction *op0_low, *op1_low; + CmpInst * icmp_inv_cmp = nullptr; + BasicBlock * inv_cmp_bb = + BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_UGT) { - BasicBlock *bb = IcmpInst->getParent(); + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, + op0_high, op1_high); - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); + } else { - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - if (!intTyOp0) { continue; } - unsigned bitw = intTyOp0->getBitWidth(); - IntegerType *IntType = IntegerType::get(C, bitw); + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, + op0_high, op1_high); - /* get the new predicate */ - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - CmpInst::Predicate new_pred; + } - if (pred == CmpInst::ICMP_SGT) { + inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); + worklist.push_back(icmp_inv_cmp); - new_pred = CmpInst::ICMP_UGT; + auto term = bb->getTerminator(); + term->eraseFromParent(); + BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); - } else { + /* create a bb which handles the cmp of the lower halves */ + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb); + op0_low = new TruncInst(op0, NewIntType); + cmp_low_bb->getInstList().push_back(op0_low); + op1_low = new TruncInst(op1, NewIntType); + cmp_low_bb->getInstList().push_back(op1_low); - new_pred = CmpInst::ICMP_ULT; + icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); + cmp_low_bb->getInstList().push_back(icmp_low); + BranchInst::Create(end_bb, cmp_low_bb); + + BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); + + PN = PHINode::Create(Int1Ty, 3); + PN->addIncoming(icmp_low, cmp_low_bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); + break; } - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + default: + return false; - /* create a 1 bit compare for the sign bit. to do this shift and trunc - * the original operands so only the first bit remains.*/ - Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; + } - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, - ConstantInt::get(IntType, bitw - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0); - t_op0 = new TruncInst(s_op0, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0); + BasicBlock::iterator ii(cmp_inst); + ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN); - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, - ConstantInt::get(IntType, bitw - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1); - t_op1 = new TruncInst(s_op1, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1); + // We split the comparison into low and high. If this isn't our target + // bitwidth we recursivly split the low and high parts again until we have + // target bitwidth. + if ((bitw / 2) > target_bitwidth) { - /* compare of the sign bits */ - icmp_sign_bit = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_sign_bit); + worklist.push_back(icmp_high); + worklist.push_back(icmp_low); - /* create a new basic block which is executed if the signedness bit is - * different */ - Instruction *icmp_inv_sig_cmp; - BasicBlock * sign_bb = - BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); - if (pred == CmpInst::ICMP_SGT) { + } - /* if we check for > and the op0 positive and op1 negative then the final - * result is true. if op0 negative and op1 pos, the cmp must result - * in false - */ - icmp_inv_sig_cmp = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); + return true; - } else { +} + +bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) { + + CmpWorklist worklist; + + auto op0 = I->getOperand(0); + auto op1 = I->getOperand(1); + if (!op0 || !op1) { return false; } + auto op0Ty = dyn_cast<IntegerType>(op0->getType()); + if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; } + + unsigned bitw = op0Ty->getBitWidth(); + +#ifdef VERIFY_TOO_MUCH + auto F = I->getParent()->getParent(); +#endif - /* just the inverse of the above statement */ - icmp_inv_sig_cmp = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); + // we run the comparison simplification on all compares regardless of their + // bitwidth. + if (I->getPredicate() == CmpInst::ICMP_UGE || + I->getPredicate() == CmpInst::ICMP_SGE || + I->getPredicate() == CmpInst::ICMP_ULE || + I->getPredicate() == CmpInst::ICMP_SLE) { + + if (!simplifyOrEqualsCompare(I, M, worklist)) { + + reportError( + "Failed to simplify inequality or equals comparison " + "(UGE,SGE,ULE,SLE)", + I, M); } - sign_bb->getInstList().push_back(icmp_inv_sig_cmp); - BranchInst::Create(end_bb, sign_bb); + } else if (I->getPredicate() == CmpInst::ICMP_SGT || - /* create a new bb which is executed if signedness is equal */ - Instruction *icmp_usign_cmp; - BasicBlock * middle_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - /* we can do a normal unsigned compare now */ - icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - middle_bb->getInstList().push_back(icmp_usign_cmp); - BranchInst::Create(end_bb, middle_bb); + I->getPredicate() == CmpInst::ICMP_SLT) { - auto term = bb->getTerminator(); - /* if the sign is eq do a normal unsigned cmp, else we have to check the - * signedness bit */ - BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); - term->eraseFromParent(); + if (!simplifySignedCompare(I, M, worklist)) { - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + reportError("Failed to simplify signed comparison (SGT,SLT)", I, M); + + } + + } + +#ifdef VERIFY_TOO_MUCH + if (verifyFunction(*F, &errs())) { + + reportError("simpliyfing compare lead to broken function", nullptr, M); + + } + +#endif + + // the simplification methods replace the original CmpInst and push the + // resulting new CmpInst into the worklist. If the worklist is empty then + // we only have to split the original CmpInst. + if (worklist.size() == 0) { worklist.push_back(I); } + + while (!worklist.empty()) { + + CmpInst *cmp = worklist.pop_back_val(); + // we split the simplified compares into comparisons with smaller bitwidths + // if they are larger than our target_bitwidth. + if (bitw > target_bitwidth) { + + if (!splitCompare(cmp, M, worklist)) { + + reportError("Failed to split comparison", cmp, M); + + } + +#ifdef VERIFY_TOO_MUCH + if (verifyFunction(*F, &errs())) { + + reportError("splitting compare lead to broken function", nullptr, M); + + } - PN->addIncoming(icmp_usign_cmp, middle_bb); - PN->addIncoming(icmp_inv_sig_cmp, sign_bb); +#endif - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } } + count++; return true; } @@ -1050,306 +1316,110 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { } -/* splits icmps of size bitw into two nested icmps with bitw/2 size each */ -size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { - - size_t count = 0; - - LLVMContext &C = M.getContext(); - - IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - IntegerType *OldIntType = IntegerType::get(C, bitw); - IntegerType *NewIntType = IntegerType::get(C, bitw / 2); - - std::vector<Instruction *> icomps; - - if (bitw % 2) { return 0; } - - /* not supported yet */ - if (bitw > 64) { return 0; } - - /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the - * functions simplifyCompares() and simplifyIntSignedness() - * were executed only these four predicates should exist */ - for (auto &F : M) { - - if (!isInInstrumentList(&F)) continue; +bool SplitComparesTransform::runOnModule(Module &M) { - for (auto &BB : F) { + char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); + if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); + if (bitw_env) { target_bitwidth = atoi(bitw_env); } - for (auto &IN : BB) { + enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL; - CmpInst *selectcmpInst = nullptr; + if ((isatty(2) && getenv("AFL_QUIET") == NULL) || + getenv("AFL_DEBUG") != NULL) { - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + errs() << "Split-compare-pass by laf.intel@gmail.com, extended by " + "heiko@hexco.de (splitting icmp to " + << target_bitwidth << " bit)\n"; - if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ || - selectcmpInst->getPredicate() == CmpInst::ICMP_NE || - selectcmpInst->getPredicate() == CmpInst::ICMP_UGT || - selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) { + if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; } - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); - - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + } else { - if (!intTyOp0 || !intTyOp1) { continue; } + be_quiet = 1; - /* check if the bitwidths are the one we are looking for */ - if (intTyOp0->getBitWidth() != bitw || - intTyOp1->getBitWidth() != bitw) { + } - continue; + if (enableFPSplit) { - } + count = splitFPCompares(M); - icomps.push_back(selectcmpInst); + /* + if (!be_quiet) { - } + errs() << "Split-floatingpoint-compare-pass: " << count + << " FP comparisons split\n"; } - } - - } + */ + simplifyFPCompares(M); } - if (!icomps.size()) { return 0; } - - for (auto &IcmpInst : icomps) { - - BasicBlock *bb = IcmpInst->getParent(); - - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); - - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); - - /* create the comparison of the top halves of the original operands */ - Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high; - - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, - ConstantInt::get(OldIntType, bitw / 2)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0); - op0_high = new TruncInst(s_op0, NewIntType); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - op0_high); - - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, - ConstantInt::get(OldIntType, bitw / 2)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1); - op1_high = new TruncInst(s_op1, NewIntType); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - op1_high); - - icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_high); - - /* now we have to destinguish between == != and > < */ - if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) { - - /* transformation for == and != icmps */ - - /* create a compare for the lower half of the original operands */ - Instruction *op0_low, *op1_low, *icmp_low; - BasicBlock * cmp_low_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - - op0_low = new TruncInst(op0, NewIntType); - cmp_low_bb->getInstList().push_back(op0_low); - - op1_low = new TruncInst(op1, NewIntType); - cmp_low_bb->getInstList().push_back(op1_low); - - icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); - cmp_low_bb->getInstList().push_back(icmp_low); - BranchInst::Create(end_bb, cmp_low_bb); - - /* dependent on the cmp of the high parts go to the end or go on with - * the comparison */ - auto term = bb->getTerminator(); - if (pred == CmpInst::ICMP_EQ) { - - BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); - - } else { - - /* CmpInst::ICMP_NE */ - BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); - - } - - term->eraseFromParent(); - - /* create the PHI and connect the edges accordingly */ - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - PN->addIncoming(icmp_low, cmp_low_bb); - if (pred == CmpInst::ICMP_EQ) { - - PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); - - } else { + std::vector<CmpInst *> worklist; + /* iterate over all functions, bbs and instruction search for all integer + * compare instructions. Save them into the worklist for later. */ + for (auto &F : M) { - /* CmpInst::ICMP_NE */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + if (!isInInstrumentList(&F)) continue; - } + for (auto &BB : F) { - /* replace the old icmp with the new PHI */ - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + for (auto &IN : BB) { - } else { + if (auto CI = dyn_cast<CmpInst>(&IN)) { - /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */ - /* transformations for < and > */ + auto op0 = CI->getOperand(0); + auto op1 = CI->getOperand(1); + if (!op0 || !op1) { return false; } + auto iTy1 = dyn_cast<IntegerType>(op0->getType()); + if (iTy1 && isa<IntegerType>(op1->getType())) { - /* create a basic block which checks for the inverse predicate. - * if this is true we can go to the end if not we have to go to the - * bb which checks the lower half of the operands */ - Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low; - BasicBlock * inv_cmp_bb = - BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); - if (pred == CmpInst::ICMP_UGT) { + unsigned bitw = iTy1->getBitWidth(); + if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); } - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, - op0_high, op1_high); - - } else { + } - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, - op0_high, op1_high); + } } - inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); - - auto term = bb->getTerminator(); - term->eraseFromParent(); - BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); - - /* create a bb which handles the cmp of the lower halves */ - BasicBlock *cmp_low_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - op0_low = new TruncInst(op0, NewIntType); - cmp_low_bb->getInstList().push_back(op0_low); - op1_low = new TruncInst(op1, NewIntType); - cmp_low_bb->getInstList().push_back(op1_low); - - icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); - cmp_low_bb->getInstList().push_back(icmp_low); - BranchInst::Create(end_bb, cmp_low_bb); - - BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); - - PHINode *PN = PHINode::Create(Int1Ty, 3); - PN->addIncoming(icmp_low, cmp_low_bb); - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); - - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); - } - ++count; - } - return count; - -} - -bool SplitComparesTransform::runOnModule(Module &M) { - - int bitw = 64; - size_t count = 0; - - char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); - if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); - if (bitw_env) { bitw = atoi(bitw_env); } - - enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL; - - if ((isatty(2) && getenv("AFL_QUIET") == NULL) || - getenv("AFL_DEBUG") != NULL) { - - printf( - "Split-compare-pass by laf.intel@gmail.com, extended by " - "heiko@hexco.de\n"); + // now that we have a list of all integer comparisons we can start replacing + // them with the splitted alternatives. + for (auto CI : worklist) { - } else { - - be_quiet = 1; + simplifyAndSplit(CI, M); } - if (enableFPSplit) { - - count = splitFPCompares(M); - - /* - if (!be_quiet) { - - errs() << "Split-floatingpoint-compare-pass: " << count - << " FP comparisons split\n"; + bool brokenDebug = false; + if (verifyModule(M, &errs() +#if LLVM_VERSION_MAJOR > 3 || \ + (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9) + , + &brokenDebug // 9th May 2016 +#endif + )) { - } - - */ - simplifyFPCompares(M); + reportError( + "Module Verifier failed! Consider reporting a bug with the AFL++ " + "project.", + nullptr, M); } - simplifyCompares(M); - - simplifyIntSignedness(M); + if (brokenDebug) { - switch (bitw) { - - case 64: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - bitw >>= 1; -#if LLVM_VERSION_MAJOR > 3 || \ - (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7) - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ -#endif - case 32: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - bitw >>= 1; -#if LLVM_VERSION_MAJOR > 3 || \ - (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7) - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ -#endif - case 16: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - // bitw >>= 1; - break; - - default: - // if (!be_quiet) errs() << "NOT Running split-compare-pass \n"; - return false; - break; + reportError("Module Verifier reported broken Debug Infos - Stripping!", + nullptr, M); + StripDebugInfo(M); } - verifyModule(M); return true; } @@ -1373,3 +1443,8 @@ static RegisterStandardPasses RegisterSplitComparesTransPassLTO( registerSplitComparesPass); #endif +static RegisterPass<SplitComparesTransform> X("splitcompares", + "AFL++ split compares", + true /* Only looks at CFG */, + true /* Analysis Pass */); + |