-- editorconfig-checker-disable-file
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase     #-}

-- | Various helpers for defining evaluation tests.
module Evaluation.Helpers (
  -- * Generators
  forAllByteString,
  forAllByteStringThat,
  -- * Evaluation helpers
  evaluateTheSame,
  evaluatesToConstant,
  assertEvaluatesToConstant,
  evaluateToHaskell,
  ) where

import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.Kind (Type)
import Evaluation.Builtins.Common (typecheckEvaluateCek, typecheckReadKnownCek)
import GHC.Stack (HasCallStack)
import Hedgehog (PropertyT, annotateShow, failure, forAllWith, (===))
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Numeric (showHex)
import PlutusCore qualified as PLC
import PlutusCore.Builtin (ReadKnownIn)
import PlutusCore.Evaluation.Machine.ExBudgetingDefaults (defaultBuiltinCostModelForTesting)
import PlutusCore.MkPlc (mkConstant)
import PlutusPrelude (Word8, def)
import Test.Tasty.HUnit (assertEqual, assertFailure)
import UntypedPlutusCore qualified as UPLC

-- | Given a lower and upper bound (both inclusive) on length, generate a 'ByteString' whose length
-- falls within these bounds. Furthermore, the generated 'ByteString' will show as a list of
-- hex-encoded bytes on a failure, instead of the default 'Show' output.
--
-- = Note
--
-- It is the caller's responsibility to ensure that the bounds are sensible: that is, that neither
-- the upper nor the lower bound is negative, and that the lower bound is not greater than the upper
-- bound.
forAllByteString :: forall (m :: Type -> Type) .
  (Monad m, HasCallStack) =>
  Int -> Int -> PropertyT m ByteString
forAllByteString lo hi =
  -- 'hexShow' replaces the default 'Show' rendering in Hedgehog's failure report.
  forAllWith hexShow . Gen.bytes $ Range.linear lo hi

-- | As 'forAllByteString', but only yields 'ByteString's satisfying the given postcondition.
--
-- = Note
--
-- Filtering is done with 'Gen.filterT'. If the postcondition is rarely satisfied, the generator
-- may give up after too many retries and the test case will be discarded. Choose a postcondition
-- that most generated values satisfy to avoid problems.
forAllByteStringThat :: forall (m :: Type -> Type) .
  (Monad m, HasCallStack) =>
  (ByteString -> Bool) -> Int -> Int -> PropertyT m ByteString
forAllByteStringThat p lo hi =
  forAllWith hexShow . Gen.filterT p . Gen.bytes $ Range.linear lo hi

-- | Typechecks and evaluates both PLC expressions. If either of them fails to typecheck, fail the
-- test, noting what the failure was. If both typecheck, but either errors when run, fail the test,
-- noting the log(s) for any failing expression. If both run without error, compare the results
-- using '==='.
--
-- Both expressions are evaluated with the default builtin semantics variant ('def') and
-- 'defaultBuiltinCostModelForTesting'.
evaluateTheSame ::
  HasCallStack =>
  PLC.Term UPLC.TyName UPLC.Name UPLC.DefaultUni UPLC.DefaultFun () ->
  PLC.Term UPLC.TyName UPLC.Name UPLC.DefaultUni UPLC.DefaultFun () ->
  PropertyT IO ()
evaluateTheSame lhs rhs =
  case typecheckEvaluateCek def defaultBuiltinCostModelForTesting lhs of
    Left x -> annotateShow x >> failure
    Right (resLhs, logsLhs) ->
      case typecheckEvaluateCek def defaultBuiltinCostModelForTesting rhs of
        Left x -> annotateShow x >> failure
        Right (resRhs, logsRhs) -> case (resLhs, resRhs) of
          -- Per the contract above, any evaluation error fails the test, even when both sides
          -- error; report both logs in that case.
          (PLC.EvaluationFailure, PLC.EvaluationFailure) -> do
            annotateShow logsLhs
            annotateShow logsRhs
            failure
          (PLC.EvaluationSuccess rLhs, PLC.EvaluationSuccess rRhs) -> rLhs === rRhs
          -- Exactly one side failed: only its logs are relevant.
          (PLC.EvaluationFailure, _) -> annotateShow logsLhs >> failure
          (_, PLC.EvaluationFailure) -> annotateShow logsRhs >> failure

-- | As 'evaluateTheSame', but for cases where we want to compare a more complex computation to a
-- constant (as if by @mkConstant@). This is slightly more efficient, because only the complex
-- expression is typechecked and evaluated.
--
-- Fails the test, annotating the error, if the expression does not typecheck. Fails, annotating
-- the logs, if evaluation errors. Otherwise the result is compared to @mkConstant () k@ using '==='.
evaluatesToConstant :: forall (a :: Type) .
  PLC.Contains UPLC.DefaultUni a =>
  a ->
  PLC.Term UPLC.TyName UPLC.Name UPLC.DefaultUni UPLC.DefaultFun () ->
  PropertyT IO ()
evaluatesToConstant k expr =
  case typecheckEvaluateCek def defaultBuiltinCostModelForTesting expr of
    Left err -> annotateShow err >> failure
    Right (res, logs) -> case res of
      PLC.EvaluationFailure   -> annotateShow logs >> failure
      PLC.EvaluationSuccess r -> r === mkConstant () k

-- | Given a PLC expression and an intended type (via a type argument), typecheck the expression,
-- evaluate it, then produce the required Haskell value from the results. If we fail at any stage,
-- instead fail the test and report the failure.
--
-- There are two failure stages:
--
-- * an outer 'Left', which is an error from typechecking (or other pre-evaluation errors)
-- * an inner 'Left', which is an exception raised while running the CEK machine
--
-- In both cases the error is annotated before failing.
evaluateToHaskell :: forall (a :: Type) .
  ReadKnownIn UPLC.DefaultUni (UPLC.Term UPLC.Name UPLC.DefaultUni UPLC.DefaultFun ()) a =>
  PLC.Term UPLC.TyName UPLC.Name UPLC.DefaultUni UPLC.DefaultFun () ->
  PropertyT IO a
evaluateToHaskell expr =
  case typecheckReadKnownCek def defaultBuiltinCostModelForTesting expr of
    Left err         -> annotateShow err >> failure
    Right (Left err) -> annotateShow err >> failure
    Right (Right x)  -> pure x

-- | As 'evaluatesToConstant', but for a unit test instead of a property.
--
-- Fails via 'assertFailure' with the shown error if the expression does not typecheck. Fails with
-- the shown logs if evaluation errors. Otherwise asserts that the result equals
-- @mkConstant () k@.
assertEvaluatesToConstant :: forall (a :: Type) .
  PLC.Contains UPLC.DefaultUni a =>
  a ->
  PLC.Term UPLC.TyName UPLC.Name UPLC.DefaultUni UPLC.DefaultFun () ->
  IO ()
assertEvaluatesToConstant k expr =
  case typecheckEvaluateCek def defaultBuiltinCostModelForTesting expr of
    Left err -> assertFailure . show $ err
    Right (res, logs) -> case res of
      PLC.EvaluationFailure   -> assertFailure . show $ logs
      -- HUnit's 'assertEqual' takes @preface expected actual@: the constant is what we expect,
      -- the evaluated term is what we actually got.
      PLC.EvaluationSuccess r -> assertEqual "" (mkConstant () k) r

-- Helpers

-- | Render a 'ByteString' as a bracketed, comma-separated list of bytes, each rendered by
-- 'byteToHex' (e.g. @[0x0a, 0xff]@). Used by the generators in this module in place of the
-- default 'Show' output.
hexShow :: ByteString -> String
hexShow bs = "[" <> go (BS.unpack bs) <> "]"
  where
    -- Join the rendered bytes with ", ", with no trailing separator.
    go :: [Word8] -> String
    go = \case
      []         -> ""
      [w8]       -> byteToHex w8
      (w8 : w8s) -> byteToHex w8 <> ", " <> go w8s

-- | Render a byte as a @0x@-prefixed, zero-padded, two-digit lowercase hex literal, e.g. @0x0a@
-- or @0xff@.
byteToHex :: Word8 -> String
byteToHex w8
  -- 'showHex' yields a single digit only for bytes below 0x10, so only those need a leading zero.
  -- Padding anything larger would produce three digits (e.g. "0x010").
  | w8 < 16   = "0x0" <> showHex w8 ""
  | otherwise = "0x" <> showHex w8 ""