{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE TypeApplications #-}

module Transform.CaseOfCase.Spec where

import Data.ByteString.Lazy qualified as BSL
import Data.Text.Encoding (encodeUtf8)
import PlutusCore qualified as PLC
import PlutusCore.Evaluation.Machine.BuiltinCostModel (BuiltinCostModel)
import PlutusCore.Evaluation.Machine.ExBudgetingDefaults
  ( defaultBuiltinCostModelForTesting
  , defaultCekMachineCostsForTesting
  )
import PlutusCore.Evaluation.Machine.MachineParameters
  ( CostModel (..)
  , MachineParameters (..)
  , mkMachineVariantParameters
  )
import PlutusCore.Evaluation.Machine.MachineParameters.Default (DefaultMachineParameters)
import PlutusCore.MkPlc (mkConstant)
import PlutusCore.Pretty
import PlutusPrelude (Default (def))
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.Golden (goldenVsString)
import Test.Tasty.HUnit (testCase, (@?=))
import Transform.Lib (builtinTrue, case_, con, constr, err, ite, sopFalse, sopTrue, var)
import UntypedPlutusCore (DefaultFun, DefaultUni, Name, Term (..))
import UntypedPlutusCore.Core qualified as UPLC
import UntypedPlutusCore.Evaluation.Machine.Cek
  ( CekMachineCosts
  , EvaluationResult (..)
  , evaluateCek
  , noEmitter
  , unsafeSplitStructuralOperational
  )
import UntypedPlutusCore.Transform.CaseOfCase (caseOfCase)
import UntypedPlutusCore.Transform.Optimizer (evalOptimizer)

{-| Top-level test tree for the @CaseOfCase@ transformation.

It contains one golden test per sample term (@"1"@, @"2"@, @"3"@ and
@"withError"@) plus a unit test that compares the evaluation result of
'caseOfCaseWithError' before and after the transformation. -}
test_caseOfCase :: TestTree
test_caseOfCase =
  testGroup
    "CaseOfCase"
    [ goldenVsSimplified "1" caseOfCase1
    , goldenVsSimplified "2" caseOfCase2
    , goldenVsSimplified "3" caseOfCase3
    , goldenVsSimplified "withError" caseOfCaseWithError
    , testCaseOfCaseWithError
    ]

{-| A @case@ whose scrutinee is an @ifThenElse@ on the variable @b@, with
'sopTrue' and 'sopFalse' as its two branches, and two integer constants as the
case alternatives. Input of the golden test @"1"@.

NOTE(review): presumably the positive example, i.e. the one the transformation
is expected to rewrite since both branches are known constructors — confirm
against the golden file. -}
caseOfCase1 :: Term Name PLC.DefaultUni PLC.DefaultFun ()
caseOfCase1 =
  case_
    (ite (var "b") sopTrue sopFalse)
    [ con 1 -- True branch
    , con 2 -- False branch
    ]

{-| This should not simplify, because one of the branches of `ifThenElse` is not a `Constr`
(the @then@ branch is the free variable @t@).
Unless both branches are known constructors, the case-of-case transformation
may increase the program size. -}
caseOfCase2 :: Term Name PLC.DefaultUni PLC.DefaultFun ()
caseOfCase2 = case_ (ite (var "b") (var "t") sopFalse) [con 1, con 2]

{-| Similar to `caseOfCase1`, but the type of the @true@ and @false@ branches is
@[Integer]@ rather than Bool (note that @constr 0@ has two parameters, @x@ and @xs@).

The first case alternative is the variable @f@; the second is the constant @2@. -}
caseOfCase3 :: Term Name PLC.DefaultUni PLC.DefaultFun ()
caseOfCase3 =
  case_
    (ite (var "b") (constr 0 [var "x", var "xs"]) sopFalse)
    [var "f", con 2]

{-|

@
  case (force ifThenElse) True True False of
    True -> ()
    False -> _|_
@

Evaluates to `()` because the first case alternative is selected.
(The _|_ is not evaluated because case alternatives are evaluated lazily).

After the `CaseOfCase` transformation the program should evaluate to `()` as well.

@
  force ((force ifThenElse) True (delay ()) (delay _|_))
@ -}
caseOfCaseWithError :: Term Name DefaultUni DefaultFun ()
caseOfCaseWithError =
  case_
    (ite builtinTrue sopTrue sopFalse)
    [ mkConstant @() () () -- selected alternative: the unit constant
    , err -- must never be evaluated
    ]

{-| Checks that running 'caseOfCase' on 'caseOfCaseWithError' does not change
what the term evaluates to: the transformed term and the original term must
produce equal results under 'evaluateUplc'. -}
testCaseOfCaseWithError :: TestTree
testCaseOfCaseWithError =
  testCase "Transformation doesn't evaluate error eagerly" do
    let simplifiedTerm = evalCaseOfCase caseOfCaseWithError
    evaluateUplc simplifiedTerm @?= evaluateUplc caseOfCaseWithError

----------------------------------------------------------------------------------------------------
-- Helper functions --------------------------------------------------------------------------------

{-| Apply the 'caseOfCase' transformation to a term and extract the resulting
term from the optimizer computation with 'evalOptimizer'. -}
evalCaseOfCase
  :: Term Name DefaultUni DefaultFun ()
  -> Term Name DefaultUni DefaultFun ()
evalCaseOfCase term = evalOptimizer $ caseOfCase term

{-| Evaluate an untyped Plutus Core term with 'evaluateCek', using 'noEmitter'
and machine parameters built from the default /testing/ cost models.

The @[Text]@ component returned alongside the result is discarded, and the
remaining @Either@ is converted to an 'EvaluationResult' by
'unsafeSplitStructuralOperational'.

NOTE(review): judging by its name, 'unsafeSplitStructuralOperational' presumably
throws on structural errors instead of returning them — confirm before relying
on this helper for terms that may fail structurally. -}
evaluateUplc
  :: UPLC.Term Name DefaultUni DefaultFun ()
  -> EvaluationResult (UPLC.Term Name DefaultUni DefaultFun ())
evaluateUplc =
  -- Plain composition; equivalent to the former
  -- @unsafeSplitStructuralOperational . fst <$> evaluateCek ...@, which used
  -- 'fmap' on the function functor.
  unsafeSplitStructuralOperational
    . fst
    . evaluateCek noEmitter machineParameters
  where
    -- Pairs the default testing costs of the CEK machine and of the builtins.
    costModel :: CostModel CekMachineCosts BuiltinCostModel
    costModel =
      CostModel defaultCekMachineCostsForTesting defaultBuiltinCostModelForTesting

    -- Both the caser and the builtin semantics variant are taken from 'def'.
    machineParameters :: DefaultMachineParameters
    machineParameters =
      -- TODO: proper semantic variant. What should def be?
      MachineParameters def $ mkMachineVariantParameters def costModel

{-| Build a golden test for the 'caseOfCase' transformation.

The given term is transformed with 'evalCaseOfCase', pretty-printed with
'prettyClassicSimple', rendered to text, UTF-8 encoded, and compared against
the golden file
@untyped-plutus-core/test/Transform/CaseOfCase/\<testName\>.golden.uplc@.

The first argument is used both as the test name and as the base name of the
golden file; the second argument is the term to transform. -}
goldenVsSimplified :: String -> Term Name PLC.DefaultUni PLC.DefaultFun () -> TestTree
goldenVsSimplified testName =
  goldenVsString
    testName
    ("untyped-plutus-core/test/Transform/CaseOfCase/" ++ testName ++ ".golden.uplc")
    . pure
    . BSL.fromStrict
    . encodeUtf8
    . render
    . prettyClassicSimple
    . evalCaseOfCase