aboutsummaryrefslogtreecommitdiffhomepage
path: root/lib/Solver/Solver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Solver/Solver.cpp')
-rw-r--r--lib/Solver/Solver.cpp643
1 files changed, 643 insertions, 0 deletions
diff --git a/lib/Solver/Solver.cpp b/lib/Solver/Solver.cpp
new file mode 100644
index 00000000..24d3ef86
--- /dev/null
+++ b/lib/Solver/Solver.cpp
@@ -0,0 +1,643 @@
+//===-- Solver.cpp --------------------------------------------------------===//
+//
+// The KLEE Symbolic Virtual Machine
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "klee/Solver.h"
+#include "klee/SolverImpl.h"
+
+#include "SolverStats.h"
+#include "STPBuilder.h"
+
+#include "klee/Constraints.h"
+#include "klee/Expr.h"
+#include "klee/TimerStatIncrementer.h"
+#include "klee/util/Assignment.h"
+#include "klee/util/ExprPPrinter.h"
+#include "klee/util/ExprUtil.h"
+#include "klee/Internal/Support/Timer.h"
+
+#define vc_bvBoolExtract IAMTHESPAWNOFSATAN
+
+#include <cassert>
+#include <map>
+#include <vector>
+
+#include <sys/wait.h>
+#include <sys/ipc.h>
+#include <sys/shm.h>
+
+using namespace klee;
+
+/***/
+
+const char *Solver::validity_to_str(Validity v) {
+ switch (v) {
+ default: return "Unknown";
+ case True: return "True";
+ case False: return "False";
+ }
+}
+
+Solver::~Solver() {
+ delete impl;
+}
+
+SolverImpl::~SolverImpl() {
+}
+
+bool Solver::evaluate(const Query& query, Validity &result) {
+ assert(query.expr.getWidth() == Expr::Bool && "Invalid expression type!");
+
+ // Maintain invariants implementation expect.
+ if (query.expr.isConstant()) {
+ result = query.expr.getConstantValue() ? True : False;
+ return true;
+ }
+
+ return impl->computeValidity(query, result);
+}
+
+bool SolverImpl::computeValidity(const Query& query, Solver::Validity &result) {
+ bool isTrue, isFalse;
+ if (!computeTruth(query, isTrue))
+ return false;
+ if (isTrue) {
+ result = Solver::True;
+ } else {
+ if (!computeTruth(query.negateExpr(), isFalse))
+ return false;
+ result = isFalse ? Solver::False : Solver::Unknown;
+ }
+ return true;
+}
+
+bool Solver::mustBeTrue(const Query& query, bool &result) {
+ assert(query.expr.getWidth() == Expr::Bool && "Invalid expression type!");
+
+ // Maintain invariants implementation expect.
+ if (query.expr.isConstant()) {
+ result = query.expr.getConstantValue() ? true : false;
+ return true;
+ }
+
+ return impl->computeTruth(query, result);
+}
+
+bool Solver::mustBeFalse(const Query& query, bool &result) {
+ return mustBeTrue(query.negateExpr(), result);
+}
+
+bool Solver::mayBeTrue(const Query& query, bool &result) {
+ bool res;
+ if (!mustBeFalse(query, res))
+ return false;
+ result = !res;
+ return true;
+}
+
+bool Solver::mayBeFalse(const Query& query, bool &result) {
+ bool res;
+ if (!mustBeTrue(query, res))
+ return false;
+ result = !res;
+ return true;
+}
+
+bool Solver::getValue(const Query& query, ref<Expr> &result) {
+ // Maintain invariants implementation expect.
+ if (query.expr.isConstant()) {
+ result = query.expr;
+ return true;
+ }
+
+ return impl->computeValue(query, result);
+}
+
+bool
+Solver::getInitialValues(const Query& query,
+ const std::vector<const Array*> &objects,
+ std::vector< std::vector<unsigned char> > &values) {
+ bool hasSolution;
+ bool success =
+ impl->computeInitialValues(query, objects, values, hasSolution);
+ // FIXME: Propogate this out.
+ if (!hasSolution)
+ return false;
+
+ return success;
+}
+
+std::pair< ref<Expr>, ref<Expr> > Solver::getRange(const Query& query) {
+ ref<Expr> e = query.expr;
+ Expr::Width width = e.getWidth();
+ uint64_t min, max;
+
+ if (width==1) {
+ Solver::Validity result;
+ if (!evaluate(query, result))
+ assert(0 && "computeValidity failed");
+ switch (result) {
+ case Solver::True:
+ min = max = 1; break;
+ case Solver::False:
+ min = max = 0; break;
+ default:
+ min = 0, max = 1; break;
+ }
+ } else if (e.isConstant()) {
+ min = max = e.getConstantValue();
+ } else {
+ // binary search for # of useful bits
+ uint64_t lo=0, hi=width, mid, bits=0;
+ while (lo<hi) {
+ mid = (lo+hi)/2;
+ bool res;
+ bool success =
+ mustBeTrue(query.withExpr(
+ EqExpr::create(LShrExpr::create(e,
+ ConstantExpr::create(mid,
+ width)),
+ ConstantExpr::create(0, width))),
+ res);
+ assert(success && "FIXME: Unhandled solver failure");
+ if (res) {
+ hi = mid;
+ } else {
+ lo = mid+1;
+ }
+
+ bits = lo;
+ }
+
+ // could binary search for training zeros and offset
+ // min max but unlikely to be very useful
+
+ // check common case
+ bool res = false;
+ bool success =
+ mayBeTrue(query.withExpr(EqExpr::create(e, ConstantExpr::create(0,
+ width))),
+ res);
+ assert(success && "FIXME: Unhandled solver failure");
+ if (res) {
+ min = 0;
+ } else {
+ // binary search for min
+ lo=0, hi=bits64::maxValueOfNBits(bits);
+ while (lo<hi) {
+ mid = (lo+hi)/2;
+ bool res = false;
+ bool success =
+ mayBeTrue(query.withExpr(UleExpr::create(e,
+ ConstantExpr::create(mid,
+ width))),
+ res);
+ assert(success && "FIXME: Unhandled solver failure");
+ if (res) {
+ hi = mid;
+ } else {
+ lo = mid+1;
+ }
+ }
+
+ min = lo;
+ }
+
+ // binary search for max
+ lo=min, hi=bits64::maxValueOfNBits(bits);
+ while (lo<hi) {
+ mid = (lo+hi)/2;
+ bool res;
+ bool success =
+ mustBeTrue(query.withExpr(UleExpr::create(e,
+ ConstantExpr::create(mid,
+ width))),
+ res);
+ assert(success && "FIXME: Unhandled solver failure");
+ if (res) {
+ hi = mid;
+ } else {
+ lo = mid+1;
+ }
+ }
+
+ max = lo;
+ }
+
+ return std::make_pair(ConstantExpr::create(min, width),
+ ConstantExpr::create(max, width));
+}
+
+/***/
+
+class ValidatingSolver : public SolverImpl {
+private:
+ Solver *solver, *oracle;
+
+public:
+ ValidatingSolver(Solver *_solver, Solver *_oracle)
+ : solver(_solver), oracle(_oracle) {}
+ ~ValidatingSolver() { delete solver; }
+
+ bool computeValidity(const Query&, Solver::Validity &result);
+ bool computeTruth(const Query&, bool &isValid);
+ bool computeValue(const Query&, ref<Expr> &result);
+ bool computeInitialValues(const Query&,
+ const std::vector<const Array*> &objects,
+ std::vector< std::vector<unsigned char> > &values,
+ bool &hasSolution);
+};
+
+bool ValidatingSolver::computeTruth(const Query& query,
+ bool &isValid) {
+ bool answer;
+
+ if (!solver->impl->computeTruth(query, isValid))
+ return false;
+ if (!oracle->impl->computeTruth(query, answer))
+ return false;
+
+ if (isValid != answer)
+ assert(0 && "invalid solver result (computeTruth)");
+
+ return true;
+}
+
+bool ValidatingSolver::computeValidity(const Query& query,
+ Solver::Validity &result) {
+ Solver::Validity answer;
+
+ if (!solver->impl->computeValidity(query, result))
+ return false;
+ if (!oracle->impl->computeValidity(query, answer))
+ return false;
+
+ if (result != answer)
+ assert(0 && "invalid solver result (computeValidity)");
+
+ return true;
+}
+
+bool ValidatingSolver::computeValue(const Query& query,
+ ref<Expr> &result) {
+ bool answer;
+
+ if (!solver->impl->computeValue(query, result))
+ return false;
+ // We don't want to compare, but just make sure this is a legal
+ // solution.
+ if (!oracle->impl->computeTruth(query.withExpr(NeExpr::create(query.expr,
+ result)),
+ answer))
+ return false;
+
+ if (answer)
+ assert(0 && "invalid solver result (computeValue)");
+
+ return true;
+}
+
+bool
+ValidatingSolver::computeInitialValues(const Query& query,
+ const std::vector<const Array*>
+ &objects,
+ std::vector< std::vector<unsigned char> >
+ &values,
+ bool &hasSolution) {
+ bool answer;
+
+ if (!solver->impl->computeInitialValues(query, objects, values,
+ hasSolution))
+ return false;
+
+ if (hasSolution) {
+ // Assert the bindings as constraints, and verify that the
+ // conjunction of the actual constraints is satisfiable.
+ std::vector< ref<Expr> > bindings;
+ for (unsigned i = 0; i != values.size(); ++i) {
+ const Array *array = objects[i];
+ for (unsigned j=0; j<array->size; j++) {
+ unsigned char value = values[i][j];
+ bindings.push_back(EqExpr::create(ReadExpr::create(UpdateList(array,
+ true, 0),
+ ref<Expr>(j, Expr::Int32)),
+ ref<Expr>(value, Expr::Int8)));
+ }
+ }
+ ConstraintManager tmp(bindings);
+ ref<Expr> constraints = Expr::createNot(query.expr);
+ for (ConstraintManager::const_iterator it = query.constraints.begin(),
+ ie = query.constraints.end(); it != ie; ++it)
+ constraints = AndExpr::create(constraints, *it);
+
+ if (!oracle->impl->computeTruth(Query(tmp, constraints), answer))
+ return false;
+ if (!answer)
+ assert(0 && "invalid solver result (computeInitialValues)");
+ } else {
+ if (!oracle->impl->computeTruth(query, answer))
+ return false;
+ if (!answer)
+ assert(0 && "invalid solver result (computeInitialValues)");
+ }
+
+ return true;
+}
+
+Solver *klee::createValidatingSolver(Solver *s, Solver *oracle) {
+ return new Solver(new ValidatingSolver(s, oracle));
+}
+
+/***/
+
+class STPSolverImpl : public SolverImpl {
+private:
+ /// The solver we are part of, for access to public information.
+ STPSolver *solver;
+ VC vc;
+ STPBuilder *builder;
+ double timeout;
+ bool useForkedSTP;
+
+public:
+ STPSolverImpl(STPSolver *_solver, bool _useForkedSTP);
+ ~STPSolverImpl();
+
+ char *getConstraintLog(const Query&);
+ void setTimeout(double _timeout) { timeout = _timeout; }
+
+ bool computeTruth(const Query&, bool &isValid);
+ bool computeValue(const Query&, ref<Expr> &result);
+ bool computeInitialValues(const Query&,
+ const std::vector<const Array*> &objects,
+ std::vector< std::vector<unsigned char> > &values,
+ bool &hasSolution);
+};
+
+static unsigned char *shared_memory_ptr;
+static const unsigned shared_memory_size = 1<<20;
+static int shared_memory_id;
+
+static void stp_error_handler(const char* err_msg) {
+ fprintf(stderr, "error: STP Error: %s\n", err_msg);
+ abort();
+}
+
+STPSolverImpl::STPSolverImpl(STPSolver *_solver, bool _useForkedSTP)
+ : solver(_solver),
+ vc(vc_createValidityChecker()),
+ builder(new STPBuilder(vc)),
+ timeout(0.0),
+ useForkedSTP(_useForkedSTP)
+{
+ assert(vc && "unable to create validity checker");
+ assert(builder && "unable to create STPBuilder");
+
+ vc_registerErrorHandler(::stp_error_handler);
+
+ if (useForkedSTP) {
+ shared_memory_id = shmget(IPC_PRIVATE, shared_memory_size, IPC_CREAT | 0700);
+ assert(shared_memory_id>=0 && "shmget failed");
+ shared_memory_ptr = (unsigned char*) shmat(shared_memory_id, NULL, 0);
+ assert(shared_memory_ptr!=(void*)-1 && "shmat failed");
+ shmctl(shared_memory_id, IPC_RMID, NULL);
+ }
+}
+
+STPSolverImpl::~STPSolverImpl() {
+ delete builder;
+
+ vc_Destroy(vc);
+}
+
+/***/
+
+STPSolver::STPSolver(bool useForkedSTP)
+ : Solver(new STPSolverImpl(this, useForkedSTP))
+{
+}
+
+char *STPSolver::getConstraintLog(const Query &query) {
+ return static_cast<STPSolverImpl*>(impl)->getConstraintLog(query);
+}
+
+void STPSolver::setTimeout(double timeout) {
+ static_cast<STPSolverImpl*>(impl)->setTimeout(timeout);
+}
+
+/***/
+
+char *STPSolverImpl::getConstraintLog(const Query &query) {
+ vc_push(vc);
+ for (std::vector< ref<Expr> >::const_iterator it = query.constraints.begin(),
+ ie = query.constraints.end(); it != ie; ++it)
+ vc_assertFormula(vc, builder->construct(*it));
+ assert(query.expr == ref<Expr>(0, Expr::Bool) &&
+ "Unexpected expression in query!");
+
+ char *buffer;
+ unsigned long length;
+ vc_printQueryStateToBuffer(vc, builder->getFalse(),
+ &buffer, &length, false);
+ vc_pop(vc);
+
+ return buffer;
+}
+
+bool STPSolverImpl::computeTruth(const Query& query,
+ bool &isValid) {
+ std::vector<const Array*> objects;
+ std::vector< std::vector<unsigned char> > values;
+ bool hasSolution;
+
+ if (!computeInitialValues(query, objects, values, hasSolution))
+ return false;
+
+ isValid = !hasSolution;
+ return true;
+}
+
+bool STPSolverImpl::computeValue(const Query& query,
+ ref<Expr> &result) {
+ std::vector<const Array*> objects;
+ std::vector< std::vector<unsigned char> > values;
+ bool hasSolution;
+
+ // Find the object used in the expression, and compute an assignment
+ // for them.
+ findSymbolicObjects(query.expr, objects);
+ if (!computeInitialValues(query.withFalse(), objects, values, hasSolution))
+ return false;
+ assert(hasSolution && "state has invalid constraint set");
+
+ // Evaluate the expression with the computed assignment.
+ Assignment a(objects, values);
+ result = a.evaluate(query.expr);
+
+ return true;
+}
+
+static void runAndGetCex(::VC vc, STPBuilder *builder, ::VCExpr q,
+ const std::vector<const Array*> &objects,
+ std::vector< std::vector<unsigned char> > &values,
+ bool &hasSolution) {
+ // XXX I want to be able to timeout here, safely
+ hasSolution = !vc_query(vc, q);
+
+ if (hasSolution) {
+ values.reserve(objects.size());
+ for (std::vector<const Array*>::const_iterator
+ it = objects.begin(), ie = objects.end(); it != ie; ++it) {
+ const Array *array = *it;
+ std::vector<unsigned char> data;
+
+ data.reserve(array->size);
+ for (unsigned offset = 0; offset < array->size; offset++) {
+ ExprHandle counter =
+ vc_getCounterExample(vc, builder->getInitialRead(array, offset));
+ unsigned char val = getBVUnsigned(counter);
+ data.push_back(val);
+ }
+
+ values.push_back(data);
+ }
+ }
+}
+
+static void stpTimeoutHandler(int x) {
+ _exit(52);
+}
+
+static bool runAndGetCexForked(::VC vc,
+ STPBuilder *builder,
+ ::VCExpr q,
+ const std::vector<const Array*> &objects,
+ std::vector< std::vector<unsigned char> >
+ &values,
+ bool &hasSolution,
+ double timeout) {
+ unsigned char *pos = shared_memory_ptr;
+ unsigned sum = 0;
+ for (std::vector<const Array*>::const_iterator
+ it = objects.begin(), ie = objects.end(); it != ie; ++it)
+ sum += (*it)->size;
+ assert(sum<shared_memory_size && "not enough shared memory for counterexample");
+
+ fflush(stdout);
+ fflush(stderr);
+ int pid = fork();
+ if (pid==-1) {
+ fprintf(stderr, "error: fork failed (for STP)");
+ return false;
+ }
+
+ if (pid == 0) {
+ if (timeout) {
+ ::alarm(0); /* Turn off alarm so we can safely set signal handler */
+ ::signal(SIGALRM, stpTimeoutHandler);
+ ::alarm(std::max(1, (int)timeout));
+ }
+ unsigned res = vc_query(vc, q);
+ if (!res) {
+ for (std::vector<const Array*>::const_iterator
+ it = objects.begin(), ie = objects.end(); it != ie; ++it) {
+ const Array *array = *it;
+ for (unsigned offset = 0; offset < array->size; offset++) {
+ ExprHandle counter =
+ vc_getCounterExample(vc, builder->getInitialRead(array, offset));
+ *pos++ = getBVUnsigned(counter);
+ }
+ }
+ }
+ _exit(res);
+ } else {
+ int status;
+ int res = waitpid(pid, &status, 0);
+
+ if (res<0) {
+ fprintf(stderr, "error: waitpid() for STP failed");
+ return false;
+ }
+
+ // From timed_run.py: It appears that linux at least will on
+ // "occasion" return a status when the process was terminated by a
+ // signal, so test signal first.
+ if (WIFSIGNALED(status) || !WIFEXITED(status)) {
+ fprintf(stderr, "error: STP did not return successfully");
+ return false;
+ }
+
+ int exitcode = WEXITSTATUS(status);
+ if (exitcode==0) {
+ hasSolution = true;
+ } else if (exitcode==1) {
+ hasSolution = false;
+ } else if (exitcode==52) {
+ fprintf(stderr, "error: STP timed out");
+ return false;
+ } else {
+ fprintf(stderr, "error: STP did not return a recognized code");
+ return false;
+ }
+
+ if (hasSolution) {
+ values = std::vector< std::vector<unsigned char> >(objects.size());
+ unsigned i=0;
+ for (std::vector<const Array*>::const_iterator
+ it = objects.begin(), ie = objects.end(); it != ie; ++it) {
+ const Array *array = *it;
+ std::vector<unsigned char> &data = values[i++];
+ data.insert(data.begin(), pos, pos + array->size);
+ pos += array->size;
+ }
+ }
+
+ return true;
+ }
+}
+
+bool
+STPSolverImpl::computeInitialValues(const Query &query,
+ const std::vector<const Array*>
+ &objects,
+ std::vector< std::vector<unsigned char> >
+ &values,
+ bool &hasSolution) {
+ TimerStatIncrementer t(stats::queryTime);
+
+ vc_push(vc);
+
+ for (ConstraintManager::const_iterator it = query.constraints.begin(),
+ ie = query.constraints.end(); it != ie; ++it)
+ vc_assertFormula(vc, builder->construct(*it));
+
+ ++stats::queries;
+ ++stats::queryCounterexamples;
+
+ ExprHandle stp_e = builder->construct(query.expr);
+
+ bool success;
+ if (useForkedSTP) {
+ success = runAndGetCexForked(vc, builder, stp_e, objects, values,
+ hasSolution, timeout);
+ } else {
+ runAndGetCex(vc, builder, stp_e, objects, values, hasSolution);
+ success = true;
+ }
+
+ if (success) {
+ if (hasSolution)
+ ++stats::queriesInvalid;
+ else
+ ++stats::queriesValid;
+ }
+
+ vc_pop(vc);
+
+ return success;
+}