about summary refs log tree commit diff
path: root/instrumentation/cmplog-instructions-pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation/cmplog-instructions-pass.cc')
-rw-r--r--instrumentation/cmplog-instructions-pass.cc586
1 files changed, 498 insertions, 88 deletions
diff --git a/instrumentation/cmplog-instructions-pass.cc b/instrumentation/cmplog-instructions-pass.cc
index 3499ccf0..a74fb6c8 100644
--- a/instrumentation/cmplog-instructions-pass.cc
+++ b/instrumentation/cmplog-instructions-pass.cc
@@ -85,9 +85,25 @@ class CmpLogInstructions : public ModulePass {
 
 char CmpLogInstructions::ID = 0;
 
+template <class Iterator>
+Iterator Unique(Iterator first, Iterator last) {
+
+  while (first != last) {
+
+    Iterator next(first);
+    last = std::remove(++next, last, *first);
+    first = next;
+
+  }
+
+  return last;
+
+}
+
 bool CmpLogInstructions::hookInstrs(Module &M) {
 
   std::vector<Instruction *> icomps;
+  std::vector<SwitchInst *>  switches;
   LLVMContext &              C = M.getContext();
 
   Type *       VoidTy = Type::getVoidTy(C);
@@ -95,13 +111,15 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
   IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
   IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
   IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
+  IntegerType *Int128Ty = IntegerType::getInt128Ty(C);
 
 #if LLVM_VERSION_MAJOR < 9
   Constant *
 #else
   FunctionCallee
 #endif
-      c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty
+      c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
+                                 Int8Ty
 #if LLVM_VERSION_MAJOR < 5
                                  ,
                                  NULL
@@ -118,7 +136,8 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 #else
   FunctionCallee
 #endif
-      c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty
+      c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
+                                 Int8Ty
 #if LLVM_VERSION_MAJOR < 5
                                  ,
                                  NULL
@@ -135,7 +154,8 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 #else
   FunctionCallee
 #endif
-      c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty
+      c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
+                                 Int8Ty
 #if LLVM_VERSION_MAJOR < 5
                                  ,
                                  NULL
@@ -152,7 +172,8 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 #else
   FunctionCallee
 #endif
-      c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty
+      c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
+                                 Int8Ty
 #if LLVM_VERSION_MAJOR < 5
                                  ,
                                  NULL
@@ -164,6 +185,42 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
   FunctionCallee cmplogHookIns8 = c8;
 #endif
 
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c16 = M.getOrInsertFunction("__cmplog_ins_hook16", VoidTy, Int128Ty,
+                                  Int128Ty, Int8Ty
+#if LLVM_VERSION_MAJOR < 5
+                                  ,
+                                  NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookIns16 = cast<Function>(c16);
+#else
+  FunctionCallee cmplogHookIns16 = c16;
+#endif
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      cN = M.getOrInsertFunction("__cmplog_ins_hookN", VoidTy, Int128Ty,
+                                 Int128Ty, Int8Ty, Int8Ty
+#if LLVM_VERSION_MAJOR < 5
+                                 ,
+                                 NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookInsN = cast<Function>(cN);
+#else
+  FunctionCallee cmplogHookInsN = cN;
+#endif
+
   /* iterate over all functions, bbs and instruction and add suitable calls */
   for (auto &F : M) {
 
@@ -174,35 +231,16 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
       for (auto &IN : BB) {
 
         CmpInst *selectcmpInst = nullptr;
-
         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
 
-          if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_NE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_UGT ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SGT ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_ULT ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SLT ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_UGE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SGE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_ULE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SLE ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_ULE ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_UGT ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_OLT ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_ULT ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_UNE ||
-              selectcmpInst->getPredicate() == CmpInst::FCMP_ONE) {
-
-            icomps.push_back(selectcmpInst);
+          icomps.push_back(selectcmpInst);
 
-          }
+        }
+
+        SwitchInst *switchInst = nullptr;
+        if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
+
+          if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); }
 
         }
 
@@ -212,101 +250,473 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
 
   }
 
-  if (!icomps.size()) return false;
-  // if (!be_quiet) errs() << "Hooking " << icomps.size() << " cmp
-  // instructions\n";
+  // unique the collected switches
+  switches.erase(Unique(switches.begin(), switches.end()), switches.end());
+
+  // Instrument switch values for cmplog
+  if (switches.size()) {
+
+    if (!be_quiet)
+      errs() << "Hooking " << switches.size() << " switch instructions\n";
 
-  for (auto &selectcmpInst : icomps) {
+    for (auto &SI : switches) {
 
-    IRBuilder<> IRB(selectcmpInst->getParent());
-    IRB.SetInsertPoint(selectcmpInst);
+      Value *       Val = SI->getCondition();
+      unsigned int  max_size = Val->getType()->getIntegerBitWidth(), cast_size;
+      unsigned char do_cast = 0;
 
-    auto op0 = selectcmpInst->getOperand(0);
-    auto op1 = selectcmpInst->getOperand(1);
+      if (!SI->getNumCases() || max_size <= 8) {
 
-    IntegerType *        intTyOp0 = NULL;
-    IntegerType *        intTyOp1 = NULL;
-    unsigned             max_size = 0;
-    std::vector<Value *> args;
+        // if (!be_quiet) errs() << "skip trivial switch..\n";
+        continue;
 
-    if (selectcmpInst->getOpcode() == Instruction::FCmp) {
+      }
+
+      IRBuilder<> IRB(SI->getParent());
+      IRB.SetInsertPoint(SI);
+
+      if (max_size % 8) {
+
+        max_size = (((max_size / 8) + 1) * 8);
+        do_cast = 1;
+
+      }
+
+      if (max_size > 128) {
+
+        if (!be_quiet) {
+
+          fprintf(stderr,
+                  "Cannot handle this switch bit size: %u (truncating)\n",
+                  max_size);
+
+        }
+
+        max_size = 128;
+        do_cast = 1;
+
+      }
+
+      // do we need to cast?
+      switch (max_size) {
+
+        case 8:
+        case 16:
+        case 32:
+        case 64:
+        case 128:
+          cast_size = max_size;
+          break;
+        default:
+          cast_size = 128;
+          do_cast = 1;
+
+      }
+
+      Value *CompareTo = Val;
+
+      if (do_cast) {
+
+        ConstantInt *cint = dyn_cast<ConstantInt>(Val);
+        if (cint) {
+
+          uint64_t val = cint->getZExtValue();
+          // fprintf(stderr, "ConstantInt: %lu\n", val);
+          switch (cast_size) {
+
+            case 8:
+              CompareTo = ConstantInt::get(Int8Ty, val);
+              break;
+            case 16:
+              CompareTo = ConstantInt::get(Int16Ty, val);
+              break;
+            case 32:
+              CompareTo = ConstantInt::get(Int32Ty, val);
+              break;
+            case 64:
+              CompareTo = ConstantInt::get(Int64Ty, val);
+              break;
+            case 128:
+              CompareTo = ConstantInt::get(Int128Ty, val);
+              break;
+
+          }
 
-      auto ty0 = op0->getType();
-      if (ty0->isHalfTy()
+        } else {
+
+          CompareTo = IRB.CreateBitCast(Val, IntegerType::get(C, cast_size));
+
+        }
+
+      }
+
+      for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
+           ++i) {
+
+#if LLVM_VERSION_MAJOR < 5
+        ConstantInt *cint = i.getCaseValue();
+#else
+        ConstantInt *cint = i->getCaseValue();
+#endif
+
+        if (cint) {
+
+          std::vector<Value *> args;
+          args.push_back(CompareTo);
+
+          Value *new_param = cint;
+
+          if (do_cast) {
+
+            uint64_t val = cint->getZExtValue();
+            // fprintf(stderr, "ConstantInt: %lu\n", val);
+            switch (cast_size) {
+
+              case 8:
+                new_param = ConstantInt::get(Int8Ty, val);
+                break;
+              case 16:
+                new_param = ConstantInt::get(Int16Ty, val);
+                break;
+              case 32:
+                new_param = ConstantInt::get(Int32Ty, val);
+                break;
+              case 64:
+                new_param = ConstantInt::get(Int64Ty, val);
+                break;
+              case 128:
+                new_param = ConstantInt::get(Int128Ty, val);
+                break;
+
+            }
+
+          }
+
+          if (new_param) {
+
+            args.push_back(new_param);
+            ConstantInt *attribute = ConstantInt::get(Int8Ty, 1);
+            args.push_back(attribute);
+            if (cast_size != max_size) {
+
+              ConstantInt *bitsize =
+                  ConstantInt::get(Int8Ty, (max_size / 8) - 1);
+              args.push_back(bitsize);
+
+            }
+
+            switch (cast_size) {
+
+              case 8:
+                IRB.CreateCall(cmplogHookIns1, args);
+                break;
+              case 16:
+                IRB.CreateCall(cmplogHookIns2, args);
+                break;
+              case 32:
+                IRB.CreateCall(cmplogHookIns4, args);
+                break;
+              case 64:
+                IRB.CreateCall(cmplogHookIns8, args);
+                break;
+              case 128:
+                if (max_size == 128) {
+
+                  IRB.CreateCall(cmplogHookIns16, args);
+
+                } else {
+
+                  IRB.CreateCall(cmplogHookInsN, args);
+
+                }
+
+                break;
+
+            }
+
+          }
+
+        }
+
+      }
+
+    }
+
+  }
+
+  if (icomps.size()) {
+
+    // if (!be_quiet) errs() << "Hooking " << icomps.size() <<
+    //                          " cmp instructions\n";
+
+    for (auto &selectcmpInst : icomps) {
+
+      IRBuilder<> IRB(selectcmpInst->getParent());
+      IRB.SetInsertPoint(selectcmpInst);
+
+      Value *op0 = selectcmpInst->getOperand(0);
+      Value *op1 = selectcmpInst->getOperand(1);
+
+      IntegerType *        intTyOp0 = NULL;
+      IntegerType *        intTyOp1 = NULL;
+      unsigned             max_size = 0, cast_size = 0;
+      unsigned char        attr = 0, do_cast = 0;
+      std::vector<Value *> args;
+
+      CmpInst *cmpInst = dyn_cast<CmpInst>(selectcmpInst);
+
+      if (!cmpInst) { continue; }
+
+      switch (cmpInst->getPredicate()) {
+
+        case CmpInst::ICMP_NE:
+        case CmpInst::FCMP_UNE:
+        case CmpInst::FCMP_ONE:
+          break;
+        case CmpInst::ICMP_EQ:
+        case CmpInst::FCMP_UEQ:
+        case CmpInst::FCMP_OEQ:
+          attr += 1;
+          break;
+        case CmpInst::ICMP_UGT:
+        case CmpInst::ICMP_SGT:
+        case CmpInst::FCMP_OGT:
+        case CmpInst::FCMP_UGT:
+          attr += 2;
+          break;
+        case CmpInst::ICMP_UGE:
+        case CmpInst::ICMP_SGE:
+        case CmpInst::FCMP_OGE:
+        case CmpInst::FCMP_UGE:
+          attr += 3;
+          break;
+        case CmpInst::ICMP_ULT:
+        case CmpInst::ICMP_SLT:
+        case CmpInst::FCMP_OLT:
+        case CmpInst::FCMP_ULT:
+          attr += 4;
+          break;
+        case CmpInst::ICMP_ULE:
+        case CmpInst::ICMP_SLE:
+        case CmpInst::FCMP_OLE:
+        case CmpInst::FCMP_ULE:
+          attr += 5;
+          break;
+        default:
+          break;
+
+      }
+
+      if (selectcmpInst->getOpcode() == Instruction::FCmp) {
+
+        auto ty0 = op0->getType();
+        if (ty0->isHalfTy()
 #if LLVM_VERSION_MAJOR >= 11
-          || ty0->isBFloatTy()
+            || ty0->isBFloatTy()
 #endif
-      )
-        max_size = 16;
-      else if (ty0->isFloatTy())
-        max_size = 32;
-      else if (ty0->isDoubleTy())
-        max_size = 64;
+        )
+          max_size = 16;
+        else if (ty0->isFloatTy())
+          max_size = 32;
+        else if (ty0->isDoubleTy())
+          max_size = 64;
+        else if (ty0->isX86_FP80Ty())
+          max_size = 80;
+        else if (ty0->isFP128Ty() || ty0->isPPC_FP128Ty())
+          max_size = 128;
+
+        attr += 8;
+        do_cast = 1;
 
-      if (max_size) {
+      } else {
 
-        Value *V0 = IRB.CreateBitCast(op0, IntegerType::get(C, max_size));
-        intTyOp0 = dyn_cast<IntegerType>(V0->getType());
-        Value *V1 = IRB.CreateBitCast(op1, IntegerType::get(C, max_size));
-        intTyOp1 = dyn_cast<IntegerType>(V1->getType());
+        intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+        intTyOp1 = dyn_cast<IntegerType>(op1->getType());
 
         if (intTyOp0 && intTyOp1) {
 
           max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
                          ? intTyOp0->getBitWidth()
                          : intTyOp1->getBitWidth();
-          args.push_back(V0);
-          args.push_back(V1);
 
-        } else {
+        }
+
+      }
+
+      if (!max_size) { continue; }
+
+      // _ExtInt() with non-8th values
+      if (max_size % 8) {
+
+        max_size = (((max_size / 8) + 1) * 8);
+        do_cast = 1;
+
+      }
+
+      if (max_size > 128) {
+
+        if (!be_quiet) {
 
-          max_size = 0;
+          fprintf(stderr,
+                  "Cannot handle this compare bit size: %u (truncating)\n",
+                  max_size);
 
         }
 
+        max_size = 128;
+        do_cast = 1;
+
+      }
+
+      // do we need to cast?
+      switch (max_size) {
+
+        case 8:
+        case 16:
+        case 32:
+        case 64:
+        case 128:
+          cast_size = max_size;
+          break;
+        default:
+          cast_size = 128;
+          do_cast = 1;
+
       }
 
-    } else {
+      if (do_cast) {
+
+        // F*cking LLVM optimized out any kind of bitcasts of ConstantInt values
+        // creating illegal calls. WTF. So we have to work around this.
+
+        ConstantInt *cint = dyn_cast<ConstantInt>(op0);
+        if (cint) {
+
+          uint64_t val = cint->getZExtValue();
+          // fprintf(stderr, "ConstantInt: %lu\n", val);
+          ConstantInt *new_param = NULL;
+          switch (cast_size) {
+
+            case 8:
+              new_param = ConstantInt::get(Int8Ty, val);
+              break;
+            case 16:
+              new_param = ConstantInt::get(Int16Ty, val);
+              break;
+            case 32:
+              new_param = ConstantInt::get(Int32Ty, val);
+              break;
+            case 64:
+              new_param = ConstantInt::get(Int64Ty, val);
+              break;
+            case 128:
+              new_param = ConstantInt::get(Int128Ty, val);
+              break;
+
+          }
+
+          if (!new_param) { continue; }
+          args.push_back(new_param);
+
+        } else {
+
+          Value *V0 = IRB.CreateBitCast(op0, IntegerType::get(C, cast_size));
+          args.push_back(V0);
+
+        }
+
+        cint = dyn_cast<ConstantInt>(op1);
+        if (cint) {
+
+          uint64_t     val = cint->getZExtValue();
+          ConstantInt *new_param = NULL;
+          switch (cast_size) {
+
+            case 8:
+              new_param = ConstantInt::get(Int8Ty, val);
+              break;
+            case 16:
+              new_param = ConstantInt::get(Int16Ty, val);
+              break;
+            case 32:
+              new_param = ConstantInt::get(Int32Ty, val);
+              break;
+            case 64:
+              new_param = ConstantInt::get(Int64Ty, val);
+              break;
+            case 128:
+              new_param = ConstantInt::get(Int128Ty, val);
+              break;
+
+          }
+
+          if (!new_param) { continue; }
+          args.push_back(new_param);
+
+        } else {
+
+          Value *V1 = IRB.CreateBitCast(op1, IntegerType::get(C, cast_size));
+          args.push_back(V1);
 
-      intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-      intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+        }
 
-      if (intTyOp0 && intTyOp1) {
+      } else {
 
-        max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
-                       ? intTyOp0->getBitWidth()
-                       : intTyOp1->getBitWidth();
         args.push_back(op0);
         args.push_back(op1);
 
       }
 
-    }
+      ConstantInt *attribute = ConstantInt::get(Int8Ty, attr);
+      args.push_back(attribute);
+
+      if (cast_size != max_size) {
+
+        ConstantInt *bitsize = ConstantInt::get(Int8Ty, (max_size / 8) - 1);
+        args.push_back(bitsize);
+
+      }
+
+      // fprintf(stderr, "_ExtInt(%u) castTo %u with attr %u didcast %u\n",
+      //         max_size, cast_size, attr, do_cast);
+
+      switch (cast_size) {
 
-    if (max_size < 8 || max_size > 64 || !intTyOp0 || !intTyOp1) continue;
-
-    switch (max_size) {
-
-      case 8:
-        IRB.CreateCall(cmplogHookIns1, args);
-        break;
-      case 16:
-        IRB.CreateCall(cmplogHookIns2, args);
-        break;
-      case 32:
-        IRB.CreateCall(cmplogHookIns4, args);
-        break;
-      case 64:
-        IRB.CreateCall(cmplogHookIns8, args);
-        break;
-      default:
-        break;
+        case 8:
+          IRB.CreateCall(cmplogHookIns1, args);
+          break;
+        case 16:
+          IRB.CreateCall(cmplogHookIns2, args);
+          break;
+        case 32:
+          IRB.CreateCall(cmplogHookIns4, args);
+          break;
+        case 64:
+          IRB.CreateCall(cmplogHookIns8, args);
+          break;
+        case 128:
+          if (max_size == 128) {
+
+            IRB.CreateCall(cmplogHookIns16, args);
+
+          } else {
+
+            IRB.CreateCall(cmplogHookInsN, args);
+
+          }
+
+          break;
+
+      }
 
     }
 
   }
 
-  return true;
+  if (switches.size() || icomps.size())
+    return true;
+  else
+    return false;
 
 }