-- editorconfig-checker-disable-file
{-# LANGUAGE TypeApplications #-}

module PlutusIR.Generators.QuickCheck.Common where

import PlutusCore.Generators.QuickCheck.Common
import PlutusCore.Generators.QuickCheck.Substitutions
import PlutusCore.Generators.QuickCheck.Unification

import PlutusCore.Default
import PlutusCore.Name.Unique
import PlutusCore.Quote (runQuoteT)
import PlutusCore.Rename
import PlutusIR
import PlutusIR.Compiler
import PlutusIR.Error
import PlutusIR.Subst
import PlutusIR.TypeCheck

import Control.Monad (void)
import Data.Bifunctor
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Set.Lens (setOf)
import Text.PrettyBy

-- | Compute the datatype declarations that escape from a term.
--
-- Returns the @(type name, kind)@ of every datatype declared by a @let@
-- reachable along the spine of term lambdas, type lambdas and @let@ bodies,
-- outermost first. Datatypes declared anywhere else (e.g. inside the function
-- or argument of an application) are not reported.
datatypes :: Term TyName Name DefaultUni DefaultFun ()
          -> [(TyName, Kind ())]
datatypes tm = case tm of
  Var _ _           -> mempty
  Builtin _ _       -> mempty
  Constant _ _      -> mempty
  Apply _ _ _       -> mempty
  LamAbs _ _ _ tm'  -> datatypes tm'
  TyAbs _ _ _ tm'   -> datatypes tm'
  TyInst _ _ _      -> mempty
  Let _ _ binds tm' -> foldr addDatatype (datatypes tm') binds
    where
      -- Prepend the type variable introduced by a datatype binding;
      -- term and type bindings contribute nothing.
      addDatatype (DatatypeBind _ (Datatype _ (TyVarDecl _ a k) _ _ _)) = ((a, k) :)
      addDatatype _                                                     = id
  Error _ _         -> mempty
  -- Every remaining term form (presumably 'IWrap', 'Unwrap', 'Constr' and
  -- 'Case' -- TODO confirm against the PIR 'Term' definition) is treated like
  -- 'Apply': it binds nothing in scope of the whole term. Returning 'mempty'
  -- keeps this function (and 'inferTypeInContext', which calls it) total
  -- instead of crashing with an uninformative error.
  _                 -> mempty

-- | Try to infer the type of an expression in a given type and term context.
--
-- Off-the-shelf type inference cannot be used directly here because
-- 'inferType' happily renames things. Instead:
--
--   1. the term is closed over the contexts: every term variable becomes a
--      lambda and every type variable becomes a type abstraction;
--   2. the type of the closed term is inferred;
--   3. datatypes that escape into the inferred type under their renamed names
--      are mapped back to their original names; and
--   4. the foralls and arrows added in step 1 are peeled off again, with each
--      forall-bound variable substituted back to the context's name.
--
-- A type error is pretty-printed into the 'Left' string.
inferTypeInContext :: TypeCtx
                   -> Map Name (Type TyName DefaultUni ())
                   -> Term TyName Name DefaultUni DefaultFun ()
                   -> Either String (Type TyName DefaultUni ())
inferTypeInContext tyctx ctx tm0 =
  first display . runQuoteT @(Either (Error DefaultUni DefaultFun ())) $ do
    -- CODE REVIEW: this algorithm is fragile, it relies on knowing that `inferType`
    -- does renaming to compute the `esc` substitution for datatypes. However, there is also
    -- not any other way to do this in a way that makes type inference actually useful - you
    -- want to do type inference in non-top-level contexts. Ideally I think type inference
    -- probably shouldn't do renaming of datatypes... Or alternatively we need to ensure that
    -- the renaming behaviour of type inference is documented and maintained.
    cfg <- getDefTypeCheckConfig ()
    Normalized inferred <- inferType cfg closed
    -- Undo the renaming of escaping datatypes. The first argument to
    -- 'substEscape' collects every type name involved (presumably so the
    -- substitution can avoid capture -- TODO confirm).
    let avoid =
          Map.keysSet escapes <> setOf ftvTy inferred <> foldMap (setOf ftvTy) escapes
        unrenamed = substEscape avoid escapes inferred
    -- Remove the binders that were only added to encode the contexts.
    pure (stripFuns termVars (stripForalls mempty typeVars unrenamed))
  where
    typeVars = Map.toList tyctx
    termVars = Map.toList ctx

    -- The input term closed over both contexts, type abstractions outermost.
    closed = addTyLams typeVars (addLams termVars tm0)

    -- 'closed' after renaming; presumably this matches the renaming that
    -- 'inferType' performs internally (see the CODE REVIEW note above).
    -- Renaming cannot fail, hence the unreachable 'Left' branch.
    renamedClosed =
      either (\_ -> error "impossible") id (runQuoteT (rename closed))

    -- Maps each escaping datatype name as it appears in the renamed term to a
    -- type variable carrying its name in the un-renamed term. Kept
    -- polymorphic in the universe because it is used at more than one.
    escapes :: Map TyName (Type TyName uni ())
    escapes = Map.fromList (zip renamedDats (TyVar () <$> originalDats))

    renamedDats  = fst <$> datatypes renamedClosed
    originalDats = fst <$> datatypes closed

    addTyLams binders body = foldr (\(x, k) -> TyAbs () x k) body binders

    addLams binders body = foldr (\(x, xTy) -> LamAbs () x xTy) body binders

    -- Strip one forall per context type variable, remembering to rename the
    -- bound variable back to the context's name, then apply all renamings.
    stripForalls sub [] t = substTypeParallel sub t
    stripForalls sub ((x, _) : rest) (TyForall _ y _ inner) =
      stripForalls (Map.insert y (TyVar () x) sub) rest inner
    stripForalls _ _ _ = error "stripForalls"

    -- Strip one function arrow per context term variable.
    stripFuns [] t = t
    stripFuns (_ : rest) (TyFun _ _ res) = stripFuns rest res
    stripFuns _ _ = error "stripFuns"

-- | Check that a term has the given type with no surrounding context; this is
-- 'typeCheckTermInContext' with empty type and term contexts.
typeCheckTerm :: Term TyName Name DefaultUni DefaultFun ()
              -> Type TyName DefaultUni ()
              -> Either String ()
typeCheckTerm tm ty = typeCheckTermInContext Map.empty Map.empty tm ty

-- | Check that a term has the given type under the given type and term
-- contexts. The type is inferred with 'inferTypeInContext' and then unified
-- with the expected type (inferred type first); the unifying substitution is
-- discarded. An error from either step is returned in 'Left'.
typeCheckTermInContext :: TypeCtx
                       -> Map Name (Type TyName DefaultUni ())
                       -> Term TyName Name DefaultUni DefaultFun ()
                       -> Type TyName DefaultUni ()
                       -> Either String ()
typeCheckTermInContext tyctx ctx tm ty =
  inferTypeInContext tyctx ctx tm >>= \inferred ->
    void (unifyType tyctx mempty inferred ty)