{-# LANGUAGE LambdaCase #-}

{-|
A trivial simplification that merges adjacent non-recursive let terms. -}
module PlutusIR.Transform.LetMerge
  ( letMerge
  , letMergePass
  ) where

import PlutusIR

import Control.Lens (transformOf)
import PlutusCore qualified as PLC
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

{-|
A single, non-recursive step of let-merging.

When a non-recursive @let@ has another non-recursive @let@ as its immediate
body, the two are fused into one non-recursive @let@ whose bindings are the
outer bindings followed by the inner ones. Any other term is returned
unchanged. Subterms are not inspected here; see 'letMerge' for the recursive
version. -}
letMergeStep
  :: Term tyname name uni fun a
  -> Term tyname name uni fun a
letMergeStep = \case
  -- The merged term keeps the outer annotation; the inner one is discarded.
  -- Outer bindings come first so the original binding order is preserved.
  Let a NonRec bs (Let _ NonRec bs' t) -> Let a NonRec (bs <> bs') t
  t -> t

{-|
Recursively apply let-merging throughout a term.

'letMergeStep' is applied at every node reached through 'termSubterms' by
'transformOf', which rewrites children before their parent. A chain of
directly nested non-recursive @let@s therefore ends up as a single @let@. -}
letMerge
  :: Term tyname name uni fun a
  -> Term tyname name uni fun a
letMerge = transformOf termSubterms letMergeStep

{-|
'letMerge' packaged as a 'Pass' named @"let merge"@.

The pass is built with 'simplePass' from the given type-checker configuration
and the pure 'letMerge' transformation. How the configuration is used is
decided by 'simplePass' (presumably for type checking around the
transformation; confirm in "PlutusIR.Pass"). -}
letMergePass
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
letMergePass tcconfig = simplePass "let merge" tcconfig letMerge