-- | Kind/type inference/checking.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies    #-}

module PlutusCore.TypeCheck
    ( ToKind
    , MonadKindCheck
    , MonadTypeCheck
    , Typecheckable
    -- * Configuration.
    , BuiltinTypes (..)
    , KindCheckConfig (..)
    , TypeCheckConfig (..)
    , tccBuiltinTypes
    , defKindCheckConfig
    , builtinMeaningsToTypes
    , getDefTypeCheckConfig
    -- * Kind/type inference/checking.
    , inferKind
    , checkKind
    , inferType
    , checkType
    , inferTypeOfProgram
    , checkTypeOfProgram
    ) where

import PlutusPrelude

import PlutusCore.Builtin
import PlutusCore.Core
import PlutusCore.Default
import PlutusCore.Name.Unique
import PlutusCore.Normalize
import PlutusCore.Quote
import PlutusCore.Rename
import PlutusCore.TypeCheck.Internal

-- | The constraint stating that built-in types\/functions are kind\/type-checkable.
--
-- This is deliberately kept apart from 'MonadKindCheck'\/'MonadTypeCheck': those mainly constrain
-- the monad, whereas 'Typecheckable' constrains only the builtins, which is particularly handy
-- when the monad gets instantiated and the builtins don't. A further reason is that
-- 'Typecheckable' is not required during type checking itself: it's only needed for computing
-- 'BuiltinTypes', which is then passed as a regular argument to the worker of the type checker.
type Typecheckable uni fun =
    ( ToKind uni
    , HasUniApply uni
    , ToBuiltinMeaning uni fun
    )

-- | The default kind checking config.
--
-- It is a 'KindCheckConfig' built with the 'DetectNameMismatches' handling mode.
defKindCheckConfig :: KindCheckConfig
defKindCheckConfig = KindCheckConfig DetectNameMismatches

-- | Extract the 'TypeScheme' from a 'BuiltinMeaning' and convert it to the
-- corresponding 'Type' for each built-in function.
--
-- For every built-in function (as enumerated by 'tabulateArray') this:
--
-- 1. computes its type via 'typeOfBuiltinFunction' for the given semantics variant;
-- 2. runs 'inferKind' (with 'defKindCheckConfig') on that type, discarding the resulting kind,
--    so the call matters only for the error it may throw;
-- 3. normalizes the type and wraps the result with 'dupable'.
--
-- The results are collected into a 'BuiltinTypes' and the whole computation is run with
-- 'runQuoteT'.
--
-- The @ann@ argument replaces the annotation of every node of the type before kind inference
-- (presumably so that a kind error is reported with this annotation -- confirm against the
-- error type in use).
builtinMeaningsToTypes
    :: (MonadKindCheck err term uni fun ann m, Typecheckable uni fun)
    => BuiltinSemanticsVariant fun
    -> ann
    -> m (BuiltinTypes uni fun)
builtinMeaningsToTypes semvar ann =
    runQuoteT . fmap BuiltinTypes . sequence . tabulateArray $ \fun -> do
        let ty = typeOfBuiltinFunction semvar fun
        -- The inferred kind is not needed: this is purely a well-kindedness check.
        _ <- inferKind defKindCheckConfig $ ann <$ ty
        dupable <$> normalizeType ty

-- | Get the default type checking config.
--
-- Pairs 'defKindCheckConfig' with the 'BuiltinTypes' computed by 'builtinMeaningsToTypes' for
-- the default ('def') semantics variant. The @ann@ argument is passed through to
-- 'builtinMeaningsToTypes' unchanged.
getDefTypeCheckConfig
    :: (MonadKindCheck err term uni fun ann m, Typecheckable uni fun)
    => ann -> m (TypeCheckConfig uni fun)
getDefTypeCheckConfig ann =
    TypeCheckConfig defKindCheckConfig <$> builtinMeaningsToTypes def ann

-- | Infer the kind of a type.
--
-- Runs the 'inferKindM' worker under the supplied 'KindCheckConfig' via 'runTypeCheckM'.
inferKind
    :: MonadKindCheck err term uni fun ann m
    => KindCheckConfig -> Type TyName uni ann -> m (Kind ())
inferKind config = runTypeCheckM config . inferKindM

-- | Check a type against a kind.
-- Infers the kind of the type and checks that it's equal to the given kind
-- throwing a 'TypeError' (annotated with the value of the @ann@ argument) otherwise.
--
-- Runs the 'checkKindM' worker under the supplied 'KindCheckConfig' via 'runTypeCheckM'.
-- The expected kind is the final (point-free) argument.
checkKind
    :: MonadKindCheck err term uni fun ann m
    => KindCheckConfig -> ann -> Type TyName uni ann -> Kind () -> m ()
checkKind config ann ty = runTypeCheckM config . checkKindM ann ty

-- | Infer the type of a term.
--
-- The term is first passed through 'rename' and the renamed term is then given to the
-- 'inferTypeM' worker, which is run under the supplied 'TypeCheckConfig' via 'runTypeCheckM'.
-- (Presumably the worker relies on the uniqueness of binders established by 'rename' --
-- confirm in "PlutusCore.TypeCheck.Internal".)
inferType
    :: MonadTypeCheckPlc err uni fun ann m
    => TypeCheckConfig uni fun
    -> Term TyName Name uni fun ann
    -> m (Normalized (Type TyName uni ()))
-- Note that '.' binds tighter than '>=>', so this is @rename >=> (runTypeCheckM config . inferTypeM)@.
inferType config = rename >=> runTypeCheckM config . inferTypeM

-- | Check a term against a type.
-- Infers the type of the term and checks that it's equal to the given type
-- throwing a 'TypeError' (annotated with the value of the @ann@ argument) otherwise.
--
-- The term is first passed through 'rename'; the renamed term is then given to the
-- 'checkTypeM' worker, which is run under the supplied 'TypeCheckConfig' via 'runTypeCheckM'.
checkType
    :: MonadTypeCheckPlc err uni fun ann m
    => TypeCheckConfig uni fun
    -> ann
    -> Term TyName Name uni fun ann
    -> Normalized (Type TyName uni ())
    -> m ()
checkType config ann term ty = do
    termRen <- rename term
    runTypeCheckM config $ checkTypeM ann termRen ty

-- | Infer the type of a program.
--
-- Only the term inside the 'Program' is inspected: the other two fields (the annotation and
-- the version) are ignored and the term is handed to 'inferType'.
inferTypeOfProgram
    :: MonadTypeCheckPlc err uni fun ann m
    => TypeCheckConfig uni fun
    -> Program TyName Name uni fun ann
    -> m (Normalized (Type TyName uni ()))
inferTypeOfProgram config (Program _ _ term) = inferType config term

-- | Check a program against a type.
-- Infers the type of the program and checks that it's equal to the given type
-- throwing a 'TypeError' (annotated with the value of the @ann@ argument) otherwise.
--
-- Only the term inside the 'Program' is inspected: the other two fields (the annotation and
-- the version) are ignored and the term is handed to 'checkType'. The expected type is the
-- final (point-free) argument.
checkTypeOfProgram
    :: MonadTypeCheckPlc err uni fun ann m
    => TypeCheckConfig uni fun
    -> ann
    -> Program TyName Name uni fun ann
    -> Normalized (Type TyName uni ())
    -> m ()
checkTypeOfProgram config ann (Program _ _ term) = checkType config ann term