{-# LANGUAGE LambdaCase #-}

{-| The Float Delay optimization floats `Delay` from arguments into function bodies,
if possible. It turns @(\n -> ...Force n...Force n...) (Delay arg)@ into
@(\n -> ...Force (Delay n)...Force (Delay n)...) arg@.

The above transformation is performed if:

    * All occurrences of @n@ are under @Force@.

    * @arg@ is essentially work-free.

This achieves a similar effect to Plutonomy's "Split Delay" optimization. The difference
is that Split Delay simply splits the @Delay@ argument into multiple arguments, turning the
above example into @(\m -> (\n -> ...Force n...Force n...) (Delay m)) arg@, and then relies
on other optimizations to simplify it further. Specifically, once the inliner inlines
@Delay m@, it will be identical to the result of Float Delay.

The advantages of Float Delay are:

    * It doesn't rely on the inliner. In this example, Split Delay relies on the inliner to
      inline @Delay m@, but there's no guarantee that the inliner will do so, because inlining
      it may increase the program size.

      We can potentially modify the inliner such that it is aware of Float Delay and
      Force-Delay Cancel, and makes inlining decisions with these other optimizations in mind.
      The problem is that, not only does this make the inlining heuristics much more
      complex, but it could easily lead to code duplication. Other optimizations often
      need to do some calculation in order to make certain optimization decisions (e.g., in
      this case, we want to check whether all occurrences of @n@ are under @Force@), and
      if we rely on the inliner to inline the @Delay@, then the same check would need to be
      performed by the inliner.

    * Because Float Delay requires that all occurrences of @n@ are under @Force@, it is
      guaranteed not to increase the size or the cost of the program. This is not the case
      with Split Delay: in this example, if the occurrences of @n@ are not under @Force@,
      then Split Delay may increase the size of the program, regardless of whether or not
      @Delay m@ is inlined. If @Delay m@ is not inlined, then it will also increase the
      cost of the program, due to the additional application.

The alternative approach that always floats the @Delay@ regardless of whether or not all
occurrences of @n@ are under @Force@ was implemented and tested, and it is strictly worse than
Float Delay on our current test suite (specifically, Split Delay causes one test case
to have a slightly bigger program, and everything else is equal).

Why is this optimization performed on UPLC, not PIR?

    1. Not only are the types and let-bindings in PIR not useful for this optimization,
       they can also get in the way. For example, we cannot transform
       @let f = /\a. ...a... in ...{f t1}...{f t2}...@ into
       @let f = ...a... in ...f...f...@.

    2. This optimization mainly interacts with ForceDelayCancel and the inliner, and
       both are part of the UPLC simplifier. -}
module UntypedPlutusCore.Transform.FloatDelay (floatDelay) where

import PlutusCore qualified as PLC
import PlutusCore.Name.Unique qualified as PLC
import PlutusCore.Name.UniqueMap qualified as UMap
import PlutusCore.Name.UniqueSet qualified as USet
import UntypedPlutusCore.Core.Plated (termSubterms)
import UntypedPlutusCore.Core.Type (Term (..))
import UntypedPlutusCore.Transform.Simplifier
  ( SimplifierStage (FloatDelay)
  , SimplifierT
  , recordSimplification
  )

import Control.Lens (forOf, forOf_, transformOf)
import Control.Monad.Trans.Writer.CPS (Writer, execWriter, runWriter, tell)

{-| Run the Float Delay optimization on a term (see the module header).

The term is first renamed with 'PLC.rename'. The three passes below identify
variables solely by their 'PLC.TermUnique', so renaming presumably serves to
keep distinct binders from being conflated. The passes are:

  1. 'unforcedVars' collects the variables that have an occurrence not
     directly under @Force@; such variables are never floated.
  2. 'simplifyArgs' strips @Delay@ from eligible arguments and records, per
     variable, the annotation of the removed @Delay@.
  3. 'simplifyBodies' wraps every occurrence of those variables in a @Delay@
     carrying the recorded annotation.

The step is recorded under the 'FloatDelay' stage, with the original
(un-renamed) term as its input and the final term as its output. -}
floatDelay
  :: ( PLC.MonadQuote m
     , PLC.Rename (Term name uni fun a)
     , PLC.HasUnique name PLC.TermUnique
     )
  => Term name uni fun a
  -> SimplifierT name uni fun a m (Term name uni fun a)
floatDelay term = do
  renamed <- PLC.rename term
  let -- Pass 1 feeds pass 2: variables with an unforced occurrence are excluded.
      (argsFloated, floated) = simplifyArgs (unforcedVars renamed) renamed
      -- Pass 3 re-introduces the Delay at each occurrence of a floated variable.
      result = simplifyBodies floated argsFloated
  recordSimplification term FloatDelay result
  pure result

{-| First pass. Returns the uniques of all variables that have at least one
occurrence which is /not/ of the form @Force (Var n)@.

Such variables are excluded from Float Delay: 'simplifyBodies' wraps every
occurrence of a floated variable in @Delay@, and an occurrence that is not
under @Force@ would keep that extra @Delay@ (nothing cancels it), growing the
program instead of shrinking it. -}
unforcedVars
  :: forall name uni fun a
   . PLC.HasUnique name PLC.TermUnique
  => Term name uni fun a
  -> PLC.UniqueSet PLC.TermUnique
unforcedVars = execWriter . collect
  where
    collect :: Term name uni fun a -> Writer (PLC.UniqueSet PLC.TermUnique) ()
    collect = \case
      -- A variable directly under 'Force' is a forced occurrence: record nothing.
      -- (A 'Var' has no subterms, so there is nothing to recurse into either.)
      Force _ Var{} -> pure ()
      -- Any other variable occurrence is unforced.
      Var _ n -> tell (USet.singletonName n)
      -- Everything else: keep searching in the immediate subterms.
      t -> forOf_ termSubterms t collect

{-| Second pass. Rewrites every redex of the form

@Apply (LamAbs n body) (Delay arg)@

into @Apply (LamAbs n body) arg@, provided that @arg@ is essentially work-free
(see 'isEssentiallyWorkFree') and @n@ is not in the given blacklist.

Returns the rewritten term together with a map from the unique of each such
@n@ to the annotation of the @Delay@ that was removed, so that the third pass
can re-introduce the @Delay@ with the same annotation. -}
simplifyArgs
  :: forall name uni fun a
   . PLC.HasUnique name PLC.TermUnique
  => PLC.UniqueSet PLC.TermUnique
  -- ^ The set of variables returned by `unforcedVars`.
  -> Term name uni fun a
  -> (Term name uni fun a, PLC.UniqueMap PLC.TermUnique a)
simplifyArgs blacklist = runWriter . go
  where
    go
      :: Term name uni fun a
      -> Writer (PLC.UniqueMap PLC.TermUnique a) (Term name uni fun a)
    go = \case
      Apply appAnn (LamAbs lamAnn n lamBody) (Delay delayAnn arg)
        | isEssentiallyWorkFree arg
        , n `USet.notMemberByName` blacklist -> do
            -- Remember which variable lost its Delay, and with which annotation.
            tell (UMap.singletonByName n delayAnn)
            -- Keep processing both the body and the (now un-delayed) argument,
            -- body first, as nested redexes may also be eligible.
            body' <- go lamBody
            arg' <- go arg
            pure (Apply appAnn (LamAbs lamAnn n body') arg')
      t -> forOf termSubterms t go

{-| Third pass. Wraps every occurrence of each variable in the given map in a
@Delay@ carrying the recorded annotation, i.e. turns @Var n@ into
@Delay (Var n)@.

Because variables with an unforced occurrence were excluded by the first pass,
in practice every rewritten occurrence sits under a @Force@, so @Force n@
becomes @Force (Delay n)@, which Force-Delay Cancel can then eliminate. -}
simplifyBodies
  :: PLC.HasUnique name PLC.TermUnique
  => PLC.UniqueMap PLC.TermUnique a
  -> Term name uni fun a
  -> Term name uni fun a
simplifyBodies whitelist = transformOf termSubterms $ \case
  -- 'transformOf' rewrites bottom-up, and the new 'Delay' node is not revisited
  -- by this function, so each occurrence is wrapped exactly once.
  var@(Var _ n)
    | Just ann <- UMap.lookupName n whitelist -> Delay ann var
  t -> t

{-| Whether evaluating the given `Term` is pure and essentially work-free
(barring the CEK machine overhead).

Float Delay uses this to decide whether @Delay arg@ may be replaced by @arg@:
the argument is then evaluated eagerly, even if the bound variable is never
forced, so evaluating it must neither fail nor do real work. -}
-- This should be the erased version of 'PlutusIR.Transform.LetFloat.isEssentiallyWorkFree'.
isEssentiallyWorkFree :: Term name uni fun a -> Bool
isEssentiallyWorkFree = \case
  LamAbs{} -> True
  Constant{} -> True
  Delay{} -> True
  -- Evaluating a constructor application evaluates every field, so it is only
  -- work-free if all fields are. Treating any 'Constr' as work-free would let
  -- e.g. @(\n -> con 1) (delay (constr 0 [error]))@ be rewritten into
  -- @(\n -> con 1) (constr 0 [error])@, which fails instead of returning 1.
  Constr _ _ fields -> all isEssentiallyWorkFree fields
  Builtin{} -> True
  Var{} -> False
  Force{} -> False
  -- Unsaturated builtin applications should also be essentially work-free,
  -- but this is currently not implemented for UPLC.
  -- `UntypedPlutusCore.Transform.Inline.isPure` has the same problem.
  Apply{} -> False
  Case{} -> False
  Error{} -> False