{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

{-|
Perform the case-of-case transformation. This pushes
case expressions into the case branches of other case
expressions, which can often yield optimization opportunities.

Example:
@
    case (case s of { C1 a -> x; C2 b -> y; }) of
      D1 -> w
      D2 -> z

    -->

    case s of
      C1 a -> case x of { D1 -> w; D2 -> z; }
      C2 b -> case y of { D1 -> w; D2 -> z; }
@

We also transform

@
    case ((force ifThenElse) b (constr t) (constr f)) alts
@

into

@
    force (force ifThenElse b (delay (case (constr t) alts)) (delay (case (constr f) alts)))
@

This is always an improvement. -}
module UntypedPlutusCore.Transform.CaseOfCase (caseOfCase) where

import PlutusPrelude

import PlutusCore qualified as PLC
import PlutusCore.Builtin (CaseBuiltin (..))
import PlutusCore.MkPlc (mkIterApp)
import UntypedPlutusCore.Core
import UntypedPlutusCore.Transform.CaseReduce qualified as CaseReduce
import UntypedPlutusCore.Transform.Simplifier
  ( SimplifierStage (CaseOfCase)
  , SimplifierT
  , recordSimplification
  )

import Control.Lens
import Data.List (nub)

{-| Run the case-of-case transformation over a whole term.

'processTerm' is applied to the subterms reached through 'termSubterms' by
means of 'transformOf'. The original term and the rewritten term are then
passed to 'recordSimplification' under the 'CaseOfCase' stage, and the
rewritten term is returned. -}
caseOfCase
  :: ( fun ~ PLC.DefaultFun
     , Monad m
     , CaseBuiltin uni
     , PLC.GEq uni
     , PLC.Closed uni
     , uni `PLC.Everywhere` Eq
     )
  => Term name uni fun a
  -> SimplifierT name uni fun a m (Term name uni fun a)
caseOfCase term = do
  let result = transformOf termSubterms processTerm term
  recordSimplification term CaseOfCase result
  return result

{-| Rewrite a single term node; any term that matches neither shape below is
returned unchanged.

Two shapes are rewritten:

1. A @case@ whose scrutinee is @force ifThenElse@ applied to exactly three
   arguments, where both the "then" and the "else" arguments are @constr@
   terms. The @case@ is pushed into both branches, each branch is wrapped in
   @delay@, and the whole application is wrapped in @force@.

2. A @case@ whose scrutinee is itself a @case@, provided that every inner
   alternative is a @constr@ or a constant and that no two inner alternatives
   share the same constructor tag or constant. The outer @case@ is pushed into
   each inner alternative.

In both rewrites every newly built @case@ is passed through
'CaseReduce.processTerm'. -}
processTerm
  :: ( fun ~ PLC.DefaultFun
     , CaseBuiltin uni
     , PLC.GEq uni
     , PLC.Closed uni
     , uni `PLC.Everywhere` Eq
     )
  => Term name uni fun a -> Term name uni fun a
processTerm = \case
  Case ann scrut alts
    | ( ite@(Force a (Builtin _ PLC.IfThenElse))
        , [cond, (trueAnn, true@Constr {}), (falseAnn, false@Constr {})]
        ) <-
        splitApplication scrut ->
        -- The annotation of the inner @force ifThenElse@ is reused for the new outer @force@.
        Force a $
          mkIterApp
            ite
            [ cond
            , -- Here we call a single step of case-reduce in order to immediately clean up the
              -- duplication of @alts@. Otherwise optimizing case-of-case-of-case-of... would create
              -- exponential blowup of the case branches, which would eventually get deduplicated
              -- with case-reduce, but only after that exponential blowup has already slowed the
              -- optimizer down unnecessarily.
              (trueAnn, Delay trueAnn . CaseReduce.processTerm $ Case ann true alts)
            , (falseAnn, Delay falseAnn . CaseReduce.processTerm $ Case ann false alts)
            ]
  original@(Case annOuter (Case annInner scrut altsInner) altsOuter) ->
    -- If the checks in the 'Maybe' block fail, the term is left exactly as it was.
    maybe
      original
      (Case annInner scrut)
      ( do
          -- Pair every inner alternative with a key: the constructor tag for a @constr@,
          -- the value for a constant. Any other kind of alternative aborts the rewrite.
          constrs <- for altsInner $ \case
            c@(Constr _ i _) -> Just (Left i, c)
            c@(Constant _ val) -> Just (Right val, c)
            _ -> Nothing
          -- See Note [Case-of-case and duplicating code].
          -- Only proceed when all keys are pairwise distinct.
          guard $ length (nub . toList $ fmap fst constrs) == length constrs
          -- Push the outer case into each inner alternative, reducing each new case right away.
          pure $ constrs <&> \(_, c) -> CaseReduce.processTerm $ Case annOuter c altsOuter
      )
  other -> other