{-# LANGUAGE LambdaCase #-}
{-|
Pass to convert pointlessly non-strict bindings into strict bindings, which have less overhead.
-}
module PlutusIR.Transform.StrictifyBindings (
  strictifyBindings,
  strictifyBindingsPass
  ) where

import PlutusCore.Builtin
import PlutusIR
import PlutusIR.Purity

import Control.Lens (transformOf)
import PlutusCore qualified as PLC
import PlutusCore.Name.Unique qualified as PLC
import PlutusIR.Analysis.Builtins
import PlutusIR.Analysis.VarInfo
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

{-| Single, non-recursive rewriting step of the pass.

If the given term is a @Let@, every non-strict term binding in it whose
right-hand side satisfies 'isPure' (under the supplied builtin and variable
information) is turned into a strict binding. The annotation, recursivity,
remaining bindings and body of the @Let@ are kept as they are. Any term that
is not a @Let@ is returned unchanged; subterms are not visited here.
-}
strictifyBindingsStep
    :: (ToBuiltinMeaning uni fun, PLC.HasUnique name PLC.TermUnique)
    => BuiltinsInfo uni fun
    -> VarsInfo tyname name uni a
    -> Term tyname name uni fun a
    -> Term tyname name uni fun a
strictifyBindingsStep binfo vinfo = \case
    Let a s bs t -> Let a s (fmap strictifyBinding bs) t
      where
        -- Only a non-strict term binding with a pure right-hand side is
        -- rewritten; if the guard fails, matching falls through to the
        -- catch-all equation below and the binding is left alone.
        strictifyBinding (TermBind x NonStrict vd rhs)
          | isPure binfo vinfo rhs = TermBind x Strict vd rhs
        strictifyBinding b = b
    t -> t

{-| Strictify pointlessly non-strict bindings throughout a term.

Applies 'strictifyBindingsStep' across the term using 'transformOf' over
'termSubterms'. The variable information handed to each step is computed
once, from the original input term, via 'termVarInfo'.
-}
strictifyBindings
    :: (ToBuiltinMeaning uni fun, PLC.HasUnique name PLC.TermUnique
    , PLC.HasUnique tyname PLC.TypeUnique)
    => BuiltinsInfo uni fun
    -> Term tyname name uni fun a
    -> Term tyname name uni fun a
strictifyBindings binfo term =
  transformOf
    termSubterms
    (strictifyBindingsStep binfo (termVarInfo term))
    term

{-| 'strictifyBindings' packaged as a 'Pass' named @"strictify bindings"@.

The transformation itself is pure and is lifted into @m@ with 'pure'. The
pass carries one 'Condition', @Typechecks tcconfig@, and one 'BiCondition',
the same check wrapped in @ConstCondition@.

NOTE(review): these two lists are presumably the pass's pre- and
post-conditions respectively -- confirm against "PlutusIR.Pass".
-}
strictifyBindingsPass ::
    forall m uni fun a.
    (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m) =>
    TC.PirTCConfig uni fun ->
    BuiltinsInfo uni fun ->
    Pass m TyName Name uni fun a
strictifyBindingsPass tcconfig binfo =
  NamedPass "strictify bindings" $
    Pass
      (pure . strictifyBindings binfo)
      [Typechecks tcconfig]
      [ConstCondition (Typechecks tcconfig)]