diff options
author | Daniel Dunbar <daniel@zuster.org> | 2009-05-21 04:36:41 +0000 |
---|---|---|
committer | Daniel Dunbar <daniel@zuster.org> | 2009-05-21 04:36:41 +0000 |
commit | 6f290d8f9e9d7faac295cb51fc96884a18f4ded4 (patch) | |
tree | 46e7d426abc0c9f06ac472ac6f7f9e661b5d78cb /lib/Solver/Solver.cpp | |
parent | a55960edd4dcd7535526de8d2277642522aa0209 (diff) | |
download | klee-6f290d8f9e9d7faac295cb51fc96884a18f4ded4.tar.gz |
Initial KLEE checkin.
- Lots more tweaks, documentation, and web page content is needed, but this should compile & work on OS X & Linux. git-svn-id: https://llvm.org/svn/llvm-project/klee/trunk@72205 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Solver/Solver.cpp')
-rw-r--r-- | lib/Solver/Solver.cpp | 643 |
1 files changed, 643 insertions, 0 deletions
diff --git a/lib/Solver/Solver.cpp b/lib/Solver/Solver.cpp new file mode 100644 index 00000000..24d3ef86 --- /dev/null +++ b/lib/Solver/Solver.cpp @@ -0,0 +1,643 @@ +//===-- Solver.cpp --------------------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/Solver.h" +#include "klee/SolverImpl.h" + +#include "SolverStats.h" +#include "STPBuilder.h" + +#include "klee/Constraints.h" +#include "klee/Expr.h" +#include "klee/TimerStatIncrementer.h" +#include "klee/util/Assignment.h" +#include "klee/util/ExprPPrinter.h" +#include "klee/util/ExprUtil.h" +#include "klee/Internal/Support/Timer.h" + +#define vc_bvBoolExtract IAMTHESPAWNOFSATAN + +#include <cassert> +#include <map> +#include <vector> + +#include <sys/wait.h> +#include <sys/ipc.h> +#include <sys/shm.h> + +using namespace klee; + +/***/ + +const char *Solver::validity_to_str(Validity v) { + switch (v) { + default: return "Unknown"; + case True: return "True"; + case False: return "False"; + } +} + +Solver::~Solver() { + delete impl; +} + +SolverImpl::~SolverImpl() { +} + +bool Solver::evaluate(const Query& query, Validity &result) { + assert(query.expr.getWidth() == Expr::Bool && "Invalid expression type!"); + + // Maintain invariants implementation expect. + if (query.expr.isConstant()) { + result = query.expr.getConstantValue() ? True : False; + return true; + } + + return impl->computeValidity(query, result); +} + +bool SolverImpl::computeValidity(const Query& query, Solver::Validity &result) { + bool isTrue, isFalse; + if (!computeTruth(query, isTrue)) + return false; + if (isTrue) { + result = Solver::True; + } else { + if (!computeTruth(query.negateExpr(), isFalse)) + return false; + result = isFalse ? Solver::False : Solver::Unknown; + } + return true; +} + +bool Solver::mustBeTrue(const Query& query, bool &result) { + assert(query.expr.getWidth() == Expr::Bool && "Invalid expression type!"); + + // Maintain invariants implementation expect. + if (query.expr.isConstant()) { + result = query.expr.getConstantValue() ? true : false; + return true; + } + + return impl->computeTruth(query, result); +} + +bool Solver::mustBeFalse(const Query& query, bool &result) { + return mustBeTrue(query.negateExpr(), result); +} + +bool Solver::mayBeTrue(const Query& query, bool &result) { + bool res; + if (!mustBeFalse(query, res)) + return false; + result = !res; + return true; +} + +bool Solver::mayBeFalse(const Query& query, bool &result) { + bool res; + if (!mustBeTrue(query, res)) + return false; + result = !res; + return true; +} + +bool Solver::getValue(const Query& query, ref<Expr> &result) { + // Maintain invariants implementation expect. + if (query.expr.isConstant()) { + result = query.expr; + return true; + } + + return impl->computeValue(query, result); +} + +bool +Solver::getInitialValues(const Query& query, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values) { + bool hasSolution; + bool success = + impl->computeInitialValues(query, objects, values, hasSolution); + // FIXME: Propogate this out. + if (!hasSolution) + return false; + + return success; +} + +std::pair< ref<Expr>, ref<Expr> > Solver::getRange(const Query& query) { + ref<Expr> e = query.expr; + Expr::Width width = e.getWidth(); + uint64_t min, max; + + if (width==1) { + Solver::Validity result; + if (!evaluate(query, result)) + assert(0 && "computeValidity failed"); + switch (result) { + case Solver::True: + min = max = 1; break; + case Solver::False: + min = max = 0; break; + default: + min = 0, max = 1; break; + } + } else if (e.isConstant()) { + min = max = e.getConstantValue(); + } else { + // binary search for # of useful bits + uint64_t lo=0, hi=width, mid, bits=0; + while (lo<hi) { + mid = (lo+hi)/2; + bool res; + bool success = + mustBeTrue(query.withExpr( + EqExpr::create(LShrExpr::create(e, + ConstantExpr::create(mid, + width)), + ConstantExpr::create(0, width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + hi = mid; + } else { + lo = mid+1; + } + + bits = lo; + } + + // could binary search for training zeros and offset + // min max but unlikely to be very useful + + // check common case + bool res = false; + bool success = + mayBeTrue(query.withExpr(EqExpr::create(e, ConstantExpr::create(0, + width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + min = 0; + } else { + // binary search for min + lo=0, hi=bits64::maxValueOfNBits(bits); + while (lo<hi) { + mid = (lo+hi)/2; + bool res = false; + bool success = + mayBeTrue(query.withExpr(UleExpr::create(e, + ConstantExpr::create(mid, + width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + hi = mid; + } else { + lo = mid+1; + } + } + + min = lo; + } + + // binary search for max + lo=min, hi=bits64::maxValueOfNBits(bits); + while (lo<hi) { + mid = (lo+hi)/2; + bool res; + bool success = + mustBeTrue(query.withExpr(UleExpr::create(e, + ConstantExpr::create(mid, + width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + hi = mid; + } else { + lo = mid+1; + } + } + + max = lo; + } + + return std::make_pair(ConstantExpr::create(min, width), + ConstantExpr::create(max, width)); +} + +/***/ + +class ValidatingSolver : public SolverImpl { +private: + Solver *solver, *oracle; + +public: + ValidatingSolver(Solver *_solver, Solver *_oracle) + : solver(_solver), oracle(_oracle) {} + ~ValidatingSolver() { delete solver; } + + bool computeValidity(const Query&, Solver::Validity &result); + bool computeTruth(const Query&, bool &isValid); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query&, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution); +}; + +bool ValidatingSolver::computeTruth(const Query& query, + bool &isValid) { + bool answer; + + if (!solver->impl->computeTruth(query, isValid)) + return false; + if (!oracle->impl->computeTruth(query, answer)) + return false; + + if (isValid != answer) + assert(0 && "invalid solver result (computeTruth)"); + + return true; +} + +bool ValidatingSolver::computeValidity(const Query& query, + Solver::Validity &result) { + Solver::Validity answer; + + if (!solver->impl->computeValidity(query, result)) + return false; + if (!oracle->impl->computeValidity(query, answer)) + return false; + + if (result != answer) + assert(0 && "invalid solver result (computeValidity)"); + + return true; +} + +bool ValidatingSolver::computeValue(const Query& query, + ref<Expr> &result) { + bool answer; + + if (!solver->impl->computeValue(query, result)) + return false; + // We don't want to compare, but just make sure this is a legal + // solution. + if (!oracle->impl->computeTruth(query.withExpr(NeExpr::create(query.expr, + result)), + answer)) + return false; + + if (answer) + assert(0 && "invalid solver result (computeValue)"); + + return true; +} + +bool +ValidatingSolver::computeInitialValues(const Query& query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + bool answer; + + if (!solver->impl->computeInitialValues(query, objects, values, + hasSolution)) + return false; + + if (hasSolution) { + // Assert the bindings as constraints, and verify that the + // conjunction of the actual constraints is satisfiable. + std::vector< ref<Expr> > bindings; + for (unsigned i = 0; i != values.size(); ++i) { + const Array *array = objects[i]; + for (unsigned j=0; j<array->size; j++) { + unsigned char value = values[i][j]; + bindings.push_back(EqExpr::create(ReadExpr::create(UpdateList(array, + true, 0), + ref<Expr>(j, Expr::Int32)), + ref<Expr>(value, Expr::Int8))); + } + } + ConstraintManager tmp(bindings); + ref<Expr> constraints = Expr::createNot(query.expr); + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + constraints = AndExpr::create(constraints, *it); + + if (!oracle->impl->computeTruth(Query(tmp, constraints), answer)) + return false; + if (!answer) + assert(0 && "invalid solver result (computeInitialValues)"); + } else { + if (!oracle->impl->computeTruth(query, answer)) + return false; + if (!answer) + assert(0 && "invalid solver result (computeInitialValues)"); + } + + return true; +} + +Solver *klee::createValidatingSolver(Solver *s, Solver *oracle) { + return new Solver(new ValidatingSolver(s, oracle)); +} + +/***/ + +class STPSolverImpl : public SolverImpl { +private: + /// The solver we are part of, for access to public information. + STPSolver *solver; + VC vc; + STPBuilder *builder; + double timeout; + bool useForkedSTP; + +public: + STPSolverImpl(STPSolver *_solver, bool _useForkedSTP); + ~STPSolverImpl(); + + char *getConstraintLog(const Query&); + void setTimeout(double _timeout) { timeout = _timeout; } + + bool computeTruth(const Query&, bool &isValid); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query&, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution); +}; + +static unsigned char *shared_memory_ptr; +static const unsigned shared_memory_size = 1<<20; +static int shared_memory_id; + +static void stp_error_handler(const char* err_msg) { + fprintf(stderr, "error: STP Error: %s\n", err_msg); + abort(); +} + +STPSolverImpl::STPSolverImpl(STPSolver *_solver, bool _useForkedSTP) + : solver(_solver), + vc(vc_createValidityChecker()), + builder(new STPBuilder(vc)), + timeout(0.0), + useForkedSTP(_useForkedSTP) +{ + assert(vc && "unable to create validity checker"); + assert(builder && "unable to create STPBuilder"); + + vc_registerErrorHandler(::stp_error_handler); + + if (useForkedSTP) { + shared_memory_id = shmget(IPC_PRIVATE, shared_memory_size, IPC_CREAT | 0700); + assert(shared_memory_id>=0 && "shmget failed"); + shared_memory_ptr = (unsigned char*) shmat(shared_memory_id, NULL, 0); + assert(shared_memory_ptr!=(void*)-1 && "shmat failed"); + shmctl(shared_memory_id, IPC_RMID, NULL); + } +} + +STPSolverImpl::~STPSolverImpl() { + delete builder; + + vc_Destroy(vc); +} + +/***/ + +STPSolver::STPSolver(bool useForkedSTP) + : Solver(new STPSolverImpl(this, useForkedSTP)) +{ +} + +char *STPSolver::getConstraintLog(const Query &query) { + return static_cast<STPSolverImpl*>(impl)->getConstraintLog(query); +} + +void STPSolver::setTimeout(double timeout) { + static_cast<STPSolverImpl*>(impl)->setTimeout(timeout); +} + +/***/ + +char *STPSolverImpl::getConstraintLog(const Query &query) { + vc_push(vc); + for (std::vector< ref<Expr> >::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + vc_assertFormula(vc, builder->construct(*it)); + assert(query.expr == ref<Expr>(0, Expr::Bool) && + "Unexpected expression in query!"); + + char *buffer; + unsigned long length; + vc_printQueryStateToBuffer(vc, builder->getFalse(), + &buffer, &length, false); + vc_pop(vc); + + return buffer; +} + +bool STPSolverImpl::computeTruth(const Query& query, + bool &isValid) { + std::vector<const Array*> objects; + std::vector< std::vector<unsigned char> > values; + bool hasSolution; + + if (!computeInitialValues(query, objects, values, hasSolution)) + return false; + + isValid = !hasSolution; + return true; +} + +bool STPSolverImpl::computeValue(const Query& query, + ref<Expr> &result) { + std::vector<const Array*> objects; + std::vector< std::vector<unsigned char> > values; + bool hasSolution; + + // Find the object used in the expression, and compute an assignment + // for them. + findSymbolicObjects(query.expr, objects); + if (!computeInitialValues(query.withFalse(), objects, values, hasSolution)) + return false; + assert(hasSolution && "state has invalid constraint set"); + + // Evaluate the expression with the computed assignment. + Assignment a(objects, values); + result = a.evaluate(query.expr); + + return true; +} + +static void runAndGetCex(::VC vc, STPBuilder *builder, ::VCExpr q, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution) { + // XXX I want to be able to timeout here, safely + hasSolution = !vc_query(vc, q); + + if (hasSolution) { + values.reserve(objects.size()); + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) { + const Array *array = *it; + std::vector<unsigned char> data; + + data.reserve(array->size); + for (unsigned offset = 0; offset < array->size; offset++) { + ExprHandle counter = + vc_getCounterExample(vc, builder->getInitialRead(array, offset)); + unsigned char val = getBVUnsigned(counter); + data.push_back(val); + } + + values.push_back(data); + } + } +} + +static void stpTimeoutHandler(int x) { + _exit(52); +} + +static bool runAndGetCexForked(::VC vc, + STPBuilder *builder, + ::VCExpr q, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution, + double timeout) { + unsigned char *pos = shared_memory_ptr; + unsigned sum = 0; + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) + sum += (*it)->size; + assert(sum<shared_memory_size && "not enough shared memory for counterexample"); + + fflush(stdout); + fflush(stderr); + int pid = fork(); + if (pid==-1) { + fprintf(stderr, "error: fork failed (for STP)"); + return false; + } + + if (pid == 0) { + if (timeout) { + ::alarm(0); /* Turn off alarm so we can safely set signal handler */ + ::signal(SIGALRM, stpTimeoutHandler); + ::alarm(std::max(1, (int)timeout)); + } + unsigned res = vc_query(vc, q); + if (!res) { + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) { + const Array *array = *it; + for (unsigned offset = 0; offset < array->size; offset++) { + ExprHandle counter = + vc_getCounterExample(vc, builder->getInitialRead(array, offset)); + *pos++ = getBVUnsigned(counter); + } + } + } + _exit(res); + } else { + int status; + int res = waitpid(pid, &status, 0); + + if (res<0) { + fprintf(stderr, "error: waitpid() for STP failed"); + return false; + } + + // From timed_run.py: It appears that linux at least will on + // "occasion" return a status when the process was terminated by a + // signal, so test signal first. + if (WIFSIGNALED(status) || !WIFEXITED(status)) { + fprintf(stderr, "error: STP did not return successfully"); + return false; + } + + int exitcode = WEXITSTATUS(status); + if (exitcode==0) { + hasSolution = true; + } else if (exitcode==1) { + hasSolution = false; + } else if (exitcode==52) { + fprintf(stderr, "error: STP timed out"); + return false; + } else { + fprintf(stderr, "error: STP did not return a recognized code"); + return false; + } + + if (hasSolution) { + values = std::vector< std::vector<unsigned char> >(objects.size()); + unsigned i=0; + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) { + const Array *array = *it; + std::vector<unsigned char> &data = values[i++]; + data.insert(data.begin(), pos, pos + array->size); + pos += array->size; + } + } + + return true; + } +} + +bool +STPSolverImpl::computeInitialValues(const Query &query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + TimerStatIncrementer t(stats::queryTime); + + vc_push(vc); + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + vc_assertFormula(vc, builder->construct(*it)); + + ++stats::queries; + ++stats::queryCounterexamples; + + ExprHandle stp_e = builder->construct(query.expr); + + bool success; + if (useForkedSTP) { + success = runAndGetCexForked(vc, builder, stp_e, objects, values, + hasSolution, timeout); + } else { + runAndGetCex(vc, builder, stp_e, objects, values, hasSolution); + success = true; + } + + if (success) { + if (hasSolution) + ++stats::queriesInvalid; + else + ++stats::queriesValid; + } + + vc_pop(vc); + + return success; +} |