{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}

module UntypedPlutusCore.Transform.CaseReduce
  ( caseReduce
  , processTerm
  ) where

import Control.Lens (transformOf)
import Data.Bifunctor (second)
import Data.Vector qualified as V
import PlutusCore.Builtin (CaseBuiltin (..))
import PlutusCore.MkPlc
import UntypedPlutusCore.Core
import UntypedPlutusCore.Transform.Simplifier
  ( SimplifierStage (CaseReduce)
  , SimplifierT
  , recordSimplification
  )

-- | Run the case-reduction pass over a whole term.
--
-- 'processTerm' is applied throughout the term by 'transformOf' over 'termSubterms'. The input
-- and the rewritten term are then passed to 'recordSimplification' under the 'CaseReduce' stage,
-- and the rewritten term is returned.
caseReduce
  :: (Monad m, CaseBuiltin uni)
  => Term name uni fun a
  -> SimplifierT name uni fun a m (Term name uni fun a)
caseReduce original = do
  recordSimplification original CaseReduce reduced
  pure reduced
  where
    -- The rewritten term; shared between the recording step and the final result.
    reduced = transformOf termSubterms processTerm original

-- | Reduce a single @case@ node whose scrutinee is already known.
--
-- * @Case ann (Constr _ i args) cs@ becomes branch @i@ of @cs@ applied to @args@ via 'mkIterApp'
--   (every argument paired with @ann@), provided @i@ is a valid index into @cs@.
-- * @Case ann (Constant _ con) cs@ becomes the term that 'headSpineToTerm' builds from the result
--   of 'caseBuiltin', provided 'headSpineToTerm' returns 'Right'.
-- * Any other term is returned unchanged.
--
-- Only the root of the given term is inspected; no recursion into subterms happens here.
processTerm :: CaseBuiltin uni => Term name uni fun a -> Term name uni fun a
processTerm = \case
  -- We could've rewritten those patterns as 'Error' in the 'Nothing' cases, but that would turn a
  -- structural error into an operational one, which would be unfortunate, so instead we decided
  -- not to fully optimize such scripts, since they aren't valid anyway.
  Case ann (Constr _ i args) cs
    -- Compare the tag against the branch count before narrowing it to 'Int'. Where 'Int' is
    -- narrower than 64 bits, 'fromIntegral' alone could wrap an out-of-range tag onto a valid
    -- index and select the wrong branch. The length is non-negative, so widening it is safe.
    | i < fromIntegral (V.length cs)
    , Just c <- cs V.!? fromIntegral i ->
        mkIterApp c ((ann,) <$> args)
  Case ann (Constant _ con) cs
    | Right t <- headSpineToTerm ann (second (Constant ann) (caseBuiltin con cs)) -> t
  t -> t