//===-- ExprPPrinter.cpp - Expression pretty printer ----------------------===//
//
//                     The KLEE Symbolic Virtual Machine
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "klee/util/ExprPPrinter.h"

#include "klee/Constraints.h"

#include "llvm/Support/CommandLine.h"

#include <map>
#include <vector>
#include <iostream>
#include <sstream>
#include <iomanip>

using namespace klee;

namespace {
  // Command line knobs controlling the textual form produced by PPrinter.

  llvm::cl::opt<bool>
  PCWidthAsArg("pc-width-as-arg",
               llvm::cl::desc("Print expression widths as a separate argument "
                              "after the expression kind (default=on)"),
               llvm::cl::init(true));

  llvm::cl::opt<bool>
  PCAllWidths("pc-all-widths",
              llvm::cl::desc("Print widths of boolean expressions as well "
                             "(default=off)"),
              llvm::cl::init(false));

  llvm::cl::opt<bool>
  PCPrefixWidth("pc-prefix-width",
                llvm::cl::desc("Prefix width arguments with 'w' "
                               "(only with -pc-width-as-arg, default=on)"),
                llvm::cl::init(true));

  llvm::cl::opt<bool>
  PCMultibyteReads("pc-multibyte-reads",
                   llvm::cl::desc("Print concatenations of consecutive byte "
                                  "reads as ReadMSB/ReadLSB (default=on)"),
                   llvm::cl::init(true));

  llvm::cl::opt<bool>
  PCAllConstWidths("pc-all-const-widths",
                   llvm::cl::desc("Print the width of every non-boolean "
                                  "constant (default=off)"),
                   llvm::cl::init(false));
}

/// PrintContext - Helper class for storing extra information for
/// the pretty printer.
///
/// Wraps an output stream and tracks the column of the current line
/// (\c pos) so that callers can align continuation lines via breakLine().
class PrintContext {
private:
  std::ostream &os;
  /// Scratch buffer used to stringify values before writing them, so that
  /// \c pos can be advanced by the exact number of characters produced.
  std::stringstream ss;
  /// Line terminator emitted by breakLine().
  std::string newline;

public:
  /// Number of characters on the current line.
  unsigned pos;

public:
  explicit PrintContext(std::ostream &_os) : os(_os), newline("\n"), pos(0) {}

  /// setNewline - Change the line terminator used by breakLine().
  void setNewline(const std::string &_newline) {
    newline = _newline;
  }

  /// breakLine - Emit the newline string followed by \arg indent spaces,
  /// and set the current column to \arg indent.
  void breakLine(unsigned indent=0) {
    os << newline;
    if (indent)
      os << std::setw(indent) << ' ';
    pos = indent;
  }

  /// write - Output a string to the stream and update the
  /// position. The string should not contain any newlines.
  void write(const std::string &s) {
    os << s;
    pos += s.length();
  }

  /// operator<< - Format \arg elt with the standard stream operators and
  /// write the result, keeping \c pos up to date.
  template <typename T>
  PrintContext &operator<<(const T &elt) {
    ss.str("");
    // Reset any error state so one failed insertion does not silently
    // turn every later write into an empty string.
    ss.clear();
    ss << elt;
    write(ss.str());
    return *this;
  }
};

/// PPrinter - The concrete ExprPPrinter. Expressions are written in a
/// parenthesized prefix form, "(Kind [width] args...)".
///
/// After scan() has been called, only the non-constant subexpressions and
/// update list nodes that scan() reached more than once get a label
/// ("N<k>:" / "U<k>:") at their first printed occurrence; later occurrences
/// print just the label. Without a prior scan(), every non-constant
/// expression and every update node is labelled.
class PPrinter : public ExprPPrinter {
  /// Non-constant expressions that have been given an "N<k>" label.
  std::map<ref<Expr>, unsigned> bindings;
  /// Update nodes that have been given a "U<k>" label.
  std::map<const UpdateNode*, unsigned> updateBindings;
  /// couldPrint holds every non-constant expression visited by scan();
  /// shouldPrint holds those visited more than once (they get a label).
  std::set< ref<Expr> > couldPrint, shouldPrint;
  /// Same as couldPrint / shouldPrint, but for update list nodes.
  std::set<const UpdateNode*> couldPrintUpdates, shouldPrintUpdates;
  std::ostream &os;
  /// Next "N" label to hand out.
  unsigned counter;
  /// Next "U" label to hand out.
  unsigned updateCounter;
  /// True once scan() has been called since the last reset().
  bool hasScan;
  /// Line terminator used for line breaks produced by print(e, level).
  std::string newline;

  /// shouldPrintWidth - Predicate for whether this expression should
  /// be printed with its width.
  bool shouldPrintWidth(ref<Expr> e) {
    if (PCAllWidths)
      return true;
    return e.getWidth() != Expr::Bool;
  }

  /// isVerySimple - True if \arg e is a constant or already has a label,
  /// i.e. it will print compactly without any nested structure.
  bool isVerySimple(const ref<Expr> &e) {
    return e.isConstant() || bindings.find(e) != bindings.end();
  }

  /// isVerySimpleUpdate - True if \arg un is the empty list (null) or
  /// already has a label.
  bool isVerySimpleUpdate(const UpdateNode *un) {
    return !un || updateBindings.find(un) != updateBindings.end();
  }

  /// isSimple - True if \arg e is very simple, or sits directly above very
  /// simple operands: a read whose index and update list head are very
  /// simple, or any other expression whose kids are all very simple.
  bool isSimple(const ref<Expr> &e) {
    if (isVerySimple(e)) {
      return true;
    } else if (const ReadExpr *re = dyn_ref_cast<ReadExpr>(e)) {
      return isVerySimple(re->index) && isVerySimpleUpdate(re->updates.head);
    } else {
      Expr *ep = e.get();
      for (unsigned i=0; i<ep->getNumKids(); i++)
        if (!isVerySimple(ep->getKid(i)))
          return false;
      return true;
    }
  }

  /// hasSimpleKids - True if every kid of \arg ep is simple; such
  /// expressions keep their operands on one line.
  bool hasSimpleKids(const Expr *ep) {
    for (unsigned i=0; i<ep->getNumKids(); i++)
      if (!isSimple(ep->getKid(i)))
        return false;
    return true;
  }

  /// scanUpdate - Record the update nodes reachable from \arg un. A node
  /// seen a second time is marked as needing a label; its tail and
  /// operands are only walked on the first visit.
  void scanUpdate(const UpdateNode *un) {
    if (un) {
      if (couldPrintUpdates.insert(un).second) {
        scanUpdate(un->next);
        scan1(un->index);
        scan1(un->value);
      } else {
        shouldPrintUpdates.insert(un);
      }
    }
  }

  /// scan1 - Record the non-constant subexpressions of \arg e. An
  /// expression seen a second time is marked as needing a label; its kids
  /// (and, for reads, its update list) are only walked on the first visit.
  void scan1(const ref<Expr> &e) {
    if (!e.isConstant()) {
      if (couldPrint.insert(e).second) {
        Expr *ep = e.get();
        for (unsigned i=0; i<ep->getNumKids(); i++)
          scan1(ep->getKid(i));
        if (const ReadExpr *re = dyn_ref_cast<ReadExpr>(e))
          scanUpdate(re->updates.head);
      } else {
        shouldPrint.insert(e);
      }
    }
  }

  /// printUpdateList - Print \arg updates as bracketed "index=value"
  /// entries, followed by "@ arr<id>" for rooted lists.
  ///
  /// Line breaking policy: each update node which needs a label starts a
  /// new bracketed segment ("U<k>:[...") on a new line aligned with the
  /// column where the list started. Within a segment, entries are
  /// separated by a space when the previous entry's index and value were
  /// both constants, otherwise by a line break aligned with the first entry
  /// of the segment. Printing stops at the first node that already has a
  /// label, which is printed as just "U<k>".
  void printUpdateList(const UpdateList &updates, PrintContext &PC) {
    const UpdateNode *head = updates.head;

    // Special case empty list.
    if (!head) {
      if (updates.isRooted) {
        PC << "arr" << updates.root->id;
      } else {
        PC << "[]";
      }
      return;
    }

    bool openedList = false, nextShouldBreak = false;
    unsigned outerIndent = PC.pos;
    unsigned middleIndent = 0;
    for (const UpdateNode *un = head; un; un = un->next) {
      // We are done if we hit the cache.
      std::map<const UpdateNode*, unsigned>::iterator it =
        updateBindings.find(un);
      if (it!=updateBindings.end()) {
        if (openedList)
          PC << "] @ ";
        PC << "U" << it->second;
        return;
      } else if (!hasScan || shouldPrintUpdates.count(un)) {
        if (openedList)
          PC << "] @";
        if (un != head)
          PC.breakLine(outerIndent);
        PC << "U" << updateCounter << ":";
        updateBindings.insert(std::make_pair(un, updateCounter++));
        openedList = nextShouldBreak = false;
      }

      if (!openedList) {
        openedList = true;
        PC << '[';
        middleIndent = PC.pos;
      } else {
        PC << ',';
        printSeparator(PC, !nextShouldBreak, middleIndent);
      }
      print(un->index, PC);
      PC << "=";
      print(un->value, PC);

      nextShouldBreak = !(un->index.isConstant() && un->value.isConstant());
    }

    if (openedList)
      PC << ']';

    if (updates.isRooted)
      PC << " @ arr" << updates.root->id;
  }

  /// printWidth - Print the width of \arg e if shouldPrintWidth() allows
  /// it, formatted according to the -pc-width-as-arg and -pc-prefix-width
  /// options.
  void printWidth(PrintContext &PC, ref<Expr> e) {
    if (!shouldPrintWidth(e))
      return;

    if (PCWidthAsArg) {
      PC << ' ';
      if (PCPrefixWidth)
        PC << 'w';
    }

    PC << e.getWidth();
  }

  /// hasOrderedReads - True iff all children are reads with
  /// consecutive offsets according to the given \arg stride.
  bool hasOrderedReads(const Expr *ep, int stride) {
    const ReadExpr *base = dyn_ref_cast<ReadExpr>(ep->getKid(0));
    if (!base)
      return false;

    // Get stride expr in proper index width.
    Expr::Width idxWidth = base->index.getWidth();
    ref<Expr> strideExpr(stride, idxWidth), offset(0, idxWidth);
    for (unsigned i=1; i<ep->getNumKids(); ++i) {
      const ReadExpr *re = dyn_ref_cast<ReadExpr>(ep->getKid(i));
      if (!re)
        return false;

      // Check if the index follows the stride.
      // FIXME: How aggressive should this be simplified. The
      // canonicalizing builder is probably the right choice, but this
      // is yet another area where we would really prefer it to be
      // global or else use static methods.
      offset = AddExpr::create(offset, strideExpr);
      if (SubExpr::create(re->index, base->index) != offset)
        return false;
    }

    return true;
  }

  /// hasAllByteReads - True iff all children are byte level reads.
  bool hasAllByteReads(const Expr *ep) {
    for (unsigned i=0; i<ep->getNumKids(); ++i) {
      const ReadExpr *re = dyn_ref_cast<ReadExpr>(ep->getKid(i));
      if (!re || re->getWidth() != Expr::Int8)
        return false;
    }
    return true;
  }

  /// printRead - Print the index of \arg re, then its update list, either
  /// on the same line or on a new line at \arg indent.
  void printRead(const ReadExpr *re, PrintContext &PC, unsigned indent) {
    print(re->index, PC);
    printSeparator(PC, isVerySimple(re->index), indent);
    printUpdateList(re->updates, PC);
  }

  /// printExtract - Print the bit offset of \arg ee followed by the
  /// extracted expression. \arg indent is currently unused.
  void printExtract(const ExtractExpr *ee, PrintContext &PC, unsigned indent) {
    PC << ee->offset << ' ';
    print(ee->expr, PC);
  }

  /// printExpr - Print every kid of \arg ep. Kids are separated by spaces
  /// when all are simple, otherwise each kid after the first starts a new
  /// line at \arg indent. When \arg printConstWidth is set, constant kids
  /// are printed with their width; this applies to every operand,
  /// including the first (the only one for SExt).
  void printExpr(const Expr *ep, PrintContext &PC, unsigned indent,
                 bool printConstWidth=false) {
    bool simple = hasSimpleKids(ep);

    print(ep->getKid(0), PC, printConstWidth);
    for (unsigned i=1; i<ep->getNumKids(); i++) {
      printSeparator(PC, simple, indent);
      print(ep->getKid(i), PC, printConstWidth);
    }
  }

public:
  PPrinter(std::ostream &_os) : os(_os), newline("\n") {
    reset();
  }

  void setNewline(const std::string &_newline) {
    newline = _newline;
  }

  /// reset - Forget all labels and scan results.
  void reset() {
    counter = 0;
    updateCounter = 0;
    hasScan = false;
    bindings.clear();
    updateBindings.clear();
    couldPrint.clear();
    shouldPrint.clear();
    couldPrintUpdates.clear();
    shouldPrintUpdates.clear();
  }

  /// scan - Record sharing in \arg e; only shared nodes are labelled by
  /// subsequent print calls.
  void scan(const ref<Expr> &e) {
    hasScan = true;
    scan1(e);
  }

  /// print - Print \arg e to the underlying stream, treating the current
  /// column as \arg level and using the newline set by setNewline().
  void print(const ref<Expr> &e, unsigned level=0) {
    PrintContext PC(os);
    // Without this, setNewline() would have no effect on the output.
    PC.setNewline(newline);
    PC.pos = level;
    print(e, PC);
  }

  /// printConst - Print constant \arg e: booleans as true/false, other
  /// constants as their value, wrapped as "(w<width> value)" when
  /// \arg printWidth or -pc-all-const-widths is set.
  void printConst(const ref<Expr> &e, PrintContext &PC, bool printWidth) {
    assert(e.isConstant());

    if (e.getWidth() == Expr::Bool)
      PC << (e.getConstantValue() ? "true" : "false");
    else {
      if (PCAllConstWidths)
        printWidth = true;

      if (printWidth)
        PC << "(w" << e.getWidth() << " ";

      PC << e.getConstantValue();

      if (printWidth)
        PC << ")";
    }
  }

  /// print - Print \arg e into \arg PC, emitting a label or label
  /// reference for non-constant expressions as described on the class.
  void print(const ref<Expr> &e, PrintContext &PC, bool printConstWidth=false) {
    if (e.isConstant())
      printConst(e, PC, printConstWidth);
    else {
      std::map<ref<Expr>, unsigned>::iterator it = bindings.find(e);
      if (it!=bindings.end()) {
        PC << 'N' << it->second;
      } else {
        if (!hasScan || shouldPrint.count(e)) {
          PC << 'N' << counter << ':';
          bindings.insert(std::make_pair(e, counter++));
        }

        // Detect Not.
        // FIXME: This should be in common code.
        if (const EqExpr *ee = dyn_ref_cast<EqExpr>(e)) {
          if (ee->left == ref<Expr>(false, Expr::Bool)) {
            PC << "(Not";
            printWidth(PC, e);
            PC << ' ';
            // FIXME: This is a boom if right is a constant.
            print(ee->right, PC);
            PC << ')';
            return;
          }
        }

        // Detect multibyte reads.
        // FIXME: Hrm. One problem with doing this is that we are
        // masking the sharing of the indices which aren't
        // visible. Need to think if this matters... probably not
        // because if they are offset reads then its either constant,
        // or they are (base + offset) and base will get printed with
        // a declaration.
        if (PCMultibyteReads && e.getKind() == Expr::Concat) {
          const Expr *ep = e.get();
          if (hasAllByteReads(ep)) {
            bool isMSB = hasOrderedReads(ep, 1);
            if (isMSB || hasOrderedReads(ep, -1)) {
              PC << "(Read" << (isMSB ? "MSB" : "LSB");
              printWidth(PC, e);
              PC << ' ';
              unsigned firstIdx = isMSB ? 0 : ep->getNumKids()-1;
              printRead(static_ref_cast<ReadExpr>(ep->getKid(firstIdx)),
                        PC, PC.pos);
              PC << ')';
              return;
            }
          }
        }

        PC << '(' << e.getKind();
        printWidth(PC, e);
        PC << ' ';

        // Indent at first argument and dispatch to appropriate print
        // routine for exprs which require special handling.
        unsigned indent = PC.pos;
        if (const ReadExpr *re = dyn_ref_cast<ReadExpr>(e)) {
          printRead(re, PC, indent);
        } else if (const ExtractExpr *ee = dyn_ref_cast<ExtractExpr>(e)) {
          printExtract(ee, PC, indent);
        } else if (e.getKind() == Expr::Concat || e.getKind() == Expr::SExt) {
          // The operand widths of these cannot be recovered from the
          // result width alone, so print constant operands with widths.
          printExpr(e.get(), PC, indent, true);
        } else {
          printExpr(e.get(), PC, indent);
        }
        PC << ")";
      }
    }
  }

  /* Public utility functions */

  /// printSeparator - Emit a single space when \arg simple, otherwise a
  /// line break followed by \arg indent spaces.
  void printSeparator(PrintContext &PC, bool simple, unsigned indent) {
    if (simple) {
      PC << ' ';
    } else {
      PC.breakLine(indent);
    }
  }
};

ExprPPrinter *klee::ExprPPrinter::create(std::ostream &os) {
  // PPrinter is defined only in this file; hand it out through the
  // abstract ExprPPrinter interface.
  PPrinter *printer = new PPrinter(os);
  return printer;
}

void ExprPPrinter::printOne(std::ostream &os,
                            const char *message, 
                            const ref<Expr> &e) {
  PPrinter p(os);
  p.scan(e);

  // FIXME: Need to figure out what to do here. Probably print as a
  // "forward declaration" with whatever syntax we pick for that.
  PrintContext PC(os);
  PC << message << ": ";
  p.print(e, PC);
  PC.breakLine();
}

void ExprPPrinter::printConstraints(std::ostream &os,
                                    const ConstraintManager &constraints) {
  printQuery(os, constraints, ref<Expr>(false, Expr::Bool));
}

void ExprPPrinter::printQuery(std::ostream &os,
                              const ConstraintManager &constraints,
                              const ref<Expr> &q) {
  PPrinter p(os);
  
  for (ConstraintManager::const_iterator it = constraints.begin(),
         ie = constraints.end(); it != ie; ++it)
    p.scan(*it);
  p.scan(q);

  PrintContext PC(os);
  PC << "(query [";
  
  // Ident at constraint list;
  unsigned indent = PC.pos;
  for (ConstraintManager::const_iterator it = constraints.begin(),
         ie = constraints.end(); it != ie;) {
    p.print(*it, PC);
    ++it;
    if (it != ie)
      PC.breakLine(indent);
  }
  PC << ']';

  p.printSeparator(PC, constraints.empty(), indent-1);
  p.print(q, PC);

  PC << ')';
  PC.breakLine();
}