-- editorconfig-checker-disable-file
{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs            #-}

-- | Convenient functions for compiling binders.
module PlutusTx.Compiler.Binders where

import PlutusTx.Compiler.Names
import PlutusTx.Compiler.Types
import PlutusTx.PIRTypes

import GHC.Plugins qualified as GHC

import PlutusIR qualified as PIR

import Control.Monad.Reader

import Data.Traversable

-- Binder helpers

{- Note [Iterated abstraction and application]
PLC doesn't expose iterated (multi-variable) abstraction and application
directly, so we build them from the single-variable forms.

We typically do this with a fold.
- Iterated application uses a *left* fold, since we want to apply the first
  argument first.
- Iterated abstraction uses a *right* fold, since we want to bind the first
  variable *last*. That way its binder ends up outermost, so it is the first
  one consumed when the result is applied.
-}

-- | Compile a GHC term variable to a fresh PIR variable declaration (via
-- 'compileVarFresh' with 'annMayInline'). Then run the continuation with that
-- declaration.
--
-- While the continuation runs, the GHC variable's 'GHC.Name' is pushed onto
-- 'ccScope', mapped to the new declaration. The extension is local to the
-- continuation (it is applied with 'local').
withVarScoped ::
    CompilingDefault uni fun m ann =>
    GHC.Var ->
    (PIR.VarDecl PIR.TyName PIR.Name uni Ann -> m a) ->
    m a
withVarScoped v k =
    compileVarFresh annMayInline v >>= \decl ->
        local (bindInScope decl) (k decl)
  where
    -- Record the GHC name -> PIR declaration mapping in the compile scope.
    bindInScope decl ctx = ctx {ccScope = pushName (GHC.getName v) decl (ccScope ctx)}

-- | Like 'withVarScoped', but takes a 'PIRType' and uses it as the type of
-- the compiled 'GHC.Var' (the declaration is built by
-- 'compileVarWithTyFresh').
--
-- The GHC variable's name is mapped to the new declaration in 'ccScope' only
-- for the duration of the continuation.
withVarTyScoped ::
    CompilingDefault uni fun m ann =>
    GHC.Var ->
    PIRType uni ->
    (PIR.VarDecl PIR.TyName PIR.Name uni Ann -> m a) ->
    m a
withVarTyScoped v t k = do
    decl <- compileVarWithTyFresh annMayInline v t
    local (addToScope decl) (k decl)
  where
    addToScope decl ctx = ctx {ccScope = pushName (GHC.getName v) decl (ccScope ctx)}

-- | Many-variable form of 'withVarScoped'.
--
-- Each 'GHC.Var' is compiled to a fresh PIR variable declaration, in list
-- order. All of the (name, declaration) pairs are then pushed onto 'ccScope'
-- in one go with 'pushNames'. The continuation receives the declarations in
-- the same order as the input variables. The scope extension is local to the
-- continuation.
withVarsScoped ::
    CompilingDefault uni fun m ann =>
    [GHC.Var] ->
    ([PIR.VarDecl PIR.TyName PIR.Name uni Ann] -> m a) ->
    m a
withVarsScoped vs k = do
    named <- traverse (\v -> (,) (GHC.getName v) <$> compileVarFresh annMayInline v) vs
    local (\ctx -> ctx {ccScope = pushNames named (ccScope ctx)}) $
        k (map snd named)

-- | Compile a GHC type variable to a fresh PIR type-variable declaration
-- (via 'compileTyVarFresh'). Then run the continuation with it, while the GHC
-- variable's name is mapped to that declaration in 'ccScope' (pushed with
-- 'pushTyName'). The extension lasts only for the continuation.
withTyVarScoped ::
    Compiling uni fun m ann =>
    GHC.Var ->
    (PIR.TyVarDecl PIR.TyName Ann -> m a) ->
    m a
withTyVarScoped v k = do
    tyDecl <- compileTyVarFresh v
    let withTyName ctx = ctx {ccScope = pushTyName (GHC.getName v) tyDecl (ccScope ctx)}
    local withTyName (k tyDecl)

-- | Many-variable form of 'withTyVarScoped'.
--
-- Each 'GHC.Var' is compiled to a fresh PIR type-variable declaration, in
-- list order. All (name, declaration) pairs are pushed onto 'ccScope' together
-- with 'pushTyNames'. The continuation receives the declarations in input
-- order. The scope extension is local to the continuation.
withTyVarsScoped ::
    Compiling uni fun m ann =>
    [GHC.Var] ->
    ([PIR.TyVarDecl PIR.TyName Ann] -> m a) ->
    m a
withTyVarsScoped vs k = do
    named <- for vs $ \v -> fmap ((,) (GHC.getName v)) (compileTyVarFresh v)
    let (_, decls) = unzip named
    local (\ctx -> ctx {ccScope = pushTyNames named (ccScope ctx)}) (k decls)

-- | Builds a lambda, binding the given variable to a name that will be in
-- scope when running the second argument.
--
-- The variable is compiled and scoped with 'withVarScoped'. The term produced
-- by the body action is wrapped in 'PIR.LamAbs', using the compiled
-- variable's name and type and the 'annMayInline' annotation.
mkLamAbsScoped ::
    CompilingDefault uni fun m ann =>
    GHC.Var ->
    m (PIRTerm uni fun) ->
    m (PIRTerm uni fun)
mkLamAbsScoped v body = withVarScoped v lamOver
  where
    lamOver (PIR.VarDecl _ varName varTy) = PIR.LamAbs annMayInline varName varTy <$> body

-- | Builds nested lambdas, one per variable, by right-folding
-- 'mkLamAbsScoped' over the list (see Note [Iterated abstraction and
-- application]).
--
-- The first variable becomes the outermost lambda. Every variable is in scope
-- while the body action runs. An empty list returns the body action
-- unchanged.
mkIterLamAbsScoped ::
    CompilingDefault uni fun m ann =>
    [GHC.Var] ->
    m (PIRTerm uni fun) ->
    m (PIRTerm uni fun)
-- Right fold: mkLamAbsScoped v1 (mkLamAbsScoped v2 (... body)).
mkIterLamAbsScoped vars body = foldr mkLamAbsScoped body vars

-- | Builds a type abstraction, binding the given variable to a name that will
-- be in scope when running the second argument.
--
-- The type variable is compiled and scoped with 'withTyVarScoped'. The body's
-- term is wrapped in 'PIR.TyAbs', using the compiled type name and kind and
-- the 'annMayInline' annotation.
mkTyAbsScoped ::
    Compiling uni fun m ann =>
    GHC.Var ->
    m (PIRTerm uni fun) ->
    m (PIRTerm uni fun)
mkTyAbsScoped v body = withTyVarScoped v tyAbsOver
  where
    tyAbsOver (PIR.TyVarDecl _ tyName kind) = PIR.TyAbs annMayInline tyName kind <$> body

-- | Builds nested type abstractions, one per type variable, by right-folding
-- 'mkTyAbsScoped' over the list (see Note [Iterated abstraction and
-- application]).
--
-- The first variable becomes the outermost abstraction. Every variable is in
-- scope while the body action runs. An empty list returns the body action
-- unchanged.
mkIterTyAbsScoped ::
    Compiling uni fun m ann =>
    [GHC.Var] ->
    m (PIRTerm uni fun) ->
    m (PIRTerm uni fun)
-- Right fold: mkTyAbsScoped v1 (mkTyAbsScoped v2 (... body)).
mkIterTyAbsScoped vars body = foldr mkTyAbsScoped body vars

-- | Builds a forall type, binding the given variable to a name that will be
-- in scope when running the second argument.
--
-- The type variable is compiled and scoped with 'withTyVarScoped'. The body's
-- type is wrapped in 'PIR.TyForall', using the compiled type name and kind
-- and the 'annMayInline' annotation.
mkTyForallScoped ::
    Compiling uni fun m ann =>
    GHC.Var ->
    m (PIRType uni) ->
    m (PIRType uni)
mkTyForallScoped v body =
    withTyVarScoped v $ \decl -> case decl of
        PIR.TyVarDecl _ tyName kind -> fmap (PIR.TyForall annMayInline tyName kind) body

-- | Builds nested forall types, one per type variable, by right-folding
-- 'mkTyForallScoped' over the list (see Note [Iterated abstraction and
-- application]).
--
-- The first variable becomes the outermost forall. Every variable is in scope
-- while the body action runs. An empty list returns the body action
-- unchanged.
mkIterTyForallScoped ::
    Compiling uni fun m ann =>
    [GHC.Var] ->
    m (PIRType uni) ->
    m (PIRType uni)
-- Right fold: mkTyForallScoped v1 (mkTyForallScoped v2 (... body)).
mkIterTyForallScoped vars body = foldr mkTyForallScoped body vars

-- | Builds a type-level lambda, binding the given variable to a name that
-- will be in scope when running the second argument.
--
-- The type variable is compiled and scoped with 'withTyVarScoped'. The body's
-- type is wrapped in 'PIR.TyLam', using the compiled type name and kind and
-- the 'annMayInline' annotation.
mkTyLamScoped ::
    Compiling uni fun m ann =>
    GHC.Var ->
    m (PIRType uni) ->
    m (PIRType uni)
mkTyLamScoped v body = withTyVarScoped v tyLamOver
  where
    tyLamOver (PIR.TyVarDecl _ tyName kind) = fmap (PIR.TyLam annMayInline tyName kind) body

-- | Builds nested type-level lambdas, one per type variable, by
-- right-folding 'mkTyLamScoped' over the list (see Note [Iterated
-- abstraction and application]).
--
-- The first variable becomes the outermost lambda. Every variable is in scope
-- while the body action runs. An empty list returns the body action
-- unchanged.
mkIterTyLamScoped ::
    Compiling uni fun m ann =>
    [GHC.Var] ->
    m (PIRType uni) ->
    m (PIRType uni)
-- Right fold: mkTyLamScoped v1 (mkTyLamScoped v2 (... body)).
mkIterTyLamScoped vars body = foldr mkTyLamScoped body vars