{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module PlutusIR.Transform.RecSplit
(recSplit, recSplitPass) where
import PlutusCore.Name.Unique qualified as PLC
import PlutusIR
import PlutusIR.Subst
import Algebra.Graph.AdjacencyMap qualified as AM
import Algebra.Graph.AdjacencyMap.Algorithm qualified as AM hiding (isAcyclic)
import Algebra.Graph.NonEmpty.AdjacencyMap qualified as AMN
import Algebra.Graph.ToGraph (isAcyclic)
import Control.Lens
import Data.Either
import Data.Foldable qualified as Foldable (foldl')
import Data.List (nub)
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Semigroup.Foldable
import Data.Set qualified as S
import Data.Set.Lens (setOf)
import PlutusCore qualified as PLC
import PlutusIR.MkPir (mkLet)
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC
import PlutusPrelude ((<^>))
-- | The letrec-splitting transformation 'recSplit', packaged as a 'Pass'.
--
-- Built with 'simplePass' under the name @"recursive let split"@ and the
-- supplied typechecker configuration.
-- NOTE(review): presumably 'simplePass' uses @tcconfig@ to typecheck the term
-- around the transformation — confirm against "PlutusIR.Pass".
recSplitPass
    :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m)
    => TC.PirTCConfig uni fun
    -> Pass m TyName Name uni fun a
recSplitPass tcconfig = simplePass "recursive let split" tcconfig recSplit
-- | Split every recursive @let@ in a term into a nest of smaller @let@s, one
-- per strongly-connected component of the group's dependency graph.
--
-- 'recSplitStep' is applied to every subterm reachable through
-- 'termSubterms', bottom-up (the traversal order of lens's 'transformOf').
recSplit :: forall uni fun a name tyname.
    (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
    => Term tyname name uni fun a
    -> Term tyname name uni fun a
recSplit = transformOf termSubterms recSplitStep
-- | Split a single @let rec@ at the root of a term. Subterms are not visited
-- here; 'recSplit' takes care of the traversal.
--
-- The bindings of the group are partitioned into the strongly-connected
-- components of their dependency graph ('buildLocalDepGraph'). Each component
-- becomes its own @let@:
--
-- * 'NonRec' when the component is acyclic (a single binding with no
--   self-loop), and
-- * 'Rec' otherwise.
--
-- The resulting @let@s are nested so that every component sits inside the
-- components it depends on. Any term that is not a @let rec@ is returned
-- unchanged.
recSplitStep :: forall uni fun a name tyname.
    (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
    => Term tyname name uni fun a -> Term tyname name uni fun a
recSplitStep = \case
    Let a Rec bs t ->
        let
            -- Lookup from each binding's principal unique to the binding itself.
            bindingsTable :: M.Map PLC.Unique (Binding tyname name uni fun a)
            bindingsTable = M.fromList . NE.toList $ fmap (\b -> (principal b, b)) bs

            -- The components in topological order: a component is listed before
            -- every component it has an edge to, i.e. before its dependencies.
            -- The condensation produced by 'AM.scc' is acyclic, so the 'Left'
            -- case of 'AM.topSort' is unreachable.
            hereSccs =
                fromRight (error "Cycle detected in the scc-graph. This shouldn't happen in the first place.")
                . AM.topSort
                . AM.scc
                $ buildLocalDepGraph bs

            -- Wrap the term built so far in a @let@ holding one component's
            -- bindings. Binding order inside the @let@ follows 'M.elems', i.e.
            -- ascending principal unique.
            genLetFromScc acc scc = mkLet a
                (if isAcyclic scc then NonRec else Rec)
                (M.elems . M.restrictKeys bindingsTable $ AMN.vertexSet scc)
                acc
        -- A left fold wraps the first components (the dependents) innermost and
        -- the last ones (their dependencies) outermost, so every binding is in
        -- scope wherever it is referenced.
        in Foldable.foldl' genLetFromScc t hereSccs
    t -> t
-- | Build the dependency graph of the bindings of one @let rec@ group.
--
-- Vertices are the bindings' principal uniques (see 'principal'). Each
-- binding gets an edge from its principal to the principal of every binding
-- in the group that introduces a unique occurring free in it. If a binding
-- mentions one of its own uniques, this yields a self-loop.
buildLocalDepGraph :: forall uni fun a name tyname.
    (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
    => NE.NonEmpty (Binding tyname name uni fun a) -> AM.AdjacencyMap PLC.Unique
buildLocalDepGraph bs =
    AM.overlays . NE.toList $ fmap bindingSubGraph bs
  where
    -- Maps every unique introduced by a binding of the group (as enumerated by
    -- 'bindingIds') to that binding's principal unique.
    idTable :: M.Map PLC.Unique PLC.Unique
    idTable = foldMap1 (\b -> M.fromList . fmap (, principal b) $ b ^.. bindingIds) bs

    -- Edges from this binding's principal to the principals of the
    -- group members it refers to.
    bindingSubGraph :: Binding tyname name uni fun a -> AM.AdjacencyMap PLC.Unique
    bindingSubGraph b =
        let
            -- Free term and type uniques of the binding.
            -- NOTE(review): 'ftvBinding' is queried with 'NonRec', presumably
            -- so that a datatype's references to its own type constructor
            -- count as free (giving a self-loop, hence a 'Rec' let) — confirm
            -- against "PlutusIR.Subst".
            freeUniques =
                setOf (fvBinding . PLC.theUnique <^> ftvBinding NonRec . PLC.theUnique) b
            -- Those free uniques that are introduced inside this group.
            occursIds = M.keysSet idTable `S.intersection` freeUniques
            -- The principals owning them (several ids may share one principal).
            occursPrincipals = nub . M.elems $ idTable `M.restrictKeys` occursIds
        in AM.connect (AM.vertex $ principal b) (AM.vertices occursPrincipals)
-- | The unique that identifies a binding as a whole in the dependency graph:
--
-- * term binding: the unique of the bound variable;
-- * type binding: the unique of the bound type variable;
-- * datatype binding: the unique of the type constructor.
principal :: (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
          => Binding tyname name uni fun a
          -> PLC.Unique
principal = \case
    TermBind _ _ (VarDecl _ n _) _ -> n ^. PLC.theUnique
    TypeBind _ (TyVarDecl _ n _) _ -> n ^. PLC.theUnique
    DatatypeBind _ (Datatype _ (TyVarDecl _ tyConstr _) _ _ _) -> tyConstr ^. PLC.theUnique