{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes       #-}

module PlutusTx.Test.Util.Compiled (
  Program,
  Term,
  toAnonDeBruijnTerm,
  toAnonDeBruijnProg,
  toNamedDeBruijnTerm,
  compiledCodeToTerm,
  cekResultMatchesHaskellValue,
)
where

import Prelude

import PlutusCore qualified as PLC
import PlutusCore.Default
import PlutusCore.Evaluation.Machine.ExBudgetingDefaults qualified as PLC
import PlutusTx qualified as Tx
import UntypedPlutusCore qualified as UPLC
import UntypedPlutusCore.Evaluation.Machine.Cek as Cek

-- | An untyped PLC term with named de Bruijn variables, the default universe
-- and builtins, and unit annotations.
type Term = UPLC.Term PLC.NamedDeBruijn DefaultUni DefaultFun ()
-- | An untyped PLC program whose body is a 'Term'.
type Program = UPLC.Program PLC.NamedDeBruijn DefaultUni DefaultFun ()

{-| Given a term with anonymous de Bruijn indices, attach a placeholder
   textual name to every variable via 'UPLC.fakeNameDeBruijn'.  The de Bruijn
   indices themselves are left unchanged.

   NOTE(review): an earlier version of this comment said every variable is
   named "v" and that a later 'unDeBruijn' yields names like "v123" (123 being
   the index); confirm the actual placeholder against 'UPLC.fakeNameDeBruijn'.
-}
toNamedDeBruijnTerm
  :: UPLC.Term UPLC.DeBruijn DefaultUni DefaultFun ()
  -> UPLC.Term UPLC.NamedDeBruijn DefaultUni DefaultFun ()
toNamedDeBruijnTerm = UPLC.termMapNames UPLC.fakeNameDeBruijn

{-| Remove the textual names from a 'NamedDeBruijn' term, keeping only the
   de Bruijn indices ('UPLC.unNameDeBruijn' is applied to every variable).
-}
toAnonDeBruijnTerm :: Term -> UPLC.Term UPLC.DeBruijn DefaultUni DefaultFun ()
toAnonDeBruijnTerm = UPLC.termMapNames UPLC.unNameDeBruijn

{-| Remove the textual names from every variable of a 'NamedDeBruijn'
   program.  The program's version is preserved and the body is rewritten with
   'toAnonDeBruijnTerm'.
-}
toAnonDeBruijnProg
  :: UPLC.Program UPLC.NamedDeBruijn DefaultUni DefaultFun ()
  -> UPLC.Program UPLC.DeBruijn DefaultUni DefaultFun ()
toAnonDeBruijnProg (UPLC.Program () ver body) =
  UPLC.Program () ver $ toAnonDeBruijnTerm body

{-| Just extract the body of a program wrapped in a 'Tx.CompiledCodeIn'.
   The program is obtained with 'Tx.getPlcNoAnn' (so annotations are @()@) and
   its version is discarded.  We use this a lot.
-}
compiledCodeToTerm :: Tx.CompiledCodeIn DefaultUni DefaultFun a -> Term
compiledCodeToTerm code =
  -- 'UPLC.Program' has a single constructor, so this lazy pattern cannot fail.
  let UPLC.Program _ _ body = Tx.getPlcNoAnn code in body

{-| Evaluate a PLC term and check that the result matches a given Haskell value
   (perhaps obtained by running the Haskell code that the term was compiled
   from).  We evaluate the lifted Haskell value as well, because lifting may
   produce reducible terms. The function is polymorphic in the comparison
   operator so that we can use it with both HUnit Assertions and QuickCheck
   Properties.

   Both sides are run on the CEK machine with the testing cost parameters, the
   'Cek.restrictingEnormous' budget mode and no log emitter; the two
   'EvaluationResult's are then handed to the supplied comparison.
-}
cekResultMatchesHaskellValue
  :: (Tx.Lift DefaultUni hask)
  => Term
  -> (forall r. (Eq r, Show r) => r -> r -> k)
  -> hask
  -> k
cekResultMatchesHaskellValue actual matches expected =
  matches
    (unsafeRunTermCek actual)
    (unsafeRunTermCek (compiledCodeToTerm (Tx.liftCodeDef expected)))
 where
  -- Run a term on the CEK machine, keeping only the evaluation outcome.
  -- 'runCekDeBruijn' returns (result, final budget state, emitted logs); the
  -- budget state and logs are discarded.  NOTE(review): as its name suggests,
  -- 'unsafeSplitStructuralOperational' presumably throws on structural
  -- (machine) errors instead of returning them — confirm.
  unsafeRunTermCek :: Term -> EvaluationResult Term
  unsafeRunTermCek =
    unsafeSplitStructuralOperational
      . (\(res, _, _) -> res)
      . runCekDeBruijn
        PLC.defaultCekParametersForTesting
        Cek.restrictingEnormous
        Cek.noEmitter