about summary refs log tree commit diff homepage
path: root/lib/Core/Differentiator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Core/Differentiator.cpp')
-rw-r--r--lib/Core/Differentiator.cpp250
1 files changed, 250 insertions, 0 deletions
diff --git a/lib/Core/Differentiator.cpp b/lib/Core/Differentiator.cpp
new file mode 100644
index 00000000..20aa69c9
--- /dev/null
+++ b/lib/Core/Differentiator.cpp
@@ -0,0 +1,250 @@
+// Metaprogram's revision differentiator implementation
+// Copyright 2023  Nguyễn Gia Phong
+//
+// This file is distributed under the University of Illinois
+// Open Source License.  See LICENSE.TXT for details.
+
+#include "Differentiator.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <cerrno>
+
+#include <spawn.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+// Required positional command-line option.  Differentiator's constructor
+// copies it into prog, which extract() hands to posix_spawn as both
+// the executable path and argv[0].
+llvm::cl::opt<std::string>
+InputProgram{llvm::cl::desc("<input program>"),
+             llvm::cl::Positional, llvm::cl::Required};
+
+using namespace klee;
+
+// Build a visitor that tags symbolic output arrays with the revision
+// character rev and creates the tagged arrays through ac.
+ExprCmbnVisitor::ExprCmbnVisitor(char rev, ArrayCache& ac)
+  : revision(rev),
+    arrayCache(ac)
+{}
+
+// Redirect every read of a symbolic output array (see
+// Differentiator::isSymOut) to an array named <original>!<revision>,
+// keeping the update list head and the index.  The replacement read is
+// remembered in outVars under the original array name and the index.
+// Any other expression is left untouched.
+ExprVisitor::Action
+ExprCmbnVisitor::visitExprPost(const Expr& e)
+{
+  const auto read = (e.getKind() == Expr::Read) ? dyn_cast<ReadExpr>(&e)
+                                                : nullptr;
+  if (read == nullptr)
+    return ExprVisitor::Action::doChildren();
+
+  const auto root = read->updates.root;
+  if (!Differentiator::isSymOut(root->name))
+    return ExprVisitor::Action::doChildren();
+
+  // Same shape as the original array, only the name carries the revision.
+  std::string tagged = root->name;
+  tagged += '!';
+  tagged += this->revision;
+  const auto copy = this->arrayCache.CreateArray(
+    tagged, root->size, 0, 0, root->domain, root->range);
+
+  const auto rewritten = ReadExpr::create({copy, read->updates.head},
+                                          read->index);
+  this->outVars[root->name][read->index] = dyn_cast<ReadExpr>(rewritten.get());
+  return ExprVisitor::Action::changeTo(rewritten);
+}
+
+// Tell whether name denotes a symbolic command-line argument,
+// i.e. it is exactly "arg" followed by two decimal digits.
+bool
+Differentiator::isSymArg(const std::string& name)
+{
+  // string::starts_with requires C++20
+  if (name.size() != 5 || name.compare(0, 3, "arg") != 0)
+    return false;
+  const auto isDigit = [](char c) { return '0' <= c && c <= '9'; };
+  return isDigit(name[3]) && isDigit(name[4]);
+}
+
+// Tell whether name denotes a symbolic output,
+// i.e. it begins with "out!" and its last character is a decimal digit.
+bool
+Differentiator::isSymOut(const std::string& name)
+{
+  // string::starts_with requires C++20
+  if (name.compare(0, 4, "out!") != 0)
+    return false;
+  // The prefix matched, so name holds at least four characters.
+  const char last = name[name.size() - 1];
+  return '0' <= last && last <= '9';
+}
+
+// Set up the differentiator.
+//
+// s  - pointer to the owner of the solver used by search()
+// t  - solver timeout applied around each query in search()
+// ac - cache through which per-revision arrays are created
+//
+// envs starts with one entry, revision 0 with an empty environment string,
+// and prog is taken from the InputProgram command-line option.
+Differentiator::Differentiator(std::unique_ptr<TimingSolver>* s,
+                               time::Span& t, ArrayCache& ac)
+  : envs{{0, ""}},
+    prog{InputProgram},
+    solver{s},
+    solverTimeout{t},
+    arrayCache{ac},
+    visitorA{'a', ac},
+    visitorB{'b', ac}
+{}
+
+// Turn a solver model into a concrete differential test and record it.
+//
+// objects and values run in parallel: values[i] is the assignment
+// for objects[i].  Arrays named argNN supply the NN-th command-line
+// argument; arrays named out!<var>!a and out!<var>!b supply the symbolic
+// output <var> of revision a->patchNo and b->patchNo respectively.
+//
+// The input program is then spawned once per entry of this->envs with
+// those arguments and its standard output is captured.  The outcome is
+// stored in this->tests, and every pair of revisions that disagree on
+// some symbolic output or on stdout is registered in this->done.
+void
+Differentiator::extract(ExecutionState* a, ExecutionState* b,
+                        const std::vector<const Array*>& objects,
+                        const std::vector<Bytes>& values)
+{
+  TestArgs argv;
+  TestOuts outputs;
+  {
+    std::map<std::uint8_t, Bytes> args;
+    for (unsigned i = 0; i < objects.size(); ++i) {
+      const auto& name = objects[i]->name;
+      if (isSymArg(name)) {
+        auto& arg = args[std::atoi(name.c_str() + 3)];
+        arg = values[i];
+        arg.push_back(0); // null termination
+      } else if (isSymOut(name.substr(0, name.size() - 2))) {
+        // name is out!<var>!<a|b>: the last character picks the revision
+        const auto rev = ((name[name.size() - 1] == 'a') ? a : b)->patchNo;
+        outputs[rev].first[name.substr(4, name.size() - 6)] = values[i];
+      }
+    }
+    uint8_t last = 0;
+    for (const auto& p : args) {
+      assert(p.first == last); // arguments must be numbered 0, 1, 2, ...
+      argv.push_back(p.second);
+      last++;
+    }
+  }
+
+  char buffer[128]; // output buffer for concrete execution
+  for (const auto& rev : this->envs) {
+    const auto& env = rev.second;
+    std::vector<const char*> argp {this->prog.c_str()};
+    for (const auto& v : argv)
+      argp.push_back(reinterpret_cast<const char*>(v.data()));
+    argp.push_back(nullptr);
+
+    int fildes[2];
+    auto err = pipe(fildes);
+    assert(!err);
+    posix_spawn_file_actions_t action;
+    posix_spawn_file_actions_init(&action);
+    // In the child: drop the read end and make the write end its stdout.
+    posix_spawn_file_actions_addclose(&action, fildes[0]);
+    posix_spawn_file_actions_adddup2(&action, fildes[1], 1);
+    char *const envp[] = {const_cast<char*>(env.c_str()), nullptr};
+    pid_t pid;
+    err = posix_spawn(&pid, this->prog.c_str(), &action, nullptr,
+                      const_cast<char* const *>(argp.data()),
+                      env.empty() ? nullptr : envp);
+    assert(!err);
+    const bool spawned = !err;
+    posix_spawn_file_actions_destroy(&action);
+    // Close our copy of the write end so that read() sees end-of-file
+    // once the child is done writing.
+    close(fildes[1]);
+
+    auto& out = outputs[rev.first].second;
+    for (;;) {
+      // read() returns a signed count: -1 signals an error and must not
+      // be mistaken for a number of bytes.
+      const ssize_t n = read(fildes[0], buffer, sizeof(buffer));
+      if (n < 0 && errno == EINTR)
+        continue;
+      assert(n >= 0);
+      if (n <= 0)
+        break;
+      out.insert(out.end(), buffer, buffer + n);
+    }
+    out.push_back(0); // null termination
+    close(fildes[0]);
+
+    // Reap the child so it does not linger as a zombie.
+    if (spawned) {
+      int status;
+      while (waitpid(pid, &status, 0) < 0 && errno == EINTR)
+        ;
+    }
+  }
+  auto& recorded = this->tests[argv];
+  recorded = outputs;
+
+  // :var :val cluster
+  std::map<std::string, std::map<Bytes, std::set<std::uint64_t>>> revOut;
+  for (auto& o : outputs) {
+    for (auto& var : o.second.first)
+      revOut[var.first][var.second].insert(o.first);
+    revOut[""][o.second.second].insert(o.first); // stdout
+  }
+
+  // Revisions falling into different value clusters of the same variable
+  // are told apart by this test; done is keyed by (smaller, larger).
+  for (auto& vp : revOut)
+    for (auto& p : vp.second)
+      for (auto& q : vp.second) {
+        if (&p == &q)
+          continue;
+        for (std::uint64_t i : p.second)
+          for (std::uint64_t j : q.second)
+            if (i < j)
+              this->done.emplace(std::make_pair(i, j), &recorded);
+      }
+}
+
+// Try to tell the newly terminated state apart from terminated states
+// of other revisions.
+//
+// latest is registered under its revision (latest->patchNo); nothing
+// more happens if it was registered before.  For every other revision
+// not yet paired with latest's revision in this->done, each recorded
+// state is tried in turn: its constraints (outputs tagged 'a') are
+// combined with those of latest (outputs tagged 'b') plus the requirement
+// that at least one symbolic output differs.  The first satisfiable
+// combination is turned into a test by extract().
+void
+Differentiator::search(ExecutionState* latest)
+{
+  if (!this->exits[latest->patchNo].insert(latest).second)
+    return; // skip when seen before
+  for (const auto& rev : this->exits) {
+    if (rev.first == latest->patchNo)
+      continue;
+    // done is keyed by (smaller revision, larger revision)
+    if (rev.first < latest->patchNo) {
+      if (this->done.find(std::make_pair(rev.first, latest->patchNo))
+          != this->done.end())
+        continue;
+    } else {
+      if (this->done.find(std::make_pair(latest->patchNo, rev.first))
+          != this->done.end())
+        continue;
+    }
+    for (const auto& state : rev.second) {
+      ConstraintSet cmbnSet;
+      {
+        // The set drops duplicated constraints before they are combined.
+        std::set<ref<Expr>> combination;
+        for (auto const& constraint : state->constraints)
+          combination.insert(this->visitorA.visit(constraint));
+        for (auto const& constraint : latest->constraints)
+          combination.insert(this->visitorB.visit(constraint));
+        for (auto const& constraint : combination)
+          cmbnSet.push_back(constraint);
+      }
+
+      std::vector<const Array*> objects;
+      ref<Expr> distinction;
+      for (const auto& sym : state->symbolics) {
+        if (isSymArg(sym.second->name)) {
+          objects.push_back(sym.second);
+          continue;
+        }
+        if (!isSymOut(sym.second->name))
+          continue;
+        // Walk back to the last '!' or '_'.  isSymOut guarantees a '!'
+        // at index 3, so the walk terminates.  A full-width index is
+        // used because the name may be longer than 256 characters.
+        std::string::size_type i = sym.second->name.size() - 1;
+        while (sym.second->name[i] != '!' && sym.second->name[i] != '_')
+          --i;
+        const std::string name = sym.second->name[i] == '_'
+                               ? sym.second->name.substr(0, i)
+                               : sym.second->name;
+        objects.push_back(this->arrayCache.CreateArray(name + "!a",
+                                                       sym.second->size));
+        objects.push_back(this->arrayCache.CreateArray(name + "!b",
+                                                       sym.second->size));
+
+        // FIXME: impossible to use visitor hash
+        for (const auto& a : this->visitorA.outVars[name]) {
+          const auto ne = NotExpr::create(EqExpr::create(a.second,
+            this->visitorB.outVars[name][a.first]));
+          if (distinction.get() == nullptr)
+            distinction = ne;
+          else
+            distinction = OrExpr::create(ne, distinction);
+        }
+      }
+
+      if (distinction.get() == nullptr)
+        continue; // no common symbolic
+      cmbnSet.push_back(distinction);
+      std::vector<Bytes> values;
+      std::unique_ptr<TimingSolver>& solver = *this->solver; // do judge!
+      solver->setTimeout(solverTimeout);
+      bool success = solver->getInitialValues(cmbnSet, objects, values,
+                                              state->queryMetaData);
+      solver->setTimeout(time::Span()); // lift the timeout again
+      if (!success)
+        continue;
+      // visitorA tagged state's outputs with 'a' and visitorB tagged
+      // latest's with 'b', so state must be passed as extract's a and
+      // latest as its b for values to land under the right revision.
+      this->extract(state, latest, objects, values);
+      assert(!this->tests.empty());
+      break; // one diff test per termination for now
+    }
+  }
+}
+
+// Report every recorded test on which revisions disagree on stdout.
+//
+// For each such test the following is written to llvm::errs():
+// a line "Args:" listing each argument as its bytes in decimal, every
+// byte followed by '.', then for each distinct standard output a line
+// "Revisions:" listing the revisions that produced it, followed by
+// that output.  Tests on which all revisions agree are skipped.
+void
+Differentiator::log()
+{
+  for (const auto& t : this->tests) {
+    // Cluster all revisions by captured stdout before reporting,
+    // so each test is reported once and with complete clusters.
+    std::map<Bytes, std::set<std::uint64_t>> stdoutClusters;
+    for (const auto& rev : t.second)
+      stdoutClusters[rev.second.second].insert(rev.first);
+    if (stdoutClusters.size() <= 1)
+      continue;
+
+    llvm::errs() << "Args:";
+    for (const auto& s : t.first) {
+      llvm::errs() << ' ';
+      for (const auto c : s)
+        llvm::errs() << static_cast<int>(c) << '.';
+    }
+    llvm::errs() << '\n';
+    for (const auto& so : stdoutClusters) {
+      llvm::errs() << "Revisions:";
+      for (const auto& r : so.second)
+        llvm::errs() << ' ' << r;
+      // The captured output is null-terminated by extract().
+      llvm::errs() << '\n'
+                   << std::string(reinterpret_cast<const char*>(
+                        so.first.data()));
+    }
+    llvm::errs() << '\n';
+  }
+}