{-# LANGUAGE LambdaCase #-}
{-|
Pass that converts the following kinds of non-strict bindings into strict bindings, which have less overhead:

  * non-strict bindings whose right-hand sides are pure
  * non-strict bindings whose bound variable the @let@ body is strict in
-}
module PlutusIR.Transform.StrictifyBindings (
  strictifyBindings,
  strictifyBindingsPass
  ) where

import PlutusCore.Builtin
import PlutusIR
import PlutusIR.Purity
import PlutusIR.Strictness

import Control.Lens (transformOf, (^.))
import PlutusCore qualified as PLC
import PlutusCore.Name.Unique qualified as PLC
import PlutusIR.Analysis.Builtins
import PlutusIR.Analysis.VarInfo
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

-- | Strictify the term bindings of a single 'Let' node.
--
-- Only the bindings of the 'Let' at the root of the given term are
-- examined. This function does not descend into subterms, so it is meant to
-- be driven by a traversal.
--
-- A binding is rewritten from 'NonStrict' to 'Strict' when it is a
-- 'TermBind' and either
--
--   * its right-hand side is pure according to 'isPure', or
--   * the body of the 'Let' is strict in the bound variable according to
--     'isStrictIn'.
--
-- All other bindings are returned unchanged. This covers non-term bindings
-- and term bindings that are already strict. Any term that is not a 'Let' is
-- also returned unchanged. The recursivity flag of the 'Let' is preserved, and
-- the same rules are applied to both recursive and non-recursive lets.
strictifyBindingsStep
    :: (ToBuiltinMeaning uni fun, PLC.HasUnique name PLC.TermUnique, Eq name)
    => BuiltinsInfo uni fun
    -- ^ Builtin information, passed to 'isPure'.
    -> VarsInfo tyname name uni a
    -- ^ Variable information, passed to 'isPure'.
    -> Term tyname name uni fun a
    -> Term tyname name uni fun a
strictifyBindingsStep binfo vinfo = \case
    Let a s bs t -> Let a s (fmap strictifyBinding bs) t
      where
        -- Deliberately no local type signature. This helper closes over
        -- binfo, vinfo and t, so it is monomorphic in the outer type
        -- variables. Those variables are not in scope here because the outer
        -- signature has no explicit 'forall'.
        strictifyBinding (TermBind x NonStrict vd rhs)
          -- 'isPure' is tried first; 'isStrictIn' only runs if it fails.
          -- Strictness is checked against the let body 't' only.
          | isPure binfo vinfo rhs || isStrictIn (vd ^. PLC.varDeclName) t =
              TermBind x Strict vd rhs
        strictifyBinding b = b
    t -> t

-- | Strictify non-strict term bindings throughout a term.
--
-- 'strictifyBindingsStep' is applied to every subterm using 'transformOf'
-- over 'termSubterms', which rewrites bottom-up.
--
-- The 'VarsInfo' consulted by the purity check is computed once, from the
-- input term. It is not refreshed as bindings are rewritten during the
-- traversal. NOTE(review): presumably this can only make the purity check
-- more conservative — confirm.
strictifyBindings
    :: (ToBuiltinMeaning uni fun, PLC.HasUnique name PLC.TermUnique
    , PLC.HasUnique tyname PLC.TypeUnique, Eq name)
    => BuiltinsInfo uni fun
    -> Term tyname name uni fun a
    -> Term tyname name uni fun a
strictifyBindings binfo term =
  transformOf termSubterms (strictifyBindingsStep binfo vinfo) term
  where
    -- Variable information for the whole (original) input term.
    vinfo = termVarInfo term

-- | 'strictifyBindings' packaged as a compiler 'Pass' named
-- @"strictify bindings"@.
--
-- The transformation itself is pure and is lifted into @m@ with 'pure'.
-- The pass carries a 'Typechecks' condition under the given configuration,
-- and the same check wrapped in 'ConstCondition'. NOTE(review): presumably
-- these are checked on the input and on the output respectively — confirm
-- against "PlutusIR.Pass".
strictifyBindingsPass ::
    forall m uni fun a.
    (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m) =>
    TC.PirTCConfig uni fun ->
    BuiltinsInfo uni fun ->
    Pass m TyName Name uni fun a
strictifyBindingsPass tcconfig binfo =
  NamedPass "strictify bindings" $
    Pass
      (pure . strictifyBindings binfo)
      [Typechecks tcconfig]
      [ConstCondition (Typechecks tcconfig)]