{-# LANGUAGE LambdaCase #-}
{-|
A trivial simplification that merges adjacent non-recursive let terms.
-}
module PlutusIR.Transform.LetMerge (
  letMerge
  , letMergePass
  ) where

import PlutusIR

import Control.Lens (transformOf)
import PlutusCore qualified as PLC
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

{-|
One local let-merging rewrite, applied only at the root of the given term.

If the term is a non-recursive let whose body is itself immediately a
non-recursive let, the two are fused into a single non-recursive let:

  * the binding lists are concatenated with the outer bindings first, so the
    original binding order is preserved;
  * the outer let's annotation is kept and the inner one is discarded.

Any other term is returned unchanged.
-}
letMergeStep
    :: Term tyname name uni fun a
    -> Term tyname name uni fun a
letMergeStep term = case term of
    Let ann NonRec outerBinds (Let _ NonRec innerBinds body) ->
        Let ann NonRec (outerBinds <> innerBinds) body
    unchanged ->
        unchanged

{-|
Merge adjacent non-recursive lets everywhere in a term.

'letMergeStep' is applied to every subterm reachable through 'termSubterms'
using lens's 'transformOf', which rewrites bottom-up. Because children are
rewritten before their parent, a chain of nested non-recursive lets has
already been collapsed below a node by the time that node itself is
considered.
-}
letMerge
    :: Term tyname name uni fun a
    -> Term tyname name uni fun a
letMerge term = transformOf termSubterms letMergeStep term

{-|
The let-merging transformation packaged as a 'Pass' labelled @"let merge"@.

It is built with 'simplePass' from the given typechecker configuration and
'letMerge'. How 'simplePass' uses the configuration (for example, whether it
typechecks before and/or after the rewrite) is defined in "PlutusIR.Pass".
-}
letMergePass
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
letMergePass config = simplePass passLabel config letMerge
  where
    passLabel = "let merge"