-- editorconfig-checker-disable-file
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}
module PlutusIR.Transform.RecSplit
    (recSplit, recSplitPass) where

import PlutusCore.Name.Unique qualified as PLC
import PlutusIR
import PlutusIR.Subst

import Algebra.Graph.AdjacencyMap qualified as AM
import Algebra.Graph.AdjacencyMap.Algorithm qualified as AM hiding (isAcyclic)
import Algebra.Graph.NonEmpty.AdjacencyMap qualified as AMN
import Algebra.Graph.ToGraph (isAcyclic)
import Control.Lens
import Data.Either
import Data.Foldable qualified as Foldable (foldl')
import Data.List (nub)
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Semigroup.Foldable
import Data.Set qualified as S
import Data.Set.Lens (setOf)
import PlutusCore qualified as PLC
import PlutusIR.MkPir (mkLet)
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC
import PlutusPrelude ((<^>))

{- Note [LetRec splitting pass]

This pass can achieve two things:

- turn recursive let-bindings which are not really recursive into non-recursive let-bindings.
- break down letrec groups into smaller ones, based on the dependencies of the group's bindings.

This pass examines a single letrec group at a time
and maybe splits the group into sub-letgroups (rec or nonrec).

Invariants of the pass:

- Preserves the well-scopedness of the term.
- Does not turn an ill-scoped term into a well-scoped one.
- Does not place/move the sub-letgroups into locations other than the original letrec location (hole).
- Does not touch let-nonrec groups.

Both (a) the grouping into sub-letgroups
and (b) the order of appearance of these sub-groups inside the result term
are determined by locally constructing, at each letrec-group location,
a dependency graph between the bindings of the original let.

The created sub-letgroups will either be
(a) let-rec with 2 or more bindings
(b) let-nonrec with a single binding

Currently the implementation relies on 'Unique's, so there is the assumption of global uniqueness of the input term.
However, the algorithm could be changed to work without this assumption (has not been tested).
-}

{- Note [Principal id]
The algorithm identifies & stores bindings and their corresponding rhs'es in some intermediate tables.
To identify/store each binding to such tables, we need to "key" them by a single unique identifier.

For term bindings and type bindings this is easily achieved by using the single introduced name or tyname as "the key" (principal id).

Datatype bindings, however, introduce multiple names and tynames (i.e. type-constructor, type args, destructor, data-constructors)
and the 'principal' function arbitrarily chooses one of these introduced names/tynames of the datatype binding
to represent the "principal" id of the whole datatype binding so it can be used as "the key".
-}

{-|
The letrec-splitting transformation packaged as a compiler 'Pass'.

It is a 'simplePass' named @"recursive let split"@ whose term transformation
is 'recSplit'. The given type-checking configuration is passed straight
through to 'simplePass'.
-}
recSplitPass
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
recSplitPass tcconfig = simplePass passName tcconfig recSplit
  where
    passName = "recursive let split"

{-|
Apply letrec splitting to every subterm of a term.

The rewrite runs bottom-up: 'transformOf' over 'termSubterms' rewrites the
children of a node before the node itself. 'recSplitStep' is the rewrite
applied at each node.
-}
recSplit :: forall uni fun a name tyname.
           (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
         => Term tyname name uni fun a
         -> Term tyname name uni fun a
recSplit term0 = transformOf termSubterms recSplitStep term0

{-|
Apply splitting to a single letrec group. Any other term is returned unchanged.

The group's bindings are partitioned into the strongly-connected components
of their local dependency graph (see 'buildLocalDepGraph'). Each component
becomes its own let:

* 'NonRec' when the component is acyclic, i.e. a single binding that does
  not refer to itself;
* 'Rec' otherwise.

The components are nested following a topological sort of the component
graph. That sort lists a component before the components it depends on.
Folding from the left therefore makes later components the outer lets.
-}
recSplitStep :: forall uni fun a name tyname.
               (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
             => Term tyname name uni fun a -> Term tyname name uni fun a
recSplitStep tm = case tm of
    -- See Note [LetRec splitting pass]
    Let ann Rec bs bodyTerm -> Foldable.foldl' wrapScc bodyTerm orderedSccs
      where
        -- every binding of the group, keyed by its principal id
        -- (see Note [Principal id])
        byPrincipal :: M.Map PLC.Unique (Binding tyname name uni fun a)
        byPrincipal = M.fromList [ (principal b, b) | b <- NE.toList bs ]

        -- The SCCs of the local dependency graph give the grouping; the
        -- topological sort of the SCC graph gives the order. An SCC
        -- condensation is always acyclic, so the error is unreachable.
        orderedSccs :: [AMN.AdjacencyMap PLC.Unique]
        orderedSccs =
            fromRight (error "Cycle detected in the scc-graph. This shouldn't happen in the first place.")
                (AM.topSort (AM.scc (buildLocalDepGraph bs)))

        -- wrap the term built so far in a let that holds exactly the
        -- bindings of one SCC
        wrapScc builtSoFar scc =
            let recKind     = if isAcyclic scc then NonRec else Rec
                sccBindings = M.elems (byPrincipal `M.restrictKeys` AMN.vertexSet scc)
            in mkLet ann recKind sccBindings builtSoFar
    other -> other

{-|
Build the dependency graph of the bindings of the let-group being examined.

The vertices are the principal ids (see Note [Principal id]) of the group's
bindings. An edge @x -> y@ means that the binding keyed by @x@ mentions a
name introduced by the binding keyed by @y@.

This local graph may contain loops:

* a "self-edge" indicates a self-recursive binding;
* any other cycle indicates mutually-recursive bindings.
-}
buildLocalDepGraph :: forall uni fun a name tyname.
                     (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
                   => NE.NonEmpty (Binding tyname name uni fun a) -> AM.AdjacencyMap PLC.Unique
buildLocalDepGraph bs = AM.overlays [ bindingSubGraph b | b <- NE.toList bs ]
    where
      -- every unique introduced by this let-group, mapped to the principal id
      -- of the binding that introduces it
      idTable :: M.Map PLC.Unique PLC.Unique
      idTable = foldMap1 introducedIds bs

      -- the uniques introduced by one binding, each mapped to that binding's
      -- principal id
      introducedIds :: Binding tyname name uni fun a -> M.Map PLC.Unique PLC.Unique
      introducedIds b = M.fromList [ (i, principal b) | i <- b ^.. bindingIds ]

      -- The sub-graph that connects one binding to every binding of this
      -- let-group it refers to. This may include the binding itself, when it
      -- is self-recursive.
      bindingSubGraph :: Binding tyname name uni fun a -> AM.AdjacencyMap PLC.Unique
      bindingSubGraph b =
          AM.connect (AM.vertex (principal b)) (AM.vertices localDeps)
        where
          -- Free term and type uniques of the binding. A datatype binding is
          -- inspected as if it came from a let-nonrec (`ftvBinding NonRec`),
          -- so a self-recursive datatype reports its own type constructor as
          -- free.
          freeUniques :: S.Set PLC.Unique
          freeUniques = setOf (fvBinding . PLC.theUnique <^> ftvBinding NonRec . PLC.theUnique) b
          -- the free uniques that this let-group introduces ...
          localIds = M.keysSet idTable `S.intersection` freeUniques
          -- ... translated to the principal ids of their bindings, without duplicates
          localDeps = nub (M.elems (M.restrictKeys idTable localIds))


{-|
Return the single 'PLC.Unique' that keys a binding in the tables of this pass.
See Note [Principal id]

* A term binding is keyed by the unique of the variable it binds.
* A type binding is keyed by the unique of the type variable it binds.
* A datatype binding is keyed by the unique of its type constructor. This is
  an arbitrary choice among the several names such a binding introduces.
-}
principal :: (PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
            => Binding tyname name uni fun a
            -> PLC.Unique
principal (TermBind _ _ (VarDecl _ n _) _) = n ^. PLC.theUnique
principal (TypeBind _ (TyVarDecl _ n _) _) = n ^. PLC.theUnique
principal (DatatypeBind _ (Datatype _ (TyVarDecl _ tyConstr _) _ _ _)) = tyConstr ^. PLC.theUnique