{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes       #-}

module PlutusTx.Test.Util.Compiled (
  Program,
  Term,
  countFlatBytes,
  toAnonDeBruijnTerm,
  toAnonDeBruijnProg,
  toNamedDeBruijnTerm,
  compiledCodeToTerm,
  cekResultMatchesHaskellValue,
)
where

import Prelude

import Codec.Extras.SerialiseViaFlat (SerialiseViaFlat (..))
import Codec.Serialise (serialise)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import PlutusCore qualified as PLC
import PlutusCore.Default
import PlutusCore.Evaluation.Machine.ExBudgetingDefaults qualified as PLC
import PlutusTx qualified as Tx
import PlutusTx.Code (CompiledCode, getPlcNoAnn)
import UntypedPlutusCore qualified as UPLC
import UntypedPlutusCore.Evaluation.Machine.Cek as Cek

-- | A UPLC term whose variables are 'PLC.NamedDeBruijn' (textual name plus
-- de Bruijn index), over the default universe and builtins, with unit
-- annotations.
type Term =
  UPLC.Term PLC.NamedDeBruijn DefaultUni DefaultFun ()

-- | A UPLC program with the same name representation, universe, builtins
-- and annotation type as 'Term'.
type Program =
  UPLC.Program PLC.NamedDeBruijn DefaultUni DefaultFun ()

{-| The size of a 'CompiledCode' as measured in Flat bytes.

This function serialises the code to a 'ByteString' and counts the number
of bytes. It uses the same serialisation format as the ledger:
@CBOR(Flat(StripNames(StripAnnotations(UPLC))))@.

Caveat: the 'SerialisedCode' constructor of the 'CompiledCode' type
already contains a PLC program as a 'ByteString', but that is not the byte
representation the ledger sees. It uses the 'NamedDeBruijn' representation,
which also stores names. On the mainnet we don't serialise names, only
DeBruijn indices, so this function re-serialises the code to get the size
in bytes that would actually be used on the mainnet.
-}
countFlatBytes :: CompiledCode ann -> Integer
countFlatBytes =
  fromIntegral
    . BS.length
    . BSL.toStrict
    -- Outer CBOR layer around the Flat encoding.
    . serialise
    . SerialiseViaFlat
    . UPLC.UnrestrictedProgram
    -- Strip textual names, keeping only the de Bruijn indices.
    . toAnonDeBruijnProg
    -- Program with every annotation replaced by ().
    . getPlcNoAnn

{-| Given a DeBruijn-indexed term, attach a placeholder textual name to every
variable using 'UPLC.fakeNameDeBruijn', so the term can be used wherever a
'NamedDeBruijn' term ('Term') is expected.

The placeholder carries no information of its own. A later @unDeBruijn@ pass
is what produces readable names based on each variable's de Bruijn index
(e.g. @v123@ for index 123).
NOTE(review): the exact placeholder text and generated-name scheme come from
'UPLC.fakeNameDeBruijn' / @unDeBruijn@; confirm there before relying on them.
-}
toNamedDeBruijnTerm
  :: UPLC.Term UPLC.DeBruijn DefaultUni DefaultFun ()
  -> Term
toNamedDeBruijnTerm = UPLC.termMapNames UPLC.fakeNameDeBruijn

-- | Remove the textual names from a 'NamedDeBruijn' term, leaving plain
-- 'UPLC.DeBruijn' variables. Each name is mapped with 'UPLC.unNameDeBruijn';
-- the rest of the term is untouched.
toAnonDeBruijnTerm :: Term -> UPLC.Term UPLC.DeBruijn DefaultUni DefaultFun ()
toAnonDeBruijnTerm = UPLC.termMapNames UPLC.unNameDeBruijn

-- | Program-level counterpart of 'toAnonDeBruijnTerm': strip the textual
-- names from the program body. The program's version is copied unchanged,
-- and the annotation stays @()@.
toAnonDeBruijnProg
  :: Program
  -> UPLC.Program UPLC.DeBruijn DefaultUni DefaultFun ()
toAnonDeBruijnProg (UPLC.Program () ver body) =
  UPLC.Program () ver (toAnonDeBruijnTerm body)

{-| Extract the body of the program wrapped in a 'Tx.CompiledCodeIn'. The
program version is discarded. Annotations are already @()@ because the
program is obtained with 'Tx.getPlcNoAnn'. This is used a lot.
-}
compiledCodeToTerm :: Tx.CompiledCodeIn DefaultUni DefaultFun a -> Term
compiledCodeToTerm code =
  case Tx.getPlcNoAnn code of
    UPLC.Program () _version body -> body

{-| Evaluate a UPLC term and check that the result matches a given Haskell
value (perhaps obtained by running the Haskell code that the term was
compiled from).

The Haskell value is lifted with 'Tx.liftCodeDef' and the lifted term is
evaluated too, because lifting may produce reducible terms. Both sides use
the same evaluator: 'runCekDeBruijn' with
'PLC.defaultCekParametersForTesting', the 'Cek.restrictingEnormous' budget
and 'Cek.noEmitter'. Its result is passed through
'unsafeSplitStructuralOperational'; as the name suggests, this presumably
throws rather than returning on some error classes (verify there).

The function is polymorphic in the comparison operator so that it can be
used with both HUnit Assertions and QuickCheck Properties.
-}
cekResultMatchesHaskellValue
  :: (Tx.Lift DefaultUni hask)
  => Term
  -> (forall r. (Eq r, Show r) => r -> r -> k)
  -> hask
  -> k
cekResultMatchesHaskellValue actual matches expected =
  matches (unsafeRunTermCek actual) (unsafeRunTermCek expectedTerm)
 where
  -- The expected Haskell value, lifted into a UPLC term.
  expectedTerm :: Term
  expectedTerm = compiledCodeToTerm (Tx.liftCodeDef expected)

  -- Run the CEK machine on a term, keeping only the evaluation result
  -- (the remaining budget state and emitted logs are discarded).
  unsafeRunTermCek :: Term -> EvaluationResult Term
  unsafeRunTermCek term =
    case runCekDeBruijn
      PLC.defaultCekParametersForTesting
      Cek.restrictingEnormous
      Cek.noEmitter
      term of
      (res, _budget, _logs) -> unsafeSplitStructuralOperational res