{-# LANGUAGE TypeFamilies  #-}
{-# LANGUAGE TypeOperators #-}

-- | Kind and type inference and checking for PIR, mirroring "PlutusCore.TypeCheck".
module PlutusIR.TypeCheck (
  -- * Configuration.
  BuiltinTypes (..),
  PirTCConfig (..),
  tccBuiltinTypes,
  getDefTypeCheckConfig,

  -- * Type checking, extending the plc typechecker
  inferType,
  checkType,
  inferTypeOfProgram,
  checkTypeOfProgram,
  MonadTypeCheckPir,
) where

import PlutusCore.Rename
import PlutusCore.TypeCheck qualified as PLC
import PlutusIR
import PlutusIR.Error
import PlutusIR.Transform.Rename ()
import PlutusIR.TypeCheck.Internal

import Control.Monad ((>=>))

{- Note [Goal of PIR typechecker]

The PIR typechecker is an extension of the PLC typechecker: whereas the PLC typechecker
works on PLC terms, the PIR typechecker works on PIR terms. The PIR term language
can be thought of as a superset of the PLC term language: it adds the `LetRec` and
`LetNonRec` syntactic constructs. Because of this, the PIR typechecker simply extends the
PLC typechecker by adding checks for these two let constructs of PIR.

Since we already have a PIR->PLC compiler, one could argue that it would suffice to first
compile the PIR to PLC and then run only the PLC typechecker. While this is mostly true,
there are several reasons for also having the PIR typechecker as an extra step in the
compiler pipeline:

- Error messages can refer to features of PIR syntax that don't exist in PLC,
  such as let-terms.

- Although PIR is an IR and as such is not supposed to be written by humans, we do have
  some hand-written PIR code in our examples/samples/testcases, and we would like to make
  sure that it typechecks.

- Our dead-code eliminator, which works on PIR (in `PlutusIR.Optimizer.Deadcode`), may
  eliminate ill-typed code, which would surprisingly turn an ill-typed program into a
  well-typed one. Typechecking the PIR itself catches such errors before they can be
  optimized away.

- Some user-written lets may be declared recursive even though they do not *have to*
  be, e.g. `let (rec) x = 3 in` would be better written as `let (nonrec) x = 3 in`.
  In such cases we could signal a warning/error (NB: not implemented at the moment, and
  probably not the job of the typechecker pass).

- In general, as an extra source of (type) safety.
-}

-- | The default 'PirTCConfig'.
--
-- It pairs the default PLC configuration (obtained via
-- 'PLC.getDefTypeCheckConfig') with 'YesEscape', so inferred types are
-- allowed to escape their scope; see
-- Note [PIR vs Paper Escaping Types Difference].
getDefTypeCheckConfig ::
  (MonadKindCheck err term uni fun ann m, PLC.Typecheckable uni fun) =>
  ann ->
  m (PirTCConfig uni fun)
getDefTypeCheckConfig ann = withEscape <$> PLC.getDefTypeCheckConfig ann
  where
    -- Extend the PLC configuration with the PIR-specific escape setting.
    withEscape configPlc = PirTCConfig configPlc YesEscape

{- | Infer the type of a term.

The term is renamed first, and its type is then inferred with 'inferTypeM'
inside the 'PirTCEnv' environment built from the given configuration.

Note: the "inferred type" can escape its scope if a 'YesEscape' config is
passed; see Note [PIR vs Paper Escaping Types Difference].
-}
inferType ::
  (MonadTypeCheckPir err uni fun ann m) =>
  PirTCConfig uni fun ->
  Term TyName Name uni fun ann ->
  m (Normalized (Type TyName uni ()))
-- NOTE(review): the renaming step presumably establishes the name-uniqueness
-- invariant the checker relies on — confirm against 'inferTypeM'.
inferType config = rename >=> runTypeCheckM config . inferTypeM

{- | Check a term against a type.

Infers the type of the term and checks that it is equal to the given type,
throwing a 'TypeError' (annotated with the value of the @ann@ argument)
otherwise. The term is renamed before checking, and the check runs inside
the 'PirTCEnv' environment built from the given configuration.

Note: this may allow witnessing a type that escapes its scope; see
Note [PIR vs Paper Escaping Types Difference].
-}
checkType ::
  (MonadTypeCheckPir err uni fun ann m) =>
  PirTCConfig uni fun ->
  ann ->
  Term TyName Name uni fun ann ->
  Normalized (Type TyName uni ()) ->
  m ()
checkType config ann term ty = do
  -- NOTE(review): renaming presumably establishes the name-uniqueness
  -- invariant the checker relies on — confirm against 'checkTypeM'.
  termRen <- rename term
  runTypeCheckM config $ checkTypeM ann termRen ty

{- | Infer the type of a program.

The program's annotation and version are ignored; this is 'inferType'
applied to the program's body term.

Note: the "inferred type" can escape its scope if a 'YesEscape' config is
passed; see Note [PIR vs Paper Escaping Types Difference].
-}
inferTypeOfProgram ::
  (MonadTypeCheckPir err uni fun ann m) =>
  PirTCConfig uni fun ->
  Program TyName Name uni fun ann ->
  m (Normalized (Type TyName uni ()))
inferTypeOfProgram config (Program _ _ term) = inferType config term

{- | Check a program against a type.

Infers the type of the program's body term and checks that it is equal to
the given type, throwing a 'TypeError' (annotated with the value of the
@ann@ argument) otherwise. The program's own annotation and version are
ignored; this is 'checkType' applied to the body term.

Note: this may allow witnessing a type that escapes its scope; see
Note [PIR vs Paper Escaping Types Difference].
-}
checkTypeOfProgram ::
  (MonadTypeCheckPir err uni fun ann m) =>
  PirTCConfig uni fun ->
  ann ->
  Program TyName Name uni fun ann ->
  Normalized (Type TyName uni ()) ->
  m ()
checkTypeOfProgram config ann (Program _ _ term) = checkType config ann term