{-# LANGUAGE BlockArguments    #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications  #-}
{-# LANGUAGE TypeFamilies      #-}

module UntypedPlutusCore.Transform.Cse (cse) where

import PlutusCore (MonadQuote, Name, Rename, freshName, rename)
import PlutusCore.Builtin (ToBuiltinMeaning (BuiltinSemanticsVariant))
import UntypedPlutusCore.Core
import UntypedPlutusCore.Purity (isWorkFree)
import UntypedPlutusCore.Size (termSize)

import Control.Arrow ((>>>))
import Control.Lens (foldrOf, transformOf)
import Control.Monad (join, void)
import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Trans.Reader (ReaderT (runReaderT), ask, local)
import Control.Monad.Trans.State.Strict (State, evalState, get, put)
import Data.Foldable as Foldable (foldl')
import Data.Hashable (Hashable)
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as Map
import Data.List.Extra (isSuffixOf, sortOn)
import Data.Ord (Down (..))
import Data.Proxy (Proxy (..))
import Data.Traversable (for)
import Data.Tuple.Extra (snd3, thd3)
import PlutusCore.Arity (builtinArity)

{- Note [CSE]

-------------------------------------------------------------------------------
1. Simplifications
-------------------------------------------------------------------------------

This is a simplified (i.e., not fully optimal) implementation of CSE. The simplification
we made is that there is no alpha equivalence check, i.e., `\x -> x` and `\y -> y` are
considered different expressions.

Builtin function arities are not approximated from the program: they are looked up with
`builtinArity` for the given `BuiltinSemanticsVariant`. The arity information is used to
determine whether a builtin application is saturated.

-------------------------------------------------------------------------------
2. How does it work?
-------------------------------------------------------------------------------

We use the following example to explain how the implementation works:

\x y -> (1+(2+x))
        +
        (case y [ (1+(2+x)) + (3+x)
                , (2+x) + (3+x)
                , 4+x
                ]
        )

The implementation makes several passes on the given term. The first pass renames the term
(via `rename`).

In the second pass, we assign a unique ID to each `LamAbs`, `Delay`, and each `Case` branch.
Then, we annotate each subterm with a path, consisting of IDs encountered from the root
to that subterm (not including itself). The reason to do this is because `LamAbs`, `Delay`,
and `Case` branches represent places where computation stops, i.e., subexpressions are not
immediately evaluated, and may not be evaluated at all.

In the above example, the ID of `\x` is 0, the ID of `\y` is 1, and the IDs of the
three case branches are 2, 3, 4 (the actual numbers don't matter, as long as they are unique).
The path for the first `1+(2+x)` and the first `2+x` is "0.1"; the path for the second
`1+(2+x)` and the second `2+x` is "0.1.2"; the path for `4+x` is "0.1.4".

In the third pass, we calculate a count for each `(term, path)` pair, where `term` is a
non-workfree term, and `path` is its path. If the same term has two paths, and one is an
ancestor (i.e., prefix) of the other, we increment the count for the ancestor path in both
instances.

In the above example, there are three occurrences of `2+x`, whose paths are "0.1", "0.1.2"
and "0.1.3", respectively. The first path is an ancestor of the latter two. Therefore,
the count for `(2+x, "0.1")` is 3, while the count for `(2+x, "0.1.2")` and `(2+x, "0.1.3")`
is 0. The following all have a count of 1: `(3+x, "0.1.2")`, `(3+x, "0.1.3")` and
`(4+x, "0.1.4")`.

Now, each `(term, path)` pair whose count is greater than 1 is a CSE candidate.
In the above example, the CSE candidates are `(2+x, "0.1")` and `(1+(2+x), "0.1")`.
Note that `3+x` is not a CSE candidate, because it has two paths, and neither has a count
greater than 1. `2+` is also not a CSE candidate, because it is workfree.

The CSE candidates are then processed in descending order of their `termSize`s. For each CSE
candidate, we generate a fresh variable, create a LamAbs for it under its path, and substitute
it for all occurrences in the original term whose paths are descendants (or self) of
the candidate's path. This order matters because a bigger expression may contain a smaller
subexpression, so the bigger one must be substituted first.

In the above example, we first process CSE candidate `(1+(2+x), "0.1")`. We create a fresh
variable `cse1` for it, perform substitution, and create a `LamAbs` under path "0.1" (i.e., around
the body of `y`). After processing this CSE candidate, the original term becomes

\x y -> (\cse1 -> cse1
                  +
                  (case y [ cse1 + (3+x)
                          , (2+x) + (3+x)
                          , 4+x
                          ]
                  )
        ) (1+(2+x))

The second CSE candidate is processed similarly, and the final result is

\x y -> (\cse2 -> (\cse1 -> cse1
                            +
                            (case y [ cse1 + (3+x)
                                    , cse2 + (3+x)
                                    , 4+x
                                    ]
                            )
                  ) (1+cse2)
        ) (2+x)

Here's another example:

force (force ifThenElse
         (lessThanEqualsInteger 0 0)
         (delay ((1+2) + (1+2)))
         (delay (1+2))
      )

In this case, the first two occurrences of `1+2` can be CSE'd, but the third occurrence
cannot. This is ensured by checking the path when substituting `cse1` for `1+2`. The result is

force (force ifThenElse
         (lessThanEqualsInteger 0 0)
         (delay ((\cse1 -> cse1 + cse1) (1+2)))
         (delay (1+2))
      )

-------------------------------------------------------------------------------
3. When should CSE run?
-------------------------------------------------------------------------------

CSE should run for multiple iterations, and should interleave with inlining. The following
example illustrates why:

\x ->
  f
    ((\y -> 1+(y+y)) (0+x))
    ((\z -> 2+(z+z)) (0+x))

There is no inlining opportunity in this term. After the first iteration of CSE, where
the common subexpression is `0+x`, we get:

\x ->
  (\cse1 ->
    f
      ((\y -> 1+(y+y)) cse1)
      ((\z -> 2+(z+z)) cse1)
  ) (0+x)

Now `y` and `z` can be inlined, after which we get

\x ->
  (\cse1 ->
    f
      (1+(cse1+cse1))
      (2+(cse1+cse1))
  ) (0+x)

Now there's a new common subexpression: `cse1+cse1`. So another iteration of CSE is
needed, yielding:

\x ->
  (\cse1 ->
    (\cse2 ->
      f
        (1+cse2)
        (2+cse2)
    ) (cse1+cse1)
  ) (0+x)

With this example in mind, one may be tempted to make CSE part of the simplifier, and simply
run it along with the rest of the simplifier. That is, however, a bad idea. CSE does the reverse
of inlining; inlining tends to expose more optimization opportunities, and conversely, CSE
tends to destroy optimization opportunities. Running CSE on a not-fully-optimized program
may cause many optimization opportunities to be permanently lost. Give it a try if you want
to see how bad it is!

Therefore, this is what we do: first run the simplifier iterations. Then, run the CSE iterations,
interleaving with the simplifier. For example, suppose max-simplifier-iterations-uplc=12, and
max-cse-iterations=4. We first run 12 iterations of the simplifier, then run 4 iterations
of CSE, with a simplifier pass after each iteration of CSE (i.e., the simplifier is run for a
total of 16 times).

Finally, since CSE can change the order or the number of occurrences of effects, it is only run
when conservative optimization is off.
-}

-- | The path from the root of a term to one of its subterms: the IDs of the
-- enclosing `LamAbs`, `Delay` and `Case` branches (see 'annotate').
-- In reverse order, e.g., "1.2.3" is `[3, 2, 1]`.
type Path = [Int]

-- | @isAncestorOrSelf p q@ holds when @p@ equals @q@ or is a prefix of @q@ in
-- root-to-leaf order. Since paths are stored innermost-first, this is a suffix
-- test on the underlying lists: "1.2" (@[2, 1]@) is an ancestor of "1.2.3"
-- (@[3, 2, 1]@), and the empty (root) path is an ancestor of every path.
isAncestorOrSelf :: Path -> Path -> Bool
isAncestorOrSelf ancestor descendant = ancestor `isSuffixOf` descendant

-- | A common subexpression chosen for elimination, together with the fresh
-- variable that will stand for it. See Note [CSE].
data CseCandidate uni fun ann = CseCandidate
  { ccFreshName     :: Name
  -- ^ The fresh variable generated for this subexpression.
  , ccTerm          :: Term Name uni fun ()
  -- ^ The subexpression with annotations erased, used for equality comparison.
  , ccAnnotatedTerm :: Term Name uni fun (Path, ann)
  -- ^ `ccTerm` is needed for equality comparison, while `ccAnnotatedTerm` is needed
  -- for the actual substitution. They are always the same term barring the annotations.
  }

-- | One iteration of common subexpression elimination. See Note [CSE].
--
-- The term is renamed, every subterm is annotated with its 'Path' ('annotate'),
-- occurrences are counted per path ('countOccs'), and every annotated subterm
-- whose count is greater than 1 is bound to a fresh variable by 'mkCseTerm'.
cse ::
  ( MonadQuote m
  , Hashable (Term Name uni fun ())
  , Rename (Term Name uni fun ann)
  , ToBuiltinMeaning uni fun
  ) =>
  BuiltinSemanticsVariant fun ->
  Term Name uni fun ann ->
  m (Term Name uni fun ann)
cse builtinSemanticsVariant t0 = do
  t <- rename t0
  let annotated = annotate t
      -- All (path, annotated term, count) entries, across every distinct subterm.
      occurrences = join (Map.elems (countOccs builtinSemanticsVariant annotated))
      -- A subexpression is common if the count is greater than 1.
      common = snd3 <$> filter ((> 1) . thd3) occurrences
      -- Process the common subexpressions in descending order of `termSize`,
      -- so that a bigger expression is handled before the smaller ones it
      -- contains. See Note [CSE].
      commonSubexprs = sortOn (Down . termSize) common
  mkCseTerm commonSubexprs annotated

-- | The second pass. See Note [CSE].
--
-- Annotate every subterm with its 'Path'. Each `LamAbs`, `Delay` and `Case`
-- branch gets a fresh ID (the first one is 1), and that ID is prepended to the
-- path of everything inside it. The node itself (and, for `Case`, its
-- scrutinee) keeps the enclosing path.
annotate :: Term name uni fun ann -> Term name uni fun (Path, ann)
annotate = flip evalState 0 . flip runReaderT [] . go
  where
    -- The integer state is the highest ID assigned so far.
    -- The reader context is the current path.
    go :: Term name uni fun ann -> ReaderT Path (State Int) (Term name uni fun (Path, ann))
    go t = do
      path <- ask
      case t of
        Apply ann fun arg -> Apply (path, ann) <$> go fun <*> go arg
        Force ann body -> Force (path, ann) <$> go body
        Constr ann i args -> Constr (path, ann) i <$> traverse go args
        Constant ann val -> pure $ Constant (path, ann) val
        Error ann -> pure $ Error (path, ann)
        Builtin ann fun -> pure $ Builtin (path, ann) fun
        Var ann name -> pure $ Var (path, ann) name
        -- Computation stops at `LamAbs`, `Delay` and `Case` branches, so their
        -- contents are annotated under a new path element.
        LamAbs ann n body -> LamAbs (path, ann) n <$> descend (go body)
        Delay ann body -> Delay (path, ann) <$> descend (go body)
        Case ann scrut branches ->
          Case (path, ann) <$> go scrut <*> for branches (descend . go)

    -- Allocate the next unused ID and run the given action with that ID
    -- pushed onto the current path.
    descend :: ReaderT Path (State Int) a -> ReaderT Path (State Int) a
    descend action = do
      freshId <- (+ 1) <$> lift get
      -- Store the counter strictly so successive IDs do not build a thunk chain.
      lift (put $! freshId)
      local (freshId :) action

-- | The third pass. See Note [CSE].
--
-- Folds over every subterm of the annotated term (via 'termSubtermsDeep') and records
-- each CSE-eligible subterm in a map keyed by that subterm with its annotations erased,
-- so occurrences that differ only in their annotations are grouped together.
countOccs ::
  forall name uni fun ann.
  (Hashable (Term name uni fun ()), ToBuiltinMeaning uni fun) =>
  BuiltinSemanticsVariant fun ->
  Term name uni fun (Path, ann) ->
  -- | For each key, the value is a list of @(path, annotated term, count)@ triples.
  -- Besides the count, each triple records the annotated term corresponding to the
  -- key, because the annotated terms are needed later for substitution.
  -- Occurrences of the same term are merged into these triples by 'combinePaths'.
  HashMap (Term name uni fun ()) [(Path, Term name uni fun (Path, ann), Int)]
countOccs builtinSemanticsVariant = foldrOf termSubtermsDeep addToMap Map.empty
  where
    addToMap ::
      Term name uni fun (Path, ann) ->
      HashMap (Term name uni fun ()) [(Path, Term name uni fun (Path, ann), Int)] ->
      HashMap (Term name uni fun ()) [(Path, Term name uni fun (Path, ann), Int)]
    addToMap t0
      -- We don't consider work-free terms for CSE, because doing so may or may not
      -- have a size benefit, but certainly doesn't have any cost benefit (the cost
      -- will in fact be slightly higher due to the additional application).
      -- Unsaturated builtin applications and bare (possibly forced) builtins are
      -- skipped as well.
      | isWorkFree builtinSemanticsVariant t0
          || not (isBuiltinSaturated t0)
          || isForcingBuiltin t0 =
          id
      | otherwise =
          -- The first occurrence of a term starts a single entry with count 1;
          -- later occurrences are merged into the existing entries by path.
          Map.alter
            (Just . maybe [(path, t0, 1)] (combinePaths t0 path))
            (void t0)
      where
        path :: Path
        path = fst (termAnn t0)

    -- Whether a term is NOT an under-applied builtin. For a term whose application
    -- head (as returned by 'splitApplication') is a bare 'Builtin', the number of
    -- arguments must be at least the length of the builtin's 'builtinArity'. Any
    -- other head is treated as saturated.
    -- NOTE(review): a builtin wrapped in 'Force' (e.g. one with type parameters) does
    -- not match the 'Builtin' pattern here and so always counts as saturated, even
    -- when partially applied — confirm this is intended.
    isBuiltinSaturated :: Term name uni fun (Path, ann) -> Bool
    isBuiltinSaturated =
      splitApplication >>> \case
        (Builtin _ bn, args) ->
          length args >= length (builtinArity (Proxy @uni) builtinSemanticsVariant bn)
        _term -> True

    -- A bare builtin, possibly wrapped in any number of 'Force's.
    isForcingBuiltin :: Term name uni fun (Path, ann) -> Bool
    isForcingBuiltin = \case
      Builtin{} -> True
      Force _ t -> isForcingBuiltin t
      _ -> False

-- | Combine a new occurrence of a term (at path @path@, annotated term @t@) with the
-- existing @(path, annotated term, count)@ entries recorded for the same term.
--
-- * Every existing entry whose path is a descendant-or-self of the new path is
--   removed. Its count is added to a new entry for @(path, t)@, which is appended
--   at the end of the list.
-- * Otherwise, if the new path is a descendant-or-self of an existing entry's path,
--   that entry's count is incremented and no new entry is created.
-- * Otherwise the new occurrence is appended as a fresh entry with count 1.
--
-- The second case stops at the first match. This is only correct when the existing
-- paths are pairwise unrelated (no path is an ancestor of another). The three cases
-- above preserve that property when starting from a single entry.
combinePaths ::
  forall name uni fun ann.
  Term name uni fun (Path, ann) ->
  Path ->
  [(Path, Term name uni fun (Path, ann), Int)] ->
  [(Path, Term name uni fun (Path, ann), Int)]
combinePaths t path = go 1
  where
    go ::
      Int ->
      [(Path, Term name uni fun (Path, ann), Int)] ->
      [(Path, Term name uni fun (Path, ann), Int)]
    -- The new path is not a descendant-or-self of any existing path.
    go acc [] = [(path, t, acc)]
    go acc (entry@(path', t', cnt) : rest)
      -- The new path is an ancestor-or-self of an existing path.
      -- Take over all counts of the existing path, remove the existing path,
      -- and continue.
      | path `isAncestorOrSelf` path' = go (acc + cnt) rest
      -- The new path is a descendant-or-self of an existing path.
      -- Increment the count for the existing path. There can only be one such
      -- existing path, so we don't need to recurse here.
      | path' `isAncestorOrSelf` path = (path', t', cnt + 1) : rest
      -- Unrelated paths: keep the existing entry and keep looking.
      | otherwise = entry : go acc rest

-- | Build the final CSE'd term.
--
-- Each term in @ts@ is turned into a candidate with a fresh variable
-- ('mkCseCandidate'). The candidates are then applied to the annotated term one
-- after another, in list order ('applyCse'). Finally the 'Path' component of every
-- annotation is dropped.
mkCseTerm ::
  forall uni fun ann m.
  (MonadQuote m, Eq (Term Name uni fun ())) =>
  -- | The terms to turn into CSE candidates, processed left to right
  [Term Name uni fun (Path, ann)] ->
  -- | The original annotated term
  Term Name uni fun (Path, ann) ->
  m (Term Name uni fun ann)
mkCseTerm ts t = do
  candidates <- traverse mkCseCandidate ts
  -- Strict left fold: candidate @i+1@ is applied to the result of candidate @i@.
  let cseTerm = Foldable.foldl' (flip applyCse) t candidates
  pure (snd <$> cseTerm)

-- | Abstract a single CSE candidate out of an annotated term, in two phases.
--
-- 1. Bottom-up ('transformOf' 'termSubterms'): every subterm equal to the candidate
--    term (annotations ignored), and whose path is a descendant-or-self of the
--    candidate's path, is replaced by a 'Var' referring to the candidate's fresh name.
-- 2. Top-down ('mkLamApp'): descend only through subterms whose path is an
--    ancestor-or-self of the candidate's path. Each subterm whose path equals the
--    candidate's path becomes @(\\fresh -> subterm) candidateTerm@. The argument is
--    the candidate's original annotated term.
applyCse ::
  forall uni fun ann.
  (Eq (Term Name uni fun ())) =>
  CseCandidate uni fun ann ->
  Term Name uni fun (Path, ann) ->
  Term Name uni fun (Path, ann)
applyCse c = mkLamApp . transformOf termSubterms substCseVarForTerm
  where
    -- The path component of a term's annotation.
    pathOf :: Term Name uni fun (Path, ann) -> Path
    pathOf = fst . termAnn

    candidatePath :: Path
    candidatePath = pathOf (ccAnnotatedTerm c)

    substCseVarForTerm :: Term Name uni fun (Path, ann) -> Term Name uni fun (Path, ann)
    substCseVarForTerm t
      | void t == ccTerm c && candidatePath `isAncestorOrSelf` pathOf t =
          Var (termAnn t) (ccFreshName c)
      | otherwise = t

    mkLamApp :: Term Name uni fun (Path, ann) -> Term Name uni fun (Path, ann)
    mkLamApp t
      | currPath == candidatePath =
          Apply
            (termAnn t)
            (LamAbs (termAnn t) (ccFreshName c) t)
            (ccAnnotatedTerm c)
      | currPath `isAncestorOrSelf` candidatePath = case t of
          -- Leaves have no subterms to descend into; return them unchanged.
          Var{} -> t
          Constant{} -> t
          Builtin{} -> t
          Error{} -> t
          LamAbs ann name body -> LamAbs ann name (mkLamApp body)
          Apply ann fun arg -> Apply ann (mkLamApp fun) (mkLamApp arg)
          Force ann body -> Force ann (mkLamApp body)
          Delay ann body -> Delay ann (mkLamApp body)
          Constr ann i ts -> Constr ann i (mkLamApp <$> ts)
          Case ann scrut branches -> Case ann (mkLamApp scrut) (mkLamApp <$> branches)
      | otherwise = t
      where
        currPath = pathOf t

-- | Generate a fresh variable for the common subexpression @t@.
--
-- The resulting candidate pairs a fresh 'Name' (created with @freshName "cse"@) with
-- two copies of @t@. The annotation-erased copy is used for equality comparisons. The
-- annotated copy supplies the candidate's path and is the term that gets bound to the
-- fresh variable.
mkCseCandidate ::
  forall uni fun ann m.
  (MonadQuote m) =>
  Term Name uni fun (Path, ann) ->
  m (CseCandidate uni fun ann)
mkCseCandidate t = do
  var <- freshName "cse"
  pure (CseCandidate var (void t) t)