{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts   #-}
{-# LANGUAGE NamedFieldPuns     #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE RankNTypes         #-}

module PlutusTx.Test.Run.Code (
  module Eval,
  evaluationResultMatchesHaskell,
  assertEvaluatesSuccessfully,
  assertEvaluatesWithError,
  assertResult,
) where

import Prelude

import Data.Text qualified as Text
import PlutusCore.Pretty
import PlutusCore.Test (TestNested, assertEqualPretty, embed)
import PlutusTx qualified as Tx
import PlutusTx.Code (CompiledCode)
import PlutusTx.Eval as Eval
import PlutusTx.Test.Util.Compiled (cekResultMatchesHaskellValue, compiledCodeToTerm)
import Test.Tasty (TestName)
import Test.Tasty.HUnit (Assertion, assertFailure, testCase)
import UntypedPlutusCore (DefaultUni)

{-| Evaluate 'CompiledCode' and compare the outcome with a Haskell value.

   The Haskell value is typically obtained by running the Haskell code that
   the term was compiled from. It is lifted and evaluated as well, because
   lifting may produce reducible terms. The actual comparison is delegated to
   'cekResultMatchesHaskellValue'.

   The comparison operator is taken as a rank-2 argument. This keeps the
   function polymorphic in its result type @k@, so the same helper can build
   both HUnit 'Assertion's and QuickCheck properties.
-}
evaluationResultMatchesHaskell
  :: (Tx.Lift DefaultUni hask)
  => CompiledCode a
  -> (forall r. (Eq r, Show r) => r -> r -> k)
  -> hask
  -> k
evaluationResultMatchesHaskell code check expected =
  cekResultMatchesHaskellValue term check expected
  where
    -- The term extracted from the compiled program; this is what gets evaluated.
    term = compiledCodeToTerm code

{-| Assert that evaluating the given 'CompiledCode' does not fail.

   The result of a successful evaluation is ignored. If evaluation returns
   an error, the assertion fails with a message containing:

   * the pretty-printed evaluation error, and
   * the traces collected during evaluation ('evalResultTraces').
-}
assertEvaluatesSuccessfully :: CompiledCode a -> Assertion
assertEvaluatesSuccessfully code =
  case evaluateCompiledCode code of
    EvalResult{evalResult = Right _} -> pure ()
    EvalResult{evalResult = Left err, evalResultTraces} ->
      assertFailure $
        Text.unpack $
          Text.unlines
            [ "Evaluation failed with an error:"
            , render (prettyPlcClassicSimple err)
            , "Evaluation traces:"
            , -- Traces are emitted one per line below the header.
              Text.unlines evalResultTraces
            ]

{-| Assert that evaluating the given 'CompiledCode' fails with an error.

   The specific error is not inspected; any 'Left' result passes. If
   evaluation succeeds instead, the assertion fails. The failure message
   includes the traces collected during evaluation ('evalResultTraces') to
   help explain why no error occurred.
-}
assertEvaluatesWithError :: CompiledCode a -> Assertion
assertEvaluatesWithError code =
  case evaluateCompiledCode code of
    EvalResult{evalResult = Left _} -> pure ()
    EvalResult{evalResult = Right _, evalResultTraces} ->
      assertFailure $
        Text.unpack $
          Text.unlines
            [ "Evaluation succeeded, but expected an error."
            , "Evaluation traces:"
            , -- Traces are emitted one per line below the header.
              Text.unlines evalResultTraces
            ]

{-| Build a named test checking that a compiled 'Bool' program agrees with 'True'.

   Both sides are compared via 'evaluationResultMatchesHaskell'. The
   comparison is wrapped in a single tasty 'testCase' called @name@, which
   uses 'assertEqualPretty' with @name@ as the assertion message.
-}
assertResult :: TestName -> CompiledCode Bool -> TestNested
assertResult name code = evaluationResultMatchesHaskell code toTest True
  where
    -- Turns the two values handed over by 'evaluationResultMatchesHaskell'
    -- into one embedded test case. The signature keeps it polymorphic in @r@,
    -- as the rank-2 parameter requires.
    -- NOTE(review): arguments are forwarded to 'assertEqualPretty' in the
    -- order received; confirm this matches its expected/actual convention.
    toTest :: (Eq r, Show r) => r -> r -> TestNested
    toTest x y = embed (testCase name (assertEqualPretty name x y))