about summary refs log tree commit diff
path: root/instrumentation/cmplog-instructions-pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation/cmplog-instructions-pass.cc')
-rw-r--r--instrumentation/cmplog-instructions-pass.cc184
1 files changed, 22 insertions, 162 deletions
diff --git a/instrumentation/cmplog-instructions-pass.cc b/instrumentation/cmplog-instructions-pass.cc
index ad334d3b..0562c5b2 100644
--- a/instrumentation/cmplog-instructions-pass.cc
+++ b/instrumentation/cmplog-instructions-pass.cc
@@ -104,7 +104,6 @@ Iterator Unique(Iterator first, Iterator last) {
 bool CmpLogInstructions::hookInstrs(Module &M) {
 
   std::vector<Instruction *> icomps;
-  std::vector<SwitchInst *>  switches;
   LLVMContext &              C = M.getContext();
 
   Type *       VoidTy = Type::getVoidTy(C);
@@ -222,6 +221,18 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
   FunctionCallee cmplogHookInsN = cN;
 #endif
 
+  GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
+
+  if (!AFLCmplogPtr) {
+
+    AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
+                                      GlobalValue::ExternalWeakLinkage, 0,
+                                      "__afl_cmp_map");
+
+  }
+
+  Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
+
   /* iterate over all functions, bbs and instruction and add suitable calls */
   for (auto &F : M) {
 
@@ -238,164 +249,6 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
         }
 
-        SwitchInst *switchInst = nullptr;
-        if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
-
-          if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); }
-
-        }
-
-      }
-
-    }
-
-  }
-
-  // unique the collected switches
-  switches.erase(Unique(switches.begin(), switches.end()), switches.end());
-
-  // Instrument switch values for cmplog
-  if (switches.size()) {
-
-    if (!be_quiet)
-      errs() << "Hooking " << switches.size() << " switch instructions\n";
-
-    for (auto &SI : switches) {
-
-      Value *       Val = SI->getCondition();
-      unsigned int  max_size = Val->getType()->getIntegerBitWidth(), cast_size;
-      unsigned char do_cast = 0;
-
-      if (!SI->getNumCases() || max_size < 16) {
-
-        // if (!be_quiet) errs() << "skip trivial switch..\n";
-        continue;
-
-      }
-
-      if (max_size % 8) {
-
-        max_size = (((max_size / 8) + 1) * 8);
-        do_cast = 1;
-
-      }
-
-      IRBuilder<> IRB(SI->getParent());
-      IRB.SetInsertPoint(SI);
-
-      if (max_size > 128) {
-
-        if (!be_quiet) {
-
-          fprintf(stderr,
-                  "Cannot handle this switch bit size: %u (truncating)\n",
-                  max_size);
-
-        }
-
-        max_size = 128;
-        do_cast = 1;
-
-      }
-
-      // do we need to cast?
-      switch (max_size) {
-
-        case 8:
-        case 16:
-        case 32:
-        case 64:
-        case 128:
-          cast_size = max_size;
-          break;
-        default:
-          cast_size = 128;
-          do_cast = 1;
-
-      }
-
-      Value *CompareTo = Val;
-
-      if (do_cast) {
-
-        CompareTo =
-            IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false);
-
-      }
-
-      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
-           ++i) {
-
-#if LLVM_VERSION_MAJOR < 5
-        ConstantInt *cint = i.getCaseValue();
-#else
-        ConstantInt *cint = i->getCaseValue();
-#endif
-
-        if (cint) {
-
-          std::vector<Value *> args;
-          args.push_back(CompareTo);
-
-          Value *new_param = cint;
-
-          if (do_cast) {
-
-            new_param =
-                IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false);
-
-          }
-
-          if (new_param) {
-
-            args.push_back(new_param);
-            ConstantInt *attribute = ConstantInt::get(Int8Ty, 1);
-            args.push_back(attribute);
-            if (cast_size != max_size) {
-
-              ConstantInt *bitsize =
-                  ConstantInt::get(Int8Ty, (max_size / 8) - 1);
-              args.push_back(bitsize);
-
-            }
-
-            switch (cast_size) {
-
-              case 8:
-                IRB.CreateCall(cmplogHookIns1, args);
-                break;
-              case 16:
-                IRB.CreateCall(cmplogHookIns2, args);
-                break;
-              case 32:
-                IRB.CreateCall(cmplogHookIns4, args);
-                break;
-              case 64:
-                IRB.CreateCall(cmplogHookIns8, args);
-                break;
-              case 128:
-#ifdef WORD_SIZE_64
-                if (max_size == 128) {
-
-                  IRB.CreateCall(cmplogHookIns16, args);
-
-                } else {
-
-                  IRB.CreateCall(cmplogHookInsN, args);
-
-                }
-
-#endif
-                break;
-              default:
-                break;
-
-            }
-
-          }
-
-        }
-
       }
 
     }
@@ -409,8 +262,15 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
     for (auto &selectcmpInst : icomps) {
 
-      IRBuilder<> IRB(selectcmpInst->getParent());
-      IRB.SetInsertPoint(selectcmpInst);
+      IRBuilder<> IRB2(selectcmpInst->getParent());
+      IRB2.SetInsertPoint(selectcmpInst);
+      LoadInst *CmpPtr = IRB2.CreateLoad(AFLCmplogPtr);
+      CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
+      auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
+      auto ThenTerm =
+          SplitBlockAndInsertIfThen(is_not_null, selectcmpInst, false);
+
+      IRBuilder<> IRB(ThenTerm);
 
       Value *op0 = selectcmpInst->getOperand(0);
       Value *op1 = selectcmpInst->getOperand(1);
@@ -601,7 +461,7 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
   }
 
-  if (switches.size() || icomps.size())
+  if (icomps.size())
     return true;
   else
     return false;