{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE RankNTypes        #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module PlutusIR.Pass.Test where

import Control.Monad.Except
import Data.Typeable
import PlutusCore qualified as PLC
import PlutusCore.Builtin
import PlutusCore.Generators.QuickCheck (forAllDoc)
import PlutusCore.Pretty qualified as PLC
import PlutusIR.Core.Type
import PlutusIR.Error qualified as PIR
import PlutusIR.Generators.QuickCheck
import PlutusIR.Pass
import PlutusIR.TypeCheck
import PlutusIR.TypeCheck qualified as TC
import PlutusPrelude
import Test.QuickCheck

-- | Convert a PIR error result into an @Either String ()@, to match with the
-- 'Testable' instance for @Either String ()@.
--
-- A 'Left' error is rendered with 'show'; a 'Right' result is passed through
-- unchanged.
convertToEitherString
  :: Either (PIR.Error PLC.DefaultUni PLC.DefaultFun ()) ()
  -> Either String ()
convertToEitherString = first show

-- | Pick a semantics variant of the default builtins uniformly from the
-- values listed by 'enumerate'.
--
-- This is an orphan instance, which is why the module disables
-- @-Wno-orphans@ warnings.
instance Arbitrary (BuiltinSemanticsVariant PLC.DefaultFun) where
    arbitrary = elements enumerate

-- | An appropriate number of tests for a compiler pass property, so that we get some decent
-- exploration of the program space. If you also take other arguments, then consider multiplying
-- this up in order to account for the larger space.
numTestsForPassProp :: Int
numTestsForPassProp = 99

-- | Run a 'Pass' on a 'Term', setting up the typechecking config and throwing errors.
--
-- The typechecking config is built by 'TC.getDefTypeCheckConfig' using a
-- 'mempty' annotation. The pass is run through 'runPass' with a logger that
-- discards every message.
-- NOTE(review): the 'True' argument to 'runPass' presumably turns on the
-- pass's own checks; confirm against 'runPass'.
--
-- A 'PIR.Error' is raised with 'throw', i.e. as an imprecise exception. Only
-- 'Monad' @m@ is available, so there is no monadic way to throw. The
-- exception therefore surfaces when the returned action is forced or run.
runTestPass
  :: ( PLC.ThrowableBuiltins uni fun
     , PLC.Typecheckable uni fun
     , PLC.Pretty a
     , Typeable a
     , Monoid a
     , Monad m
     )
  => (TC.PirTCConfig uni fun -> Pass m tyname name uni fun a)
  -> Term tyname name uni fun a
  -> m (Term tyname name uni fun a)
runTestPass pass t = do
  res <- runExceptT $ do
    tcconfig <- TC.getDefTypeCheckConfig mempty
    runPass (\_ -> pure ()) True (pass tcconfig) t
  either throw pure res

-- | Run a 'Pass' on generated 'Term's, setting up the typechecking config
-- and reporting any pass error as a @Left@ of its 'show'n form.
--
-- This is 'testPassProp'' with a @()@ annotation and no pre-processing of
-- the generated term ('id'). The outcome of running the pass is converted
-- with 'convertToEitherString' and used as the property result.
--
-- @exitMonad@ escapes the pass's monad @m@ so the outcome can be inspected
-- purely.
testPassProp ::
  Monad m
  => (forall a . m a -> a)
  -> (TC.PirTCConfig PLC.DefaultUni PLC.DefaultFun
      -> Pass m TyName Name PLC.DefaultUni PLC.DefaultFun ())
  -> Property
testPassProp exitMonad pass = testPassProp' () id after pass
  where
    -- Leave @m@ via @exitMonad@, erase the error's annotation to @()@ with
    -- 'void', then render any error as a 'String'.
    after res = convertToEitherString $ first void $ exitMonad $ runExceptT res

-- | A version of 'testPassProp' with more control, allowing some pre-processing
-- of the term, and a more specific "exit" function.
--
-- Generated @(type, term)@ pairs come from 'genTypeAndTerm_' and are shrunk
-- with 'shrinkClosedTypedTerm'; the type component is ignored. For each
-- term, the property:
--
-- 1. builds the default typechecking config from @ann@,
-- 2. pre-processes the term with @before@,
-- 3. runs @pass@ on it via 'runPass' with a silent logger, discarding the
--    resulting term.
--
-- The resulting 'ExceptT' computation is then handed to @after@, which
-- decides the property's outcome.
testPassProp' ::
  forall m tyname name a prop
  . (Monad m, Testable prop)
  => a
  -> (Term TyName Name PLC.DefaultUni PLC.DefaultFun ()
      -> Term tyname name PLC.DefaultUni PLC.DefaultFun a)
  -> (ExceptT (PIR.Error PLC.DefaultUni PLC.DefaultFun a) m () -> prop)
  -> (TC.PirTCConfig PLC.DefaultUni PLC.DefaultFun
      -> Pass m tyname name PLC.DefaultUni PLC.DefaultFun a)
  -> Property
testPassProp' ann before after pass =
  forAllDoc "ty,tm" genTypeAndTerm_ shrinkClosedTypedTerm $ \(_ty, tm) ->
    let
      res :: ExceptT (PIR.Error PLC.DefaultUni PLC.DefaultFun a) m ()
      res = do
        tcconfig <- TC.getDefTypeCheckConfig ann
        -- Only success/failure of the pass matters; the output term is dropped.
        void $ runPass (\_ -> pure ()) True (pass tcconfig) (before tm)
    in after res