{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}

-- | Kind/type inference/checking.
module PlutusCore.TypeCheck
  ( ToKind
  , MonadKindCheck
  , MonadTypeCheck
  , TypeErrorPlc
  , Typecheckable

    -- * Configuration.
  , BuiltinTypes (..)
  , KindCheckConfig (..)
  , TypeCheckConfig (..)
  , tccBuiltinTypes
  , defKindCheckConfig
  , builtinMeaningsToTypes
  , getDefTypeCheckConfig

    -- * Kind/type inference/checking.
  , inferKind
  , checkKind
  , inferType
  , checkType
  , inferTypeOfProgram
  , checkTypeOfProgram
  ) where

import PlutusPrelude

import PlutusCore.Builtin
import PlutusCore.Core
import PlutusCore.Default
import PlutusCore.Error
import PlutusCore.Name.Unique
import PlutusCore.Normalize
import PlutusCore.Quote
import PlutusCore.Rename
import PlutusCore.TypeCheck.Internal

{-| Constraint stating that the built-in types and built-in functions are kind\/type-checkable.

This is deliberately kept apart from 'MonadKindCheck' and 'MonadTypeCheck': those mostly
constrain the monad, whereas 'Typecheckable' constrains only the builtins. The split is handy
when the monad gets instantiated but the builtins don't. It also reflects the fact that
'Typecheckable' is not needed during type checking proper: it is only required to compute
'BuiltinTypes', which is then passed to the worker of the type checker as an ordinary argument. -}
type Typecheckable uni fun =
  ( ToKind uni
  , HasUniApply uni
  , ToBuiltinMeaning uni fun
  , AnnotateCaseBuiltin uni
  )

-- | The default kind checking config: name mismatches are handled with
-- 'DetectNameMismatches'.
defKindCheckConfig :: KindCheckConfig
defKindCheckConfig = KindCheckConfig DetectNameMismatches

{-| Extract the 'TypeScheme' from a 'BuiltinMeaning' and convert it to the
corresponding 'Type' for each built-in function.

For every value of @fun@ (enumerated via 'tabulateArray'), the type of the builtin
under the given semantics variant is kind-checked and then normalized. The
normalized types are collected into 'BuiltinTypes'. The whole computation runs
under 'runQuoteT', so name supply for normalization is handled internally. -}
builtinMeaningsToTypes
  :: (MonadKindCheck (TypeError term uni fun ann) term uni fun ann m, Typecheckable uni fun)
  => BuiltinSemanticsVariant fun
  -- ^ The semantics variant whose builtin types are computed.
  -> ann
  -- ^ Annotation attached to the builtin types while they are kind-checked,
  -- so any kind error raised here carries it.
  -> m (BuiltinTypes uni fun)
builtinMeaningsToTypes semvar ann =
  runQuoteT . fmap BuiltinTypes . sequence . tabulateArray $ \fun -> do
    let ty = typeOfBuiltinFunction semvar fun
    -- The inferred kind itself is not needed; the call is made only for its
    -- checking effect. 'inferKind' expects a type annotated with @ann@, whereas
    -- 'typeOfBuiltinFunction' yields a @()@-annotated one, hence @ann <$ ty@.
    _ <- inferKind defKindCheckConfig $ ann <$ ty
    dupable <$> normalizeType ty

-- | Get the default type checking config.
--
-- Combines 'defKindCheckConfig' with the builtin types computed by
-- 'builtinMeaningsToTypes' for the default ('def') semantics variant.
getDefTypeCheckConfig
  :: (MonadKindCheck (TypeError term uni fun ann) term uni fun ann m, Typecheckable uni fun)
  => ann
  -- ^ Annotation used while kind-checking the builtin types.
  -> m (TypeCheckConfig uni fun)
getDefTypeCheckConfig ann =
  TypeCheckConfig defKindCheckConfig <$> builtinMeaningsToTypes def ann

-- | Infer the kind of a type.
--
-- Runs the 'inferKindM' worker under 'runTypeCheckM' with the given config.
inferKind
  :: MonadKindCheck (TypeError term uni fun ann) term uni fun ann m
  => KindCheckConfig -> Type TyName uni ann -> m (Kind ())
inferKind config = runTypeCheckM config . inferKindM

{-| Check a type against a kind.
Infers the kind of the type and checks that it's equal to the given kind,
throwing a 'TypeError' (annotated with the value of the @ann@ argument) otherwise.

The expected kind is the final, implicit argument; it is passed straight on to
the 'checkKindM' worker, which runs under 'runTypeCheckM' with the given config. -}
checkKind
  :: MonadKindCheck (TypeError term uni fun ann) term uni fun ann m
  => KindCheckConfig -> ann -> Type TyName uni ann -> Kind () -> m ()
checkKind config ann ty = runTypeCheckM config . checkKindM ann ty

-- | Infer the type of a term.
--
-- The term is first renamed with 'rename' and then handed to the 'inferTypeM'
-- worker under 'runTypeCheckM'.
-- NOTE(review): the renaming presumably establishes a name-uniqueness invariant
-- the checker relies on; confirm against "PlutusCore.TypeCheck.Internal".
inferType
  :: MonadTypeCheckPlc uni fun ann m
  => TypeCheckConfig uni fun
  -> Term TyName Name uni fun ann
  -> m (Normalized (Type TyName uni ()))
inferType config = rename >=> (runTypeCheckM config . inferTypeM)

{-| Check a term against a type.
Infers the type of the term and checks that it's equal to the given type,
throwing a 'TypeError' (annotated with the value of the @ann@ argument) otherwise.

As in 'inferType', the term is renamed with 'rename' before being checked. -}
checkType
  :: MonadTypeCheckPlc uni fun ann m
  => TypeCheckConfig uni fun
  -> ann
  -> Term TyName Name uni fun ann
  -> Normalized (Type TyName uni ())
  -> m ()
checkType config ann term ty = do
  termRen <- rename term
  runTypeCheckM config $ checkTypeM ann termRen ty

-- | Infer the type of a program.
--
-- The program's annotation and version are ignored; only its body term is
-- passed to 'inferType'.
inferTypeOfProgram
  :: MonadTypeCheckPlc uni fun ann m
  => TypeCheckConfig uni fun
  -> Program TyName Name uni fun ann
  -> m (Normalized (Type TyName uni ()))
inferTypeOfProgram config (Program _ _ term) = inferType config term

{-| Check a program against a type.
Infers the type of the program and checks that it's equal to the given type,
throwing a 'TypeError' (annotated with the value of the @ann@ argument) otherwise.

The program's annotation and version are ignored; only its body term is checked.
The expected type is the final, implicit argument, forwarded to 'checkType'. -}
checkTypeOfProgram
  :: MonadTypeCheckPlc uni fun ann m
  => TypeCheckConfig uni fun
  -> ann
  -> Program TyName Name uni fun ann
  -> Normalized (Type TyName uni ())
  -> m ()
checkTypeOfProgram config ann (Program _ _ term) = checkType config ann term