
open Qhasm
open Stp
open Big_int

(** Raised when a Qhasm variable or a named constant has no STP counterpart
    (see [manager#getvar] and [manager#const]). *)
exception Undefined of string

(** Raised for Qhasm constructs this translation does not support, such as
    variable shift amounts and arithmetic right shifts. *)
exception Unsupported of string

(** Bit-width of a machine word. It is also used as both widths when the
    memory array is declared, and as the width of every address. *)
let wordsize = 64

(** The big-integer constant 2, used by [d2b] to peel off binary digits. *)
let two = big_int_of_int 2

(** [d2b size n] renders the big integer [n] as exactly [size] binary digits,
    most significant digit first. [size] must be non-negative.
    - Bits above position [size - 1] are dropped.
    - [mod_big_int] always returns a non-negative remainder, and [div_big_int]
      is the matching Euclidean quotient. Negative [n] therefore comes out in
      two's complement. *)
let d2b size n =
  let rec collect remaining q digits =
    if remaining = 0 then
      digits
    else
      let bit = string_of_big_int (mod_big_int q two) in
      collect (remaining - 1) (div_big_int q two) (bit ^ digits)
  in
  collect size n ""

(** [padding n i] prepends [i] '0' characters to the string [n].
    A non-positive [i] returns [n] unchanged. The previous recursive version
    never terminated for negative [i]. It also rebuilt the string once per
    digit. *)
let padding n i = if i <= 0 then n else String.make i '0' ^ n

class manager (cmap : (string, string) Hashtbl.t) =
object (self)
  (** the ID of the next STP variable (currently not read or written by any
      method of this class) *)
  val mutable var = 0

  (** a map from a Qhasm variable to its latest STP variable *)
  val mutable vmap : (qvar, Stp.var) Hashtbl.t = Hashtbl.create 100

  (** a map from a Qhasm variable to its bit-width *)
  val mutable smap : (qvar, int) Hashtbl.t = Hashtbl.create 100

  (** a map from a Qhasm variable to the index of its latest STP version *)
  val mutable imap : (qvar, int) Hashtbl.t = Hashtbl.create 100

  (** Created STP declarations, most recent FIRST. Consing keeps each
      insertion O(1); appending with [l @ [x]] made building a program
      quadratic in its size. [getprog] restores insertion order. *)
  val mutable vdecls : Stp.decl list = []

  (** Created STP statements, most recent FIRST (see [vdecls]). *)
  val mutable vstmts : Stp.stmt list = []

  (** The query conjuncts, set by [genquery]. *)
  val mutable queries : bexp list = []

  (** The name of the STP array that models memory, once [getmemory] has
      declared it. *)
  val mutable memory = None

  (** Assumptions, most recent FIRST; [genquery] restores insertion order. *)
  val mutable assumptions : bexp list = []

  (** Assertions, most recent FIRST; [genquery] restores insertion order. *)
  val mutable assertions : bexp list = []

  (** Generates a new STP variable for a Qhasm variable.
      - The k-th version of [qv] is named [<vname>_<k>], with k starting at 0.
      - It is declared with [qv.vsize] bits.
      - It becomes the latest version of [qv]. *)
  method newvar qv =
    let i = try (Hashtbl.find imap qv) + 1 with Not_found -> 0 in
    let sv = qv.vname ^ "_" ^ string_of_int i in
    Hashtbl.replace imap qv i;
    Hashtbl.replace vmap qv sv;
    Hashtbl.replace smap qv qv.vsize;
    self#adddecl (mkvar sv qv.vsize);
    mkv sv

  (** Returns the latest STP variable for a Qhasm variable.
      A variable never seen before gets a fresh version via [newvar].
      Raises [Undefined] if [qv] has a recorded width but no STP variable. *)
  method getvar qv =
    try
      mkv (Hashtbl.find vmap qv)
    with Not_found ->
      if Hashtbl.mem smap qv then
        (* NOTE(review): within this class [smap] and [vmap] are only written
           together (in [newvar]), so this branch appears unreachable. *)
        raise (Undefined ("No variable is defined in STP for " ^ qv.vname ^ "."))
      else
        (* This is an instrumentation variable. Create a new variable if the
           instrumentation variable has not been assigned. *)
        self#newvar qv

  (** Creates a new version of [qv] and asserts that it equals [e].
      Returns the new STP variable. *)
  method assign qv e =
    let qv' = self#newvar qv in
    self#addstmt (mkassert (mkeq qv' e));
    qv'

  (** Returns the value text of the predefined constant [v].
      Raises [Undefined] if [v] is not in [cmap]. *)
  method const v =
    try
      Hashtbl.find cmap v
    with Not_found ->
      raise (Undefined ("The constant " ^ v ^ " is undefined."))

  (** Inserts an STP variable declaration (O(1)). *)
  method adddecl (decl : decl) = vdecls <- decl :: vdecls

  (** Inserts an STP statement (O(1)). *)
  method addstmt (stmt : stmt) = vstmts <- stmt :: vstmts

  (** Returns all STP declarations and statements, in insertion order, as a
      program together with the current query. *)
  method getprog =
    {vdecls = List.rev vdecls; vstmts = List.rev vstmts; vquery = mkbands queries}

  (** Returns the memory array. It is declared on first use, with [wordsize]
      as both of its widths. *)
  method getmemory =
    match memory with
    | None ->
      let sv = "_memory_" in
      self#adddecl (mkarray sv wordsize wordsize);
      memory <- Some sv;
      mkv sv
    | Some mem -> mkv mem

  (** Sets the carry. The bit-width of the carry must be 1. *)
  method setcarry e = self#assign (Qhasm.mkvar "carry" 1) e

  (** Returns the carry. The returned value always has one bit.
      If the carry was never set, it is first initialized to 0. *)
  method getcarry =
    try
      mkv (Hashtbl.find vmap (Qhasm.mkvar "carry" 1))
    with Not_found ->
      self#setcarry (self#genconsti 1 0)

  (** Inserts an assumption (O(1)). *)
  method add_assumption (e : bexp) = assumptions <- e :: assumptions

  (** Inserts an assertion (O(1)). *)
  method add_assertion (e : bexp) = assertions <- e :: assertions

  (** Casts an STP variable to a specified bit-width, for signed values.
      - s1: the original bit-width
      - s2: the desired bit-width
      Widening uses [mkextend], presumably a sign extension (confirm in Stp).
      Narrowing keeps the low [s2] bits. *)
  method cast sv s1 s2 =
    if s1 < s2 then
      mkextend sv (s2 - s1)
    else if s1 > s2 then
      mkextract sv (s2 - 1) 0
    else
      sv

  (** Casts an STP variable to a specified bit-width, for unsigned values.
      - s1: the original bit-width
      - s2: the desired bit-width
      Widening prepends zero bits. Narrowing keeps the low [s2] bits. *)
  method ucast sv s1 s2 =
    if s1 < s2 then
      mkconcat (self#genzero (s2 - s1)) sv
    else if s1 > s2 then
      mkextract sv (s2 - 1) 0
    else
      sv

  (** Inserts a comment. *)
  method gencomment c =
    self#addstmt (mkcomment c)

  (** Returns a [size]-bit constant from the binary digit string [c].
      A short [c] is zero-padded on the left. A long [c] keeps its low [size]
      digits. *)
  method genconst size c =
    let len = String.length c in
    mkconstb (
      if len = size then
        c
      else if len < size then
        padding c (size - len)
      else
        String.sub c (len - size) size
    )

  (** Returns a [size]-bit constant from the big integer [c]. *)
  method genconstd size c = mkconstb (d2b size c)

  (** Returns a [size]-bit constant from the hexadecimal digit string [c].
      A short [c] is zero-padded on the left. A [c] wider than [size] bits
      fails with [assert false].
      NOTE(review): the padding is [(size - len) / 4] digits. This assumes
      [size] is a multiple of 4; otherwise the constant is narrower than
      [size]. *)
  method genconsth size c =
    let len = (String.length c) * 4 in
    mkconsth (
      if len = size then
        c
      else if len < size then
        padding c ((size - len) / 4)
      else
        assert false
    )

  (** Returns a [size]-bit constant from the OCaml int [c]. *)
  method genconsti size c = self#genconstd size (big_int_of_int c)

  (** Generates the query ~(/\ assumptions) \/ (/\ assertions).
      Conjuncts appear in insertion order. *)
  method genquery =
    let assumption = mkbands (List.rev assumptions) in
    let assertion = mkbands (List.rev assertions) in
    queries <- [mkbor (mkbnot assumption) assertion]

  (** Returns the [size]-bit constant 0. *)
  method genzero size = self#genconsti size 0

end

(** Returns the shift amount of a shift instruction.
    Only literal amounts are supported; a variable amount raises
    [Unsupported]. *)
let getshiftsize = function
  | QIVConst bits -> bits
  | QIVVar _ ->
    raise (Unsupported "The number of bits shifted should be an integer.")

(** Builds the [wordsize]-bit STP expression for a Qhasm memory address:
    - base + offset
    - base + index
    - base + index * 8
    - base + offset + index * 8 *)
let genaddr m addr =
  let word n = m#genconsti wordsize n in
  let scaled index = mkmul wordsize (m#getvar index) (word 8) in
  match addr with
  | QAddrBO (base, offset) -> mksum wordsize [m#getvar base; word offset]
  | QAddrBI (base, index) -> mksum wordsize [m#getvar base; m#getvar index]
  | QAddrBIS (base, index) -> mksum wordsize [m#getvar base; scaled index]
  | QAddrBOIS (base, offset, index) ->
    mksum wordsize [m#getvar base; word offset; scaled index]

(** Reads the memory array at the address described by [address]. *)
let read mgr address = mkread mgr#getmemory (genaddr mgr address)

(** Returns the memory array written with [value] at [address].
    Nothing here records the result in the manager; callers must use it. *)
let write mgr address value = mkwrite mgr#getmemory (genaddr mgr address) value

(** Translates an integer-or-variable operand into a [wordsize]-bit STP
    expression. *)
let genconstvar m = function
  | QIVConst k -> m#genconsti wordsize k
  | QIVVar qv -> m#getvar qv

(** Translates a variable, memory dereference or named constant into an STP
    expression.
    - A named constant is looked up through [m#const].
    - A value starting with "0x" is parsed as hexadecimal; anything else is
      parsed as a decimal big integer.
    - Constants and memory reads are [wordsize] bits wide. *)
let genvarderef m = function
  | QVDVar v -> m#getvar v
  | QVDDeref (base, offset) -> read m (QAddrBO (base, offset))
  | QVDCoef co ->
    let value = m#const co.vname in
    let len = String.length value in
    if len > 2 && String.sub value 0 2 = "0x"
    then m#genconsth wordsize (String.sub value 2 (len - 2))
    else m#genconstd wordsize (big_int_of_string value)

(** Translates the right-hand side of a Qhasm assignment into a
    [wordsize]-bit STP expression. The 1-bit carry is zero-extended to
    [wordsize]. *)
let genexpr m expr =
  let var v = m#getvar v in
  let word n = m#genconsti wordsize n in
  let carry () = m#ucast m#getcarry 1 wordsize in
  match expr with
  | QExprConst n -> word n
  | QExprVar v -> var v
  | QExprCarry -> carry ()
  | QExprAddVarVar (v1, v2) -> mksum wordsize [var v1; var v2]
  | QExprAddVarVarConst (v1, v2, n) -> mksum wordsize [var v1; var v2; word n]
  | QExprAddVarVarVar (v1, v2, v3) -> mksum wordsize [var v1; var v2; var v3]
  | QExprAddVarVarCarry (v1, v2) -> mksum wordsize [var v1; var v2; carry ()]
  | QExprMulVarConst (v, n) -> mkmul wordsize (var v) (word n)
  | QExprMulVarCarry v -> mkmul wordsize (var v) (carry ())

(** Assigns the value of the named constant [co] to a new version of [qv]
    and returns that version. *)
let gencoef m qv co =
  let value = genvarderef m (QVDCoef co) in
  m#assign qv value

(** Translates the operand of a Qhasm addition.
    The result is [wordsize] bits, or [wordsize + 1] bits when [~carry:true],
    so that the caller can extract the carry-out bit. Word-sized operands and
    the 1-bit carry are zero-extended to that width. *)
let genaddexpr m ?(carry = false) expr =
  let size = if carry then wordsize + 1 else wordsize in
  let widen e = m#ucast e wordsize size in
  let carry_in () = m#ucast m#getcarry 1 size in
  let deref base offset = widen (read m (QAddrBO (base, offset))) in
  match expr with
  | QAddExprConst n -> m#genconsti size n
  | QAddExprVar v -> widen (m#getvar v)
  | QAddExprCarry -> carry_in ()
  | QAddExprDeref (base, offset) -> deref base offset
  | QAddExprConstCarry n -> mksum size [m#genconsti size n; carry_in ()]
  | QAddExprVarConst (v, n) -> mksum size [widen (m#getvar v); m#genconsti size n]
  | QAddExprVarCarry v -> mksum size [widen (m#getvar v); carry_in ()]
  | QAddExprDerefCarry (base, offset) -> mksum size [deref base offset; carry_in ()]
  | QAddExprCoef co -> widen (genvarderef m (QVDCoef co))

(** Translates the operand of a Qhasm subtraction.
    The result is [wordsize] bits, or [wordsize + 1] bits when [~carry:true].
    Word-sized operands and the 1-bit carry are zero-extended to that width;
    "with carry" forms add the carry to the subtrahend. *)
let gensubexpr m ?(carry = false) expr =
  let size = if carry then wordsize + 1 else wordsize in
  let widen e = m#ucast e wordsize size in
  let borrow () = m#ucast m#getcarry 1 size in
  let deref base offset = widen (read m (QAddrBO (base, offset))) in
  match expr with
  | QSubExprConst n -> m#genconsti size n
  | QSubExprVar v -> widen (m#getvar v)
  | QSubExprCarry -> borrow ()
  | QSubExprDeref (base, offset) -> deref base offset
  | QSubExprVarCarry v -> mksum size [widen (m#getvar v); borrow ()]
  | QSubExprDerefCarry (base, offset) -> mksum size [deref base offset; borrow ()]

(**
   * Translates a Qhasm annotation expression into an STP bit-vector expression.
   * m: the manager holding the current STP variables
   * size: the bit-width at which constants and arithmetic operators are generated
   * exp: the expression to translate
   * If [pure exp] holds (presumably: no variables occur in it — confirm in
   * Qhasm), the whole expression is folded by [eval] into a [size]-bit constant.
   * Raises Unsupported for arithmetic right shifts.
*)
let rec genexp m size exp =
  if pure exp then
    m#genconstd size (eval exp)
  else
    match exp with
      QExpConst n -> m#genconstd size n
    | QExpCarry -> m#ucast m#getcarry 1 size
    (* [size] is not applied here: the width is whatever [genvarderef] yields. *)
    | QExpVar vd -> genvarderef m vd
    | QExpNeg e -> mkneg (genexp m size e)
    | QExpNot e -> mknot (genexp m size e)
    (* A pure operand becomes a constant of [max size s] bits.
       NOTE(review): that constant is not [s] bits wide when size > s —
       confirm this is intended.
       Otherwise the operand is generated at its own width and resized to [s]
       bits with cast (signed) or ucast (unsigned). *)
    | QExpCast (signed, e, s) ->
      begin
        if pure e then
          m#genconstd (max size s) (eval e)
        else
          (if signed then m#cast else m#ucast) (genexp m (size_of_exp e) e) (size_of_exp e) s
      end
    | QExpAdd (e1, e2) -> mkadd size (genexp m size e1) (genexp m size e2)
    | QExpSub (e1, e2) -> mksub size (genexp m size e1) (genexp m size e2)
    | QExpMul (e1, e2) -> mkmul size (genexp m size e1) (genexp m size e2)
    | QExpAnd (e1, e2) -> mkand (genexp m size e1) (genexp m size e2)
    | QExpOr (e1, e2) -> mkor (genexp m size e1) (genexp m size e2)
    | QExpXor (e1, e2) -> mkxor (genexp m size e1) (genexp m size e2)
    | QExpSmod (e1, e2) -> mksmod size (genexp m size e1) (genexp m size e2)
    (* Unsigned modulo via signed modulo. Both operands are generated at
       [size - 1] bits and zero-extended to [size], presumably so that mksmod
       only sees non-negative values.
       NOTE(review): bit [size - 1] of each operand is not represented —
       confirm callers only use umod on values narrower than [size]. *)
    | QExpUmod (e1, e2) -> 
      let e1 = m#ucast (genexp m (size - 1) e1) (size - 1) size in
      let e2 = m#ucast (genexp m (size - 1) e2) (size - 1) size in
      mksmod size e1 e2
    (* The exponent must evaluate to an integer ([eval_int]).
       - n = 0 yields 1.
       - n > 0 multiplies the base into the result n - 1 times.
       - A negative n fails with [assert false]. *)
    | QExpPow (e, n) ->
      let n = eval_int n in
      let rec helper res base n =
        if n = 1 then
          res
        else
          helper (mkmul size res base) base (n - 1) in
      if n < 0 then
        assert false
      else if n = 0 then
        m#genconsti size 1
      else
        let base = genexp m size e in
        helper base base n
    (* Concatenation and slicing generate their operands at the operands' own widths. *)
    | QExpConcat (e1, e2) -> mkconcat (genexp m (size_of_exp e1) e1) (genexp m (size_of_exp e2) e2)
    (* Shift amounts must evaluate to integers; a left shift is truncated back to [size] bits. *)
    | QExpSll (e1, e2) -> mkextract (mksll (genexp m size e1) (eval_int e2)) (size - 1) 0
    | QExpSrl (e1, e2) -> mksrl (genexp m size e1) (eval_int e2)
    | QExpSra (e1, e2) -> raise (Unsupported "Arithmetic right shifting is unsupported.")
    | QExpSlice (e, i, j) -> mkextract (genexp m (size_of_exp e) e) i j
    (* Function application: the body built by [mkfunctor] from the definition
       and the actuals is translated instead (presumably a substitution). *)
    | QExpApp (fd, actuals) -> genexp m size ((mkfunctor fd.sexp fd.sformals) actuals)
    | QExpIte (b, e1, e2) -> mkifte (genbexp m b) (genexp m size e1) (genexp m size e2)

(* Generates every expression of [es] at the largest width among them and
   returns (that width, the generated expressions). Because this [let] is not
   recursive, the [max] used inside the fold is still the outer [max]
   function, not the int being defined. *)
and genexps m es = 
  let max = List.fold_left (fun res e -> max res (size_of_exp e)) 0 es in
  (max, List.map (fun e -> genexp m max e) es)

(* Translates a Qhasm boolean expression into an STP boolean.
   - Both comparison operands are generated at the larger of their widths
     ([size_of_exp]), the first operand before the second.
   - Ne is encoded as a negated Eq; Imp as ~e1 \/ e2.
   - Predicate applications are translated through the body built by
     [mkfunctor_b]. *)
and genbexp m exp =
  let helper f e1 e2 =
    let size = max (size_of_exp e1) (size_of_exp e2) in
    let bv1 = genexp m size e1 in
    let bv2 = genexp m size e2 in
    f bv1 bv2 in
  match exp with
    QBexpTrue -> True
  | QBexpEq (e1, e2) -> helper mkeq e1 e2
  | QBexpNe (e1, e2) -> mkbnot (helper mkeq e1 e2)
  | QBexpSlt (e1, e2) -> helper mkslt e1 e2
  | QBexpSle (e1, e2) -> helper mksle e1 e2
  | QBexpSgt (e1, e2) -> helper mksgt e1 e2
  | QBexpSge (e1, e2) -> helper mksge e1 e2
  | QBexpUlt (e1, e2) -> helper mklt e1 e2
  | QBexpUle (e1, e2) -> helper mkle e1 e2
  | QBexpUgt (e1, e2) -> helper mkgt e1 e2
  | QBexpUge (e1, e2) -> helper mkge e1 e2
  | QBexpNeg e -> mkbnot (genbexp m e)
  | QBexpAnd (e1, e2) -> mkband (genbexp m e1) (genbexp m e2)
  | QBexpOr (e1, e2) -> mkbor (genbexp m e1) (genbexp m e2)
  | QBexpImp (e1, e2) -> mkbor (mkbnot (genbexp m e1)) (genbexp m e2)
  | QBexpApp (p, actuals) -> genbexp m ((mkfunctor_b p.pbexp p.pformals) actuals)

(** Translates one annotation.
    - Auxiliary variables get a fresh version, optionally defined by an
      expression.
    - Constant expressions are generated only for their side effects on [m].
    - Assumptions and assertions are recorded for the final query.
    - Function and predicate definitions produce nothing here.
    - Invariants and cuts must have been eliminated earlier; they fail. *)
let genannot (m : manager) = function
  | QAuxVar (qv, None) -> ignore (m#newvar qv)
  | QAuxVar (qv, Some e) -> ignore (m#assign qv (genexp m (size_of_exp e) e))
  | QConst e -> ignore (genexp m (size_of_exp e) e)
  | QFunction _ | QPredicate _ -> ()
  | QInvariant _ -> assert false
  | QAssume e -> m#add_assumption (genbexp m e)
  | QAssert e -> m#add_assertion (genbexp m e)
  | QCut _ ->
    print_endline "The cut should be replaced by assume and assert.";
    assert false

(**
   * Appends to the manager the STP encoding of one Qhasm statement, preceded by
   * an STP comment containing the statement's line number and text.
   * Every assignment creates a fresh STP version of its target through m#assign.
   * Raises Unsupported for non-literal shift amounts and signed right shifts.
*)
let genstmt (m : manager) stmt : unit =
  let _ = m#gencomment (string_of_int stmt.sline ^ ": " ^ string_of_qstmt stmt) in
  match stmt.skind with
    QVar (qt, qv) -> ignore(m#newvar qv)
  | QLoad (qv, qt, addr) -> ignore(m#assign qv (read m addr))
  (* NOTE(review): the array returned by [write] is discarded, and
     m#getmemory keeps returning the same array variable. Stores therefore
     appear to have no effect on later loads — confirm whether this is
     intended. *)
  | QStore (qt, addr, cv) -> ignore(write m addr (genconstvar m cv))
  | QAssign (qv, expr) -> ignore(m#assign qv (genexpr m expr))
  (* Conditional move: qv takes the value of [expr] when the carry (negated
     if [neg]) equals 1; otherwise it keeps its previous value. *)
  | QAssignIfCarry (qv, expr, neg) -> 
    let c = if neg then mknot m#getcarry else m#getcarry in
    ignore(m#assign qv (mkifte (Eq (c, Const (Bin "1"))) (genexpr m expr) (m#getvar qv)))
  | QCoef (qv, co) -> ignore(gencoef m qv co)
  | QAdd (qv, expr) -> ignore(m#assign qv (mksum wordsize [m#getvar qv; genaddexpr m expr]))
  (* Add with carry at wordsize + 1 bits.
     - Bit [wordsize] of the sum becomes the new carry; the low [wordsize]
       bits become the new qv.
     - [rv] is built (reading the old carry) before setcarry installs the new
       one. *)
  | QAddCarry (qv, expr) -> 
    let lv = m#ucast (m#getvar qv) wordsize (wordsize + 1) in
    let rv = genaddexpr m ~carry:true expr in
    let sum = mkadd (wordsize + 1) lv rv in
    let carry = mkextract sum wordsize wordsize in
    let e = mkextract sum (wordsize - 1) 0 in
    let _ = m#setcarry carry in
    ignore(m#assign qv e)
  | QSub (qv, expr) -> ignore(m#assign qv (mksub wordsize (m#getvar qv) (gensubexpr m expr)))
  (* Subtract with borrow at wordsize + 1 bits; bit [wordsize] of the
     difference becomes the new carry, as in QAddCarry. *)
  | QSubCarry (qv, expr) ->
    let lv = m#ucast (m#getvar qv) wordsize (wordsize + 1) in
    let rv = gensubexpr m ~carry:true expr in
    let sum = mksub (wordsize + 1) lv rv in
    let carry = mkextract sum wordsize wordsize in
    let e = mkextract sum (wordsize - 1) 0 in
    let _ = m#setcarry carry in
    ignore(m#assign qv e)
  | QMul (qv, expr) -> 
    let e = mkmul wordsize (m#getvar qv) (genconstvar m expr) in
    ignore(m#assign qv e)
  | QAnd (qv, expr) -> ignore(m#assign qv (mkand (m#getvar qv) (genvarderef m expr)))
  | QOr (qv, expr) -> ignore(m#assign qv (mkor (m#getvar qv) (genvarderef m expr)))
  | QXor (qv, expr) -> ignore(m#assign qv (mkxor (m#getvar qv) (genvarderef m expr)))
  (* Double-width product of qv2 and [expr].
     - Operands are sign- or zero-extended to 2 * wordsize bits according to
       [signed].
     - qv1 receives the high word and qv2 the low word. *)
  | QConcatMul (signed, qv1, qv2, expr) ->
    let cast = if signed then m#cast else m#ucast in
    let mul = mkmul (wordsize * 2) 
      (cast (m#getvar qv2) wordsize (wordsize * 2))
      (cast (genvarderef m expr) wordsize (wordsize * 2)) in
    let e1 = mkextract mul (wordsize * 2 - 1) wordsize in
    let e2 = mkextract mul (wordsize - 1) 0 in
    ignore(m#assign qv1 e1);
    ignore(m#assign qv2 e2)
  | QNeg qv -> ignore(m#assign qv (mkneg (m#getvar qv)))
  | QNot qv -> ignore(m#assign qv (mknot (m#getvar qv)))
  (* qv1 := high word of (qv1 : qv2) << n *)
  | QConcatShiftLeft (qv1, qv2, expr) ->
    let n = getshiftsize expr in
    let e = mkextract
      (mksll (mkconcat (m#getvar qv1) (m#getvar qv2)) n)
      (wordsize * 2 - 1) wordsize in
    ignore(m#assign qv1 e)
  (* Left shift, truncated back to wordsize bits. *)
  | QShiftLeft (qv, expr) ->
    let n = getshiftsize expr in
    let e = mkextract (mksll (m#getvar qv) n) (wordsize - 1) 0 in
    ignore(m#assign qv e)
  (* qv1 := low word of (qv2 : qv1) >> n *)
  | QConcatShiftRight (qv1, qv2, expr) ->
    let n = getshiftsize expr in
    let e = mkextract
      (mksrl (mkconcat (m#getvar qv2) (m#getvar qv1)) n)
      (wordsize - 1)
      0 in
    ignore(m#assign qv1 e)
  | QShiftRight (signed, qv, expr) ->
    if signed then
        raise (Unsupported "Arithmetic right shifting is unsupported.")
    else
      let n = getshiftsize expr in
      let e = mksrl (m#getvar qv) n in
      ignore(m#assign qv e)
  (* These statements produce no STP constraints (only the comment above). *)
  | QInput _
  | QCaller _
  | QEnter _
  | QLeave
  | QComment _ -> ()
  | QAnnot annot -> genannot m annot

(**
   * Builds the STP program that serves as the verification condition of an
   * annotated Qhasm program.
   * cmap: a map from names of predefined constants to their values
   * prog: the Qhasm statements, translated in order
   * The query of the result is ~(assumptions) \/ (assertions).
*)
let generate cmap prog =
  let mgr = new manager cmap in
  List.iter (genstmt mgr) prog;
  mgr#genquery;
  mgr#getprog
