{-# LANGUAGE LambdaCase #-}
-- | Definition analysis for Plutus IR.
-- This mostly adapts term-related code from PlutusCore.Analysis.Definitions;
-- we just re-use the typed machinery to do the hard work here.
module PlutusIR.Analysis.Definitions
    ( UniqueInfos
    , termDefs
    , runTermDefs
    ) where

import Control.Lens (forMOf_)
import Control.Monad (forM_)
import Control.Monad.State (MonadState, execStateT)
import Control.Monad.Writer (MonadWriter, runWriterT)
import PlutusCore.Error (UniqueError)
import PlutusCore.Name.Unique
import PlutusIR.Core.Plated
import PlutusIR.Core.Type

import PlutusCore.Analysis.Definitions hiding (runTermDefs, termDefs)

-- | Add the declarations introduced by a single PIR 'Binding' to the
-- definition maps.
--
-- * 'TermBind': the bound term variable, in 'TermScope'.
-- * 'TypeBind': the bound type variable, in 'TypeScope'.
-- * 'DatatypeBind': the destructor name and every constructor name, in
--   'TermScope'; the datatype's own type name and each of its type
--   parameters, in 'TypeScope'.
--
-- Only definitions are recorded here. The right-hand sides, kinds and types
-- of the binding are not inspected by this function. Usages inside them are
-- expected to be found by the traversal in 'termDefs'.
--
-- NOTE(review): 'addDef' (from "PlutusCore.Analysis.Definitions") runs in
-- 'MonadWriter' @[UniqueError ann]@, presumably to report redefinitions.
-- Confirm against that module.
addBindingDef :: (Ord ann,
    HasUnique name TermUnique,
    HasUnique tyname TypeUnique,
    MonadState (UniqueInfos ann) m,
    MonadWriter [UniqueError ann] m)
    => Binding tyname name uni fun ann -> m ()
addBindingDef = \case
    TermBind _ _ (VarDecl varAnn n _) _ ->
        addDef n varAnn TermScope
    TypeBind _ (TyVarDecl tyAnn tyN _) _ ->
        addDef tyN tyAnn TypeScope
    DatatypeBind _
        (Datatype dataAnn (TyVarDecl tyAnn tyN _) tyVarDecls dataName varDecls) -> do
        -- The order matches the previous implementation, so any errors are
        -- emitted in the same order.
        addDef dataName dataAnn TermScope
        addDef tyN tyAnn TypeScope
        forM_ tyVarDecls $ \(TyVarDecl tyVarAnn tyVarN _) ->
            addDef tyVarN tyVarAnn TypeScope
        forM_ varDecls $ \(VarDecl conAnn conName _) ->
            addDef conName conAnn TermScope

-- | Record every term and type definition and usage found in a PIR term,
-- including those in all of its subterms and subtypes, in the global
-- 'UniqueInfos' state.
--
-- The term is traversed twice:
--
-- 1. every subterm (via 'termSubtermsDeep') is passed to 'handleTerm';
-- 2. every subtype (via 'termSubtypesDeep') is passed to 'handleType' from
--    "PlutusCore.Analysis.Definitions".
--
-- Errors produced by those handlers go out through the 'MonadWriter'
-- constraint.
termDefs
    :: (Ord ann,
        HasUnique name TermUnique,
        HasUnique tyname TypeUnique,
        MonadState (UniqueInfos ann) m,
        MonadWriter [UniqueError ann] m)
    => Term tyname name uni fun ann
    -> m ()
termDefs term =
    forMOf_ termSubtermsDeep term handleTerm
        *> forMOf_ termSubtypesDeep term handleType

-- | Record whatever names a single term node itself defines or uses.
-- Child nodes are not visited here; 'termDefs' drives the traversal.
--
-- * 'Let': each of its bindings is registered with 'addBindingDef'.
-- * 'Var': a usage of a term variable.
-- * 'LamAbs': a definition of a term variable.
-- * 'TyAbs': a definition of a type variable.
-- * Any other node: nothing is recorded.
handleTerm :: (Ord ann,
        HasUnique name TermUnique,
        HasUnique tyname TypeUnique,
        MonadState (UniqueInfos ann) m,
        MonadWriter [UniqueError ann] m)
    => Term tyname name uni fun ann
    -> m ()
handleTerm (Let _ _ binds _)  = mapM_ addBindingDef binds
handleTerm (Var a x)          = addUsage x a TermScope
handleTerm (LamAbs a x _ _)   = addDef x a TermScope
handleTerm (TyAbs a tv _ _)   = addDef tv a TypeScope
handleTerm _                  = pure ()

-- | Run 'termDefs' on a term, starting from empty 'UniqueInfos'.
--
-- Returns the populated definition/usage maps together with every
-- 'UniqueError' written during the traversal.
runTermDefs
    :: (Ord ann,
        HasUnique name TermUnique,
        HasUnique tyname TypeUnique,
        Monad m)
    => Term tyname name uni fun ann
    -> m (UniqueInfos ann, [UniqueError ann])
runTermDefs term = runWriterT (execStateT (termDefs term) mempty)