diff options
Diffstat (limited to 'lib/Solver')
-rw-r--r-- | lib/Solver/CachingSolver.cpp | 241 | ||||
-rw-r--r-- | lib/Solver/CexCachingSolver.cpp | 313 | ||||
-rw-r--r-- | lib/Solver/ConstantDivision.cpp | 146 | ||||
-rw-r--r-- | lib/Solver/ConstantDivision.h | 51 | ||||
-rw-r--r-- | lib/Solver/FastCexSolver.cpp | 959 | ||||
-rw-r--r-- | lib/Solver/IncompleteSolver.cpp | 136 | ||||
-rw-r--r-- | lib/Solver/IndependentSolver.cpp | 314 | ||||
-rwxr-xr-x | lib/Solver/Makefile | 16 | ||||
-rw-r--r-- | lib/Solver/PCLoggingSolver.cpp | 134 | ||||
-rw-r--r-- | lib/Solver/STPBuilder.cpp | 819 | ||||
-rw-r--r-- | lib/Solver/STPBuilder.h | 125 | ||||
-rw-r--r-- | lib/Solver/Solver.cpp | 643 | ||||
-rw-r--r-- | lib/Solver/SolverStats.cpp | 23 | ||||
-rw-r--r-- | lib/Solver/SolverStats.h | 32 |
14 files changed, 3952 insertions, 0 deletions
diff --git a/lib/Solver/CachingSolver.cpp b/lib/Solver/CachingSolver.cpp new file mode 100644 index 00000000..517e133b --- /dev/null +++ b/lib/Solver/CachingSolver.cpp @@ -0,0 +1,241 @@ +//===-- CachingSolver.cpp - Caching expression solver ---------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + + +#include "klee/Solver.h" + +#include "klee/Constraints.h" +#include "klee/Expr.h" +#include "klee/IncompleteSolver.h" +#include "klee/SolverImpl.h" + +#include "SolverStats.h" + +#include <tr1/unordered_map> + +using namespace klee; + +class CachingSolver : public SolverImpl { +private: + ref<Expr> canonicalizeQuery(ref<Expr> originalQuery, + bool &negationUsed); + + void cacheInsert(const Query& query, + IncompleteSolver::PartialValidity result); + + bool cacheLookup(const Query& query, + IncompleteSolver::PartialValidity &result); + + struct CacheEntry { + CacheEntry(const ConstraintManager &c, ref<Expr> q) + : constraints(c), query(q) {} + + CacheEntry(const CacheEntry &ce) + : constraints(ce.constraints), query(ce.query) {} + + ConstraintManager constraints; + ref<Expr> query; + + bool operator==(const CacheEntry &b) const { + return constraints==b.constraints && *query.get()==*b.query.get(); + } + }; + + struct CacheEntryHash { + unsigned operator()(const CacheEntry &ce) const { + unsigned result = ce.query.hash(); + + for (ConstraintManager::constraint_iterator it = ce.constraints.begin(); + it != ce.constraints.end(); ++it) + result ^= it->hash(); + + return result; + } + }; + + typedef std::tr1::unordered_map<CacheEntry, + IncompleteSolver::PartialValidity, + CacheEntryHash> cache_map; + + Solver *solver; + cache_map cache; + +public: + CachingSolver(Solver *s) : solver(s) {} + ~CachingSolver() { cache.clear(); delete solver; } + + bool computeValidity(const Query&, Solver::Validity &result); + bool computeTruth(const Query&, bool &isValid); + bool computeValue(const Query& query, ref<Expr> &result) { + return solver->impl->computeValue(query, result); + } + bool computeInitialValues(const Query& query, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution) { + return solver->impl->computeInitialValues(query, objects, values, + hasSolution); + } +}; + +/** @returns the canonical version of the given query. The reference + negationUsed is set to true if the original query was negated in + the canonicalization process. */ +ref<Expr> CachingSolver::canonicalizeQuery(ref<Expr> originalQuery, + bool &negationUsed) { + ref<Expr> negatedQuery = Expr::createNot(originalQuery); + + // select the "smaller" query to the be canonical representation + if (originalQuery.compare(negatedQuery) < 0) { + negationUsed = false; + return originalQuery; + } else { + negationUsed = true; + return negatedQuery; + } +} + +/** @returns true on a cache hit, false of a cache miss. Reference + value result only valid on a cache hit. */ +bool CachingSolver::cacheLookup(const Query& query, + IncompleteSolver::PartialValidity &result) { + bool negationUsed; + ref<Expr> canonicalQuery = canonicalizeQuery(query.expr, negationUsed); + + CacheEntry ce(query.constraints, canonicalQuery); + cache_map::iterator it = cache.find(ce); + + if (it != cache.end()) { + result = (negationUsed ? + IncompleteSolver::negatePartialValidity(it->second) : + it->second); + return true; + } + + return false; +} + +/// Inserts the given query, result pair into the cache. +void CachingSolver::cacheInsert(const Query& query, + IncompleteSolver::PartialValidity result) { + bool negationUsed; + ref<Expr> canonicalQuery = canonicalizeQuery(query.expr, negationUsed); + + CacheEntry ce(query.constraints, canonicalQuery); + IncompleteSolver::PartialValidity cachedResult = + (negationUsed ? IncompleteSolver::negatePartialValidity(result) : result); + + cache.insert(std::make_pair(ce, cachedResult)); +} + +bool CachingSolver::computeValidity(const Query& query, + Solver::Validity &result) { + IncompleteSolver::PartialValidity cachedResult; + bool tmp, cacheHit = cacheLookup(query, cachedResult); + + if (cacheHit) { + ++stats::queryCacheHits; + + switch(cachedResult) { + case IncompleteSolver::MustBeTrue: + result = Solver::True; + return true; + case IncompleteSolver::MustBeFalse: + result = Solver::False; + return true; + case IncompleteSolver::TrueOrFalse: + result = Solver::Unknown; + return true; + case IncompleteSolver::MayBeTrue: { + if (!solver->impl->computeTruth(query, tmp)) + return false; + if (tmp) { + cacheInsert(query, IncompleteSolver::MustBeTrue); + result = Solver::True; + return true; + } else { + cacheInsert(query, IncompleteSolver::TrueOrFalse); + result = Solver::Unknown; + return true; + } + } + case IncompleteSolver::MayBeFalse: { + if (!solver->impl->computeTruth(query.negateExpr(), tmp)) + return false; + if (tmp) { + cacheInsert(query, IncompleteSolver::MustBeFalse); + result = Solver::False; + return true; + } else { + cacheInsert(query, IncompleteSolver::TrueOrFalse); + result = Solver::Unknown; + return true; + } + } + default: assert(0 && "unreachable"); + } + } + + ++stats::queryCacheMisses; + + if (!solver->impl->computeValidity(query, result)) + return false; + + switch (result) { + case Solver::True: + cachedResult = IncompleteSolver::MustBeTrue; break; + case Solver::False: + cachedResult = IncompleteSolver::MustBeFalse; break; + default: + cachedResult = IncompleteSolver::TrueOrFalse; break; + } + + cacheInsert(query, cachedResult); + return true; +} + +bool CachingSolver::computeTruth(const Query& query, + bool &isValid) { + IncompleteSolver::PartialValidity cachedResult; + bool cacheHit = cacheLookup(query, cachedResult); + + // a cached result of MayBeTrue forces us to check whether + // a False assignment exists. + if (cacheHit && cachedResult != IncompleteSolver::MayBeTrue) { + ++stats::queryCacheHits; + isValid = (cachedResult == IncompleteSolver::MustBeTrue); + return true; + } + + ++stats::queryCacheMisses; + + // cache miss: query solver + if (!solver->impl->computeTruth(query, isValid)) + return false; + + if (isValid) { + cachedResult = IncompleteSolver::MustBeTrue; + } else if (cacheHit) { + // We know a true assignment exists, and query isn't valid, so + // must be TrueOrFalse. + assert(cachedResult == IncompleteSolver::MayBeTrue); + cachedResult = IncompleteSolver::TrueOrFalse; + } else { + cachedResult = IncompleteSolver::MayBeFalse; + } + + cacheInsert(query, cachedResult); + return true; +} + +/// + +Solver *klee::createCachingSolver(Solver *_solver) { + return new Solver(new CachingSolver(_solver)); +} diff --git a/lib/Solver/CexCachingSolver.cpp b/lib/Solver/CexCachingSolver.cpp new file mode 100644 index 00000000..79bc985d --- /dev/null +++ b/lib/Solver/CexCachingSolver.cpp @@ -0,0 +1,313 @@ +//===-- CexCachingSolver.cpp ----------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/Solver.h" + +#include "klee/Constraints.h" +#include "klee/Expr.h" +#include "klee/SolverImpl.h" +#include "klee/TimerStatIncrementer.h" +#include "klee/util/Assignment.h" +#include "klee/util/ExprUtil.h" +#include "klee/util/ExprVisitor.h" +#include "klee/Internal/ADT/MapOfSets.h" + +#include "SolverStats.h" + +#include "llvm/Support/CommandLine.h" + +using namespace klee; +using namespace llvm; + +namespace { + cl::opt<bool> + DebugCexCacheCheckBinding("debug-cex-cache-check-binding"); + + cl::opt<bool> + CexCacheTryAll("cex-cache-try-all", + cl::desc("try substituting all counterexamples before asking STP"), + cl::init(false)); + + cl::opt<bool> + CexCacheExperimental("cex-cache-exp", cl::init(false)); + +} + +/// + +typedef std::set< ref<Expr> > KeyType; + +struct AssignmentLessThan { + bool operator()(const Assignment *a, const Assignment *b) { + return a->bindings < b->bindings; + } +}; + + +class CexCachingSolver : public SolverImpl { + typedef std::set<Assignment*, AssignmentLessThan> assignmentsTable_ty; + + Solver *solver; + + MapOfSets<ref<Expr>, Assignment*> cache; + // memo table + assignmentsTable_ty assignmentsTable; + + bool searchForAssignment(KeyType &key, + Assignment *&result); + + bool lookupAssignment(const Query& query, Assignment *&result); + + bool getAssignment(const Query& query, Assignment *&result); + +public: + CexCachingSolver(Solver *_solver) : solver(_solver) {} + ~CexCachingSolver(); + + bool computeTruth(const Query&, bool &isValid); + bool computeValidity(const Query&, Solver::Validity &result); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query&, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution); +}; + +/// + +struct NullAssignment { + bool operator()(Assignment *a) const { return !a; } +}; + +struct NonNullAssignment { + bool operator()(Assignment *a) const { return a!=0; } +}; + +struct NullOrSatisfyingAssignment { + KeyType &key; + + NullOrSatisfyingAssignment(KeyType &_key) : key(_key) {} + + bool operator()(Assignment *a) const { + return !a || a->satisfies(key.begin(), key.end()); + } +}; + +bool CexCachingSolver::searchForAssignment(KeyType &key, Assignment *&result) { + Assignment * const *lookup = cache.lookup(key); + if (lookup) { + result = *lookup; + return true; + } + + if (CexCacheTryAll) { + Assignment **lookup = cache.findSuperset(key, NonNullAssignment()); + if (!lookup) lookup = cache.findSubset(key, NullAssignment()); + if (lookup) { + result = *lookup; + return true; + } + for (assignmentsTable_ty::iterator it = assignmentsTable.begin(), + ie = assignmentsTable.end(); it != ie; ++it) { + Assignment *a = *it; + if (a->satisfies(key.begin(), key.end())) { + result = a; + return true; + } + } + } else { + // XXX which order? one is sure to be better + Assignment **lookup = cache.findSuperset(key, NonNullAssignment()); + if (!lookup) lookup = cache.findSubset(key, NullOrSatisfyingAssignment(key)); + if (lookup) { + result = *lookup; + return true; + } + } + + return false; +} + +bool CexCachingSolver::lookupAssignment(const Query &query, + Assignment *&result) { + KeyType key(query.constraints.begin(), query.constraints.end()); + ref<Expr> neg = Expr::createNot(query.expr); + if (neg.isConstant()) { + if (!neg.getConstantValue()) { + result = (Assignment*) 0; + return true; + } + } else { + key.insert(neg); + } + + return searchForAssignment(key, result); +} + +bool CexCachingSolver::getAssignment(const Query& query, Assignment *&result) { + KeyType key(query.constraints.begin(), query.constraints.end()); + ref<Expr> neg = Expr::createNot(query.expr); + if (neg.isConstant()) { + if (!neg.getConstantValue()) { + result = (Assignment*) 0; + return true; + } + } else { + key.insert(neg); + } + + if (!searchForAssignment(key, result)) { + // need to solve + + std::vector<const Array*> objects; + findSymbolicObjects(key.begin(), key.end(), objects); + + std::vector< std::vector<unsigned char> > values; + bool hasSolution; + if (!solver->impl->computeInitialValues(query, objects, values, + hasSolution)) + return false; + + Assignment *binding; + if (hasSolution) { + binding = new Assignment(objects, values); + + // memoization + std::pair<assignmentsTable_ty::iterator, bool> + res = assignmentsTable.insert(binding); + if (!res.second) { + delete binding; + binding = *res.first; + } + + if (DebugCexCacheCheckBinding) + assert(binding->satisfies(key.begin(), key.end())); + } else { + binding = (Assignment*) 0; + } + + result = binding; + cache.insert(key, binding); + } + + return true; +} + +/// + +CexCachingSolver::~CexCachingSolver() { + cache.clear(); + delete solver; + for (assignmentsTable_ty::iterator it = assignmentsTable.begin(), + ie = assignmentsTable.end(); it != ie; ++it) + delete *it; +} + +bool CexCachingSolver::computeValidity(const Query& query, + Solver::Validity &result) { + TimerStatIncrementer t(stats::cexCacheTime); + Assignment *a; + if (!getAssignment(query.withFalse(), a)) + return false; + assert(a && "computeValidity() must have assignment"); + ref<Expr> q = a->evaluate(query.expr); + assert(q.isConstant() && "assignment evaluation did not result in constant"); + + if (q.getConstantValue()) { + if (!getAssignment(query, a)) + return false; + result = !a ? Solver::True : Solver::Unknown; + } else { + if (!getAssignment(query.negateExpr(), a)) + return false; + result = !a ? Solver::False : Solver::Unknown; + } + + return true; +} + +bool CexCachingSolver::computeTruth(const Query& query, + bool &isValid) { + TimerStatIncrementer t(stats::cexCacheTime); + + // There is a small amount of redundancy here. We only need to know + // truth and do not really need to compute an assignment. This means + // that we could check the cache to see if we already know that + // state ^ query has no assignment. In that case, by the validity of + // state, we know that state ^ !query must have an assignment, and + // so query cannot be true (valid). This does get hits, but doesn't + // really seem to be worth the overhead. + + if (CexCacheExperimental) { + Assignment *a; + if (lookupAssignment(query.negateExpr(), a) && !a) + return false; + } + + Assignment *a; + if (!getAssignment(query, a)) + return false; + + isValid = !a; + + return true; +} + +bool CexCachingSolver::computeValue(const Query& query, + ref<Expr> &result) { + TimerStatIncrementer t(stats::cexCacheTime); + + Assignment *a; + if (!getAssignment(query.withFalse(), a)) + return false; + assert(a && "computeValue() must have assignment"); + result = a->evaluate(query.expr); + assert(result.isConstant() && + "assignment evaluation did not result in constant"); + return true; +} + +bool +CexCachingSolver::computeInitialValues(const Query& query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + TimerStatIncrementer t(stats::cexCacheTime); + Assignment *a; + if (!getAssignment(query, a)) + return false; + hasSolution = !!a; + + if (!a) + return true; + + // FIXME: We should use smarter assignment for result so we don't + // need redundant copy. + values = std::vector< std::vector<unsigned char> >(objects.size()); + for (unsigned i=0; i < objects.size(); ++i) { + const Array *os = objects[i]; + Assignment::bindings_ty::iterator it = a->bindings.find(os); + + if (it == a->bindings.end()) { + values[i] = std::vector<unsigned char>(os->size, 0); + } else { + values[i] = it->second; + } + } + + return true; +} + +/// + +Solver *klee::createCexCachingSolver(Solver *_solver) { + return new Solver(new CexCachingSolver(_solver)); +} diff --git a/lib/Solver/ConstantDivision.cpp b/lib/Solver/ConstantDivision.cpp new file mode 100644 index 00000000..c8f8f3d5 --- /dev/null +++ b/lib/Solver/ConstantDivision.cpp @@ -0,0 +1,146 @@ +//===-- ConstantDivision.cpp ----------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "ConstantDivision.h" + +#include "klee/util/Bits.h" + +#include <algorithm> +#include <cassert> + +namespace klee { + +/* Macros and functions which define the basic bit-level operations + * needed to implement quick division operations. + * + * Based on Hacker's Delight (2003) by Henry S. Warren, Jr. + */ + +/* 32 -- number of bits in the integer type on this architecture */ + +/* 2^32 -- NUM_BITS=32 requires 64 bits to represent this unsigned value */ +#define TWO_TO_THE_32_U64 (1ULL << 32) + +/* 2e31 -- NUM_BITS=32 requires 64 bits to represent this signed value */ +#define TWO_TO_THE_31_S64 (1LL << 31) + +/* ABS(x) -- positive x */ +#define ABS(x) ( ((x)>0)?x:-(x) ) /* fails if x is the min value of its type */ + +/* XSIGN(x) -- -1 if x<0 and 0 otherwise */ +#define XSIGN(x) ( (x) >> 31 ) + +/* LOG2_CEIL(x) -- logarithm base 2 of x, rounded up */ +#define LOG2_CEIL(x) ( 32 - ldz(x - 1) ) + +/* ones(x) -- counts the number of bits in x with the value 1 */ +static uint32_t ones( register uint32_t x ) { + x -= ((x >> 1) & 0x55555555); + x = (((x >> 2) & 0x33333333) + (x & 0x33333333)); + x = (((x >> 4) + x) & 0x0f0f0f0f); + x += (x >> 8); + x += (x >> 16); + + return( x & 0x0000003f ); +} + +/* ldz(x) -- counts the number of leading zeroes in a 32-bit word */ +static uint32_t ldz( register uint32_t x ) { + x |= (x >> 1); + x |= (x >> 2); + x |= (x >> 4); + x |= (x >> 8); + x |= (x >> 16); + + return 32 - ones(x); +} + +/* exp_base_2(n) -- 2^n computed as an integer */ +static uint32_t exp_base_2( register int32_t n ) { + register uint32_t x = ~n & (n - 32); + x = x >> 31; + return( x << n ); +} + +// A simple algorithm: Iterate over all contiguous regions of 1 bits +// in x starting with the lowest bits. +// +// For a particular range where x is 1 for bits [low,high) then: +// 1) if the range is just one bit, simple add it +// 2) if the range is more than one bit, replace with an add +// of the high bit and a subtract of the low bit. we apply +// one useful optimization: if we were going to add the bit +// below the one we wish to subtract, we simply change that +// add to a subtract instead of subtracting the low bit itself. +// Obviously we must take care when high==64. +void ComputeMultConstants64(uint64_t multiplicand, + uint64_t &add, uint64_t &sub) { + uint64_t x = multiplicand; + add = sub = 0; + + while (x) { + // Determine rightmost contiguous region of 1s. + unsigned low = bits64::indexOfRightmostBit(x); + uint64_t lowbit = 1LL << low; + uint64_t p = x + lowbit; + uint64_t q = bits64::isolateRightmostBit(p); + unsigned high = q ? bits64::indexOfSingleBit(q) : 64; + + if (high==low+1) { // Just one bit... + add |= lowbit; + } else { + // Rewrite as +(1<<high) - (1<<low). + + // Optimize +(1<<x) - (1<<(x+1)) to -(1<<x). + if (low && (add & (lowbit>>1))) { + add ^= lowbit>>1; + sub ^= lowbit>>1; + } else { + sub |= lowbit; + } + + if (high!=64) + add |= 1LL << high; + } + + x = p ^ q; + } + + assert(multiplicand == add - sub); +} + +void ComputeUDivConstants32(uint32_t d, uint32_t &mprime, uint32_t &sh1, + uint32_t &sh2) { + int32_t l = LOG2_CEIL( d ); /* signed so l-1 => -1 when l=0 (see sh2) */ + uint32_t mid = exp_base_2(l) - d; + + mprime = (TWO_TO_THE_32_U64 * mid / d) + 1; + sh1 = std::min( l, 1 ); + sh2 = std::max( l-1, 0 ); +} + +void ComputeSDivConstants32(int32_t d, int32_t &mprime, int32_t &dsign, + int32_t &shpost ) { + uint64_t abs_d = ABS( (int64_t)d ); /* use 64-bits in case d is INT32_MIN */ + + /* LOG2_CEIL works on 32-bits, so we cast abs_d. The only possible value + * outside the 32-bit rep. is 2^31. This is special cased to save computer + * time since 64-bit routines would be overkill. */ + int32_t l = std::max( 1U, LOG2_CEIL((uint32_t)abs_d) ); + if( abs_d == TWO_TO_THE_31_S64 ) l = 31; + + uint32_t mid = exp_base_2( l - 1 ); + uint64_t m = TWO_TO_THE_32_U64 * mid / abs_d + 1ULL; + + mprime = m - TWO_TO_THE_32_U64; /* implicit cast to 32-bits signed */ + dsign = XSIGN( d ); + shpost = l - 1; +} + +} diff --git a/lib/Solver/ConstantDivision.h b/lib/Solver/ConstantDivision.h new file mode 100644 index 00000000..9e3e9c95 --- /dev/null +++ b/lib/Solver/ConstantDivision.h @@ -0,0 +1,51 @@ +//===-- ConstantDivision.h --------------------------------------*- C++ -*-===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef __UTIL_CONSTANTDIVISION_H__ +#define __UTIL_CONSTANTDIVISION_H__ + +#include <stdint.h> + +namespace klee { + +/// ComputeMultConstants64 - Compute add and sub such that add-sub==x, +/// while attempting to minimize the number of bits in add and sub +/// combined. +void ComputeMultConstants64(uint64_t x, uint64_t &add_out, + uint64_t &sub_out); + +/// Compute the constants to perform a quicker equivalent of a division of some +/// 32-bit unsigned integer n by a known constant d (also a 32-bit unsigned +/// integer). The constants to compute n/d without explicit division will be +/// stored in mprime, sh1, and sh2 (unsigned 32-bit integers). +/// +/// @param d - denominator (divisor) +/// +/// @param [out] mprime +/// @param [out] sh1 +/// @param [out] sh2 +void ComputeUDivConstants32(uint32_t d, uint32_t &mprime, uint32_t &sh1, + uint32_t &sh2); + +/// Compute the constants to perform a quicker equivalent of a division of some +/// 32-bit signed integer n by a known constant d (also a 32-bit signed +/// integer). The constants to compute n/d without explicit division will be +/// stored in mprime, dsign, and shpost (signed 32-bit integers). +/// +/// @param d - denominator (divisor) +/// +/// @param [out] mprime +/// @param [out] dsign +/// @param [out] shpost +void ComputeSDivConstants32(int32_t d, int32_t &mprime, int32_t &dsign, + int32_t &shpost); + +} + +#endif diff --git a/lib/Solver/FastCexSolver.cpp b/lib/Solver/FastCexSolver.cpp new file mode 100644 index 00000000..d2bc27c6 --- /dev/null +++ b/lib/Solver/FastCexSolver.cpp @@ -0,0 +1,959 @@ +//===-- FastCexSolver.cpp -------------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/Solver.h" + +#include "klee/Constraints.h" +#include "klee/Expr.h" +#include "klee/IncompleteSolver.h" +#include "klee/util/ExprEvaluator.h" +#include "klee/util/ExprRangeEvaluator.h" +#include "klee/util/ExprVisitor.h" +// FIXME: Use APInt. +#include "klee/Internal/Support/IntEvaluation.h" + +#include <iostream> +#include <sstream> +#include <cassert> +#include <map> +#include <vector> + +using namespace klee; + +/***/ + +//#define LOG +#ifdef LOG +std::ostream *theLog; +#endif + + // Hacker's Delight, pgs 58-63 +static uint64_t minOR(uint64_t a, uint64_t b, + uint64_t c, uint64_t d) { + uint64_t temp, m = ((uint64_t) 1)<<63; + while (m) { + if (~a & c & m) { + temp = (a | m) & -m; + if (temp <= b) { a = temp; break; } + } else if (a & ~c & m) { + temp = (c | m) & -m; + if (temp <= d) { c = temp; break; } + } + m >>= 1; + } + + return a | c; +} +static uint64_t maxOR(uint64_t a, uint64_t b, + uint64_t c, uint64_t d) { + uint64_t temp, m = ((uint64_t) 1)<<63; + + while (m) { + if (b & d & m) { + temp = (b - m) | (m - 1); + if (temp >= a) { b = temp; break; } + temp = (d - m) | (m -1); + if (temp >= c) { d = temp; break; } + } + m >>= 1; + } + + return b | d; +} +static uint64_t minAND(uint64_t a, uint64_t b, + uint64_t c, uint64_t d) { + uint64_t temp, m = ((uint64_t) 1)<<63; + while (m) { + if (~a & ~c & m) { + temp = (a | m) & -m; + if (temp <= b) { a = temp; break; } + temp = (c | m) & -m; + if (temp <= d) { c = temp; break; } + } + m >>= 1; + } + + return a & c; +} +static uint64_t maxAND(uint64_t a, uint64_t b, + uint64_t c, uint64_t d) { + uint64_t temp, m = ((uint64_t) 1)<<63; + while (m) { + if (b & ~d & m) { + temp = (b & ~m) | (m - 1); + if (temp >= a) { b = temp; break; } + } else if (~b & d & m) { + temp = (d & ~m) | (m - 1); + if (temp >= c) { d = temp; break; } + } + m >>= 1; + } + + return b & d; +} + +/// + +class ValueRange { +private: + uint64_t m_min, m_max; + +public: + ValueRange() : m_min(1),m_max(0) {} + ValueRange(uint64_t value) : m_min(value), m_max(value) {} + ValueRange(uint64_t _min, uint64_t _max) : m_min(_min), m_max(_max) {} + ValueRange(const ValueRange &b) : m_min(b.m_min), m_max(b.m_max) {} + + void print(std::ostream &os) const { + if (isFixed()) { + os << m_min; + } else { + os << "[" << m_min << "," << m_max << "]"; + } + } + + bool isEmpty() const { + return m_min>m_max; + } + bool contains(uint64_t value) const { + return this->intersects(ValueRange(value)); + } + bool intersects(const ValueRange &b) const { + return !this->set_intersection(b).isEmpty(); + } + + bool isFullRange(unsigned bits) { + return m_min==0 && m_max==bits64::maxValueOfNBits(bits); + } + + ValueRange set_intersection(const ValueRange &b) const { + return ValueRange(std::max(m_min,b.m_min), std::min(m_max,b.m_max)); + } + ValueRange set_union(const ValueRange &b) const { + return ValueRange(std::min(m_min,b.m_min), std::max(m_max,b.m_max)); + } + ValueRange set_difference(const ValueRange &b) const { + if (b.isEmpty() || b.m_min > m_max || b.m_max < m_min) { // no intersection + return *this; + } else if (b.m_min <= m_min && b.m_max >= m_max) { // empty + return ValueRange(1,0); + } else if (b.m_min <= m_min) { // one range out + // cannot overflow because b.m_max < m_max + return ValueRange(b.m_max+1, m_max); + } else if (b.m_max >= m_max) { + // cannot overflow because b.min > m_min + return ValueRange(m_min, b.m_min-1); + } else { + // two ranges, take bottom + return ValueRange(m_min, b.m_min-1); + } + } + ValueRange binaryAnd(const ValueRange &b) const { + // XXX + assert(!isEmpty() && !b.isEmpty() && "XXX"); + if (isFixed() && b.isFixed()) { + return ValueRange(m_min & b.m_min); + } else { + return ValueRange(minAND(m_min, m_max, b.m_min, b.m_max), + maxAND(m_min, m_max, b.m_min, b.m_max)); + } + } + ValueRange binaryAnd(uint64_t b) const { return binaryAnd(ValueRange(b)); } + ValueRange binaryOr(ValueRange b) const { + // XXX + assert(!isEmpty() && !b.isEmpty() && "XXX"); + if (isFixed() && b.isFixed()) { + return ValueRange(m_min | b.m_min); + } else { + return ValueRange(minOR(m_min, m_max, b.m_min, b.m_max), + maxOR(m_min, m_max, b.m_min, b.m_max)); + } + } + ValueRange binaryOr(uint64_t b) const { return binaryOr(ValueRange(b)); } + ValueRange binaryXor(ValueRange b) const { + if (isFixed() && b.isFixed()) { + return ValueRange(m_min ^ b.m_min); + } else { + uint64_t t = m_max | b.m_max; + while (!bits64::isPowerOfTwo(t)) + t = bits64::withoutRightmostBit(t); + return ValueRange(0, (t<<1)-1); + } + } + + ValueRange binaryShiftLeft(unsigned bits) const { + return ValueRange(m_min<<bits, m_max<<bits); + } + ValueRange binaryShiftRight(unsigned bits) const { + return ValueRange(m_min>>bits, m_max>>bits); + } + + ValueRange concat(const ValueRange &b, unsigned bits) const { + return binaryShiftLeft(bits).binaryOr(b); + } + ValueRange extract(uint64_t lowBit, uint64_t maxBit) const { + return binaryShiftRight(lowBit).binaryAnd(bits64::maxValueOfNBits(maxBit-lowBit)); + } + + ValueRange add(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + ValueRange sub(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + ValueRange mul(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + ValueRange udiv(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + ValueRange sdiv(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + ValueRange urem(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + ValueRange srem(const ValueRange &b, unsigned width) const { + return ValueRange(0, bits64::maxValueOfNBits(width)); + } + + // use min() to get value if true (XXX should we add a method to + // make code clearer?) + bool isFixed() const { return m_min==m_max; } + + bool operator==(const ValueRange &b) const { return m_min==b.m_min && m_max==b.m_max; } + bool operator!=(const ValueRange &b) const { return !(*this==b); } + + bool mustEqual(const uint64_t b) const { return m_min==m_max && m_min==b; } + bool mayEqual(const uint64_t b) const { return m_min<=b && m_max>=b; } + + bool mustEqual(const ValueRange &b) const { return isFixed() && b.isFixed() && m_min==b.m_min; } + bool mayEqual(const ValueRange &b) const { return this->intersects(b); } + + uint64_t min() const { + assert(!isEmpty() && "cannot get minimum of empty range"); + return m_min; + } + + uint64_t max() const { + assert(!isEmpty() && "cannot get maximum of empty range"); + return m_max; + } + + int64_t minSigned(unsigned bits) const { + assert((m_min>>bits)==0 && (m_max>>bits)==0 && + "range is outside given number of bits"); + + // if max allows sign bit to be set then it can be smallest value, + // otherwise since the range is not empty, min cannot have a sign + // bit + + uint64_t smallest = ((uint64_t) 1 << (bits-1)); + if (m_max >= smallest) { + return ints::sext(smallest, 64, bits); + } else { + return m_min; + } + } + + int64_t maxSigned(unsigned bits) const { + assert((m_min>>bits)==0 && (m_max>>bits)==0 && + "range is outside given number of bits"); + + uint64_t smallest = ((uint64_t) 1 << (bits-1)); + + // if max and min have sign bit then max is max, otherwise if only + // max has sign bit then max is largest signed integer, otherwise + // max is max + + if (m_min < smallest && m_max >= smallest) { + return smallest - 1; + } else { + return ints::sext(m_max, 64, bits); + } + } +}; + +inline std::ostream &operator<<(std::ostream &os, const ValueRange &vr) { + vr.print(os); + return os; +} + +// used to find all memory object ids and the maximum size of any +// object state that references them (for symbolic size). +class ObjectFinder : public ExprVisitor { +protected: + Action visitRead(const ReadExpr &re) { + addUpdates(re.updates); + return Action::doChildren(); + } + + // XXX nice if this information was cached somewhere, used by + // independence as well right? + void addUpdates(const UpdateList &ul) { + for (const UpdateNode *un=ul.head; un; un=un->next) { + visit(un->index); + visit(un->value); + } + + addObject(*ul.root); + } + +public: + void addObject(const Array& array) { + unsigned id = array.id; + std::map<unsigned,unsigned>::iterator it = results.find(id); + + // FIXME: Not 64-bit size clean. + if (it == results.end()) { + results[id] = (unsigned) array.size; + } else { + it->second = std::max(it->second, (unsigned) array.size); + } + } + +public: + std::map<unsigned, unsigned> results; +}; + +// XXX waste of space, rather have ByteValueRange +typedef ValueRange CexValueData; + +class CexObjectData { +public: + unsigned size; + CexValueData *values; + +public: + CexObjectData(unsigned _size) : size(_size), values(new CexValueData[size]) { + for (unsigned i=0; i<size; i++) + values[i] = ValueRange(0, 255); + } +}; + +class CexRangeEvaluator : public ExprRangeEvaluator<ValueRange> { +public: + std::map<unsigned, CexObjectData> &objectValues; + CexRangeEvaluator(std::map<unsigned, CexObjectData> &_objectValues) + : objectValues(_objectValues) {} + + ValueRange getInitialReadRange(const Array &os, ValueRange index) { + return ValueRange(0, 255); + } +}; + +class CexConstifier : public ExprEvaluator { +protected: + ref<Expr> getInitialValue(const Array& array, unsigned index) { + std::map<unsigned, CexObjectData>::iterator it = + objectValues.find(array.id); + assert(it != objectValues.end() && "missing object?"); + CexObjectData &cod = it->second; + + if (index >= cod.size) { + return ReadExpr::create(UpdateList(&array, true, 0), + ref<Expr>(index, Expr::Int32)); + } else { + CexValueData &cvd = cod.values[index]; + assert(cvd.min() == cvd.max() && "value is not fixed"); + return ref<Expr>(cvd.min(), Expr::Int8); + } + } + +public: + std::map<unsigned, CexObjectData> &objectValues; + CexConstifier(std::map<unsigned, CexObjectData> &_objectValues) + : objectValues(_objectValues) {} +}; + +class CexData { +public: + std::map<unsigned, CexObjectData> objectValues; + +public: + CexData(ObjectFinder &finder) { + for (std::map<unsigned,unsigned>::iterator it = finder.results.begin(), + ie = finder.results.end(); it != ie; ++it) { + objectValues.insert(std::pair<unsigned, CexObjectData>(it->first, + CexObjectData(it->second))); + } + } + ~CexData() { + for (std::map<unsigned, CexObjectData>::iterator it = objectValues.begin(), + ie = objectValues.end(); it != ie; ++it) + delete[] it->second.values; + } + + void forceExprToValue(ref<Expr> e, uint64_t value) { + forceExprToRange(e, CexValueData(value,value)); + } + + void forceExprToRange(ref<Expr> e, CexValueData range) { +#ifdef LOG + // *theLog << "force: " << e << " to " << range << "\n"; +#endif + switch (e.getKind()) { + case Expr::Constant: { + // rather a pity if the constant isn't in the range, but how can + // we use this? + break; + } + + // Special + + case Expr::NotOptimized: break; + + case Expr::Read: { + ReadExpr *re = static_ref_cast<ReadExpr>(e); + const Array *array = re->updates.root; + CexObjectData &cod = objectValues.find(array->id)->second; + + // XXX we need to respect the version here and object state chain + + if (re->index.isConstant() && + re->index.getConstantValue() < array->size) { + CexValueData &cvd = cod.values[re->index.getConstantValue()]; + CexValueData tmp = cvd.set_intersection(range); + + if (tmp.isEmpty()) { + if (range.isFixed()) // ranges conflict, if new one is fixed use that + cvd = range; + } else { + cvd = tmp; + } + } else { + // XXX fatal("XXX not implemented"); + } + + break; + } + + case Expr::Select: { + SelectExpr *se = static_ref_cast<SelectExpr>(e); + ValueRange cond = evalRangeForExpr(se->cond); + if (cond.isFixed()) { + if (cond.min()) { + forceExprToRange(se->trueExpr, range); + } else { + forceExprToRange(se->falseExpr, range); + } + } else { + // XXX imprecise... we have a choice here. One method is to + // simply force both sides into the specified range (since the + // condition is indetermined). This may lose in two ways, the + // first is that the condition chosen may limit further + // restrict the range in each of the children, however this is + // less of a problem as the range will be a superset of legal + // values. The other is if the condition ends up being forced + // by some other constraints, then we needlessly forced one + // side into the given range. + // + // The other method would be to force the condition to one + // side and force that side into the given range. This loses + // when we force the condition to an unsatisfiable value + // (either because the condition cannot be that, or the + // resulting range given that condition is not in the required + // range). + // + // Currently we just force both into the range. A hybrid would + // be to evaluate the ranges for each of the children... if + // one of the ranges happens to already be a subset of the + // required range then it may be preferable to force the + // condition to that side. + forceExprToRange(se->trueExpr, range); + forceExprToRange(se->falseExpr, range); + } + break; + } + + // XXX imprecise... the problem here is that extracting bits + // loses information about what bits are connected across the + // bytes. if a value can be 1 or 256 then either the top or + // lower byte is 0, but just extraction loses this information + // and will allow neither,one,or both to be 1. + // + // we can protect against this in a limited fashion by writing + // the extraction a byte at a time, then checking the evaluated + // value, isolating for that range, and continuing. + case Expr::Concat: { + ConcatExpr *ce = static_ref_cast<ConcatExpr>(e); + if (ce->is2ByteConcat()) { + forceExprToRange(ce->getKid(0), range.extract( 8, 16)); + forceExprToRange(ce->getKid(1), range.extract( 0, 8)); + } + else if (ce->is4ByteConcat()) { + forceExprToRange(ce->getKid(0), range.extract(24, 32)); + forceExprToRange(ce->getKid(1), range.extract(16, 24)); + forceExprToRange(ce->getKid(2), range.extract( 8, 16)); + forceExprToRange(ce->getKid(3), range.extract( 0, 8)); + } + else if (ce->is8ByteConcat()) { + forceExprToRange(ce->getKid(0), range.extract(56, 64)); + forceExprToRange(ce->getKid(1), range.extract(48, 56)); + forceExprToRange(ce->getKid(2), range.extract(40, 48)); + forceExprToRange(ce->getKid(3), range.extract(32, 40)); + forceExprToRange(ce->getKid(4), range.extract(24, 32)); + forceExprToRange(ce->getKid(5), range.extract(16, 24)); + forceExprToRange(ce->getKid(6), range.extract( 8, 16)); + forceExprToRange(ce->getKid(7), range.extract( 0, 8)); + } + + break; + } + + case Expr::Extract: { + // XXX + break; + } + + // Casting + + // Simply intersect the output range with the range of all + // possible outputs and then truncate to the desired number of + // bits. + + // For ZExt this simplifies to just intersection with the + // possible input range. + case Expr::ZExt: { + CastExpr *ce = static_ref_cast<CastExpr>(e); + unsigned inBits = ce->src.getWidth(); + ValueRange input = range.set_intersection(ValueRange(0, bits64::maxValueOfNBits(inBits))); + forceExprToRange(ce->src, input); + break; + } + // For SExt instead of doing the intersection we just take the output range + // minus the impossible values. This is nicer since it is a single interval. + case Expr::SExt: { + CastExpr *ce = static_ref_cast<CastExpr>(e); + unsigned inBits = ce->src.getWidth(); + unsigned outBits = ce->width; + ValueRange output = range.set_difference(ValueRange(1<<(inBits-1), + (bits64::maxValueOfNBits(outBits)- + bits64::maxValueOfNBits(inBits-1)-1))); + ValueRange input = output.binaryAnd(bits64::maxValueOfNBits(inBits)); + forceExprToRange(ce->src, input); + break; + } + + // Binary + + case Expr::And: { + BinaryExpr *be = static_ref_cast<BinaryExpr>(e); + if (be->getWidth()==Expr::Bool) { + if (range.isFixed()) { + ValueRange left = evalRangeForExpr(be->left); + ValueRange right = evalRangeForExpr(be->right); + + if (!range.min()) { + if (left.mustEqual(0) || right.mustEqual(0)) { + // all is well + } else { + // XXX heuristic, which order + + forceExprToValue(be->left, 0); + left = evalRangeForExpr(be->left); + + // see if that worked + if (!left.mustEqual(1)) + forceExprToValue(be->right, 0); + } + } else { + if (!left.mustEqual(1)) forceExprToValue(be->left, 1); + if (!right.mustEqual(1)) forceExprToValue(be->right, 1); + } + } + } else { + // XXX + } + break; + } + + case Expr::Or: { + BinaryExpr *be = static_ref_cast<BinaryExpr>(e); + if (be->getWidth()==Expr::Bool) { + if (range.isFixed()) { + ValueRange left = evalRangeForExpr(be->left); + ValueRange right = evalRangeForExpr(be->right); + + if (range.min()) { + if (left.mustEqual(1) || right.mustEqual(1)) { + // all is well + } else { + // XXX heuristic, which order? + + // force left to value we need + forceExprToValue(be->left, 1); + left = evalRangeForExpr(be->left); + + // see if that worked + if (!left.mustEqual(1)) + forceExprToValue(be->right, 1); + } + } else { + if (!left.mustEqual(0)) forceExprToValue(be->left, 0); + if (!right.mustEqual(0)) forceExprToValue(be->right, 0); + } + } + } else { + // XXX + } + break; + } + + case Expr::Xor: break; + + // Comparison + + case Expr::Eq: { + BinaryExpr *be = static_ref_cast<BinaryExpr>(e); + if (range.isFixed()) { + if (be->left.isConstant()) { + uint64_t value = be->left.getConstantValue(); + if (range.min()) { + forceExprToValue(be->right, value); + } else { + if (value==0) { + forceExprToRange(be->right, + CexValueData(1, + ints::sext(1, + be->right.getWidth(), + 1))); + } else { + // XXX heuristic / lossy, could be better to pick larger range? + forceExprToRange(be->right, CexValueData(0, value-1)); + } + } + } else { + // XXX what now + } + } + break; + } + + case Expr::Ult: { + BinaryExpr *be = static_ref_cast<BinaryExpr>(e); + + // XXX heuristic / lossy, what order if conflict + + if (range.isFixed()) { + ValueRange left = evalRangeForExpr(be->left); + ValueRange right = evalRangeForExpr(be->right); + + uint64_t maxValue = bits64::maxValueOfNBits(be->right.getWidth()); + + // XXX should deal with overflow (can lead to empty range) + + if (left.isFixed()) { + if (range.min()) { + forceExprToRange(be->right, CexValueData(left.min()+1, maxValue)); + } else { + forceExprToRange(be->right, CexValueData(0, left.min())); + } + } else if (right.isFixed()) { + if (range.min()) { + forceExprToRange(be->left, CexValueData(0, right.min()-1)); + } else { + forceExprToRange(be->left, CexValueData(right.min(), maxValue)); + } + } else { + // XXX ??? + } + } + break; + } + case Expr::Ule: { + BinaryExpr *be = static_ref_cast<BinaryExpr>(e); + + // XXX heuristic / lossy, what order if conflict + + if (range.isFixed()) { + ValueRange left = evalRangeForExpr(be->left); + ValueRange right = evalRangeForExpr(be->right); + + // XXX should deal with overflow (can lead to empty range) + + uint64_t maxValue = bits64::maxValueOfNBits(be->right.getWidth()); + if (left.isFixed()) { + if (range.min()) { + forceExprToRange(be->right, CexValueData(left.min(), maxValue)); + } else { + forceExprToRange(be->right, CexValueData(0, left.min()-1)); + } + } else if (right.isFixed()) { + if (range.min()) { + forceExprToRange(be->left, CexValueData(0, right.min())); + } else { + forceExprToRange(be->left, CexValueData(right.min()+1, maxValue)); + } + } else { + // XXX ??? + } + } + break; + } + + case Expr::Ne: + case Expr::Ugt: + case Expr::Uge: + case Expr::Sgt: + case Expr::Sge: + assert(0 && "invalid expressions (uncanonicalized"); + + default: + break; + } + } + + void fixValues() { + for (std::map<unsigned, CexObjectData>::iterator it = objectValues.begin(), + ie = objectValues.end(); it != ie; ++it) { + CexObjectData &cod = it->second; + for (unsigned i=0; i<cod.size; i++) { + CexValueData &cvd = cod.values[i]; + cvd = CexValueData(cvd.min() + (cvd.max()-cvd.min())/2); + } + } + } + + ValueRange evalRangeForExpr(ref<Expr> &e) { + CexRangeEvaluator ce(objectValues); + return ce.evaluate(e); + } + + bool exprMustBeValue(ref<Expr> e, uint64_t value) { + CexConstifier cc(objectValues); + ref<Expr> v = cc.visit(e); + if (!v.isConstant()) return false; + // XXX reenable once all reads and vars are fixed + // assert(v.isConstant() && "not all values have been fixed"); + return v.getConstantValue()==value; + } +}; + +/* *** */ + + +class FastCexSolver : public IncompleteSolver { +public: + FastCexSolver(); + ~FastCexSolver(); + + IncompleteSolver::PartialValidity computeTruth(const Query&); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query&, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution); +}; + +FastCexSolver::FastCexSolver() { } + +FastCexSolver::~FastCexSolver() { } + +IncompleteSolver::PartialValidity +FastCexSolver::computeTruth(const Query& query) { +#ifdef LOG + std::ostringstream log; + theLog = &log; + // log << "------ start FastCexSolver::mustBeTrue ------\n"; + log << "-- QUERY --\n"; + unsigned i=0; + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + log << " C" << i++ << ": " << *it << ", \n"; + log << " Q : " << query.expr << "\n"; +#endif + + ObjectFinder of; + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + of.visit(*it); + of.visit(query.expr); + CexData cd(of); + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + cd.forceExprToValue(*it, 1); + cd.forceExprToValue(query.expr, 0); + +#ifdef LOG + log << " -- ranges --\n"; + for (std::map<unsigned, CexObjectData>::iterator it = objectValues.begin(), + ie = objectValues.end(); it != ie; ++it) { + CexObjectData &cod = it->second; + log << " arr" << it->first << "[" << cod.size << "] = ["; + unsigned continueFrom=cod.size-1; + for (; continueFrom>0; continueFrom--) + if (cod.values[continueFrom-1]!=cod.values[continueFrom]) + break; + for (unsigned i=0; i<cod.size; i++) { + log << cod.values[i]; + if (i<cod.size-1) { + log << ", "; + if (i==continueFrom) { + log << "..."; + break; + } + } + } + log << "]\n"; + } +#endif + + // this could be done lazily of course + cd.fixValues(); + +#ifdef LOG + log << " -- fixed values --\n"; + for (std::map<unsigned, CexObjectData>::iterator it = objectValues.begin(), + ie = objectValues.end(); it != ie; ++it) { + CexObjectData &cod = it->second; + log << " arr" << it->first << "[" << cod.size << "] = ["; + unsigned continueFrom=cod.size-1; + for (; continueFrom>0; continueFrom--) + if (cod.values[continueFrom-1]!=cod.values[continueFrom]) + break; + for (unsigned i=0; i<cod.size; i++) { + log << cod.values[i]; + if (i<cod.size-1) { + log << ", "; + if (i==continueFrom) { + log << "..."; + break; + } + } + } + log << "]\n"; + } +#endif + + // check the result + + bool isGood = true; + + if (!cd.exprMustBeValue(query.expr, 0)) isGood = false; + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + if (!cd.exprMustBeValue(*it, 1)) + isGood = false; + +#ifdef LOG + log << " -- evaluating result --\n"; + + i=0; + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) { + bool res = cd.exprMustBeValue(*it, 1); + log << " C" << i++ << ": " << (res?"true":"false") << "\n"; + } + log << " Q : " + << (cd.exprMustBeValue(query.expr, 0)?"true":"false") << "\n"; + + log << "\n\n"; +#endif + + return isGood ? IncompleteSolver::MayBeFalse : IncompleteSolver::None; +} + +bool FastCexSolver::computeValue(const Query& query, ref<Expr> &result) { + ObjectFinder of; + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + of.visit(*it); + of.visit(query.expr); + CexData cd(of); + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + cd.forceExprToValue(*it, 1); + + // this could be done lazily of course + cd.fixValues(); + + // check the result + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + if (!cd.exprMustBeValue(*it, 1)) + return false; + + CexConstifier cc(cd.objectValues); + ref<Expr> value = cc.visit(query.expr); + + if (value.isConstant()) { + result = value; + return true; + } else { + return false; + } +} + +bool +FastCexSolver::computeInitialValues(const Query& query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + ObjectFinder of; + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + of.visit(*it); + of.visit(query.expr); + for (unsigned i = 0; i != objects.size(); ++i) + of.addObject(*objects[i]); + CexData cd(of); + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + cd.forceExprToValue(*it, 1); + cd.forceExprToValue(query.expr, 0); + + // this could be done lazily of course + cd.fixValues(); + + // check the result + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + if (!cd.exprMustBeValue(*it, 1)) + return false; + if (!cd.exprMustBeValue(query.expr, 0)) + return false; + + hasSolution = true; + CexConstifier cc(cd.objectValues); + for (unsigned i = 0; i != objects.size(); ++i) { + const Array *array = objects[i]; + std::vector<unsigned char> data; + data.reserve(array->size); + + for (unsigned i=0; i < array->size; i++) { + ref<Expr> value = + cc.visit(ReadExpr::create(UpdateList(array, true, 0), + ConstantExpr::create(i, + kMachinePointerType))); + + if (value.isConstant()) { + data.push_back(value.getConstantValue()); + } else { + // FIXME: When does this happen? + return false; + } + } + + values.push_back(data); + } + + return true; +} + + +Solver *klee::createFastCexSolver(Solver *s) { + return new Solver(new StagedSolverImpl(new FastCexSolver(), s)); +} diff --git a/lib/Solver/IncompleteSolver.cpp b/lib/Solver/IncompleteSolver.cpp new file mode 100644 index 00000000..f473f70b --- /dev/null +++ b/lib/Solver/IncompleteSolver.cpp @@ -0,0 +1,136 @@ +//===-- IncompleteSolver.cpp ----------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/IncompleteSolver.h" + +#include "klee/Constraints.h" + +using namespace klee; +using namespace llvm; + +/***/ + +IncompleteSolver::PartialValidity +IncompleteSolver::negatePartialValidity(PartialValidity pv) { + switch(pv) { + case MustBeTrue: return MustBeFalse; + case MustBeFalse: return MustBeTrue; + case MayBeTrue: return MayBeFalse; + case MayBeFalse: return MayBeTrue; + case TrueOrFalse: return TrueOrFalse; + default: assert(0 && "invalid partial validity"); + } +} + +IncompleteSolver::PartialValidity +IncompleteSolver::computeValidity(const Query& query) { + PartialValidity trueResult = computeTruth(query); + + if (trueResult == MustBeTrue) { + return MustBeTrue; + } else { + PartialValidity falseResult = computeTruth(query.negateExpr()); + + if (falseResult == MustBeTrue) { + return MustBeFalse; + } else { + bool trueCorrect = trueResult != None, + falseCorrect = falseResult != None; + + if (trueCorrect && falseCorrect) { + return TrueOrFalse; + } else if (trueCorrect) { // ==> trueResult == MayBeFalse + return MayBeFalse; + } else if (falseCorrect) { // ==> falseResult == MayBeFalse + return MayBeTrue; + } else { + return None; + } + } + } +} + +/***/ + +StagedSolverImpl::StagedSolverImpl(IncompleteSolver *_primary, + Solver *_secondary) + : primary(_primary), + secondary(_secondary) { +} + +StagedSolverImpl::~StagedSolverImpl() { + delete primary; + delete secondary; +} + +bool StagedSolverImpl::computeTruth(const Query& query, bool &isValid) { + IncompleteSolver::PartialValidity trueResult = primary->computeTruth(query); + + if (trueResult != IncompleteSolver::None) { + isValid = (trueResult == IncompleteSolver::MustBeTrue); + return true; + } + + return secondary->impl->computeTruth(query, isValid); +} + +bool StagedSolverImpl::computeValidity(const Query& query, + Solver::Validity &result) { + bool tmp; + + switch(primary->computeValidity(query)) { + case IncompleteSolver::MustBeTrue: + result = Solver::True; + break; + case IncompleteSolver::MustBeFalse: + result = Solver::False; + break; + case IncompleteSolver::TrueOrFalse: + result = Solver::Unknown; + break; + case IncompleteSolver::MayBeTrue: + if (!secondary->impl->computeTruth(query, tmp)) + return false; + result = tmp ? Solver::True : Solver::Unknown; + break; + case IncompleteSolver::MayBeFalse: + if (!secondary->impl->computeTruth(query.negateExpr(), tmp)) + return false; + result = tmp ? Solver::False : Solver::Unknown; + break; + default: + if (!secondary->impl->computeValidity(query, result)) + return false; + break; + } + + return true; +} + +bool StagedSolverImpl::computeValue(const Query& query, + ref<Expr> &result) { + if (primary->computeValue(query, result)) + return true; + + return secondary->impl->computeValue(query, result); +} + +bool +StagedSolverImpl::computeInitialValues(const Query& query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + if (primary->computeInitialValues(query, objects, values, hasSolution)) + return true; + + return secondary->impl->computeInitialValues(query, objects, values, + hasSolution); +} diff --git a/lib/Solver/IndependentSolver.cpp b/lib/Solver/IndependentSolver.cpp new file mode 100644 index 00000000..c966aff6 --- /dev/null +++ b/lib/Solver/IndependentSolver.cpp @@ -0,0 +1,314 @@ +//===-- IndependentSolver.cpp ---------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/Solver.h" + +#include "klee/Expr.h" +#include "klee/Constraints.h" +#include "klee/SolverImpl.h" + +#include "klee/util/ExprUtil.h" + +#include "llvm/Support/Streams.h" + +#include <map> +#include <vector> + +using namespace klee; +using namespace llvm; + +template<class T> +class DenseSet { + typedef std::set<T> set_ty; + set_ty s; + +public: + DenseSet() {} + + void add(T x) { + s.insert(x); + } + void add(T start, T end) { + for (; start<end; start++) + s.insert(start); + } + + // returns true iff set is changed by addition + bool add(const DenseSet &b) { + bool modified = false; + for (typename set_ty::const_iterator it = b.s.begin(), ie = b.s.end(); + it != ie; ++it) { + if (modified || !s.count(*it)) { + modified = true; + s.insert(*it); + } + } + return modified; + } + + bool intersects(const DenseSet &b) { + for (typename set_ty::iterator it = s.begin(), ie = s.end(); + it != ie; ++it) + if (b.s.count(*it)) + return true; + return false; + } + + void print(std::ostream &os) const { + bool first = true; + os << "{"; + for (typename set_ty::iterator it = s.begin(), ie = s.end(); + it != ie; ++it) { + if (first) { + first = false; + } else { + os << ","; + } + os << *it; + } + os << "}"; + } +}; + +template<class T> +inline std::ostream &operator<<(std::ostream &os, const DenseSet<T> &dis) { + dis.print(os); + return os; +} + +class IndependentElementSet { + typedef std::map<const Array*, DenseSet<unsigned> > elements_ty; + elements_ty elements; + std::set<const Array*> wholeObjects; + +public: + IndependentElementSet() {} + IndependentElementSet(ref<Expr> e) { + std::vector< ref<ReadExpr> > reads; + findReads(e, /* visitUpdates= */ true, reads); + for (unsigned i = 0; i != reads.size(); ++i) { + ReadExpr *re = reads[i].get(); + if (re->updates.isRooted) { + const Array *array = re->updates.root; + if (!wholeObjects.count(array)) { + if (re->index.isConstant()) { + DenseSet<unsigned> &dis = elements[array]; + dis.add((unsigned) re->index.getConstantValue()); + } else { + elements_ty::iterator it2 = elements.find(array); + if (it2!=elements.end()) + elements.erase(it2); + wholeObjects.insert(array); + } + } + } + } + } + IndependentElementSet(const IndependentElementSet &ies) : + elements(ies.elements), + wholeObjects(ies.wholeObjects) {} + + IndependentElementSet &operator=(const IndependentElementSet &ies) { + elements = ies.elements; + wholeObjects = ies.wholeObjects; + return *this; + } + + void print(std::ostream &os) const { + os << "{"; + bool first = true; + for (std::set<const Array*>::iterator it = wholeObjects.begin(), + ie = wholeObjects.end(); it != ie; ++it) { + const Array *array = *it; + + if (first) { + first = false; + } else { + os << ", "; + } + + os << "MO" << array->id; + } + for (elements_ty::const_iterator it = elements.begin(), ie = elements.end(); + it != ie; ++it) { + const Array *array = it->first; + const DenseSet<unsigned> &dis = it->second; + + if (first) { + first = false; + } else { + os << ", "; + } + + os << "MO" << array->id << " : " << dis; + } + os << "}"; + } + + // more efficient when this is the smaller set + bool intersects(const IndependentElementSet &b) { + for (std::set<const Array*>::iterator it = wholeObjects.begin(), + ie = wholeObjects.end(); it != ie; ++it) { + const Array *array = *it; + if (b.wholeObjects.count(array) || + b.elements.find(array) != b.elements.end()) + return true; + } + for (elements_ty::iterator it = elements.begin(), ie = elements.end(); + it != ie; ++it) { + const Array *array = it->first; + if (b.wholeObjects.count(array)) + return true; + elements_ty::const_iterator it2 = b.elements.find(array); + if (it2 != b.elements.end()) { + if (it->second.intersects(it2->second)) + return true; + } + } + return false; + } + + // returns true iff set is changed by addition + bool add(const IndependentElementSet &b) { + bool modified = false; + for (std::set<const Array*>::const_iterator it = b.wholeObjects.begin(), + ie = b.wholeObjects.end(); it != ie; ++it) { + const Array *array = *it; + elements_ty::iterator it2 = elements.find(array); + if (it2!=elements.end()) { + modified = true; + elements.erase(it2); + wholeObjects.insert(array); + } else { + if (!wholeObjects.count(array)) { + modified = true; + wholeObjects.insert(array); + } + } + } + for (elements_ty::const_iterator it = b.elements.begin(), + ie = b.elements.end(); it != ie; ++it) { + const Array *array = it->first; + if (!wholeObjects.count(array)) { + elements_ty::iterator it2 = elements.find(array); + if (it2==elements.end()) { + modified = true; + elements.insert(*it); + } else { + if (it2->second.add(it->second)) + modified = true; + } + } + } + return modified; + } +}; + +inline std::ostream &operator<<(std::ostream &os, const IndependentElementSet &ies) { + ies.print(os); + return os; +} + +static +IndependentElementSet getIndependentConstraints(const Query& query, + std::vector< ref<Expr> > &result) { + IndependentElementSet eltsClosure(query.expr); + std::vector< std::pair<ref<Expr>, IndependentElementSet> > worklist; + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + worklist.push_back(std::make_pair(*it, IndependentElementSet(*it))); + + // XXX This should be more efficient (in terms of low level copy stuff). + bool done = false; + do { + done = true; + std::vector< std::pair<ref<Expr>, IndependentElementSet> > newWorklist; + for (std::vector< std::pair<ref<Expr>, IndependentElementSet> >::iterator + it = worklist.begin(), ie = worklist.end(); it != ie; ++it) { + if (it->second.intersects(eltsClosure)) { + if (eltsClosure.add(it->second)) + done = false; + result.push_back(it->first); + } else { + newWorklist.push_back(*it); + } + } + worklist.swap(newWorklist); + } while (!done); + + if (0) { + std::set< ref<Expr> > reqset(result.begin(), result.end()); + llvm::cerr << "--\n"; + llvm::cerr << "Q: " << query.expr << "\n"; + llvm::cerr << "\telts: " << IndependentElementSet(query.expr) << "\n"; + int i = 0; + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) { + llvm::cerr << "C" << i++ << ": " << *it; + llvm::cerr << " " << (reqset.count(*it) ? "(required)" : "(independent)") << "\n"; + llvm::cerr << "\telts: " << IndependentElementSet(*it) << "\n"; + } + llvm::cerr << "elts closure: " << eltsClosure << "\n"; + } + + return eltsClosure; +} + +class IndependentSolver : public SolverImpl { +private: + Solver *solver; + +public: + IndependentSolver(Solver *_solver) + : solver(_solver) {} + ~IndependentSolver() { delete solver; } + + bool computeTruth(const Query&, bool &isValid); + bool computeValidity(const Query&, Solver::Validity &result); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query& query, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution) { + return solver->impl->computeInitialValues(query, objects, values, + hasSolution); + } +}; + +bool IndependentSolver::computeValidity(const Query& query, + Solver::Validity &result) { + std::vector< ref<Expr> > required; + IndependentElementSet eltsClosure = + getIndependentConstraints(query, required); + ConstraintManager tmp(required); + return solver->impl->computeValidity(Query(tmp, query.expr), + result); +} + +bool IndependentSolver::computeTruth(const Query& query, bool &isValid) { + std::vector< ref<Expr> > required; + IndependentElementSet eltsClosure = + getIndependentConstraints(query, required); + ConstraintManager tmp(required); + return solver->impl->computeTruth(Query(tmp, query.expr), + isValid); +} + +bool IndependentSolver::computeValue(const Query& query, ref<Expr> &result) { + std::vector< ref<Expr> > required; + IndependentElementSet eltsClosure = + getIndependentConstraints(query, required); + ConstraintManager tmp(required); + return solver->impl->computeValue(Query(tmp, query.expr), result); +} + +Solver *klee::createIndependentSolver(Solver *s) { + return new Solver(new IndependentSolver(s)); +} diff --git a/lib/Solver/Makefile b/lib/Solver/Makefile new file mode 100755 index 00000000..11d3d330 --- /dev/null +++ b/lib/Solver/Makefile @@ -0,0 +1,16 @@ +#===-- lib/Solver/Makefile ---------------------------------*- Makefile -*--===# +# +# The KLEE Symbolic Virtual Machine +# +# This file is distributed under the University of Illinois Open Source +# License. See LICENSE.TXT for details. +# +#===------------------------------------------------------------------------===# + +LEVEL=../.. + +LIBRARYNAME=kleaverSolver +DONT_BUILD_RELINKED=1 +BUILD_ARCHIVE=1 + +include $(LEVEL)/Makefile.common diff --git a/lib/Solver/PCLoggingSolver.cpp b/lib/Solver/PCLoggingSolver.cpp new file mode 100644 index 00000000..4b787acb --- /dev/null +++ b/lib/Solver/PCLoggingSolver.cpp @@ -0,0 +1,134 @@ +//===-- PCLoggingSolver.cpp -----------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/Solver.h" + +// FIXME: This should not be here. +#include "klee/ExecutionState.h" +#include "klee/Expr.h" +#include "klee/SolverImpl.h" +#include "klee/Statistics.h" +#include "klee/util/ExprPPrinter.h" +#include "klee/Internal/Support/QueryLog.h" +#include "klee/Internal/System/Time.h" + +#include "llvm/Support/CommandLine.h" + +#include <fstream> + +using namespace klee; +using namespace llvm; +using namespace klee::util; + +/// + +class PCLoggingSolver : public SolverImpl { + Solver *solver; + std::ofstream os; + ExprPPrinter *printer; + unsigned queryCount; + double startTime; + + void startQuery(const Query& query, const char *typeName) { + Statistic *S = theStatisticManager->getStatisticByName("Instructions"); + uint64_t instructions = S ? S->getValue() : 0; + os << "# Query " << queryCount++ << " -- " + << "Type: " << typeName << ", " + << "Instructions: " << instructions << "\n"; + printer->printQuery(os, query.constraints, query.expr); + + startTime = getWallTime(); + } + + void finishQuery(bool success) { + double delta = getWallTime() - startTime; + os << "# " << (success ? "OK" : "FAIL") << " -- " + << "Elapsed: " << delta << "\n"; + } + +public: + PCLoggingSolver(Solver *_solver, std::string path) + : solver(_solver), + os(path.c_str(), std::ios::trunc), + printer(ExprPPrinter::create(os)), + queryCount(0) { + } + ~PCLoggingSolver() { + delete printer; + delete solver; + } + + bool computeTruth(const Query& query, bool &isValid) { + startQuery(query, "Truth"); + bool success = solver->impl->computeTruth(query, isValid); + finishQuery(success); + if (success) + os << "# Is Valid: " << (isValid ? "true" : "false") << "\n"; + os << "\n"; + return success; + } + + bool computeValidity(const Query& query, Solver::Validity &result) { + startQuery(query, "Validity"); + bool success = solver->impl->computeValidity(query, result); + finishQuery(success); + if (success) + os << "# Validity: " << result << "\n"; + os << "\n"; + return success; + } + + bool computeValue(const Query& query, ref<Expr> &result) { + startQuery(query, "Value"); + bool success = solver->impl->computeValue(query, result); + finishQuery(success); + if (success) + os << "# Result: " << result << "\n"; + os << "\n"; + return success; + } + + bool computeInitialValues(const Query& query, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution) { + // FIXME: Add objects to output. + startQuery(query, "InitialValues"); + bool success = solver->impl->computeInitialValues(query, objects, + values, hasSolution); + finishQuery(success); + if (success) { + os << "# Solvable: " << (hasSolution ? "true" : "false") << "\n"; + if (hasSolution) { + std::vector< std::vector<unsigned char> >::iterator + values_it = values.begin(); + for (std::vector<const Array*>::const_iterator i = objects.begin(), + e = objects.end(); i != e; ++i, ++values_it) { + const Array *array = *i; + std::vector<unsigned char> &data = *values_it; + os << "# arr" << array->id << " = ["; + for (unsigned j = 0; j < array->size; j++) { + os << (int) data[j]; + if (j+1 < array->size) + os << ","; + } + os << "]\n"; + } + } + } + os << "\n"; + return success; + } +}; + +/// + +Solver *klee::createPCLoggingSolver(Solver *_solver, std::string path) { + return new Solver(new PCLoggingSolver(_solver, path)); +} diff --git a/lib/Solver/STPBuilder.cpp b/lib/Solver/STPBuilder.cpp new file mode 100644 index 00000000..33aee035 --- /dev/null +++ b/lib/Solver/STPBuilder.cpp @@ -0,0 +1,819 @@ +//===-- STPBuilder.cpp ----------------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "STPBuilder.h" + +#include "klee/Expr.h" +#include "klee/Solver.h" +#include "klee/util/Bits.h" + +#include "ConstantDivision.h" +#include "SolverStats.h" + +#include "llvm/Support/CommandLine.h" + +#define vc_bvBoolExtract IAMTHESPAWNOFSATAN +// unclear return +#define vc_bvLeftShiftExpr IAMTHESPAWNOFSATAN +#define vc_bvRightShiftExpr IAMTHESPAWNOFSATAN +// bad refcnt'ng +#define vc_bvVar32LeftShiftExpr IAMTHESPAWNOFSATAN +#define vc_bvVar32RightShiftExpr IAMTHESPAWNOFSATAN +#define vc_bvVar32DivByPowOfTwoExpr IAMTHESPAWNOFSATAN +#define vc_bvCreateMemoryArray IAMTHESPAWNOFSATAN +#define vc_bvReadMemoryArray IAMTHESPAWNOFSATAN +#define vc_bvWriteToMemoryArray IAMTHESPAWNOFSATAN + +#include <algorithm> // max, min +#include <cassert> +#include <iostream> +#include <map> +#include <sstream> +#include <vector> + +using namespace klee; + +namespace { + llvm::cl::opt<bool> + UseConstructHash("use-construct-hash", + llvm::cl::desc("Use hash-consing during STP query construction."), + llvm::cl::init(true)); +} + +/// + +/***/ + +STPBuilder::STPBuilder(::VC _vc, bool _optimizeDivides) + : vc(_vc), optimizeDivides(_optimizeDivides) +{ + tempVars[0] = buildVar("__tmpInt8", 8); + tempVars[1] = buildVar("__tmpInt16", 16); + tempVars[2] = buildVar("__tmpInt32", 32); + tempVars[3] = buildVar("__tmpInt64", 64); +} + +STPBuilder::~STPBuilder() { +} + +/// + +/* Warning: be careful about what c_interface functions you use. Some of + them look like they cons memory but in fact don't, which is bad when + you call vc_DeleteExpr on them. */ + +::VCExpr STPBuilder::buildVar(const char *name, unsigned width) { + // XXX don't rebuild if this stuff cons's + ::Type t = (width==1) ? vc_boolType(vc) : vc_bvType(vc, width); + ::VCExpr res = vc_varExpr(vc, (char*) name, t); + vc_DeleteExpr(t); + return res; +} + +::VCExpr STPBuilder::buildArray(const char *name, unsigned indexWidth, unsigned valueWidth) { + // XXX don't rebuild if this stuff cons's + ::Type t1 = vc_bvType(vc, indexWidth); + ::Type t2 = vc_bvType(vc, valueWidth); + ::Type t = vc_arrayType(vc, t1, t2); + ::VCExpr res = vc_varExpr(vc, (char*) name, t); + vc_DeleteExpr(t); + vc_DeleteExpr(t2); + vc_DeleteExpr(t1); + return res; +} + +ExprHandle STPBuilder::getTempVar(Expr::Width w) { + switch (w) { + case Expr::Int8: return tempVars[0]; + case Expr::Int16: return tempVars[1]; + case Expr::Int32: return tempVars[2]; + case Expr::Int64: return tempVars[3]; + default: + assert(0 && "invalid type"); + } +} + +ExprHandle STPBuilder::getTrue() { + return vc_trueExpr(vc); +} +ExprHandle STPBuilder::getFalse() { + return vc_falseExpr(vc); +} +ExprHandle STPBuilder::bvOne(unsigned width) { + return bvConst32(width, 1); +} +ExprHandle STPBuilder::bvZero(unsigned width) { + return bvConst32(width, 0); +} +ExprHandle STPBuilder::bvMinusOne(unsigned width) { + return bvConst64(width, (int64_t) -1); +} +ExprHandle STPBuilder::bvConst32(unsigned width, uint32_t value) { + return vc_bvConstExprFromInt(vc, width, value); +} +ExprHandle STPBuilder::bvConst64(unsigned width, uint64_t value) { + return vc_bvConstExprFromLL(vc, width, value); +} + +ExprHandle STPBuilder::bvBoolExtract(ExprHandle expr, int bit) { + return vc_eqExpr(vc, bvExtract(expr, bit, bit), bvOne(1)); +} +ExprHandle STPBuilder::bvExtract(ExprHandle expr, unsigned top, unsigned bottom) { + return vc_bvExtract(vc, expr, top, bottom); +} +ExprHandle STPBuilder::eqExpr(ExprHandle a, ExprHandle b) { + return vc_eqExpr(vc, a, b); +} + +// logical right shift +ExprHandle STPBuilder::bvRightShift(ExprHandle expr, unsigned amount, unsigned shiftBits) { + unsigned width = vc_getBVLength(vc, expr); + unsigned shift = amount & ((1<<shiftBits) - 1); + + if (shift==0) { + return expr; + } else if (shift>=width) { + return bvZero(width); + } else { + return vc_bvConcatExpr(vc, + bvZero(shift), + bvExtract(expr, width - 1, shift)); + } +} + +// logical left shift +ExprHandle STPBuilder::bvLeftShift(ExprHandle expr, unsigned amount, unsigned shiftBits) { + unsigned width = vc_getBVLength(vc, expr); + unsigned shift = amount & ((1<<shiftBits) - 1); + + if (shift==0) { + return expr; + } else if (shift>=width) { + return bvZero(width); + } else { + // stp shift does "expr @ [0 x s]" which we then have to extract, + // rolling our own gives slightly smaller exprs + return vc_bvConcatExpr(vc, + bvExtract(expr, width - shift - 1, 0), + bvZero(shift)); + } +} + +// left shift by a variable amount on an expression of the specified width +ExprHandle STPBuilder::bvVarLeftShift(ExprHandle expr, ExprHandle amount, unsigned width) { + ExprHandle res = bvZero(width); + + int shiftBits = getShiftBits( width ); + + //get the shift amount (looking only at the bits appropriate for the given width) + ExprHandle shift = vc_bvExtract( vc, amount, shiftBits - 1, 0 ); + + //construct a big if-then-elif-elif-... with one case per possible shift amount + for( int i=width-1; i>=0; i-- ) { + res = vc_iteExpr(vc, + eqExpr(shift, bvConst32(shiftBits, i)), + bvLeftShift(expr, i, shiftBits), + res); + } + return res; +} + +// logical right shift by a variable amount on an expression of the specified width +ExprHandle STPBuilder::bvVarRightShift(ExprHandle expr, ExprHandle amount, unsigned width) { + ExprHandle res = bvZero(width); + + int shiftBits = getShiftBits( width ); + + //get the shift amount (looking only at the bits appropriate for the given width) + ExprHandle shift = vc_bvExtract( vc, amount, shiftBits - 1, 0 ); + + //construct a big if-then-elif-elif-... with one case per possible shift amount + for( int i=width-1; i>=0; i-- ) { + res = vc_iteExpr(vc, + eqExpr(shift, bvConst32(shiftBits, i)), + bvRightShift(expr, i, shiftBits), + res); + } + + return res; +} + +// arithmetic right shift by a variable amount on an expression of the specified width +ExprHandle STPBuilder::bvVarArithRightShift(ExprHandle expr, ExprHandle amount, unsigned width) { + int shiftBits = getShiftBits( width ); + + //get the shift amount (looking only at the bits appropriate for the given width) + ExprHandle shift = vc_bvExtract( vc, amount, shiftBits - 1, 0 ); + + //get the sign bit to fill with + ExprHandle signedBool = bvBoolExtract(expr, width-1); + + //start with the result if shifting by width-1 + ExprHandle res = constructAShrByConstant(expr, width-1, signedBool, shiftBits); + + //construct a big if-then-elif-elif-... with one case per possible shift amount + // XXX more efficient to move the ite on the sign outside all exprs? + // XXX more efficient to sign extend, right shift, then extract lower bits? + for( int i=width-2; i>=0; i-- ) { + res = vc_iteExpr(vc, + eqExpr(shift, bvConst32(shiftBits,i)), + constructAShrByConstant(expr, + i, + signedBool, + shiftBits), + res); + } + + return res; +} + +ExprHandle STPBuilder::constructAShrByConstant(ExprHandle expr, + unsigned amount, + ExprHandle isSigned, + unsigned shiftBits) { + unsigned width = vc_getBVLength(vc, expr); + unsigned shift = amount & ((1<<shiftBits) - 1); + + if (shift==0) { + return expr; + } else if (shift>=width-1) { + return vc_iteExpr(vc, isSigned, bvMinusOne(width), bvZero(width)); + } else { + return vc_iteExpr(vc, + isSigned, + ExprHandle(vc_bvConcatExpr(vc, + bvMinusOne(shift), + bvExtract(expr, width - 1, shift))), + bvRightShift(expr, shift, shiftBits)); + } +} + +ExprHandle STPBuilder::constructMulByConstant(ExprHandle expr, unsigned width, uint64_t x) { + unsigned shiftBits = getShiftBits(width); + uint64_t add, sub; + ExprHandle res = 0; + + // expr*x == expr*(add-sub) == expr*add - expr*sub + ComputeMultConstants64(x, add, sub); + + // legal, these would overflow completely + add = bits64::truncateToNBits(add, width); + sub = bits64::truncateToNBits(sub, width); + + for (int j=63; j>=0; j--) { + uint64_t bit = 1LL << j; + + if ((add&bit) || (sub&bit)) { + assert(!((add&bit) && (sub&bit)) && "invalid mult constants"); + ExprHandle op = bvLeftShift(expr, j, shiftBits); + + if (add&bit) { + if (res) { + res = vc_bvPlusExpr(vc, width, res, op); + } else { + res = op; + } + } else { + if (res) { + res = vc_bvMinusExpr(vc, width, res, op); + } else { + res = vc_bvUMinusExpr(vc, op); + } + } + } + } + + if (!res) + res = bvZero(width); + + return res; +} + +/* + * Compute the 32-bit unsigned integer division of n by a divisor d based on + * the constants derived from the constant divisor d. + * + * Returns n/d without doing explicit division. The cost is 2 adds, 3 shifts, + * and a (64-bit) multiply. + * + * @param n numerator (dividend) as an expression + * @param width number of bits used to represent the value + * @param d the divisor + * + * @return n/d without doing explicit division + */ +ExprHandle STPBuilder::constructUDivByConstant(ExprHandle expr_n, unsigned width, uint64_t d) { + assert(width==32 && "can only compute udiv constants for 32-bit division"); + + // Compute the constants needed to compute n/d for constant d w/o + // division by d. + uint32_t mprime, sh1, sh2; + ComputeUDivConstants32(d, mprime, sh1, sh2); + ExprHandle expr_sh1 = bvConst32( 32, sh1); + ExprHandle expr_sh2 = bvConst32( 32, sh2); + + // t1 = MULUH(mprime, n) = ( (uint64_t)mprime * (uint64_t)n ) >> 32 + ExprHandle expr_n_64 = vc_bvConcatExpr( vc, bvZero(32), expr_n ); //extend to 64 bits + ExprHandle t1_64bits = constructMulByConstant( expr_n_64, 64, (uint64_t)mprime ); + ExprHandle t1 = vc_bvExtract( vc, t1_64bits, 63, 32 ); //upper 32 bits + + // n/d = (((n - t1) >> sh1) + t1) >> sh2; + ExprHandle n_minus_t1 = vc_bvMinusExpr( vc, width, expr_n, t1 ); + ExprHandle shift_sh1 = bvVarRightShift( n_minus_t1, expr_sh1, 32 ); + ExprHandle plus_t1 = vc_bvPlusExpr( vc, width, shift_sh1, t1 ); + ExprHandle res = bvVarRightShift( plus_t1, expr_sh2, 32 ); + + return res; +} + +/* + * Compute the 32-bitnsigned integer division of n by a divisor d based on + * the constants derived from the constant divisor d. + * + * Returns n/d without doing explicit division. The cost is 3 adds, 3 shifts, + * a (64-bit) multiply, and an XOR. + * + * @param n numerator (dividend) as an expression + * @param width number of bits used to represent the value + * @param d the divisor + * + * @return n/d without doing explicit division + */ +ExprHandle STPBuilder::constructSDivByConstant(ExprHandle expr_n, unsigned width, uint64_t d) { + assert(width==32 && "can only compute udiv constants for 32-bit division"); + + // Compute the constants needed to compute n/d for constant d w/o division by d. + int32_t mprime, dsign, shpost; + ComputeSDivConstants32(d, mprime, dsign, shpost); + ExprHandle expr_dsign = bvConst32( 32, dsign); + ExprHandle expr_shpost = bvConst32( 32, shpost); + + // q0 = n + MULSH( mprime, n ) = n + (( (int64_t)mprime * (int64_t)n ) >> 32) + int64_t mprime_64 = (int64_t)mprime; + + ExprHandle expr_n_64 = vc_bvSignExtend( vc, expr_n, 64 ); + ExprHandle mult_64 = constructMulByConstant( expr_n_64, 64, mprime_64 ); + ExprHandle mulsh = vc_bvExtract( vc, mult_64, 63, 32 ); //upper 32-bits + ExprHandle n_plus_mulsh = vc_bvPlusExpr( vc, width, expr_n, mulsh ); + + // Improved variable arithmetic right shift: sign extend, shift, + // extract. + ExprHandle extend_npm = vc_bvSignExtend( vc, n_plus_mulsh, 64 ); + ExprHandle shift_npm = bvVarRightShift( extend_npm, expr_shpost, 64 ); + ExprHandle shift_shpost = vc_bvExtract( vc, shift_npm, 31, 0 ); //lower 32-bits + + // XSIGN(n) is -1 if n is negative, positive one otherwise + ExprHandle is_signed = bvBoolExtract( expr_n, 31 ); + ExprHandle neg_one = bvMinusOne(32); + ExprHandle xsign_of_n = vc_iteExpr( vc, is_signed, neg_one, bvZero(32) ); + + // q0 = (n_plus_mulsh >> shpost) - XSIGN(n) + ExprHandle q0 = vc_bvMinusExpr( vc, width, shift_shpost, xsign_of_n ); + + // n/d = (q0 ^ dsign) - dsign + ExprHandle q0_xor_dsign = vc_bvXorExpr( vc, q0, expr_dsign ); + ExprHandle res = vc_bvMinusExpr( vc, width, q0_xor_dsign, expr_dsign ); + + return res; +} + +::VCExpr STPBuilder::getInitialArray(const Array *root) { + if (root->stpInitialArray) { + return root->stpInitialArray; + } else { + char buf[32]; + sprintf(buf, "arr%d", root->id); + root->stpInitialArray = buildArray(buf, 32, 8); + return root->stpInitialArray; + } +} + +ExprHandle STPBuilder::getInitialRead(const Array *root, unsigned index) { + return vc_readExpr(vc, getInitialArray(root), bvConst32(32, index)); +} + +::VCExpr STPBuilder::getArrayForUpdate(const Array *root, + const UpdateNode *un) { + if (!un) { + return getInitialArray(root); + } else { + // FIXME: This really needs to be non-recursive. + if (!un->stpArray) + un->stpArray = vc_writeExpr(vc, + getArrayForUpdate(root, un->next), + construct(un->index, 0), + construct(un->value, 0)); + + return un->stpArray; + } +} + +/** if *width_out!=1 then result is a bitvector, + otherwise it is a bool */ +ExprHandle STPBuilder::construct(ref<Expr> e, int *width_out) { + if (!UseConstructHash || e.isConstant()) { + return constructActual(e, width_out); + } else { + ExprHashMap< std::pair<ExprHandle, unsigned> >::iterator it = + constructed.find(e); + if (it!=constructed.end()) { + if (width_out) + *width_out = it->second.second; + return it->second.first; + } else { + int width; + if (!width_out) width_out = &width; + ExprHandle res = constructActual(e, width_out); + constructed.insert(std::make_pair(e, std::make_pair(res, *width_out))); + return res; + } + } +} + + +/** if *width_out!=1 then result is a bitvector, + otherwise it is a bool */ +ExprHandle STPBuilder::constructActual(ref<Expr> e, int *width_out) { + int width; + if (!width_out) width_out = &width; + + ++stats::queryConstructs; + + switch(e.getKind()) { + + case Expr::Constant: { + uint64_t asUInt64 = e.getConstantValue(); + *width_out = e.getWidth(); + + if (*width_out > 64) + assert(0 && "constructActual: width > 64"); + + if (*width_out == 1) + return asUInt64 ? getTrue() : getFalse(); + else if (*width_out <= 32) + return bvConst32(*width_out, asUInt64); + else return bvConst64(*width_out, asUInt64); + } + + // Special + case Expr::NotOptimized: { + NotOptimizedExpr *noe = static_ref_cast<NotOptimizedExpr>(e); + return construct(noe->src, width_out); + } + + case Expr::Read: { + ReadExpr *re = static_ref_cast<ReadExpr>(e); + *width_out = 8; + return vc_readExpr(vc, + getArrayForUpdate(re->updates.root, re->updates.head), + construct(re->index, 0)); + } + + case Expr::Select: { + SelectExpr *se = static_ref_cast<SelectExpr>(e); + ExprHandle cond = construct(se->cond, 0); + ExprHandle tExpr = construct(se->trueExpr, width_out); + ExprHandle fExpr = construct(se->falseExpr, width_out); + return vc_iteExpr(vc, cond, tExpr, fExpr); + } + + case Expr::Concat: { + ConcatExpr *ce = static_ref_cast<ConcatExpr>(e); + unsigned numKids = ce->getNumKids(); + ExprHandle res = construct(ce->getKid(numKids-1), 0); + for (int i=numKids-2; i>=0; i--) { + res = vc_bvConcatExpr(vc, construct(ce->getKid(i), 0), res); + } + *width_out = ce->getWidth(); + return res; + } + + case Expr::Extract: { + ExtractExpr *ee = static_ref_cast<ExtractExpr>(e); + ExprHandle src = construct(ee->expr, width_out); + *width_out = ee->getWidth(); + if (*width_out==1) { + return bvBoolExtract(src, 0); + } else { + return vc_bvExtract(vc, src, ee->offset + *width_out - 1, ee->offset); + } + } + + // Casting + + case Expr::ZExt: { + int srcWidth; + CastExpr *ce = static_ref_cast<CastExpr>(e); + ExprHandle src = construct(ce->src, &srcWidth); + *width_out = ce->getWidth(); + if (srcWidth==1) { + return vc_iteExpr(vc, src, bvOne(*width_out), bvZero(*width_out)); + } else { + return vc_bvConcatExpr(vc, bvZero(*width_out-srcWidth), src); + } + } + + case Expr::SExt: { + int srcWidth; + CastExpr *ce = static_ref_cast<CastExpr>(e); + ExprHandle src = construct(ce->src, &srcWidth); + *width_out = ce->getWidth(); + if (srcWidth==1) { + return vc_iteExpr(vc, src, bvMinusOne(*width_out), bvZero(*width_out)); + } else { + return vc_bvSignExtend(vc, src, *width_out); + } + } + + // Arithmetic + + case Expr::Add: { + AddExpr *ae = static_ref_cast<AddExpr>(e); + ExprHandle left = construct(ae->left, width_out); + ExprHandle right = construct(ae->right, width_out); + assert(*width_out!=1 && "uncanonicalized add"); + return vc_bvPlusExpr(vc, *width_out, left, right); + } + + case Expr::Sub: { + SubExpr *se = static_ref_cast<SubExpr>(e); + ExprHandle left = construct(se->left, width_out); + ExprHandle right = construct(se->right, width_out); + assert(*width_out!=1 && "uncanonicalized sub"); + return vc_bvMinusExpr(vc, *width_out, left, right); + } + + case Expr::Mul: { + MulExpr *me = static_ref_cast<MulExpr>(e); + ExprHandle right = construct(me->right, width_out); + assert(*width_out!=1 && "uncanonicalized mul"); + + if (me->left.isConstant()) { + return constructMulByConstant(right, *width_out, me->left.getConstantValue()); + } else { + ExprHandle left = construct(me->left, width_out); + return vc_bvMultExpr(vc, *width_out, left, right); + } + } + + case Expr::UDiv: { + UDivExpr *de = static_ref_cast<UDivExpr>(e); + ExprHandle left = construct(de->left, width_out); + assert(*width_out!=1 && "uncanonicalized udiv"); + + if (de->right.isConstant()) { + uint64_t divisor = de->right.getConstantValue(); + + if (bits64::isPowerOfTwo(divisor)) { + return bvRightShift(left, + bits64::indexOfSingleBit(divisor), + getShiftBits(*width_out)); + } else if (optimizeDivides) { + if (*width_out == 32) //only works for 32-bit division + return constructUDivByConstant( left, *width_out, (uint32_t)divisor ); + } + } + + ExprHandle right = construct(de->right, width_out); + return vc_bvDivExpr(vc, *width_out, left, right); + } + + case Expr::SDiv: { + SDivExpr *de = static_ref_cast<SDivExpr>(e); + ExprHandle left = construct(de->left, width_out); + assert(*width_out!=1 && "uncanonicalized sdiv"); + + if (de->right.isConstant()) { + uint64_t divisor = de->right.getConstantValue(); + + if (optimizeDivides) { + if (*width_out == 32) //only works for 32-bit division + return constructSDivByConstant( left, *width_out, divisor); + } + } + + // XXX need to test for proper handling of sign, not sure I + // trust STP + ExprHandle right = construct(de->right, width_out); + return vc_sbvDivExpr(vc, *width_out, left, right); + } + + case Expr::URem: { + URemExpr *de = static_ref_cast<URemExpr>(e); + ExprHandle left = construct(de->left, width_out); + assert(*width_out!=1 && "uncanonicalized urem"); + + if (de->right.isConstant()) { + uint64_t divisor = de->right.getConstantValue(); + + if (bits64::isPowerOfTwo(divisor)) { + unsigned bits = bits64::indexOfSingleBit(divisor); + + // special case for modding by 1 or else we bvExtract -1:0 + if (bits == 0) { + return bvZero(*width_out); + } else { + return vc_bvConcatExpr(vc, + bvZero(*width_out - bits), + bvExtract(left, bits - 1, 0)); + } + } + + //use fast division to compute modulo without explicit division for constant divisor + if (optimizeDivides) { + if (*width_out == 32) { //only works for 32-bit division + ExprHandle quotient = constructUDivByConstant( left, *width_out, (uint32_t)divisor ); + ExprHandle quot_times_divisor = constructMulByConstant( quotient, *width_out, divisor ); + ExprHandle rem = vc_bvMinusExpr( vc, *width_out, left, quot_times_divisor ); + return rem; + } + } + } + + ExprHandle right = construct(de->right, width_out); + return vc_bvModExpr(vc, *width_out, left, right); + } + + case Expr::SRem: { + SRemExpr *de = static_ref_cast<SRemExpr>(e); + ExprHandle left = construct(de->left, width_out); + ExprHandle right = construct(de->right, width_out); + assert(*width_out!=1 && "uncanonicalized srem"); + +#if 0 //not faster per first benchmark + if (optimizeDivides) { + if (ConstantExpr *cre = de->right->asConstant()) { + uint64_t divisor = cre->asUInt64; + + //use fast division to compute modulo without explicit division for constant divisor + if( *width_out == 32 ) { //only works for 32-bit division + ExprHandle quotient = constructSDivByConstant( left, *width_out, divisor ); + ExprHandle quot_times_divisor = constructMulByConstant( quotient, *width_out, divisor ); + ExprHandle rem = vc_bvMinusExpr( vc, *width_out, left, quot_times_divisor ); + return rem; + } + } + } +#endif + + // XXX implement my fast path and test for proper handling of sign + return vc_sbvModExpr(vc, *width_out, left, right); + } + + // Binary + + case Expr::And: { + AndExpr *ae = static_ref_cast<AndExpr>(e); + ExprHandle left = construct(ae->left, width_out); + ExprHandle right = construct(ae->right, width_out); + if (*width_out==1) { + return vc_andExpr(vc, left, right); + } else { + return vc_bvAndExpr(vc, left, right); + } + } + case Expr::Or: { + OrExpr *oe = static_ref_cast<OrExpr>(e); + ExprHandle left = construct(oe->left, width_out); + ExprHandle right = construct(oe->right, width_out); + if (*width_out==1) { + return vc_orExpr(vc, left, right); + } else { + return vc_bvOrExpr(vc, left, right); + } + } + + case Expr::Xor: { + XorExpr *xe = static_ref_cast<XorExpr>(e); + ExprHandle left = construct(xe->left, width_out); + ExprHandle right = construct(xe->right, width_out); + + if (*width_out==1) { + // XXX check for most efficient? + return vc_iteExpr(vc, left, + ExprHandle(vc_notExpr(vc, right)), right); + } else { + return vc_bvXorExpr(vc, left, right); + } + } + + case Expr::Shl: { + ShlExpr *se = static_ref_cast<ShlExpr>(e); + ExprHandle left = construct(se->left, width_out); + assert(*width_out!=1 && "uncanonicalized shl"); + + if (se->right.isConstant()) { + return bvLeftShift(left, se->right.getConstantValue(), getShiftBits(*width_out)); + } else { + int shiftWidth; + ExprHandle amount = construct(se->right, &shiftWidth); + return bvVarLeftShift( left, amount, *width_out ); + } + } + + case Expr::LShr: { + LShrExpr *lse = static_ref_cast<LShrExpr>(e); + ExprHandle left = construct(lse->left, width_out); + unsigned shiftBits = getShiftBits(*width_out); + assert(*width_out!=1 && "uncanonicalized lshr"); + + if (lse->right.isConstant()) { + return bvRightShift(left, (unsigned) lse->right.getConstantValue(), shiftBits); + } else { + int shiftWidth; + ExprHandle amount = construct(lse->right, &shiftWidth); + return bvVarRightShift( left, amount, *width_out ); + } + } + + case Expr::AShr: { + AShrExpr *ase = static_ref_cast<AShrExpr>(e); + ExprHandle left = construct(ase->left, width_out); + assert(*width_out!=1 && "uncanonicalized ashr"); + + if (ase->right.isConstant()) { + unsigned shift = (unsigned) ase->right.getConstantValue(); + ExprHandle signedBool = bvBoolExtract(left, *width_out-1); + return constructAShrByConstant(left, shift, signedBool, getShiftBits(*width_out)); + } else { + int shiftWidth; + ExprHandle amount = construct(ase->right, &shiftWidth); + return bvVarArithRightShift( left, amount, *width_out ); + } + } + + // Comparison + + case Expr::Eq: { + EqExpr *ee = static_ref_cast<EqExpr>(e); + ExprHandle left = construct(ee->left, width_out); + ExprHandle right = construct(ee->right, width_out); + if (*width_out==1) { + if (ee->left.isConstant()) { + assert(!ee->left.getConstantValue() && "uncanonicalized eq"); + return vc_notExpr(vc, right); + } else { + return vc_iffExpr(vc, left, right); + } + } else { + *width_out = 1; + return vc_eqExpr(vc, left, right); + } + } + + case Expr::Ult: { + UltExpr *ue = static_ref_cast<UltExpr>(e); + ExprHandle left = construct(ue->left, width_out); + ExprHandle right = construct(ue->right, width_out); + assert(*width_out!=1 && "uncanonicalized ult"); + *width_out = 1; + return vc_bvLtExpr(vc, left, right); + } + + case Expr::Ule: { + UleExpr *ue = static_ref_cast<UleExpr>(e); + ExprHandle left = construct(ue->left, width_out); + ExprHandle right = construct(ue->right, width_out); + assert(*width_out!=1 && "uncanonicalized ule"); + *width_out = 1; + return vc_bvLeExpr(vc, left, right); + } + + case Expr::Slt: { + SltExpr *se = static_ref_cast<SltExpr>(e); + ExprHandle left = construct(se->left, width_out); + ExprHandle right = construct(se->right, width_out); + assert(*width_out!=1 && "uncanonicalized slt"); + *width_out = 1; + return vc_sbvLtExpr(vc, left, right); + } + + case Expr::Sle: { + SleExpr *se = static_ref_cast<SleExpr>(e); + ExprHandle left = construct(se->left, width_out); + ExprHandle right = construct(se->right, width_out); + assert(*width_out!=1 && "uncanonicalized sle"); + *width_out = 1; + return vc_sbvLeExpr(vc, left, right); + } + + // unused due to canonicalization +#if 0 + case Expr::Ne: + case Expr::Ugt: + case Expr::Uge: + case Expr::Sgt: + case Expr::Sge: +#endif + + default: + assert(0 && "unhandled Expr type"); + return vc_trueExpr(vc); + } +} diff --git a/lib/Solver/STPBuilder.h b/lib/Solver/STPBuilder.h new file mode 100644 index 00000000..6382bc1f --- /dev/null +++ b/lib/Solver/STPBuilder.h @@ -0,0 +1,125 @@ +//===-- STPBuilder.h --------------------------------------------*- C++ -*-===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef __UTIL_STPBUILDER_H__ +#define __UTIL_STPBUILDER_H__ + +#include "klee/util/ExprHashMap.h" +#include "klee/Config/config.h" + +#include <vector> +#include <map> + +#define Expr VCExpr +#include "stp/c_interface.h" + +#if ENABLE_STPLOG == 1 +#include "stp/stplog.h" +#endif +#undef Expr + +namespace klee { + class ExprHolder { + friend class ExprHandle; + ::VCExpr expr; + unsigned count; + + public: + ExprHolder(const ::VCExpr _expr) : expr(_expr), count(0) {} + ~ExprHolder() { + if (expr) vc_DeleteExpr(expr); + } + }; + + class ExprHandle { + ExprHolder *H; + + public: + ExprHandle() : H(new ExprHolder(0)) { H->count++; } + ExprHandle(::VCExpr _expr) : H(new ExprHolder(_expr)) { H->count++; } + ExprHandle(const ExprHandle &b) : H(b.H) { H->count++; } + ~ExprHandle() { if (--H->count == 0) delete H; } + + ExprHandle &operator=(const ExprHandle &b) { + if (--H->count == 0) delete H; + H = b.H; + H->count++; + return *this; + } + + operator bool () { return H->expr; } + operator ::VCExpr () { return H->expr; } + }; + +class STPBuilder { + ::VC vc; + ExprHandle tempVars[4]; + ExprHashMap< std::pair<ExprHandle, unsigned> > constructed; + + /// optimizeDivides - Rewrite division and reminders by constants + /// into multiplies and shifts. STP should probably handle this for + /// use. + bool optimizeDivides; + +private: + unsigned getShiftBits(unsigned amount) { + return (amount == 64) ? 6 : 5; + } + + ExprHandle bvOne(unsigned width); + ExprHandle bvZero(unsigned width); + ExprHandle bvMinusOne(unsigned width); + ExprHandle bvConst32(unsigned width, uint32_t value); + ExprHandle bvConst64(unsigned width, uint64_t value); + + ExprHandle bvBoolExtract(ExprHandle expr, int bit); + ExprHandle bvExtract(ExprHandle expr, unsigned top, unsigned bottom); + ExprHandle eqExpr(ExprHandle a, ExprHandle b); + + //logical left and right shift (not arithmetic) + ExprHandle bvLeftShift(ExprHandle expr, unsigned shift, unsigned shiftBits); + ExprHandle bvRightShift(ExprHandle expr, unsigned amount, unsigned shiftBits); + ExprHandle bvVarLeftShift(ExprHandle expr, ExprHandle amount, unsigned width); + ExprHandle bvVarRightShift(ExprHandle expr, ExprHandle amount, unsigned width); + ExprHandle bvVarArithRightShift(ExprHandle expr, ExprHandle amount, unsigned width); + + ExprHandle constructAShrByConstant(ExprHandle expr, unsigned shift, + ExprHandle isSigned, unsigned shiftBits); + ExprHandle constructMulByConstant(ExprHandle expr, unsigned width, uint64_t x); + ExprHandle constructUDivByConstant(ExprHandle expr_n, unsigned width, uint64_t d); + ExprHandle constructSDivByConstant(ExprHandle expr_n, unsigned width, uint64_t d); + + ::VCExpr getInitialArray(const Array *os); + ::VCExpr getArrayForUpdate(const Array *root, const UpdateNode *un); + + ExprHandle constructActual(ref<Expr> e, int *width_out); + ExprHandle construct(ref<Expr> e, int *width_out); + + ::VCExpr buildVar(const char *name, unsigned width); + ::VCExpr buildArray(const char *name, unsigned indexWidth, unsigned valueWidth); + +public: + STPBuilder(::VC _vc, bool _optimizeDivides=true); + ~STPBuilder(); + + ExprHandle getTrue(); + ExprHandle getFalse(); + ExprHandle getTempVar(Expr::Width w); + ExprHandle getInitialRead(const Array *os, unsigned index); + + ExprHandle construct(ref<Expr> e) { + ExprHandle res = construct(e, 0); + constructed.clear(); + return res; + } +}; + +} + +#endif diff --git a/lib/Solver/Solver.cpp b/lib/Solver/Solver.cpp new file mode 100644 index 00000000..24d3ef86 --- /dev/null +++ b/lib/Solver/Solver.cpp @@ -0,0 +1,643 @@ +//===-- Solver.cpp --------------------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "klee/Solver.h" +#include "klee/SolverImpl.h" + +#include "SolverStats.h" +#include "STPBuilder.h" + +#include "klee/Constraints.h" +#include "klee/Expr.h" +#include "klee/TimerStatIncrementer.h" +#include "klee/util/Assignment.h" +#include "klee/util/ExprPPrinter.h" +#include "klee/util/ExprUtil.h" +#include "klee/Internal/Support/Timer.h" + +#define vc_bvBoolExtract IAMTHESPAWNOFSATAN + +#include <cassert> +#include <map> +#include <vector> + +#include <sys/wait.h> +#include <sys/ipc.h> +#include <sys/shm.h> + +using namespace klee; + +/***/ + +const char *Solver::validity_to_str(Validity v) { + switch (v) { + default: return "Unknown"; + case True: return "True"; + case False: return "False"; + } +} + +Solver::~Solver() { + delete impl; +} + +SolverImpl::~SolverImpl() { +} + +bool Solver::evaluate(const Query& query, Validity &result) { + assert(query.expr.getWidth() == Expr::Bool && "Invalid expression type!"); + + // Maintain invariants implementation expect. + if (query.expr.isConstant()) { + result = query.expr.getConstantValue() ? True : False; + return true; + } + + return impl->computeValidity(query, result); +} + +bool SolverImpl::computeValidity(const Query& query, Solver::Validity &result) { + bool isTrue, isFalse; + if (!computeTruth(query, isTrue)) + return false; + if (isTrue) { + result = Solver::True; + } else { + if (!computeTruth(query.negateExpr(), isFalse)) + return false; + result = isFalse ? Solver::False : Solver::Unknown; + } + return true; +} + +bool Solver::mustBeTrue(const Query& query, bool &result) { + assert(query.expr.getWidth() == Expr::Bool && "Invalid expression type!"); + + // Maintain invariants implementation expect. + if (query.expr.isConstant()) { + result = query.expr.getConstantValue() ? true : false; + return true; + } + + return impl->computeTruth(query, result); +} + +bool Solver::mustBeFalse(const Query& query, bool &result) { + return mustBeTrue(query.negateExpr(), result); +} + +bool Solver::mayBeTrue(const Query& query, bool &result) { + bool res; + if (!mustBeFalse(query, res)) + return false; + result = !res; + return true; +} + +bool Solver::mayBeFalse(const Query& query, bool &result) { + bool res; + if (!mustBeTrue(query, res)) + return false; + result = !res; + return true; +} + +bool Solver::getValue(const Query& query, ref<Expr> &result) { + // Maintain invariants implementation expect. + if (query.expr.isConstant()) { + result = query.expr; + return true; + } + + return impl->computeValue(query, result); +} + +bool +Solver::getInitialValues(const Query& query, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values) { + bool hasSolution; + bool success = + impl->computeInitialValues(query, objects, values, hasSolution); + // FIXME: Propogate this out. + if (!hasSolution) + return false; + + return success; +} + +std::pair< ref<Expr>, ref<Expr> > Solver::getRange(const Query& query) { + ref<Expr> e = query.expr; + Expr::Width width = e.getWidth(); + uint64_t min, max; + + if (width==1) { + Solver::Validity result; + if (!evaluate(query, result)) + assert(0 && "computeValidity failed"); + switch (result) { + case Solver::True: + min = max = 1; break; + case Solver::False: + min = max = 0; break; + default: + min = 0, max = 1; break; + } + } else if (e.isConstant()) { + min = max = e.getConstantValue(); + } else { + // binary search for # of useful bits + uint64_t lo=0, hi=width, mid, bits=0; + while (lo<hi) { + mid = (lo+hi)/2; + bool res; + bool success = + mustBeTrue(query.withExpr( + EqExpr::create(LShrExpr::create(e, + ConstantExpr::create(mid, + width)), + ConstantExpr::create(0, width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + hi = mid; + } else { + lo = mid+1; + } + + bits = lo; + } + + // could binary search for training zeros and offset + // min max but unlikely to be very useful + + // check common case + bool res = false; + bool success = + mayBeTrue(query.withExpr(EqExpr::create(e, ConstantExpr::create(0, + width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + min = 0; + } else { + // binary search for min + lo=0, hi=bits64::maxValueOfNBits(bits); + while (lo<hi) { + mid = (lo+hi)/2; + bool res = false; + bool success = + mayBeTrue(query.withExpr(UleExpr::create(e, + ConstantExpr::create(mid, + width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + hi = mid; + } else { + lo = mid+1; + } + } + + min = lo; + } + + // binary search for max + lo=min, hi=bits64::maxValueOfNBits(bits); + while (lo<hi) { + mid = (lo+hi)/2; + bool res; + bool success = + mustBeTrue(query.withExpr(UleExpr::create(e, + ConstantExpr::create(mid, + width))), + res); + assert(success && "FIXME: Unhandled solver failure"); + if (res) { + hi = mid; + } else { + lo = mid+1; + } + } + + max = lo; + } + + return std::make_pair(ConstantExpr::create(min, width), + ConstantExpr::create(max, width)); +} + +/***/ + +class ValidatingSolver : public SolverImpl { +private: + Solver *solver, *oracle; + +public: + ValidatingSolver(Solver *_solver, Solver *_oracle) + : solver(_solver), oracle(_oracle) {} + ~ValidatingSolver() { delete solver; } + + bool computeValidity(const Query&, Solver::Validity &result); + bool computeTruth(const Query&, bool &isValid); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query&, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution); +}; + +bool ValidatingSolver::computeTruth(const Query& query, + bool &isValid) { + bool answer; + + if (!solver->impl->computeTruth(query, isValid)) + return false; + if (!oracle->impl->computeTruth(query, answer)) + return false; + + if (isValid != answer) + assert(0 && "invalid solver result (computeTruth)"); + + return true; +} + +bool ValidatingSolver::computeValidity(const Query& query, + Solver::Validity &result) { + Solver::Validity answer; + + if (!solver->impl->computeValidity(query, result)) + return false; + if (!oracle->impl->computeValidity(query, answer)) + return false; + + if (result != answer) + assert(0 && "invalid solver result (computeValidity)"); + + return true; +} + +bool ValidatingSolver::computeValue(const Query& query, + ref<Expr> &result) { + bool answer; + + if (!solver->impl->computeValue(query, result)) + return false; + // We don't want to compare, but just make sure this is a legal + // solution. + if (!oracle->impl->computeTruth(query.withExpr(NeExpr::create(query.expr, + result)), + answer)) + return false; + + if (answer) + assert(0 && "invalid solver result (computeValue)"); + + return true; +} + +bool +ValidatingSolver::computeInitialValues(const Query& query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + bool answer; + + if (!solver->impl->computeInitialValues(query, objects, values, + hasSolution)) + return false; + + if (hasSolution) { + // Assert the bindings as constraints, and verify that the + // conjunction of the actual constraints is satisfiable. + std::vector< ref<Expr> > bindings; + for (unsigned i = 0; i != values.size(); ++i) { + const Array *array = objects[i]; + for (unsigned j=0; j<array->size; j++) { + unsigned char value = values[i][j]; + bindings.push_back(EqExpr::create(ReadExpr::create(UpdateList(array, + true, 0), + ref<Expr>(j, Expr::Int32)), + ref<Expr>(value, Expr::Int8))); + } + } + ConstraintManager tmp(bindings); + ref<Expr> constraints = Expr::createNot(query.expr); + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + constraints = AndExpr::create(constraints, *it); + + if (!oracle->impl->computeTruth(Query(tmp, constraints), answer)) + return false; + if (!answer) + assert(0 && "invalid solver result (computeInitialValues)"); + } else { + if (!oracle->impl->computeTruth(query, answer)) + return false; + if (!answer) + assert(0 && "invalid solver result (computeInitialValues)"); + } + + return true; +} + +Solver *klee::createValidatingSolver(Solver *s, Solver *oracle) { + return new Solver(new ValidatingSolver(s, oracle)); +} + +/***/ + +class STPSolverImpl : public SolverImpl { +private: + /// The solver we are part of, for access to public information. + STPSolver *solver; + VC vc; + STPBuilder *builder; + double timeout; + bool useForkedSTP; + +public: + STPSolverImpl(STPSolver *_solver, bool _useForkedSTP); + ~STPSolverImpl(); + + char *getConstraintLog(const Query&); + void setTimeout(double _timeout) { timeout = _timeout; } + + bool computeTruth(const Query&, bool &isValid); + bool computeValue(const Query&, ref<Expr> &result); + bool computeInitialValues(const Query&, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution); +}; + +static unsigned char *shared_memory_ptr; +static const unsigned shared_memory_size = 1<<20; +static int shared_memory_id; + +static void stp_error_handler(const char* err_msg) { + fprintf(stderr, "error: STP Error: %s\n", err_msg); + abort(); +} + +STPSolverImpl::STPSolverImpl(STPSolver *_solver, bool _useForkedSTP) + : solver(_solver), + vc(vc_createValidityChecker()), + builder(new STPBuilder(vc)), + timeout(0.0), + useForkedSTP(_useForkedSTP) +{ + assert(vc && "unable to create validity checker"); + assert(builder && "unable to create STPBuilder"); + + vc_registerErrorHandler(::stp_error_handler); + + if (useForkedSTP) { + shared_memory_id = shmget(IPC_PRIVATE, shared_memory_size, IPC_CREAT | 0700); + assert(shared_memory_id>=0 && "shmget failed"); + shared_memory_ptr = (unsigned char*) shmat(shared_memory_id, NULL, 0); + assert(shared_memory_ptr!=(void*)-1 && "shmat failed"); + shmctl(shared_memory_id, IPC_RMID, NULL); + } +} + +STPSolverImpl::~STPSolverImpl() { + delete builder; + + vc_Destroy(vc); +} + +/***/ + +STPSolver::STPSolver(bool useForkedSTP) + : Solver(new STPSolverImpl(this, useForkedSTP)) +{ +} + +char *STPSolver::getConstraintLog(const Query &query) { + return static_cast<STPSolverImpl*>(impl)->getConstraintLog(query); +} + +void STPSolver::setTimeout(double timeout) { + static_cast<STPSolverImpl*>(impl)->setTimeout(timeout); +} + +/***/ + +char *STPSolverImpl::getConstraintLog(const Query &query) { + vc_push(vc); + for (std::vector< ref<Expr> >::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + vc_assertFormula(vc, builder->construct(*it)); + assert(query.expr == ref<Expr>(0, Expr::Bool) && + "Unexpected expression in query!"); + + char *buffer; + unsigned long length; + vc_printQueryStateToBuffer(vc, builder->getFalse(), + &buffer, &length, false); + vc_pop(vc); + + return buffer; +} + +bool STPSolverImpl::computeTruth(const Query& query, + bool &isValid) { + std::vector<const Array*> objects; + std::vector< std::vector<unsigned char> > values; + bool hasSolution; + + if (!computeInitialValues(query, objects, values, hasSolution)) + return false; + + isValid = !hasSolution; + return true; +} + +bool STPSolverImpl::computeValue(const Query& query, + ref<Expr> &result) { + std::vector<const Array*> objects; + std::vector< std::vector<unsigned char> > values; + bool hasSolution; + + // Find the object used in the expression, and compute an assignment + // for them. + findSymbolicObjects(query.expr, objects); + if (!computeInitialValues(query.withFalse(), objects, values, hasSolution)) + return false; + assert(hasSolution && "state has invalid constraint set"); + + // Evaluate the expression with the computed assignment. + Assignment a(objects, values); + result = a.evaluate(query.expr); + + return true; +} + +static void runAndGetCex(::VC vc, STPBuilder *builder, ::VCExpr q, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > &values, + bool &hasSolution) { + // XXX I want to be able to timeout here, safely + hasSolution = !vc_query(vc, q); + + if (hasSolution) { + values.reserve(objects.size()); + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) { + const Array *array = *it; + std::vector<unsigned char> data; + + data.reserve(array->size); + for (unsigned offset = 0; offset < array->size; offset++) { + ExprHandle counter = + vc_getCounterExample(vc, builder->getInitialRead(array, offset)); + unsigned char val = getBVUnsigned(counter); + data.push_back(val); + } + + values.push_back(data); + } + } +} + +static void stpTimeoutHandler(int x) { + _exit(52); +} + +static bool runAndGetCexForked(::VC vc, + STPBuilder *builder, + ::VCExpr q, + const std::vector<const Array*> &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution, + double timeout) { + unsigned char *pos = shared_memory_ptr; + unsigned sum = 0; + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) + sum += (*it)->size; + assert(sum<shared_memory_size && "not enough shared memory for counterexample"); + + fflush(stdout); + fflush(stderr); + int pid = fork(); + if (pid==-1) { + fprintf(stderr, "error: fork failed (for STP)"); + return false; + } + + if (pid == 0) { + if (timeout) { + ::alarm(0); /* Turn off alarm so we can safely set signal handler */ + ::signal(SIGALRM, stpTimeoutHandler); + ::alarm(std::max(1, (int)timeout)); + } + unsigned res = vc_query(vc, q); + if (!res) { + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) { + const Array *array = *it; + for (unsigned offset = 0; offset < array->size; offset++) { + ExprHandle counter = + vc_getCounterExample(vc, builder->getInitialRead(array, offset)); + *pos++ = getBVUnsigned(counter); + } + } + } + _exit(res); + } else { + int status; + int res = waitpid(pid, &status, 0); + + if (res<0) { + fprintf(stderr, "error: waitpid() for STP failed"); + return false; + } + + // From timed_run.py: It appears that linux at least will on + // "occasion" return a status when the process was terminated by a + // signal, so test signal first. + if (WIFSIGNALED(status) || !WIFEXITED(status)) { + fprintf(stderr, "error: STP did not return successfully"); + return false; + } + + int exitcode = WEXITSTATUS(status); + if (exitcode==0) { + hasSolution = true; + } else if (exitcode==1) { + hasSolution = false; + } else if (exitcode==52) { + fprintf(stderr, "error: STP timed out"); + return false; + } else { + fprintf(stderr, "error: STP did not return a recognized code"); + return false; + } + + if (hasSolution) { + values = std::vector< std::vector<unsigned char> >(objects.size()); + unsigned i=0; + for (std::vector<const Array*>::const_iterator + it = objects.begin(), ie = objects.end(); it != ie; ++it) { + const Array *array = *it; + std::vector<unsigned char> &data = values[i++]; + data.insert(data.begin(), pos, pos + array->size); + pos += array->size; + } + } + + return true; + } +} + +bool +STPSolverImpl::computeInitialValues(const Query &query, + const std::vector<const Array*> + &objects, + std::vector< std::vector<unsigned char> > + &values, + bool &hasSolution) { + TimerStatIncrementer t(stats::queryTime); + + vc_push(vc); + + for (ConstraintManager::const_iterator it = query.constraints.begin(), + ie = query.constraints.end(); it != ie; ++it) + vc_assertFormula(vc, builder->construct(*it)); + + ++stats::queries; + ++stats::queryCounterexamples; + + ExprHandle stp_e = builder->construct(query.expr); + + bool success; + if (useForkedSTP) { + success = runAndGetCexForked(vc, builder, stp_e, objects, values, + hasSolution, timeout); + } else { + runAndGetCex(vc, builder, stp_e, objects, values, hasSolution); + success = true; + } + + if (success) { + if (hasSolution) + ++stats::queriesInvalid; + else + ++stats::queriesValid; + } + + vc_pop(vc); + + return success; +} diff --git a/lib/Solver/SolverStats.cpp b/lib/Solver/SolverStats.cpp new file mode 100644 index 00000000..9d48792a --- /dev/null +++ b/lib/Solver/SolverStats.cpp @@ -0,0 +1,23 @@ +//===-- SolverStats.cpp ---------------------------------------------------===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "SolverStats.h" + +using namespace klee; + +Statistic stats::cexCacheTime("CexCacheTime", "CCtime"); +Statistic stats::queries("Queries", "Q"); +Statistic stats::queriesInvalid("QueriesInvalid", "Qiv"); +Statistic stats::queriesValid("QueriesValid", "Qv"); +Statistic stats::queryCacheHits("QueryCacheHits", "QChits") ; +Statistic stats::queryCacheMisses("QueryCacheMisses", "QCmisses"); +Statistic stats::queryConstructTime("QueryConstructTime", "QBtime") ; +Statistic stats::queryConstructs("QueriesConstructs", "QB"); +Statistic stats::queryCounterexamples("QueriesCEX", "Qcex"); +Statistic stats::queryTime("QueryTime", "Qtime"); diff --git a/lib/Solver/SolverStats.h b/lib/Solver/SolverStats.h new file mode 100644 index 00000000..6fee7699 --- /dev/null +++ b/lib/Solver/SolverStats.h @@ -0,0 +1,32 @@ +//===-- SolverStats.h -------------------------------------------*- C++ -*-===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef KLEE_SOLVERSTATS_H +#define KLEE_SOLVERSTATS_H + +#include "klee/Statistic.h" + +namespace klee { +namespace stats { + + extern Statistic cexCacheTime; + extern Statistic queries; + extern Statistic queriesInvalid; + extern Statistic queriesValid; + extern Statistic queryCacheHits; + extern Statistic queryCacheMisses; + extern Statistic queryConstructTime; + extern Statistic queryConstructs; + extern Statistic queryCounterexamples; + extern Statistic queryTime; + +} +} + +#endif |