about summary refs log tree commit diff
path: root/instrumentation/cmplog-instructions-pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation/cmplog-instructions-pass.cc')
-rw-r--r--instrumentation/cmplog-instructions-pass.cc89
1 files changed, 71 insertions, 18 deletions
diff --git a/instrumentation/cmplog-instructions-pass.cc b/instrumentation/cmplog-instructions-pass.cc
index 9921de0c..3499ccf0 100644
--- a/instrumentation/cmplog-instructions-pass.cc
+++ b/instrumentation/cmplog-instructions-pass.cc
@@ -186,16 +186,19 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
               selectcmpInst->getPredicate() == CmpInst::ICMP_UGE ||
               selectcmpInst->getPredicate() == CmpInst::ICMP_SGE ||
               selectcmpInst->getPredicate() == CmpInst::ICMP_ULE ||
-              selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) {
-
-            auto op0 = selectcmpInst->getOperand(0);
-            auto op1 = selectcmpInst->getOperand(1);
-
-            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
-
-            /* this is probably not needed but we do it anyway */
-            if (!intTyOp0 || !intTyOp1) { continue; }
+              selectcmpInst->getPredicate() == CmpInst::ICMP_SLE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_ULE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_UGT ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OLT ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_ULT ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_UNE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_ONE) {
 
             icomps.push_back(selectcmpInst);
 
@@ -221,16 +224,66 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
     auto op0 = selectcmpInst->getOperand(0);
     auto op1 = selectcmpInst->getOperand(1);
 
-    IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-    IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+    IntegerType *        intTyOp0 = NULL;
+    IntegerType *        intTyOp1 = NULL;
+    unsigned             max_size = 0;
+    std::vector<Value *> args;
 
-    unsigned max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
-                            ? intTyOp0->getBitWidth()
-                            : intTyOp1->getBitWidth();
+    if (selectcmpInst->getOpcode() == Instruction::FCmp) {
 
-    std::vector<Value *> args;
-    args.push_back(op0);
-    args.push_back(op1);
+      auto ty0 = op0->getType();
+      if (ty0->isHalfTy()
+#if LLVM_VERSION_MAJOR >= 11
+          || ty0->isBFloatTy()
+#endif
+      )
+        max_size = 16;
+      else if (ty0->isFloatTy())
+        max_size = 32;
+      else if (ty0->isDoubleTy())
+        max_size = 64;
+
+      if (max_size) {
+
+        Value *V0 = IRB.CreateBitCast(op0, IntegerType::get(C, max_size));
+        intTyOp0 = dyn_cast<IntegerType>(V0->getType());
+        Value *V1 = IRB.CreateBitCast(op1, IntegerType::get(C, max_size));
+        intTyOp1 = dyn_cast<IntegerType>(V1->getType());
+
+        if (intTyOp0 && intTyOp1) {
+
+          max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
+                         ? intTyOp0->getBitWidth()
+                         : intTyOp1->getBitWidth();
+          args.push_back(V0);
+          args.push_back(V1);
+
+        } else {
+
+          max_size = 0;
+
+        }
+
+      }
+
+    } else {
+
+      intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+      intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+
+      if (intTyOp0 && intTyOp1) {
+
+        max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
+                       ? intTyOp0->getBitWidth()
+                       : intTyOp1->getBitWidth();
+        args.push_back(op0);
+        args.push_back(op1);
+
+      }
+
+    }
+
+    if (max_size < 8 || max_size > 64 || !intTyOp0 || !intTyOp1) continue;
 
     switch (max_size) {