//===-- ExprSMTLIBLetPrinter.cpp ------------------------------------------*-
//C++ -*-===//
//
//                     The KLEE Symbolic Virtual Machine
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/CommandLine.h"
#include "klee/util/ExprSMTLIBLetPrinter.h"

namespace ExprSMTLIBOptions {
llvm::cl::opt<bool>
useLetExpressions("smtlib-use-let-expressions",
                  llvm::cl::desc("Enables generated SMT-LIBv2 files to use "
                                 "(let) expressions (default=on)"),
                  llvm::cl::init(true));
}

namespace klee {
const char ExprSMTLIBLetPrinter::BINDING_PREFIX[] = "?B";

ExprSMTLIBLetPrinter::ExprSMTLIBLetPrinter()
    : bindings(), firstEO(), twoOrMoreEO(), disablePrintedAbbreviations(false) {
}

void ExprSMTLIBLetPrinter::generateOutput() {
  if (p == NULL || query == NULL || o == NULL) {
    llvm::errs() << "Can't print SMTLIBv2. Output or query bad!\n";
    return;
  }

  generateBindings();

  if (isHumanReadable())
    printNotice();
  printOptions();
  printSetLogic();
  printArrayDeclarations();
  printLetExpression();
  printAction();
  printExit();
}

void ExprSMTLIBLetPrinter::scan(const ref<Expr> &e) {
  if (isa<ConstantExpr>(e))
    return; // we don't need to scan simple constants

  if (firstEO.insert(e).second) {
    // We've not seen this expression before

    if (const ReadExpr *re = dyn_cast<ReadExpr>(e)) {

      // Attempt to insert array and if array wasn't present before do more
      // things
      if (usedArrays.insert(re->updates.root).second) {

        // check if the array is constant
        if (re->updates.root->isConstantArray())
          haveConstantArray = true;

        // scan the update list
        scanUpdates(re->updates.head);
      }
    }

    // recurse into the children
    Expr *ep = e.get();
    for (unsigned int i = 0; i < ep->getNumKids(); i++)
      scan(ep->getKid(i));
  } else {
    /* We must of seen the expression before. Add it to
     * the set of twoOrMoreOccurances. We don't need to
     * check if the insertion fails.
     */
    twoOrMoreEO.insert(e);
  }
}

void ExprSMTLIBLetPrinter::generateBindings() {
  // Assign a number to each binding that will be used
  unsigned int counter = 0;
  for (std::set<ref<Expr> >::const_iterator i = twoOrMoreEO.begin();
       i != twoOrMoreEO.end(); ++i) {
    bindings.insert(std::make_pair(*i, counter));
    ++counter;
  }
}

void ExprSMTLIBLetPrinter::printExpression(
    const ref<Expr> &e, ExprSMTLIBPrinter::SMTLIB_SORT expectedSort) {
  std::map<const ref<Expr>, unsigned int>::const_iterator i = bindings.find(e);

  if (disablePrintedAbbreviations || i == bindings.end()) {
    /*There is no abbreviation for this expression so print it normally.
     * Do this by using the parent method.
     */
    ExprSMTLIBPrinter::printExpression(e, expectedSort);
  } else {
    // Use binding name e.g. "?B1"

    /* Handle the corner case where the expectedSort
     * doesn't match the sort of the abbreviation. Not really very efficient
     * (calls bindings.find() twice),
     * we'll cast and call ourself again but with the correct expectedSort.
     */
    if (getSort(e) != expectedSort) {
      printCastToSort(e, expectedSort);
      return;
    }

    // No casting is needed in this depth of recursion, just print the
    // abbreviation
    *p << BINDING_PREFIX << i->second;
  }
}

void ExprSMTLIBLetPrinter::reset() {
  // Let parent clean up first
  ExprSMTLIBPrinter::reset();

  firstEO.clear();
  twoOrMoreEO.clear();
  bindings.clear();
}

void ExprSMTLIBLetPrinter::printLetExpression() {
  *p << "(assert";
  p->pushIndent();
  printSeperator();

  if (bindings.size() != 0) {
    // Only print let expression if we have bindings to use.
    *p << "(let";
    p->pushIndent();
    printSeperator();
    *p << "( ";
    p->pushIndent();

    // Print each binding
    for (std::map<const ref<Expr>, unsigned int>::const_iterator i =
             bindings.begin();
         i != bindings.end(); ++i) {
      printSeperator();
      *p << "(" << BINDING_PREFIX << i->second << " ";
      p->pushIndent();

      // Disable abbreviations so none are used here.
      disablePrintedAbbreviations = true;

      // We can abbreviate SORT_BOOL or SORT_BITVECTOR in let expressions
      printExpression(i->first, getSort(i->first));

      p->popIndent();
      printSeperator();
      *p << ")";
    }

    p->popIndent();
    printSeperator();
    *p << ")";
    printSeperator();

    // Re-enable printing abbreviations.
    disablePrintedAbbreviations = false;
  }

  // print out Expressions with abbreviations.
  unsigned int numberOfItems = query->constraints.size() + 1; //+1 for query
  unsigned int itemsLeft = numberOfItems;
  std::vector<ref<Expr> >::const_iterator constraint =
      query->constraints.begin();

  /* Produce nested (and () () statements. If the constraint set
   * is empty then we will only print the "queryAssert".
   */
  for (; itemsLeft != 0; --itemsLeft) {
    if (itemsLeft >= 2) {
      *p << "(and";
      p->pushIndent();
      printSeperator();
      printExpression(*constraint,
                      SORT_BOOL); // We must and together bool expressions
      printSeperator();
      ++constraint;
      continue;
    } else {
      // must have 1 item left (i.e. the "queryAssert")
      if (isHumanReadable()) {
        *p << "; query";
        p->breakLineI();
      }
      printExpression(queryAssert,
                      SORT_BOOL); // The query must be a bool expression
    }
  }

  /* Produce closing brackets for nested "and statements".
   * Number of "nested ands" = numberOfItems -1
   */
  itemsLeft = numberOfItems - 1;
  for (; itemsLeft != 0; --itemsLeft) {
    p->popIndent();
    printSeperator();
    *p << ")";
  }

  if (bindings.size() != 0) {
    // end let expression
    p->popIndent();
    printSeperator();
    *p << ")";
    printSeperator();
  }

  // end assert
  p->popIndent();
  printSeperator();
  *p << ")";
  p->breakLineI();
}

ExprSMTLIBPrinter *createSMTLIBPrinter() {
  if (ExprSMTLIBOptions::useLetExpressions)
    return new ExprSMTLIBLetPrinter();
  else
    return new ExprSMTLIBPrinter();
}
}