about summary refs log tree commit diff homepage
path: root/lib
diff options
context:
space:
mode:
authorNguyễn Gia Phong <mcsinyx@disroot.org>2023-03-30 20:11:04 +0900
committerNguyễn Gia Phong <cnx@loang.net>2024-03-05 17:23:52 +0900
commitcae01bffb4c60e586ad54e4d056dfc5a193faa9e (patch)
tree9e94307131661ff1c30cd66899d2db9fb8886432 /lib
parent96f2da3a83ea1d9fc92b2a6ae649c4c69909259b (diff)
downloadklee-cae01bffb4c60e586ad54e4d056dfc5a193faa9e.tar.gz
Implement differentiator extraction
Diffstat (limited to 'lib')
-rw-r--r--lib/Core/ExecutionState.cpp29
-rw-r--r--lib/Core/Executor.cpp93
-rw-r--r--lib/Core/Executor.h18
3 files changed, 139 insertions, 1 deletions
diff --git a/lib/Core/ExecutionState.cpp b/lib/Core/ExecutionState.cpp
index cb8a3ced..b4af092a 100644
--- a/lib/Core/ExecutionState.cpp
+++ b/lib/Core/ExecutionState.cpp
@@ -388,7 +388,36 @@ void ExecutionState::dumpStack(llvm::raw_ostream &out) const {
   }
 }
 
+bool isMetaConstraint(ref<Expr> e) {
+  if (e.get()->getKind() != Expr::Eq)
+    return false;
+  const auto eq = dyn_cast<EqExpr>(e.get());
+  if (eq->left.get()->getKind() != Expr::Constant)
+    return false;
+  const auto constant = dyn_cast<ConstantExpr>(eq->left.get());
+  if (constant->isFalse()) // the else branch
+    return isMetaConstraint(eq->right);
+
+  if (eq->right.get()->getKind() != Expr::Concat)
+    return false;
+  const auto concat = dyn_cast<ConcatExpr>(eq->right.get());
+  if (concat->getLeft().get()->getKind() != Expr::Read)
+    return false;
+
+  const auto read = dyn_cast<ReadExpr>(concat->getLeft().get());
+  const auto& name = read->updates.root->name;
+  // string::starts_with requires C++20
+  if (name.substr(0, 8) != "__choose")
+    return false;
+  for (const auto c : name.substr(8))
+    if ('0' > c || c > '9')
+      return false;
+  return true;
+}
+
 void ExecutionState::addConstraint(ref<Expr> e) {
+  if (isMetaConstraint(e))
+    return;
   ConstraintManager c(constraints);
   c.addConstraint(e);
 }
diff --git a/lib/Core/Executor.cpp b/lib/Core/Executor.cpp
index a2c28dee..b335e566 100644
--- a/lib/Core/Executor.cpp
+++ b/lib/Core/Executor.cpp
@@ -92,9 +92,11 @@
 #include <iosfwd>
 #include <limits>
 #include <sstream>
+#include <stdlib.h>
 #include <string>
 #include <sys/mman.h>
 #include <sys/resource.h>
+#include <unistd.h>
 #include <vector>
 
 using namespace llvm;
@@ -3802,6 +3804,86 @@ static std::string terminationTypeFileExtension(StateTerminationType type) {
   return ret;
 };
 
+void Executor::extractDifferentiator(uint64_t a, uint64_t b, const z3::model& m) {
+  auto test = Differentiator {a, b};
+  for (auto k = m.size(); k--;) {
+    const auto& name = m[k].name().str();
+    if (isSymArg(name)) {
+      const uint8_t i = std::atoi(name.c_str() + 3);
+      test.args[i] = "";
+      const auto& expr = m.eval(m[k]());
+      for (uint8_t b = 0; b < this->symArgs[i]; ++b) {
+        const auto c = z3::select(expr, b).simplify().as_uint64();
+        assert(c <= std::numeric_limits<unsigned char>::max());
+        test.args[i].push_back((unsigned char) c);
+      }
+    } else if (isSymOut(name.substr(0, name.size() - 2))) {
+      // TODO: use arguments for all patches
+      const auto& expr = m.eval(m[k]());
+      std::string binary {""};
+      const auto size = this->symOuts[name.substr(0, name.size() - 2)];
+      for (uint8_t b = 0; b < size; ++b) {
+        const auto c = z3::select(expr, b).simplify().as_uint64();
+        assert(c <= std::numeric_limits<unsigned char>::max());
+        binary.push_back((unsigned char) c);
+        const auto& ident = name.substr(4, name.size() - 6);
+        if (name[name.size() - 1] == 'a')
+          test.outputs[ident].first = binary;
+        else
+          test.outputs[ident].second = binary;
+      }
+    }
+  }
+  this->diffTests.push_back(test);
+}
+
+void Executor::searchDifferentiators(ExecutionState* latest) {
+  std::string last_smt2_name = "smt2.XXXXXX";
+  auto last_smt2_fd = mkstemp(last_smt2_name.data());
+  assert(last_smt2_fd != -1);
+  const auto written = write(last_smt2_fd, latest->formula.c_str(),
+                             latest->formula.size());
+  assert(written == latest->formula.size());
+  close(last_smt2_fd);
+  for (const auto& state : this->exitStates) {
+    // TODO: skip moar and order
+    if (state->patchNo == latest->patchNo)
+      continue;
+
+    // File I/O is expensive but SMT solving is even more expensive (-;
+    // Seriously though, FIXME: implement symbdiff natively
+    std::string smt2_name = "smt2.XXXXXX";
+    auto smt2_fd = mkstemp(smt2_name.data());
+    assert(smt2_fd != -1);
+    const auto written = write(smt2_fd, state->formula.c_str(),
+                               state->formula.size());
+    assert(written == state->formula.size());
+    close(smt2_fd);
+
+    auto command = "symbdiff " + smt2_name + " " + last_smt2_name;
+    for (const auto& out : this->symOuts)
+      command += " " + std::to_string(out.second);
+    const auto pipe = popen(command.c_str(), "r");
+    std::string formula;
+    char buffer[128];
+    while (!feof(pipe))
+      if (fgets(buffer, 128, pipe) != NULL)
+        formula += buffer;
+    assert(pclose(pipe) == 0);
+    remove(smt2_name.c_str());
+
+    static z3::context c;
+    static z3::solver s {c};
+    s.reset();
+    s.from_string(formula.c_str());
+    if (s.check() != z3::sat)
+      continue;
+    this->extractDifferentiator(latest->patchNo, state->patchNo, s.get_model());
+    llvm::errs() << this->diffTests.back() << "\n";
+  }
+  remove(last_smt2_name.c_str());
+}
+
 void Executor::terminateStateOnExit(ExecutionState &state) {
   ++stats::terminationExit;
   if (shouldWriteTest(state) || (AlwaysOutputSeeds && seedMap.count(&state)))
@@ -3811,6 +3893,7 @@ void Executor::terminateStateOnExit(ExecutionState &state) {
 
   interpreterHandler->incPathsCompleted();
   getConstraintLog(state, state.formula, Interpreter::SMTLIB2);
+  searchDifferentiators(&state);
   exitStates.insert(&state);
   terminateState(state, StateTerminationType::Exit);
 }
@@ -4615,7 +4698,15 @@ void Executor::executeMakeSymbolic(ExecutionState &state,
     const Array *array = arrayCache.CreateArray(uniqueName, mo->size);
     bindObjectInState(state, mo, false, array);
     state.addSymbolic(mo, array);
-    
+
+    if (isSymArg(uniqueName)) {
+      assert(std::atoi(name.c_str() + 3) == this->symArgs.size());
+      this->symArgs.push_back(mo->size - 1); // string's null termination
+    } else if (isSymOut(uniqueName)) {
+      assert(this->symOuts.find(uniqueName) == this->symOuts.end());
+      this->symOuts[uniqueName] = mo->size;
+    }
+
     auto found = seedMap.find(&state);
     if (found != seedMap.end()) {
       // In seed mode we need to add this as binding
diff --git a/lib/Core/Executor.h b/lib/Core/Executor.h
index 465751f6..5719c510 100644
--- a/lib/Core/Executor.h
+++ b/lib/Core/Executor.h
@@ -15,6 +15,7 @@
 #ifndef KLEE_EXECUTOR_H
 #define KLEE_EXECUTOR_H
 
+#include "Differentiator.h"
 #include "ExecutionState.h"
 #include "UserSearcher.h"
 
@@ -32,6 +33,9 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <z3.h>
+#include <z3++.h>
+
 #include <map>
 #include <memory>
 #include <set>
@@ -112,6 +116,8 @@ private:
   std::unique_ptr<TimingSolver> solver;
   std::unique_ptr<MemoryManager> memory;
   std::set<ExecutionState*, ExecutionStateIDCompare> states;
+  std::set<ExecutionState*, ExecutionStateIDCompare> exitStates;
+  std::vector<Differentiator> diffTests;
   StatsTracker *statsTracker;
   TreeStreamWriter *pathWriter, *symPathWriter;
   SpecialFunctionHandler *specialFunctionHandler;
@@ -145,6 +151,12 @@ private:
   /// globals that have no representative object (e.g. aliases).
   std::map<const llvm::GlobalValue*, ref<ConstantExpr>> globalAddresses;
 
+  /// Size of symbolic arguments.
+  std::vector<size_t> symArgs;
+
+  /// Size of symbolic outputs.
+  std::map<std::string, size_t> symOuts;
+
   /// Map of legal function addresses to the corresponding Function.
   /// Used to validate and dereference function pointers.
   std::unordered_map<std::uint64_t, llvm::Function*> legalFunctions;
@@ -416,6 +428,12 @@ private:
   const InstructionInfo & getLastNonKleeInternalInstruction(const ExecutionState &state,
       llvm::Instruction** lastInstruction);
 
+  /// Extract differencial test from SMT model
+  void extractDifferentiator(uint64_t, uint64_t, const z3::model&);
+
+  /// Compare with other exit states for possible differencial tests
+  void searchDifferentiators(ExecutionState *state);
+
   /// Remove state from queue and delete state. This function should only be
   /// used in the termination functions below.
   void terminateState(ExecutionState &state, StateTerminationType reason);