about summary refs log tree commit diff
path: root/instrumentation
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation')
-rw-r--r-- instrumentation/SanitizerCoverageLTO.so.cc 14
-rw-r--r-- instrumentation/afl-compiler-rt.o.c 160
-rw-r--r-- instrumentation/afl-llvm-common.cc 5
-rw-r--r-- instrumentation/afl-llvm-dict2file.so.cc 7
-rw-r--r-- instrumentation/afl-llvm-lto-instrumentation.so.cc 7
-rw-r--r-- instrumentation/afl-llvm-pass.so.cc 4
-rw-r--r-- instrumentation/cmplog-instructions-pass.cc 184
-rw-r--r-- instrumentation/cmplog-routines-pass.cc 67
-rw-r--r-- instrumentation/cmplog-switches-pass.cc 414
-rw-r--r-- instrumentation/compare-transform-pass.so.cc 19
-rw-r--r-- instrumentation/split-compares-pass.so.cc 987
11 files changed, 1186 insertions, 682 deletions
diff --git a/instrumentation/SanitizerCoverageLTO.so.cc b/instrumentation/SanitizerCoverageLTO.so.cc
index 20f1856e..eddbfcc8 100644
--- a/instrumentation/SanitizerCoverageLTO.so.cc
+++ b/instrumentation/SanitizerCoverageLTO.so.cc
@@ -516,6 +516,8 @@ bool ModuleSanitizerCoverage::instrumentModule(
 
     for (auto &F : M) {
 
+      if (!isInInstrumentList(&F) || !F.size()) { continue; }
+
       for (auto &BB : F) {
 
         for (auto &IN : BB) {
@@ -759,6 +761,12 @@ bool ModuleSanitizerCoverage::instrumentModule(
 
                   uint64_t literalLength = Str2.size();
                   uint64_t optLength = ilen->getZExtValue();
+                  if (optLength > literalLength + 1) {
+
+                    optLength = Str2.length() + 1;
+
+                  }
+
                   if (literalLength + 1 == optLength) {
 
                     Str2.append("\0", 1);  // add null byte
@@ -862,6 +870,12 @@ bool ModuleSanitizerCoverage::instrumentModule(
 
                 uint64_t literalLength = optLen;
                 optLen = ilen->getZExtValue();
+                if (optLen > thestring.length() + 1) {
+
+                  optLen = thestring.length() + 1;
+
+                }
+
                 if (optLen < 2) { continue; }
                 if (literalLength + 1 == optLen) {  // add null byte
                   thestring.append("\0", 1);
diff --git a/instrumentation/afl-compiler-rt.o.c b/instrumentation/afl-compiler-rt.o.c
index 2089ce78..3fec291c 100644
--- a/instrumentation/afl-compiler-rt.o.c
+++ b/instrumentation/afl-compiler-rt.o.c
@@ -98,9 +98,9 @@ u32 __afl_dictionary_len;
 u64 __afl_map_addr;
 
 // for the __AFL_COVERAGE_ON/__AFL_COVERAGE_OFF features to work:
-int __afl_selective_coverage __attribute__((weak));
-int __afl_selective_coverage_start_off __attribute__((weak));
-int __afl_selective_coverage_temp = 1;
+int        __afl_selective_coverage __attribute__((weak));
+int        __afl_selective_coverage_start_off __attribute__((weak));
+static int __afl_selective_coverage_temp = 1;
 
 #if defined(__ANDROID__) || defined(__HAIKU__)
 PREV_LOC_T __afl_prev_loc[NGRAM_SIZE_MAX];
@@ -147,7 +147,7 @@ static int __afl_dummy_fd[2] = {2, 2};
 
 /* ensure we kill the child on termination */
 
-void at_exit(int signal) {
+static void at_exit(int signal) {
 
   if (child_pid > 0) { kill(child_pid, SIGKILL); }
 
@@ -179,7 +179,7 @@ void __afl_trace(const u32 x) {
 
 /* Error reporting to forkserver controller */
 
-void send_forkserver_error(int error) {
+static void send_forkserver_error(int error) {
 
   u32 status;
   if (!error || error > 0xffff) return;
@@ -270,12 +270,6 @@ static void __afl_map_shm(void) {
 
   if (__afl_final_loc) {
 
-    if (__afl_final_loc % 64) {
-
-      __afl_final_loc = (((__afl_final_loc + 63) >> 6) << 6);
-
-    }
-
     __afl_map_size = __afl_final_loc;
 
     if (__afl_final_loc > MAP_SIZE) {
@@ -304,8 +298,9 @@ static void __afl_map_shm(void) {
 
           if (!getenv("AFL_QUIET"))
             fprintf(stderr,
-                    "Warning: AFL++ tools will need to set AFL_MAP_SIZE to %u "
-                    "to be able to run this instrumented program!\n",
+                    "Warning: AFL++ tools might need to set AFL_MAP_SIZE to %u "
+                    "to be able to run this instrumented program if this "
+                    "crashes!\n",
                     __afl_final_loc);
 
         }
@@ -622,6 +617,7 @@ static void __afl_unmap_shm(void) {
 #endif
 
     __afl_cmp_map = NULL;
+    __afl_cmp_map_backup = NULL;
 
   }
 
@@ -629,6 +625,34 @@ static void __afl_unmap_shm(void) {
 
 }
 
+#define write_error(text) write_error_with_location(text, __FILE__, __LINE__)
+/* Log `text` plus strerror(errno) to $__AFL_OUT_DIR/error.txt and stderr. */
+void write_error_with_location(char *text, char *filename, int linenumber) {
+  /* filename/linenumber: call site, filled in by the write_error() macro. */
+  u8 *  o = getenv("__AFL_OUT_DIR");
+  char *e = strerror(errno);
+  /* NOTE(review): errno is read after getenv(); u8* `o` is passed to %s. */
+  if (o) {
+
+    char buf[4096];
+    snprintf(buf, sizeof(buf), "%s/error.txt", o);
+    FILE *f = fopen(buf, "a");
+    /* Append mode; if the open fails only stderr receives the message. */
+    if (f) {
+
+      fprintf(f, "File %s, line %d: Error(%s): %s\n", filename, linenumber,
+              text, e);
+      fclose(f);
+
+    }
+
+  }
+  /* The same line is always written to stderr, with or without the file. */
+  fprintf(stderr, "File %s, line %d: Error(%s): %s\n", filename, linenumber,
+          text, e);
+
+}
+
 #ifdef __linux__
 static void __afl_start_snapshots(void) {
 
@@ -655,7 +679,12 @@ static void __afl_start_snapshots(void) {
 
   if (__afl_sharedmem_fuzzing || (__afl_dictionary_len && __afl_dictionary)) {
 
-    if (read(FORKSRV_FD, &was_killed, 4) != 4) { _exit(1); }
+    if (read(FORKSRV_FD, &was_killed, 4) != 4) {
+
+      write_error("read to afl-fuzz");
+      _exit(1);
+
+    }
 
     if (__afl_debug) {
 
@@ -724,7 +753,12 @@ static void __afl_start_snapshots(void) {
     } else {
 
       /* Wait for parent by reading from the pipe. Abort if read fails. */
-      if (read(FORKSRV_FD, &was_killed, 4) != 4) _exit(1);
+      if (read(FORKSRV_FD, &was_killed, 4) != 4) {
+
+        write_error("reading from afl-fuzz");
+        _exit(1);
+
+      }
 
     }
 
@@ -761,7 +795,12 @@ static void __afl_start_snapshots(void) {
     if (child_stopped && was_killed) {
 
       child_stopped = 0;
-      if (waitpid(child_pid, &status, 0) < 0) _exit(1);
+      if (waitpid(child_pid, &status, 0) < 0) {
+
+        write_error("child_stopped && was_killed");
+        _exit(1);  // TODO why exit?
+
+      }
 
     }
 
@@ -770,7 +809,12 @@ static void __afl_start_snapshots(void) {
       /* Once woken up, create a clone of our process. */
 
       child_pid = fork();
-      if (child_pid < 0) _exit(1);
+      if (child_pid < 0) {
+
+        write_error("fork");
+        _exit(1);
+
+      }
 
       /* In child process: close fds, resume execution. */
 
@@ -810,9 +854,19 @@ static void __afl_start_snapshots(void) {
 
     /* In parent process: write PID to pipe, then wait for child. */
 
-    if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) _exit(1);
+    if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) {
+
+      write_error("write to afl-fuzz");
+      _exit(1);
+
+    }
 
-    if (waitpid(child_pid, &status, WUNTRACED) < 0) _exit(1);
+    if (waitpid(child_pid, &status, WUNTRACED) < 0) {
+
+      write_error("waitpid");
+      _exit(1);
+
+    }
 
     /* In persistent mode, the child stops itself with SIGSTOP to indicate
        a successful run. In this case, we want to wake it up without forking
@@ -822,7 +876,12 @@ static void __afl_start_snapshots(void) {
 
     /* Relay wait status to pipe, then loop back. */
 
-    if (write(FORKSRV_FD + 1, &status, 4) != 4) _exit(1);
+    if (write(FORKSRV_FD + 1, &status, 4) != 4) {
+
+      write_error("writing to afl-fuzz");
+      _exit(1);
+
+    }
 
   }
 
@@ -955,7 +1014,12 @@ static void __afl_start_forkserver(void) {
 
     } else {
 
-      if (read(FORKSRV_FD, &was_killed, 4) != 4) _exit(1);
+      if (read(FORKSRV_FD, &was_killed, 4) != 4) {
+
+        // write_error("read from afl-fuzz");
+        _exit(1);
+
+      }
 
     }
 
@@ -992,7 +1056,12 @@ static void __afl_start_forkserver(void) {
     if (child_stopped && was_killed) {
 
       child_stopped = 0;
-      if (waitpid(child_pid, &status, 0) < 0) _exit(1);
+      if (waitpid(child_pid, &status, 0) < 0) {
+
+        write_error("child_stopped && was_killed");
+        _exit(1);
+
+      }
 
     }
 
@@ -1001,7 +1070,12 @@ static void __afl_start_forkserver(void) {
       /* Once woken up, create a clone of our process. */
 
       child_pid = fork();
-      if (child_pid < 0) _exit(1);
+      if (child_pid < 0) {
+
+        write_error("fork");
+        _exit(1);
+
+      }
 
       /* In child process: close fds, resume execution. */
 
@@ -1030,11 +1104,20 @@ static void __afl_start_forkserver(void) {
 
     /* In parent process: write PID to pipe, then wait for child. */
 
-    if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) _exit(1);
+    if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) {
 
-    if (waitpid(child_pid, &status, is_persistent ? WUNTRACED : 0) < 0)
+      write_error("write to afl-fuzz");
       _exit(1);
 
+    }
+
+    if (waitpid(child_pid, &status, is_persistent ? WUNTRACED : 0) < 0) {
+
+      write_error("waitpid");
+      _exit(1);
+
+    }
+
     /* In persistent mode, the child stops itself with SIGSTOP to indicate
        a successful run. In this case, we want to wake it up without forking
        again. */
@@ -1043,7 +1126,12 @@ static void __afl_start_forkserver(void) {
 
     /* Relay wait status to pipe, then loop back. */
 
-    if (write(FORKSRV_FD + 1, &status, 4) != 4) _exit(1);
+    if (write(FORKSRV_FD + 1, &status, 4) != 4) {
+
+      write_error("writing to afl-fuzz");
+      _exit(1);
+
+    }
 
   }
 
@@ -1599,7 +1687,7 @@ void __cmplog_ins_hookN(uint128_t arg1, uint128_t arg2, uint8_t attr,
 
 void __cmplog_ins_hook16(uint128_t arg1, uint128_t arg2, uint8_t attr) {
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
 
   uintptr_t k = (uintptr_t)__builtin_return_address(0);
   k = (k >> 4) ^ (k << 8);
@@ -1668,7 +1756,7 @@ void __sanitizer_cov_trace_cmp4(uint32_t arg1, uint32_t arg2) {
 
 }
 
-void __sanitizer_cov_trace_cost_cmp4(uint32_t arg1, uint32_t arg2) {
+void __sanitizer_cov_trace_const_cmp4(uint32_t arg1, uint32_t arg2) {
 
   __cmplog_ins_hook4(arg1, arg2, 0);
 
@@ -1703,7 +1791,7 @@ void __sanitizer_cov_trace_const_cmp16(uint128_t arg1, uint128_t arg2) {
 
 void __sanitizer_cov_trace_switch(uint64_t val, uint64_t *cases) {
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
 
   for (uint64_t i = 0; i < cases[0]; i++) {
 
@@ -1800,7 +1888,7 @@ void __cmplog_rtn_hook(u8 *ptr1, u8 *ptr2) {
     fprintf(stderr, "\n");
   */
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
   // fprintf(stderr, "RTN1 %p %p\n", ptr1, ptr2);
   int l1, l2;
   if ((l1 = area_is_valid(ptr1, 32)) <= 0 ||
@@ -1884,7 +1972,7 @@ static u8 *get_llvm_stdstring(u8 *string) {
 
 void __cmplog_rtn_gcc_stdstring_cstring(u8 *stdstring, u8 *cstring) {
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
   if (area_is_valid(stdstring, 32) <= 0 || area_is_valid(cstring, 32) <= 0)
     return;
 
@@ -1894,7 +1982,7 @@ void __cmplog_rtn_gcc_stdstring_cstring(u8 *stdstring, u8 *cstring) {
 
 void __cmplog_rtn_gcc_stdstring_stdstring(u8 *stdstring1, u8 *stdstring2) {
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
   if (area_is_valid(stdstring1, 32) <= 0 || area_is_valid(stdstring2, 32) <= 0)
     return;
 
@@ -1905,7 +1993,7 @@ void __cmplog_rtn_gcc_stdstring_stdstring(u8 *stdstring1, u8 *stdstring2) {
 
 void __cmplog_rtn_llvm_stdstring_cstring(u8 *stdstring, u8 *cstring) {
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
   if (area_is_valid(stdstring, 32) <= 0 || area_is_valid(cstring, 32) <= 0)
     return;
 
@@ -1915,7 +2003,7 @@ void __cmplog_rtn_llvm_stdstring_cstring(u8 *stdstring, u8 *cstring) {
 
 void __cmplog_rtn_llvm_stdstring_stdstring(u8 *stdstring1, u8 *stdstring2) {
 
-  if (unlikely(!__afl_cmp_map)) return;
+  if (likely(!__afl_cmp_map)) return;
   if (area_is_valid(stdstring1, 32) <= 0 || area_is_valid(stdstring2, 32) <= 0)
     return;
 
@@ -1949,7 +2037,7 @@ void __afl_coverage_on() {
   if (likely(__afl_selective_coverage && __afl_selective_coverage_temp)) {
 
     __afl_area_ptr = __afl_area_ptr_backup;
-    __afl_cmp_map = __afl_cmp_map_backup;
+    if (__afl_cmp_map_backup) { __afl_cmp_map = __afl_cmp_map_backup; }
 
   }
 
@@ -1990,3 +2078,5 @@ void __afl_coverage_interesting(u8 val, u32 id) {
 
 }
 
+#undef write_error
+
diff --git a/instrumentation/afl-llvm-common.cc b/instrumentation/afl-llvm-common.cc
index af32e2f9..3239ea91 100644
--- a/instrumentation/afl-llvm-common.cc
+++ b/instrumentation/afl-llvm-common.cc
@@ -96,9 +96,8 @@ bool isIgnoreFunction(const llvm::Function *F) {
 
   static constexpr const char *ignoreSubstringList[] = {
 
-      "__asan",       "__msan",     "__ubsan", "__lsan",
-      "__san",        "__sanitize", "__cxx",   "_GLOBAL__",
-      "DebugCounter", "DwarfDebug", "DebugLoc"
+      "__asan", "__msan",       "__ubsan",    "__lsan",  "__san", "__sanitize",
+      "__cxx",  "DebugCounter", "DwarfDebug", "DebugLoc"
 
   };
 
diff --git a/instrumentation/afl-llvm-dict2file.so.cc b/instrumentation/afl-llvm-dict2file.so.cc
index e2b44b21..58f01920 100644
--- a/instrumentation/afl-llvm-dict2file.so.cc
+++ b/instrumentation/afl-llvm-dict2file.so.cc
@@ -154,6 +154,7 @@ bool AFLdict2filePass::runOnModule(Module &M) {
   for (auto &F : M) {
 
     if (isIgnoreFunction(&F)) continue;
+    if (!isInInstrumentList(&F) || !F.size()) { continue; }
 
     /*  Some implementation notes.
      *
@@ -428,6 +429,12 @@ bool AFLdict2filePass::runOnModule(Module &M) {
 
                 uint64_t literalLength = Str2.length();
                 uint64_t optLength = ilen->getZExtValue();
+                if (optLength > literalLength + 1) {
+
+                  optLength = Str2.length() + 1;
+
+                }
+
                 if (literalLength + 1 == optLength) {
 
                   Str2.append("\0", 1);  // add null byte
diff --git a/instrumentation/afl-llvm-lto-instrumentation.so.cc b/instrumentation/afl-llvm-lto-instrumentation.so.cc
index fe43fbe5..46aa388e 100644
--- a/instrumentation/afl-llvm-lto-instrumentation.so.cc
+++ b/instrumentation/afl-llvm-lto-instrumentation.so.cc
@@ -546,6 +546,12 @@ bool AFLLTOPass::runOnModule(Module &M) {
 
                   uint64_t literalLength = Str2.size();
                   uint64_t optLength = ilen->getZExtValue();
+                  if (optLength > literalLength + 1) {
+
+                    optLength = Str2.length() + 1;
+
+                  }
+
                   if (literalLength + 1 == optLength) {
 
                     Str2.append("\0", 1);  // add null byte
@@ -649,6 +655,7 @@ bool AFLLTOPass::runOnModule(Module &M) {
 
                 uint64_t literalLength = optLen;
                 optLen = ilen->getZExtValue();
+                if (optLen > literalLength + 1) { optLen = literalLength + 1; }
                 if (optLen < 2) { continue; }
                 if (literalLength + 1 == optLen) {  // add null byte
                   thestring.append("\0", 1);
diff --git a/instrumentation/afl-llvm-pass.so.cc b/instrumentation/afl-llvm-pass.so.cc
index a8f1baff..b673d815 100644
--- a/instrumentation/afl-llvm-pass.so.cc
+++ b/instrumentation/afl-llvm-pass.so.cc
@@ -438,9 +438,9 @@ bool AFLCoverage::runOnModule(Module &M) {
       fprintf(stderr, "FUNCTION: %s (%zu)\n", F.getName().str().c_str(),
               F.size());
 
-    if (!isInInstrumentList(&F)) continue;
+    if (!isInInstrumentList(&F)) { continue; }
 
-    if (F.size() < function_minimum_size) continue;
+    if (F.size() < function_minimum_size) { continue; }
 
     std::list<Value *> todo;
     for (auto &BB : F) {
diff --git a/instrumentation/cmplog-instructions-pass.cc b/instrumentation/cmplog-instructions-pass.cc
index ad334d3b..0562c5b2 100644
--- a/instrumentation/cmplog-instructions-pass.cc
+++ b/instrumentation/cmplog-instructions-pass.cc
@@ -104,7 +104,6 @@ Iterator Unique(Iterator first, Iterator last) {
 bool CmpLogInstructions::hookInstrs(Module &M) {
 
   std::vector<Instruction *> icomps;
-  std::vector<SwitchInst *>  switches;
   LLVMContext &              C = M.getContext();
 
   Type *       VoidTy = Type::getVoidTy(C);
@@ -222,6 +221,18 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
   FunctionCallee cmplogHookInsN = cN;
 #endif
 
+  GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
+
+  if (!AFLCmplogPtr) {
+
+    AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
+                                      GlobalValue::ExternalWeakLinkage, 0,
+                                      "__afl_cmp_map");
+
+  }
+
+  Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
+
   /* iterate over all functions, bbs and instruction and add suitable calls */
   for (auto &F : M) {
 
@@ -238,164 +249,6 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
         }
 
-        SwitchInst *switchInst = nullptr;
-        if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
-
-          if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); }
-
-        }
-
-      }
-
-    }
-
-  }
-
-  // unique the collected switches
-  switches.erase(Unique(switches.begin(), switches.end()), switches.end());
-
-  // Instrument switch values for cmplog
-  if (switches.size()) {
-
-    if (!be_quiet)
-      errs() << "Hooking " << switches.size() << " switch instructions\n";
-
-    for (auto &SI : switches) {
-
-      Value *       Val = SI->getCondition();
-      unsigned int  max_size = Val->getType()->getIntegerBitWidth(), cast_size;
-      unsigned char do_cast = 0;
-
-      if (!SI->getNumCases() || max_size < 16) {
-
-        // if (!be_quiet) errs() << "skip trivial switch..\n";
-        continue;
-
-      }
-
-      if (max_size % 8) {
-
-        max_size = (((max_size / 8) + 1) * 8);
-        do_cast = 1;
-
-      }
-
-      IRBuilder<> IRB(SI->getParent());
-      IRB.SetInsertPoint(SI);
-
-      if (max_size > 128) {
-
-        if (!be_quiet) {
-
-          fprintf(stderr,
-                  "Cannot handle this switch bit size: %u (truncating)\n",
-                  max_size);
-
-        }
-
-        max_size = 128;
-        do_cast = 1;
-
-      }
-
-      // do we need to cast?
-      switch (max_size) {
-
-        case 8:
-        case 16:
-        case 32:
-        case 64:
-        case 128:
-          cast_size = max_size;
-          break;
-        default:
-          cast_size = 128;
-          do_cast = 1;
-
-      }
-
-      Value *CompareTo = Val;
-
-      if (do_cast) {
-
-        CompareTo =
-            IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false);
-
-      }
-
-      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
-           ++i) {
-
-#if LLVM_VERSION_MAJOR < 5
-        ConstantInt *cint = i.getCaseValue();
-#else
-        ConstantInt *cint = i->getCaseValue();
-#endif
-
-        if (cint) {
-
-          std::vector<Value *> args;
-          args.push_back(CompareTo);
-
-          Value *new_param = cint;
-
-          if (do_cast) {
-
-            new_param =
-                IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false);
-
-          }
-
-          if (new_param) {
-
-            args.push_back(new_param);
-            ConstantInt *attribute = ConstantInt::get(Int8Ty, 1);
-            args.push_back(attribute);
-            if (cast_size != max_size) {
-
-              ConstantInt *bitsize =
-                  ConstantInt::get(Int8Ty, (max_size / 8) - 1);
-              args.push_back(bitsize);
-
-            }
-
-            switch (cast_size) {
-
-              case 8:
-                IRB.CreateCall(cmplogHookIns1, args);
-                break;
-              case 16:
-                IRB.CreateCall(cmplogHookIns2, args);
-                break;
-              case 32:
-                IRB.CreateCall(cmplogHookIns4, args);
-                break;
-              case 64:
-                IRB.CreateCall(cmplogHookIns8, args);
-                break;
-              case 128:
-#ifdef WORD_SIZE_64
-                if (max_size == 128) {
-
-                  IRB.CreateCall(cmplogHookIns16, args);
-
-                } else {
-
-                  IRB.CreateCall(cmplogHookInsN, args);
-
-                }
-
-#endif
-                break;
-              default:
-                break;
-
-            }
-
-          }
-
-        }
-
       }
 
     }
@@ -409,8 +262,15 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
     for (auto &selectcmpInst : icomps) {
 
-      IRBuilder<> IRB(selectcmpInst->getParent());
-      IRB.SetInsertPoint(selectcmpInst);
+      IRBuilder<> IRB2(selectcmpInst->getParent());
+      IRB2.SetInsertPoint(selectcmpInst);
+      LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+      CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+      auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+      auto ThenTerm =
+          SplitBlockAndInsertIfThen(is_not_null, selectcmpInst, false);
+
+      IRBuilder<> IRB(ThenTerm);
 
       Value *op0 = selectcmpInst->getOperand(0);
       Value *op1 = selectcmpInst->getOperand(1);
@@ -601,7 +461,7 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
   }
 
-  if (switches.size() || icomps.size())
+  if (icomps.size())
     return true;
   else
     return false;
diff --git a/instrumentation/cmplog-routines-pass.cc b/instrumentation/cmplog-routines-pass.cc
index a5992c9a..1e2610f2 100644
--- a/instrumentation/cmplog-routines-pass.cc
+++ b/instrumentation/cmplog-routines-pass.cc
@@ -184,6 +184,18 @@ bool CmpLogRoutines::hookRtns(Module &M) {
   FunctionCallee cmplogGccStdC = c4;
 #endif
 
+  GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
+
+  if (!AFLCmplogPtr) {
+
+    AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
+                                      GlobalValue::ExternalWeakLinkage, 0,
+                                      "__afl_cmp_map");
+
+  }
+
+  Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
+
   /* iterate over all functions, bbs and instruction and add suitable calls */
   for (auto &F : M) {
 
@@ -289,8 +301,15 @@ bool CmpLogRoutines::hookRtns(Module &M) {
 
     Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
 
-    IRBuilder<> IRB(callInst->getParent());
-    IRB.SetInsertPoint(callInst);
+    IRBuilder<> IRB2(callInst->getParent());
+    IRB2.SetInsertPoint(callInst);
+
+    LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+    CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+    auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+    auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false);
+
+    IRBuilder<> IRB(ThenTerm);
 
     std::vector<Value *> args;
     Value *              v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
@@ -308,8 +327,15 @@ bool CmpLogRoutines::hookRtns(Module &M) {
 
     Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
 
-    IRBuilder<> IRB(callInst->getParent());
-    IRB.SetInsertPoint(callInst);
+    IRBuilder<> IRB2(callInst->getParent());
+    IRB2.SetInsertPoint(callInst);
+
+    LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+    CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+    auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+    auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false);
+
+    IRBuilder<> IRB(ThenTerm);
 
     std::vector<Value *> args;
     Value *              v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
@@ -327,8 +353,15 @@ bool CmpLogRoutines::hookRtns(Module &M) {
 
     Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
 
-    IRBuilder<> IRB(callInst->getParent());
-    IRB.SetInsertPoint(callInst);
+    IRBuilder<> IRB2(callInst->getParent());
+    IRB2.SetInsertPoint(callInst);
+
+    LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+    CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+    auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+    auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false);
+
+    IRBuilder<> IRB(ThenTerm);
 
     std::vector<Value *> args;
     Value *              v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
@@ -346,8 +379,15 @@ bool CmpLogRoutines::hookRtns(Module &M) {
 
     Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
 
-    IRBuilder<> IRB(callInst->getParent());
-    IRB.SetInsertPoint(callInst);
+    IRBuilder<> IRB2(callInst->getParent());
+    IRB2.SetInsertPoint(callInst);
+
+    LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+    CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+    auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+    auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false);
+
+    IRBuilder<> IRB(ThenTerm);
 
     std::vector<Value *> args;
     Value *              v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
@@ -365,8 +405,15 @@ bool CmpLogRoutines::hookRtns(Module &M) {
 
     Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
 
-    IRBuilder<> IRB(callInst->getParent());
-    IRB.SetInsertPoint(callInst);
+    IRBuilder<> IRB2(callInst->getParent());
+    IRB2.SetInsertPoint(callInst);
+
+    LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+    CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+    auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+    auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false);
+
+    IRBuilder<> IRB(ThenTerm);
 
     std::vector<Value *> args;
     Value *              v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
diff --git a/instrumentation/cmplog-switches-pass.cc b/instrumentation/cmplog-switches-pass.cc
new file mode 100644
index 00000000..c42d44fe
--- /dev/null
+++ b/instrumentation/cmplog-switches-pass.cc
@@ -0,0 +1,414 @@
+/*
+   american fuzzy lop++ - LLVM CmpLog instrumentation
+   --------------------------------------------------
+
+   Written by Andrea Fioraldi <andreafioraldi@gmail.com>
+
+   Copyright 2015, 2016 Google Inc. All rights reserved.
+   Copyright 2019-2020 AFLplusplus Project. All rights reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at:
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+*/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <list>
+#include <string>
+#include <fstream>
+#include <sys/time.h>
+
+#include "llvm/Config/llvm-config.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/ValueTracking.h"
+
+#if LLVM_VERSION_MAJOR > 3 || \
+    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
+  #include "llvm/IR/Verifier.h"
+  #include "llvm/IR/DebugInfo.h"
+#else
+  #include "llvm/Analysis/Verifier.h"
+  #include "llvm/DebugInfo.h"
+  #define nullptr 0
+#endif
+
+#include <set>
+#include "afl-llvm-common.h"
+
+using namespace llvm;
+
+namespace {
+/* ModulePass for this file; hookInstrs() below instruments switch values. */
+class CmpLogInstructions : public ModulePass {
+
+ public:
+  static char ID;
+  CmpLogInstructions() : ModulePass(ID) {
+    /* Helper is not defined in this file; presumably afl-llvm-common.h. */
+    initInstrumentList();
+
+  }
+  /* Declared here only; the definition is outside the visible part. */
+  bool runOnModule(Module &M) override;
+  /* getPassName()'s return type differs before and after LLVM 4. */
+#if LLVM_VERSION_MAJOR < 4
+  const char *getPassName() const override {
+
+#else
+  StringRef getPassName() const override {
+
+#endif
+    return "cmplog instructions";
+
+  }
+  /* NOTE(review): name string says "instructions" in the switches pass. */
+ private:
+  bool hookInstrs(Module &M);
+
+};
+
+}  // namespace
+
+char CmpLogInstructions::ID = 0;
+/* Order-preserving de-duplication of [first, last); returns the new end. */
+template <class Iterator>
+Iterator Unique(Iterator first, Iterator last) {
+  /* Quadratic: every kept element runs one std::remove over the tail. */
+  while (first != last) {
+    /* Drop all later copies of *first, then step to the next survivor. */
+    Iterator next(first);
+    last = std::remove(++next, last, *first);
+    first = next;
+
+  }
+  /* Caller erases [result, old end), as done for `switches` further down. */
+  return last;
+
+}
+
+bool CmpLogInstructions::hookInstrs(Module &M) {
+
+  std::vector<SwitchInst *> switches;
+  LLVMContext &             C = M.getContext();
+
+  Type *       VoidTy = Type::getVoidTy(C);
+  IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
+  IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
+  IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
+  IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
+                                 Int8Ty
+#if LLVM_VERSION_MAJOR < 5
+                                 ,
+                                 NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookIns1 = cast<Function>(c1);
+#else
+  FunctionCallee cmplogHookIns1 = c1;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
+                                 Int8Ty
+#if LLVM_VERSION_MAJOR < 5
+                                 ,
+                                 NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookIns2 = cast<Function>(c2);
+#else
+  FunctionCallee cmplogHookIns2 = c2;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
+                                 Int8Ty
+#if LLVM_VERSION_MAJOR < 5
+                                 ,
+                                 NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookIns4 = cast<Function>(c4);
+#else
+  FunctionCallee cmplogHookIns4 = c4;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
+                                 Int8Ty
+#if LLVM_VERSION_MAJOR < 5
+                                 ,
+                                 NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookIns8 = cast<Function>(c8);
+#else
+  FunctionCallee cmplogHookIns8 = c8;
+#endif
+
+  GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
+
+  if (!AFLCmplogPtr) {
+
+    AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
+                                      GlobalValue::ExternalWeakLinkage, 0,
+                                      "__afl_cmp_map");
+
+  }
+
+  Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
+
+  /* iterate over all functions, bbs and instruction and add suitable calls */
+  for (auto &F : M) {
+
+    if (!isInInstrumentList(&F)) continue;
+
+    for (auto &BB : F) {
+
+      SwitchInst *switchInst = nullptr;
+      if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
+
+        if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); }
+
+      }
+
+    }
+
+  }
+
+  // unique the collected switches
+  switches.erase(Unique(switches.begin(), switches.end()), switches.end());
+
+  // Instrument switch values for cmplog
+  if (switches.size()) {
+
+    if (!be_quiet)
+      errs() << "Hooking " << switches.size() << " switch instructions\n";
+
+    for (auto &SI : switches) {
+
+      Value *       Val = SI->getCondition();
+      unsigned int  max_size = Val->getType()->getIntegerBitWidth(), cast_size;
+      unsigned char do_cast = 0;
+
+      if (!SI->getNumCases() || max_size < 16) {
+
+        // if (!be_quiet) errs() << "skip trivial switch..\n";
+        continue;
+
+      }
+
+      if (max_size % 8) {
+
+        max_size = (((max_size / 8) + 1) * 8);
+        do_cast = 1;
+
+      }
+
+      IRBuilder<> IRB2(SI->getParent());
+      IRB2.SetInsertPoint(SI);
+
+      LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+      CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+      auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+      auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, SI, false);
+
+      IRBuilder<> IRB(ThenTerm);
+
+      if (max_size > 128) {
+
+        if (!be_quiet) {
+
+          fprintf(stderr,
+                  "Cannot handle this switch bit size: %u (truncating)\n",
+                  max_size);
+
+        }
+
+        max_size = 128;
+        do_cast = 1;
+
+      }
+
+      // do we need to cast?
+      switch (max_size) {
+
+        case 8:
+        case 16:
+        case 32:
+        case 64:
+        case 128:
+          cast_size = max_size;
+          break;
+        default:
+          cast_size = 128;
+          do_cast = 1;
+
+      }
+
+      Value *CompareTo = Val;
+
+      if (do_cast) {
+
+        CompareTo =
+            IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false);
+
+      }
+
+      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
+           ++i) {
+
+#if LLVM_VERSION_MAJOR < 5
+        ConstantInt *cint = i.getCaseValue();
+#else
+        ConstantInt *cint = i->getCaseValue();
+#endif
+
+        if (cint) {
+
+          std::vector<Value *> args;
+          args.push_back(CompareTo);
+
+          Value *new_param = cint;
+
+          if (do_cast) {
+
+            new_param =
+                IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false);
+
+          }
+
+          if (new_param) {
+
+            args.push_back(new_param);
+            ConstantInt *attribute = ConstantInt::get(Int8Ty, 1);
+            args.push_back(attribute);
+            if (cast_size != max_size) {
+
+              ConstantInt *bitsize =
+                  ConstantInt::get(Int8Ty, (max_size / 8) - 1);
+              args.push_back(bitsize);
+
+            }
+
+            switch (cast_size) {
+
+              case 8:
+                IRB.CreateCall(cmplogHookIns1, args);
+                break;
+              case 16:
+                IRB.CreateCall(cmplogHookIns2, args);
+                break;
+              case 32:
+                IRB.CreateCall(cmplogHookIns4, args);
+                break;
+              case 64:
+                IRB.CreateCall(cmplogHookIns8, args);
+                break;
+              case 128:
+#ifdef WORD_SIZE_64
+                if (max_size == 128) {
+
+                  IRB.CreateCall(cmplogHookIns16, args);
+
+                } else {
+
+                  IRB.CreateCall(cmplogHookInsN, args);
+
+                }
+
+#endif
+                break;
+              default:
+                break;
+
+            }
+
+          }
+
+        }
+
+      }
+
+    }
+
+  }
+
+  if (switches.size())
+    return true;
+  else
+    return false;
+
+}
+
+bool CmpLogInstructions::runOnModule(Module &M) {
+
+  if (getenv("AFL_QUIET") == NULL)  // banner is printed unless AFL_QUIET is set
+    printf("Running cmplog-switches-pass by andreafioraldi@gmail.com\n");
+  else
+    be_quiet = 1;  // AFL_QUIET set: be_quiet gates this pass's diagnostics
+  hookInstrs(M);    // return value is ignored
+  verifyModule(M);  // result is discarded as well
+
+  return true;  // always true, regardless of what hookInstrs reported
+
+}
+
+static void registerCmpLogInstructionsPass(const PassManagerBuilder &,
+                                           legacy::PassManagerBase &PM) {
+  // Callback for the registrations below; the builder argument is unused.
+  auto p = new CmpLogInstructions();  // fresh pass object on every call
+  PM.add(p);  // handed to PM, never deleted here (presumably PM owns it)
+
+}
+// Register the callback for the EP_OptimizerLast extension point,
+static RegisterStandardPasses RegisterCmpLogInstructionsPass(
+    PassManagerBuilder::EP_OptimizerLast, registerCmpLogInstructionsPass);
+// for EP_EnabledOnOptLevel0 (same callback),
+static RegisterStandardPasses RegisterCmpLogInstructionsPass0(
+    PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogInstructionsPass);
+// and, only when built against LLVM >= 11, EP_FullLinkTimeOptimizationLast.
+#if LLVM_VERSION_MAJOR >= 11
+static RegisterStandardPasses RegisterCmpLogInstructionsPassLTO(
+    PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
+    registerCmpLogInstructionsPass);
+#endif
+
diff --git a/instrumentation/compare-transform-pass.so.cc b/instrumentation/compare-transform-pass.so.cc
index 3ecba4e6..f5dd4a53 100644
--- a/instrumentation/compare-transform-pass.so.cc
+++ b/instrumentation/compare-transform-pass.so.cc
@@ -313,27 +313,18 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
             ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
             if (ilen) {
 
-              uint64_t len = ilen->getZExtValue();
               // if len is zero this is a pointless call but allow real
               // implementation to worry about that
-              if (len < 2) continue;
+              if (ilen->getZExtValue() < 2) { continue; }
 
-              if (isMemcmp) {
-
-                // if size of compare is larger than constant string this is
-                // likely a bug but allow real implementation to worry about
-                // that
-                uint64_t literalLength = HasStr1 ? Str1.size() : Str2.size();
-                if (literalLength + 1 < ilen->getZExtValue()) continue;
-
-              }
-
-            } else if (isMemcmp)
+            } else if (isMemcmp) {
 
               // this *may* supply a len greater than the constant string at
               // runtime so similarly we don't want to have to handle that
               continue;
 
+            }
+
           }
 
           calls.push_back(callInst);
@@ -421,7 +412,7 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     }
 
     if (TmpConstStr.length() < 2 ||
-        (TmpConstStr.length() == 2 && !TmpConstStr[1])) {
+        (TmpConstStr.length() == 2 && TmpConstStr[1] == 0)) {
 
       continue;
 
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc
index b02a89fb..13f45b69 100644
--- a/instrumentation/split-compares-pass.so.cc
+++ b/instrumentation/split-compares-pass.so.cc
@@ -47,6 +47,10 @@
 using namespace llvm;
 #include "afl-llvm-common.h"
 
+// uncomment this to toggle function verification at each step. horribly slow,
+// but helps to pinpoint a potential problem in the splitting code.
+//#define VERIFY_TOO_MUCH 1
+
 namespace {
 
 class SplitComparesTransform : public ModulePass {
@@ -67,28 +71,101 @@ class SplitComparesTransform : public ModulePass {
   const char *getPassName() const override {
 
 #endif
-    return "simplifies and splits ICMP instructions";
+    return "AFL_SplitComparesTransform";
 
   }
 
  private:
   int enableFPSplit;
 
-  size_t splitIntCompares(Module &M, unsigned bitw);
+  unsigned target_bitwidth = 8;
+
+  size_t count = 0;
+
   size_t splitFPCompares(Module &M);
-  bool   simplifyCompares(Module &M);
   bool   simplifyFPCompares(Module &M);
-  bool   simplifyIntSignedness(Module &M);
   size_t nextPowerOfTwo(size_t in);
 
+  using CmpWorklist = SmallVector<CmpInst *, 8>;
+
+  /// simplify the comparison and then split the comparison until the
+  /// target_bitwidth is reached.
+  bool simplifyAndSplit(CmpInst *I, Module &M);
+  /// simplify a non-strict comparison (e.g., less than or equals)
+  bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M,
+                               CmpWorklist &worklist);
+  /// simplify a signed comparison (signed less or greater than)
+  bool simplifySignedCompare(CmpInst *IcmpInst, Module &M,
+                             CmpWorklist &worklist);
+  /// splits an icmp into nested icmps recursively until target_bitwidth is
+  /// reached
+  bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist);
+
+  /// report msg on errs() unless be_quiet; with debug set, add I/function/module
+  void reportError(const StringRef msg, Instruction *I, Module &M) {
+
+    if (!be_quiet) {
+
+      errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n";
+      if (debug) {  // location details are only printed when debug is set
+
+        if (I) {  // callers may pass nullptr when no instruction applies
+
+          errs() << "Instruction = " << *I << "\n";
+          if (auto BB = I->getParent()) {
+
+            if (auto F = BB->getParent()) {
+
+              if (F->hasName()) {
+                // no newline here; the "in module" output below ends the line
+                errs() << "|-> in function " << F->getName() << " ";
+
+              }
+
+            }
+
+          }
+
+        }
+
+        auto n = M.getName();  // NOTE: empty name leaves the line above open
+        if (n.size() > 0) { errs() << "in module " << n << "\n"; }
+
+      }
+
+    }
+
+  }
+
+  bool isSupportedBitWidth(unsigned bitw) {
+    // Whitelist of operand widths: only 8/16/32/64/128/256 bits return true.
+    // It is unclear whether the icmp code works on other bitwidths (probably
+    // not), so we avoid dealing with other weird icmp's that llvm might use
+    // (looking at you `icmp i0`).
+    switch (bitw) {
+
+      case 8:
+      case 16:
+      case 32:
+      case 64:
+      case 128:
+      case 256:
+        return true;
+      default:  // anything else, including widths that are not listed above
+        return false;
+
+    }
+
+  }
+
 };
 
 }  // namespace
 
 char SplitComparesTransform::ID = 0;
 
-/* This function splits FCMP instructions with xGE or xLE predicates into two
- * FCMP instructions with predicate xGT or xLT and EQ */
+/// This function splits FCMP instructions with xGE or xLE predicates into two
+/// FCMP instructions with predicate xGT or xLT and EQ
 bool SplitComparesTransform::simplifyFPCompares(Module &M) {
 
   LLVMContext &              C = M.getContext();
@@ -221,292 +298,481 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
 
 }
 
-/* This function splits ICMP instructions with xGE or xLE predicates into two
- * ICMP instructions with predicate xGT or xLT and EQ */
-bool SplitComparesTransform::simplifyCompares(Module &M) {
+/// This function splits ICMP instructions with xGE or xLE predicates into two
+/// ICMP instructions with predicate xGT or xLT and EQ
+bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst *    IcmpInst,
+                                                     Module &     M,
+                                                     CmpWorklist &worklist) {
 
-  LLVMContext &              C = M.getContext();
-  std::vector<Instruction *> icomps;
-  IntegerType *              Int1Ty = IntegerType::getInt1Ty(C);
+  LLVMContext &C = M.getContext();
+  IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
 
-  /* iterate over all functions, bbs and instruction and add
-   * all integer comparisons with >= and <= predicates to the icomps vector */
-  for (auto &F : M) {
+  /* find out what the new predicate is going to be */
+  auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+  if (!cmp_inst) { return false; }
 
-    if (!isInInstrumentList(&F)) continue;
+  BasicBlock *bb = IcmpInst->getParent();
 
-    for (auto &BB : F) {
+  auto op0 = IcmpInst->getOperand(0);
+  auto op1 = IcmpInst->getOperand(1);
 
-      for (auto &IN : BB) {
+  CmpInst::Predicate pred = cmp_inst->getPredicate();
+  CmpInst::Predicate new_pred;
 
-        CmpInst *selectcmpInst = nullptr;
+  switch (pred) {
 
-        if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+    case CmpInst::ICMP_UGE:
+      new_pred = CmpInst::ICMP_UGT;
+      break;
+    case CmpInst::ICMP_SGE:
+      new_pred = CmpInst::ICMP_SGT;
+      break;
+    case CmpInst::ICMP_ULE:
+      new_pred = CmpInst::ICMP_ULT;
+      break;
+    case CmpInst::ICMP_SLE:
+      new_pred = CmpInst::ICMP_SLT;
+      break;
+    default:  // keep the compiler happy
+      return false;
 
-          if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SGE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_ULE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) {
+  }
 
-            auto op0 = selectcmpInst->getOperand(0);
-            auto op1 = selectcmpInst->getOperand(1);
+  /* split before the icmp instruction */
+  BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+  /* the old bb now contains a unconditional jump to the new one (end_bb)
+   * we need to delete it later */
+
+  /* create the ICMP instruction with new_pred and add it to the old basic
+   * block bb it is now at the position where the old IcmpInst was */
+  CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
+  bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np);
+
+  /* create a new basic block which holds the new EQ icmp */
+  CmpInst *icmp_eq;
+  /* insert middle_bb before end_bb */
+  BasicBlock *middle_bb =
+      BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+  icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
+  middle_bb->getInstList().push_back(icmp_eq);
+  /* add an unconditional branch to the end of middle_bb with destination
+   * end_bb */
+  BranchInst::Create(end_bb, middle_bb);
+
+  /* replace the uncond branch with a conditional one, which depends on the
+   * new_pred icmp. True goes to end, false to the middle (injected) bb */
+  auto term = bb->getTerminator();
+  BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
+  term->eraseFromParent();
+
+  /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
+   * inst to wire up the loose ends */
+  PHINode *PN = PHINode::Create(Int1Ty, 2, "");
+  /* the first result depends on the outcome of icmp_eq */
+  PN->addIncoming(icmp_eq, middle_bb);
+  /* if the source was the original bb we know that the icmp_np yielded true
+   * hence we can hardcode this value */
+  PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+  /* replace the old IcmpInst with our new and shiny PHI inst */
+  BasicBlock::iterator ii(IcmpInst);
+  ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+
+  worklist.push_back(icmp_np);
+  worklist.push_back(icmp_eq);
 
-            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+  return true;
 
-            /* this is probably not needed but we do it anyway */
-            if (!intTyOp0 || !intTyOp1) { continue; }
+}
 
-            icomps.push_back(selectcmpInst);
+/// Simplify a signed comparison operator by splitting it into an unsigned and
+/// a sign-bit comparison. Only the resulting unsigned comparison is added to
+/// the worklist passed as a reference.
+bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M,
+                                                   CmpWorklist &worklist) {
 
-          }
+  LLVMContext &C = M.getContext();
+  IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
 
-        }
+  BasicBlock *bb = IcmpInst->getParent();
 
-      }
+  auto op0 = IcmpInst->getOperand(0);
+  auto op1 = IcmpInst->getOperand(1);
 
-    }
+  IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+  if (!intTyOp0) { return false; }
+  unsigned     bitw = intTyOp0->getBitWidth();
+  IntegerType *IntType = IntegerType::get(C, bitw);
 
-  }
+  /* get the new predicate */
+  auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+  if (!cmp_inst) { return false; }
+  auto               pred = cmp_inst->getPredicate();
+  CmpInst::Predicate new_pred;
 
-  if (!icomps.size()) { return false; }
+  if (pred == CmpInst::ICMP_SGT) {
 
-  for (auto &IcmpInst : icomps) {
+    new_pred = CmpInst::ICMP_UGT;
 
-    BasicBlock *bb = IcmpInst->getParent();
+  } else {
 
-    auto op0 = IcmpInst->getOperand(0);
-    auto op1 = IcmpInst->getOperand(1);
+    new_pred = CmpInst::ICMP_ULT;
 
-    /* find out what the new predicate is going to be */
-    auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
-    if (!cmp_inst) { continue; }
-    auto               pred = cmp_inst->getPredicate();
-    CmpInst::Predicate new_pred;
+  }
 
-    switch (pred) {
+  BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+  /* create a 1 bit compare for the sign bit. to do this shift and trunc
+   * the original operands so only the first bit remains.*/
+  Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
+
+  IRBuilder<> IRB(bb->getTerminator());
+  s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1));
+  t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty);
+  s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1));
+  t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty);
+  /* compare of the sign bits */
+  icmp_sign_bit = IRB.CreateICmp(CmpInst::ICMP_EQ, t_op0, t_op1);
+
+  /* create a new basic block which is executed if the signedness bit is
+   * different */
+  CmpInst *   icmp_inv_sig_cmp;
+  BasicBlock *sign_bb =
+      BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
+  if (pred == CmpInst::ICMP_SGT) {
+
+    /* if we check for > and the op0 positive and op1 negative then the final
+     * result is true. if op0 negative and op1 pos, the cmp must result
+     * in false
+     */
+    icmp_inv_sig_cmp =
+        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
 
-      case CmpInst::ICMP_UGE:
-        new_pred = CmpInst::ICMP_UGT;
-        break;
-      case CmpInst::ICMP_SGE:
-        new_pred = CmpInst::ICMP_SGT;
-        break;
-      case CmpInst::ICMP_ULE:
-        new_pred = CmpInst::ICMP_ULT;
-        break;
-      case CmpInst::ICMP_SLE:
-        new_pred = CmpInst::ICMP_SLT;
-        break;
-      default:  // keep the compiler happy
-        continue;
+  } else {
 
-    }
+    /* just the inverse of the above statement */
+    icmp_inv_sig_cmp =
+        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
 
-    /* split before the icmp instruction */
-    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+  }
 
-    /* the old bb now contains a unconditional jump to the new one (end_bb)
-     * we need to delete it later */
+  sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
+  BranchInst::Create(end_bb, sign_bb);
 
-    /* create the ICMP instruction with new_pred and add it to the old basic
-     * block bb it is now at the position where the old IcmpInst was */
-    Instruction *icmp_np;
-    icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
-                             icmp_np);
+  /* create a new bb which is executed if signedness is equal */
+  CmpInst *   icmp_usign_cmp;
+  BasicBlock *middle_bb =
+      BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+  /* we can do a normal unsigned compare now */
+  icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
 
-    /* create a new basic block which holds the new EQ icmp */
-    Instruction *icmp_eq;
-    /* insert middle_bb before end_bb */
-    BasicBlock *middle_bb =
-        BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
-    icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
-    middle_bb->getInstList().push_back(icmp_eq);
-    /* add an unconditional branch to the end of middle_bb with destination
-     * end_bb */
-    BranchInst::Create(end_bb, middle_bb);
+  middle_bb->getInstList().push_back(icmp_usign_cmp);
+  BranchInst::Create(end_bb, middle_bb);
 
-    /* replace the uncond branch with a conditional one, which depends on the
-     * new_pred icmp. True goes to end, false to the middle (injected) bb */
-    auto term = bb->getTerminator();
-    BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
-    term->eraseFromParent();
+  auto term = bb->getTerminator();
+  /* if the sign is eq do a normal unsigned cmp, else we have to check the
+   * signedness bit */
+  BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
+  term->eraseFromParent();
 
-    /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
-     * inst to wire up the loose ends */
-    PHINode *PN = PHINode::Create(Int1Ty, 2, "");
-    /* the first result depends on the outcome of icmp_eq */
-    PN->addIncoming(icmp_eq, middle_bb);
-    /* if the source was the original bb we know that the icmp_np yielded true
-     * hence we can hardcode this value */
-    PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
-    /* replace the old IcmpInst with our new and shiny PHI inst */
-    BasicBlock::iterator ii(IcmpInst);
-    ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+  PHINode *PN = PHINode::Create(Int1Ty, 2, "");
 
-  }
+  PN->addIncoming(icmp_usign_cmp, middle_bb);
+  PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
+
+  BasicBlock::iterator ii(IcmpInst);
+  ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+
+  // save for later
+  worklist.push_back(icmp_usign_cmp);
+
+  // signed comparisons are not supported by the splitting code, so we must not
+  // add it to the worklist.
+  // worklist.push_back(icmp_inv_sig_cmp);
 
   return true;
 
 }
 
-/* this function transforms signed compares to equivalent unsigned compares */
-bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
+bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M,
+                                          CmpWorklist &worklist) {
 
-  LLVMContext &              C = M.getContext();
-  std::vector<Instruction *> icomps;
-  IntegerType *              Int1Ty = IntegerType::getInt1Ty(C);
+  auto pred = cmp_inst->getPredicate();
+  switch (pred) {
 
-  /* iterate over all functions, bbs and instructions and add
-   * all signed compares to icomps vector */
-  for (auto &F : M) {
+    case CmpInst::ICMP_EQ:
+    case CmpInst::ICMP_NE:
+    case CmpInst::ICMP_UGT:
+    case CmpInst::ICMP_ULT:
+      break;
+    default:
+      // unsupported predicate!
+      return false;
 
-    if (!isInInstrumentList(&F)) continue;
+  }
 
-    for (auto &BB : F) {
+  auto op0 = cmp_inst->getOperand(0);
+  auto op1 = cmp_inst->getOperand(1);
 
-      for (auto &IN : BB) {
+  // get bitwidth by checking the bitwidth of the first operator
+  IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+  if (!intTyOp0) {
 
-        CmpInst *selectcmpInst = nullptr;
+    // not an integer type
+    return false;
 
-        if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+  }
 
-          if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) {
+  unsigned bitw = intTyOp0->getBitWidth();
+  if (bitw == target_bitwidth) {
 
-            auto op0 = selectcmpInst->getOperand(0);
-            auto op1 = selectcmpInst->getOperand(1);
+    // already the target bitwidth so we have to do nothing here.
+    return true;
+
+  }
+
+  LLVMContext &C = M.getContext();
+  IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
+  BasicBlock * bb = cmp_inst->getParent();
+  IntegerType *OldIntType = IntegerType::get(C, bitw);
+  IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
+  BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst));
+  CmpInst *    icmp_high, *icmp_low;
 
-            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+  /* create the comparison of the top halves of the original operands */
+  Value *s_op0, *op0_high, *s_op1, *op1_high;
 
-            /* see above */
-            if (!intTyOp0 || !intTyOp1) { continue; }
+  IRBuilder<> IRB(bb->getTerminator());
 
-            /* i think this is not possible but to lazy to look it up */
-            if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) {
+  s_op0 = IRB.CreateBinOp(Instruction::LShr, op0,
+                          ConstantInt::get(OldIntType, bitw / 2));
+  op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType);
 
-              continue;
+  s_op1 = IRB.CreateBinOp(Instruction::LShr, op1,
+                          ConstantInt::get(OldIntType, bitw / 2));
+  op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType);
+  icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high));
 
-            }
+  PHINode *PN = nullptr;
 
-            icomps.push_back(selectcmpInst);
+  /* now we have to distinguish between == != and > < */
+  switch (pred) {
 
-          }
+    case CmpInst::ICMP_EQ:
+    case CmpInst::ICMP_NE: {
 
-        }
+      /* transformation for == and != icmps */
+
+      /* create a compare for the lower half of the original operands */
+      BasicBlock *cmp_low_bb =
+          BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
+
+      Value *     op0_low, *op1_low;
+      IRBuilder<> Builder(cmp_low_bb);
+
+      op0_low = Builder.CreateTrunc(op0, NewIntType);
+      op1_low = Builder.CreateTrunc(op1, NewIntType);
+      icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low));
+
+      BranchInst::Create(end_bb, cmp_low_bb);
+
+      /* dependent on the cmp of the high parts go to the end or go on with
+       * the comparison */
+      auto        term = bb->getTerminator();
+      BranchInst *br = nullptr;
+      if (pred == CmpInst::ICMP_EQ) {
+
+        br = BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
+
+      } else {
+
+        /* CmpInst::ICMP_NE */
+        br = BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
+
+      }
+
+      term->eraseFromParent();
+
+      /* create the PHI and connect the edges accordingly */
+      PN = PHINode::Create(Int1Ty, 2, "");
+      PN->addIncoming(icmp_low, cmp_low_bb);
+      Value *val = nullptr;
+      if (pred == CmpInst::ICMP_EQ) {
+
+        val = ConstantInt::get(Int1Ty, 0);
+
+      } else {
+
+        /* CmpInst::ICMP_NE */
+        val = ConstantInt::get(Int1Ty, 1);
 
       }
 
+      PN->addIncoming(val, icmp_high->getParent());
+      break;
+
     }
 
-  }
+    case CmpInst::ICMP_UGT:
+    case CmpInst::ICMP_ULT: {
 
-  if (!icomps.size()) { return false; }
+      /* transformations for < and > */
 
-  for (auto &IcmpInst : icomps) {
+      /* create a basic block which checks for the inverse predicate.
+       * if this is true we can go to the end if not we have to go to the
+       * bb which checks the lower half of the operands */
+      Instruction *op0_low, *op1_low;
+      CmpInst *    icmp_inv_cmp = nullptr;
+      BasicBlock * inv_cmp_bb =
+          BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
+      if (pred == CmpInst::ICMP_UGT) {
 
-    BasicBlock *bb = IcmpInst->getParent();
+        icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
+                                       op0_high, op1_high);
 
-    auto op0 = IcmpInst->getOperand(0);
-    auto op1 = IcmpInst->getOperand(1);
+      } else {
 
-    IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-    if (!intTyOp0) { continue; }
-    unsigned     bitw = intTyOp0->getBitWidth();
-    IntegerType *IntType = IntegerType::get(C, bitw);
+        icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
+                                       op0_high, op1_high);
 
-    /* get the new predicate */
-    auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
-    if (!cmp_inst) { continue; }
-    auto               pred = cmp_inst->getPredicate();
-    CmpInst::Predicate new_pred;
+      }
 
-    if (pred == CmpInst::ICMP_SGT) {
+      inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
+      worklist.push_back(icmp_inv_cmp);
 
-      new_pred = CmpInst::ICMP_UGT;
+      auto term = bb->getTerminator();
+      term->eraseFromParent();
+      BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
 
-    } else {
+      /* create a bb which handles the cmp of the lower halves */
+      BasicBlock *cmp_low_bb =
+          BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
+      op0_low = new TruncInst(op0, NewIntType);
+      cmp_low_bb->getInstList().push_back(op0_low);
+      op1_low = new TruncInst(op1, NewIntType);
+      cmp_low_bb->getInstList().push_back(op1_low);
 
-      new_pred = CmpInst::ICMP_ULT;
+      icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
+      cmp_low_bb->getInstList().push_back(icmp_low);
+      BranchInst::Create(end_bb, cmp_low_bb);
+
+      BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
+
+      PN = PHINode::Create(Int1Ty, 3);
+      PN->addIncoming(icmp_low, cmp_low_bb);
+      PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+      PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
+      break;
 
     }
 
-    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+    default:
+      return false;
 
-    /* create a 1 bit compare for the sign bit. to do this shift and trunc
-     * the original operands so only the first bit remains.*/
-    Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
+  }
 
-    s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
-                                   ConstantInt::get(IntType, bitw - 1));
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
-    t_op0 = new TruncInst(s_op0, Int1Ty);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0);
+  BasicBlock::iterator ii(cmp_inst);
+  ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN);
 
-    s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
-                                   ConstantInt::get(IntType, bitw - 1));
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
-    t_op1 = new TruncInst(s_op1, Int1Ty);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1);
+  // We split the comparison into low and high. If this isn't our target
+  // bitwidth we recursively split the low and high parts again until we have
+  // target bitwidth.
+  if ((bitw / 2) > target_bitwidth) {
 
-    /* compare of the sign bits */
-    icmp_sign_bit =
-        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
-                             icmp_sign_bit);
+    worklist.push_back(icmp_high);
+    worklist.push_back(icmp_low);
 
-    /* create a new basic block which is executed if the signedness bit is
-     * different */
-    Instruction *icmp_inv_sig_cmp;
-    BasicBlock * sign_bb =
-        BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
-    if (pred == CmpInst::ICMP_SGT) {
+  }
 
-      /* if we check for > and the op0 positive and op1 negative then the final
-       * result is true. if op0 negative and op1 pos, the cmp must result
-       * in false
-       */
-      icmp_inv_sig_cmp =
-          CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
+  return true;
 
-    } else {
+}
+
+bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) {
+
+  CmpWorklist worklist;
+
+  auto op0 = I->getOperand(0);
+  auto op1 = I->getOperand(1);
+  if (!op0 || !op1) { return false; }
+  auto op0Ty = dyn_cast<IntegerType>(op0->getType());
+  if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; }
+
+  unsigned bitw = op0Ty->getBitWidth();
+
+#ifdef VERIFY_TOO_MUCH
+  auto F = I->getParent()->getParent();
+#endif
 
-      /* just the inverse of the above statement */
-      icmp_inv_sig_cmp =
-          CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
+  // we run the comparison simplification on all compares regardless of their
+  // bitwidth.
+  if (I->getPredicate() == CmpInst::ICMP_UGE ||
+      I->getPredicate() == CmpInst::ICMP_SGE ||
+      I->getPredicate() == CmpInst::ICMP_ULE ||
+      I->getPredicate() == CmpInst::ICMP_SLE) {
+
+    if (!simplifyOrEqualsCompare(I, M, worklist)) {
+
+      reportError(
+          "Failed to simplify inequality or equals comparison "
+          "(UGE,SGE,ULE,SLE)",
+          I, M);
 
     }
 
-    sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
-    BranchInst::Create(end_bb, sign_bb);
+  } else if (I->getPredicate() == CmpInst::ICMP_SGT ||
 
-    /* create a new bb which is executed if signedness is equal */
-    Instruction *icmp_usign_cmp;
-    BasicBlock * middle_bb =
-        BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
-    /* we can do a normal unsigned compare now */
-    icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
-    middle_bb->getInstList().push_back(icmp_usign_cmp);
-    BranchInst::Create(end_bb, middle_bb);
+             I->getPredicate() == CmpInst::ICMP_SLT) {
 
-    auto term = bb->getTerminator();
-    /* if the sign is eq do a normal unsigned cmp, else we have to check the
-     * signedness bit */
-    BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
-    term->eraseFromParent();
+    if (!simplifySignedCompare(I, M, worklist)) {
 
-    PHINode *PN = PHINode::Create(Int1Ty, 2, "");
+      reportError("Failed to simplify signed comparison (SGT,SLT)", I, M);
+
+    }
+
+  }
+
+#ifdef VERIFY_TOO_MUCH
+  if (verifyFunction(*F, &errs())) {
+
+    reportError("simpliyfing compare lead to broken function", nullptr, M);
+
+  }
+
+#endif
+
+  // the simplification methods replace the original CmpInst and push the
+  // resulting new CmpInst into the worklist. If the worklist is empty then
+  // we only have to split the original CmpInst.
+  if (worklist.size() == 0) { worklist.push_back(I); }
+
+  while (!worklist.empty()) {
+
+    CmpInst *cmp = worklist.pop_back_val();
+    // we split the simplified compares into comparisons with smaller bitwidths
+    // if they are larger than our target_bitwidth.
+    if (bitw > target_bitwidth) {
+
+      if (!splitCompare(cmp, M, worklist)) {
+
+        reportError("Failed to split comparison", cmp, M);
+
+      }
+
+#ifdef VERIFY_TOO_MUCH
+      if (verifyFunction(*F, &errs())) {
+
+        reportError("splitting compare lead to broken function", nullptr, M);
+
+      }
 
-    PN->addIncoming(icmp_usign_cmp, middle_bb);
-    PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
+#endif
 
-    BasicBlock::iterator ii(IcmpInst);
-    ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+    }
 
   }
 
+  count++;
   return true;
 
 }
@@ -1050,306 +1316,110 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
 
 }
 
-/* splits icmps of size bitw into two nested icmps with bitw/2 size each */
-size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
-
-  size_t count = 0;
-
-  LLVMContext &C = M.getContext();
-
-  IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
-  IntegerType *OldIntType = IntegerType::get(C, bitw);
-  IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
-
-  std::vector<Instruction *> icomps;
-
-  if (bitw % 2) { return 0; }
-
-  /* not supported yet */
-  if (bitw > 64) { return 0; }
-
-  /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the
-   * functions simplifyCompares() and simplifyIntSignedness()
-   * were executed only these four predicates should exist */
-  for (auto &F : M) {
-
-    if (!isInInstrumentList(&F)) continue;
+bool SplitComparesTransform::runOnModule(Module &M) {
 
-    for (auto &BB : F) {
+  char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
+  if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
+  if (bitw_env) { target_bitwidth = atoi(bitw_env); }
 
-      for (auto &IN : BB) {
+  enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
 
-        CmpInst *selectcmpInst = nullptr;
+  if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
+      getenv("AFL_DEBUG") != NULL) {
 
-        if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+    errs() << "Split-compare-pass by laf.intel@gmail.com, extended by "
+              "heiko@hexco.de (splitting icmp to "
+           << target_bitwidth << " bit)\n";
 
-          if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_NE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_UGT ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) {
+    if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; }
 
-            auto op0 = selectcmpInst->getOperand(0);
-            auto op1 = selectcmpInst->getOperand(1);
-
-            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+  } else {
 
-            if (!intTyOp0 || !intTyOp1) { continue; }
+    be_quiet = 1;
 
-            /* check if the bitwidths are the one we are looking for */
-            if (intTyOp0->getBitWidth() != bitw ||
-                intTyOp1->getBitWidth() != bitw) {
+  }
 
-              continue;
+  if (enableFPSplit) {
 
-            }
+    count = splitFPCompares(M);
 
-            icomps.push_back(selectcmpInst);
+    /*
+        if (!be_quiet) {
 
-          }
+          errs() << "Split-floatingpoint-compare-pass: " << count
+                 << " FP comparisons split\n";
 
         }
 
-      }
-
-    }
+    */
+    simplifyFPCompares(M);
 
   }
 
-  if (!icomps.size()) { return 0; }
-
-  for (auto &IcmpInst : icomps) {
-
-    BasicBlock *bb = IcmpInst->getParent();
-
-    auto op0 = IcmpInst->getOperand(0);
-    auto op1 = IcmpInst->getOperand(1);
-
-    auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
-    if (!cmp_inst) { continue; }
-    auto pred = cmp_inst->getPredicate();
-
-    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
-
-    /* create the comparison of the top halves of the original operands */
-    Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high;
-
-    s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
-                                   ConstantInt::get(OldIntType, bitw / 2));
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
-    op0_high = new TruncInst(s_op0, NewIntType);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
-                             op0_high);
-
-    s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
-                                   ConstantInt::get(OldIntType, bitw / 2));
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
-    op1_high = new TruncInst(s_op1, NewIntType);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
-                             op1_high);
-
-    icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high);
-    bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
-                             icmp_high);
-
-    /* now we have to destinguish between == != and > < */
-    if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
-
-      /* transformation for == and != icmps */
-
-      /* create a compare for the lower half of the original operands */
-      Instruction *op0_low, *op1_low, *icmp_low;
-      BasicBlock * cmp_low_bb =
-          BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
-
-      op0_low = new TruncInst(op0, NewIntType);
-      cmp_low_bb->getInstList().push_back(op0_low);
-
-      op1_low = new TruncInst(op1, NewIntType);
-      cmp_low_bb->getInstList().push_back(op1_low);
-
-      icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
-      cmp_low_bb->getInstList().push_back(icmp_low);
-      BranchInst::Create(end_bb, cmp_low_bb);
-
-      /* dependent on the cmp of the high parts go to the end or go on with
-       * the comparison */
-      auto term = bb->getTerminator();
-      if (pred == CmpInst::ICMP_EQ) {
-
-        BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
-
-      } else {
-
-        /* CmpInst::ICMP_NE */
-        BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
-
-      }
-
-      term->eraseFromParent();
-
-      /* create the PHI and connect the edges accordingly */
-      PHINode *PN = PHINode::Create(Int1Ty, 2, "");
-      PN->addIncoming(icmp_low, cmp_low_bb);
-      if (pred == CmpInst::ICMP_EQ) {
-
-        PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb);
-
-      } else {
+  std::vector<CmpInst *> worklist;
+  /* iterate over all functions, basic blocks and instructions, searching for
+   * all integer compare instructions. Save them into the worklist for later. */
+  for (auto &F : M) {
 
-        /* CmpInst::ICMP_NE */
-        PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+    if (!isInInstrumentList(&F)) continue;
 
-      }
+    for (auto &BB : F) {
 
-      /* replace the old icmp with the new PHI */
-      BasicBlock::iterator ii(IcmpInst);
-      ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+      for (auto &IN : BB) {
 
-    } else {
+        if (auto CI = dyn_cast<CmpInst>(&IN)) {
 
-      /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */
-      /* transformations for < and > */
+          auto op0 = CI->getOperand(0);
+          auto op1 = CI->getOperand(1);
+          if (!op0 || !op1) { return false; }
+          auto iTy1 = dyn_cast<IntegerType>(op0->getType());
+          if (iTy1 && isa<IntegerType>(op1->getType())) {
 
-      /* create a basic block which checks for the inverse predicate.
-       * if this is true we can go to the end if not we have to go to the
-       * bb which checks the lower half of the operands */
-      Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low;
-      BasicBlock * inv_cmp_bb =
-          BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
-      if (pred == CmpInst::ICMP_UGT) {
+            unsigned bitw = iTy1->getBitWidth();
+            if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
 
-        icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
-                                       op0_high, op1_high);
-
-      } else {
+          }
 
-        icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
-                                       op0_high, op1_high);
+        }
 
       }
 
-      inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
-
-      auto term = bb->getTerminator();
-      term->eraseFromParent();
-      BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
-
-      /* create a bb which handles the cmp of the lower halves */
-      BasicBlock *cmp_low_bb =
-          BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
-      op0_low = new TruncInst(op0, NewIntType);
-      cmp_low_bb->getInstList().push_back(op0_low);
-      op1_low = new TruncInst(op1, NewIntType);
-      cmp_low_bb->getInstList().push_back(op1_low);
-
-      icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
-      cmp_low_bb->getInstList().push_back(icmp_low);
-      BranchInst::Create(end_bb, cmp_low_bb);
-
-      BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
-
-      PHINode *PN = PHINode::Create(Int1Ty, 3);
-      PN->addIncoming(icmp_low, cmp_low_bb);
-      PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
-      PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
-
-      BasicBlock::iterator ii(IcmpInst);
-      ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
-
     }
 
-    ++count;
-
   }
 
-  return count;
-
-}
-
-bool SplitComparesTransform::runOnModule(Module &M) {
-
-  int    bitw = 64;
-  size_t count = 0;
-
-  char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
-  if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
-  if (bitw_env) { bitw = atoi(bitw_env); }
-
-  enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
-
-  if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
-      getenv("AFL_DEBUG") != NULL) {
-
-    printf(
-        "Split-compare-pass by laf.intel@gmail.com, extended by "
-        "heiko@hexco.de\n");
+  // now that we have a list of all integer comparisons we can start replacing
+  // them with the split alternatives.
+  for (auto CI : worklist) {
 
-  } else {
-
-    be_quiet = 1;
+    simplifyAndSplit(CI, M);
 
   }
 
-  if (enableFPSplit) {
-
-    count = splitFPCompares(M);
-
-    /*
-        if (!be_quiet) {
-
-          errs() << "Split-floatingpoint-compare-pass: " << count
-                 << " FP comparisons split\n";
+  bool brokenDebug = false;
+  if (verifyModule(M, &errs()
+#if LLVM_VERSION_MAJOR > 3 || \
+    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9)
+                          ,
+                   &brokenDebug  // 9th May 2016
+#endif
+                   )) {
 
-        }
-
-    */
-    simplifyFPCompares(M);
+    reportError(
+        "Module Verifier failed! Consider reporting a bug with the AFL++ "
+        "project.",
+        nullptr, M);
 
   }
 
-  simplifyCompares(M);
-
-  simplifyIntSignedness(M);
+  if (brokenDebug) {
 
-  switch (bitw) {
-
-    case 64:
-      count += splitIntCompares(M, bitw);
-      if (debug)
-        errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
-               << " split\n";
-      bitw >>= 1;
-#if LLVM_VERSION_MAJOR > 3 || \
-    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
-      [[clang::fallthrough]]; /*FALLTHRU*/                   /* FALLTHROUGH */
-#endif
-    case 32:
-      count += splitIntCompares(M, bitw);
-      if (debug)
-        errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
-               << " split\n";
-      bitw >>= 1;
-#if LLVM_VERSION_MAJOR > 3 || \
-    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
-      [[clang::fallthrough]]; /*FALLTHRU*/                   /* FALLTHROUGH */
-#endif
-    case 16:
-      count += splitIntCompares(M, bitw);
-      if (debug)
-        errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
-               << " split\n";
-      // bitw >>= 1;
-      break;
-
-    default:
-      // if (!be_quiet) errs() << "NOT Running split-compare-pass \n";
-      return false;
-      break;
+    reportError("Module Verifier reported broken Debug Infos - Stripping!",
+                nullptr, M);
+    StripDebugInfo(M);
 
   }
 
-  verifyModule(M);
   return true;
 
 }
@@ -1373,3 +1443,8 @@ static RegisterStandardPasses RegisterSplitComparesTransPassLTO(
     registerSplitComparesPass);
 #endif
 
+static RegisterPass<SplitComparesTransform> X("splitcompares",
+                                              "AFL++ split compares",
+                                              true /* Only looks at CFG */,
+                                              true /* Analysis Pass */);
+