{-# LANGUAGE LambdaCase #-}

{- Note [Applying force to delays in case branches]

Often, the following pattern occurs in UPLC terms:

> force (case scrut [\x1... -> delay term_1, ..., \x1... -> delay term_m])

It's sound to remove the 'force' and the 'delay's, as long as the original term is "well-formed".
The lambda abstractions may also be absent, i.e. a branch may be simply 'delay term_i'; that case
is handled as well.

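Intuitively, what we mean by "well-formed" is that the term does not evaluate to bottom unless
the user intended it to (e.g. by introducing an 'error' subterm). For example, if a branch binds
more variables than the scrutinee's constructor has fields, the original term fails, because
'force' is applied to the leftover lambda. The rewritten term would instead return that lambda.
Well-formedness rules out such mismatches.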

In the context of the Plinth compiler pipeline, UPLC is always generated from TPLC, by
erasing the types of a TPLC term. So at the beginning of the UPLC phase of the compiler,
the UPLC can be considered "well-formed", since it is guaranteed by the types of the original
TPLC term.

The other UPLC transformations we have are guaranteed to preserve this property, so we can assume
that any UPLC term coming from the UPLC phase of the compiler is "well-formed".

What about preserving the correct evaluation strategy? Since we are removing 'delay's, we need to
ensure that the terms under the delays are not evaluated ahead of time. However, this is not a problem
because the 'case' construct is lazy in its branches, so the terms under the 'delay's will not be
evaluated unless the corresponding branch is taken.

We should, however, formally define what "well-formed" means, and this is left as future work:
FIXME(https://github.com/IntersectMBO/plutus-private/issues/1644).

-}
module UntypedPlutusCore.Transform.ForceCaseDelay
  ( forceCaseDelay
  )
where

import UntypedPlutusCore.Core
import UntypedPlutusCore.Transform.Simplifier
  ( SimplifierStage (ForceCaseDelay)
  , SimplifierT
  , recordSimplification
  )

import Control.Lens

-- | Simplifier pass that drops a 'Force' around a 'Case' whose branches all
-- end (possibly beneath lambdas) in a 'Delay', removing those delays too.
-- See Note [Applying force to delays in case branches].
--
-- The per-node rewrite 'processTerm' is applied to every subterm via
-- 'transformOf' 'termSubterms' (bottom-up). The input and output terms are
-- then recorded as a 'ForceCaseDelay' simplification step, and the rewritten
-- term is returned.
forceCaseDelay
  :: Monad m
  => Term name uni fun a
  -> SimplifierT name uni fun a m (Term name uni fun a)
forceCaseDelay term =
  simplified <$ recordSimplification term ForceCaseDelay simplified
  where
    simplified = transformOf termSubterms processTerm term

-- | Single-node rewrite used by 'forceCaseDelay'.
--
-- @force (case scrut bs)@ becomes @case scrut bs'@ when every branch in @bs@
-- is a (possibly empty) chain of lambdas ending in a 'Delay'. Here @bs'@ is
-- @bs@ with each trailing 'Delay' removed and the lambdas kept.
--
-- The node is returned unchanged in two cases: any branch lacks that shape,
-- or the node is not a 'Force' applied to a 'Case'. Only this node is
-- inspected; traversal of subterms is the caller's job. The annotations of
-- the removed 'Force' and 'Delay' nodes are discarded.
processTerm :: Term name uni fun a -> Term name uni fun a
processTerm t = case t of
  Force _ (Case caseAnn scrutinee alts) ->
    -- All-or-nothing: one branch without a trailing delay cancels the rewrite.
    maybe t (Case caseAnn scrutinee) (traverse stripDelay alts)
  _ -> t
  where
    -- Walk down through lambdas until a 'Delay' is found, returning the
    -- delay's body re-wrapped in the same lambdas; 'Nothing' if any other
    -- kind of term is reached first.
    stripDelay :: Term name uni fun a -> Maybe (Term name uni fun a)
    stripDelay (LamAbs lamAnn v b) = LamAbs lamAnn v <$> stripDelay b
    stripDelay (Delay _ inner) = Just inner
    stripDelay _ = Nothing