{-# LANGUAGE NamedFieldPuns #-}

module UntypedPlutusCore.Transform.Simplifier (
    SimplifierT (..),
    SimplifierTrace (..),
    SimplifierStage (..),
    Simplification (..),
    runSimplifierT,
    evalSimplifierT,
    execSimplifierT,
    Simplifier,
    runSimplifier,
    evalSimplifier,
    execSimplifier,
    initSimplifierTrace,
    recordSimplification,
) where

import Control.Monad.State (MonadTrans, StateT)
import Control.Monad.State qualified as State

import Control.Monad.Identity (Identity, runIdentity)
import PlutusCore.Quote (MonadQuote)
import UntypedPlutusCore.Core.Type (Term)

newtype SimplifierT name uni fun ann m a =
  SimplifierT
    { forall name (uni :: * -> *) fun ann (m :: * -> *) a.
SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
getSimplifierT :: StateT (SimplifierTrace name uni fun ann) m a
    }
  deriving newtype ((forall a b.
 (a -> b)
 -> SimplifierT name uni fun ann m a
 -> SimplifierT name uni fun ann m b)
-> (forall a b.
    a
    -> SimplifierT name uni fun ann m b
    -> SimplifierT name uni fun ann m a)
-> Functor (SimplifierT name uni fun ann m)
forall a b.
a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
forall a b.
(a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Functor m =>
a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Functor m =>
(a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
$cfmap :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Functor m =>
(a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
fmap :: forall a b.
(a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
$c<$ :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Functor m =>
a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
<$ :: forall a b.
a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
Functor, Functor (SimplifierT name uni fun ann m)
Functor (SimplifierT name uni fun ann m) =>
(forall a. a -> SimplifierT name uni fun ann m a)
-> (forall a b.
    SimplifierT name uni fun ann m (a -> b)
    -> SimplifierT name uni fun ann m a
    -> SimplifierT name uni fun ann m b)
-> (forall a b c.
    (a -> b -> c)
    -> SimplifierT name uni fun ann m a
    -> SimplifierT name uni fun ann m b
    -> SimplifierT name uni fun ann m c)
-> (forall a b.
    SimplifierT name uni fun ann m a
    -> SimplifierT name uni fun ann m b
    -> SimplifierT name uni fun ann m b)
-> (forall a b.
    SimplifierT name uni fun ann m a
    -> SimplifierT name uni fun ann m b
    -> SimplifierT name uni fun ann m a)
-> Applicative (SimplifierT name uni fun ann m)
forall a. a -> SimplifierT name uni fun ann m a
forall a b.
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
forall a b.
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
forall a b.
SimplifierT name uni fun ann m (a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
forall a b c.
(a -> b -> c)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m c
forall name (uni :: * -> *) fun ann (m :: * -> *).
Monad m =>
Functor (SimplifierT name uni fun ann m)
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
Monad m =>
a -> SimplifierT name uni fun ann m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m (a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
forall name (uni :: * -> *) fun ann (m :: * -> *) a b c.
Monad m =>
(a -> b -> c)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m c
forall (f :: * -> *).
Functor f =>
(forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
$cpure :: forall name (uni :: * -> *) fun ann (m :: * -> *) a.
Monad m =>
a -> SimplifierT name uni fun ann m a
pure :: forall a. a -> SimplifierT name uni fun ann m a
$c<*> :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m (a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
<*> :: forall a b.
SimplifierT name uni fun ann m (a -> b)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
$cliftA2 :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b c.
Monad m =>
(a -> b -> c)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m c
liftA2 :: forall a b c.
(a -> b -> c)
-> SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m c
$c*> :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
*> :: forall a b.
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
$c<* :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
<* :: forall a b.
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m a
Applicative, Applicative (SimplifierT name uni fun ann m)
Applicative (SimplifierT name uni fun ann m) =>
(forall a b.
 SimplifierT name uni fun ann m a
 -> (a -> SimplifierT name uni fun ann m b)
 -> SimplifierT name uni fun ann m b)
-> (forall a b.
    SimplifierT name uni fun ann m a
    -> SimplifierT name uni fun ann m b
    -> SimplifierT name uni fun ann m b)
-> (forall a. a -> SimplifierT name uni fun ann m a)
-> Monad (SimplifierT name uni fun ann m)
forall a. a -> SimplifierT name uni fun ann m a
forall a b.
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
forall a b.
SimplifierT name uni fun ann m a
-> (a -> SimplifierT name uni fun ann m b)
-> SimplifierT name uni fun ann m b
forall name (uni :: * -> *) fun ann (m :: * -> *).
Monad m =>
Applicative (SimplifierT name uni fun ann m)
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
Monad m =>
a -> SimplifierT name uni fun ann m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> (a -> SimplifierT name uni fun ann m b)
-> SimplifierT name uni fun ann m b
forall (m :: * -> *).
Applicative m =>
(forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
$c>>= :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> (a -> SimplifierT name uni fun ann m b)
-> SimplifierT name uni fun ann m b
>>= :: forall a b.
SimplifierT name uni fun ann m a
-> (a -> SimplifierT name uni fun ann m b)
-> SimplifierT name uni fun ann m b
$c>> :: forall name (uni :: * -> *) fun ann (m :: * -> *) a b.
Monad m =>
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
>> :: forall a b.
SimplifierT name uni fun ann m a
-> SimplifierT name uni fun ann m b
-> SimplifierT name uni fun ann m b
$creturn :: forall name (uni :: * -> *) fun ann (m :: * -> *) a.
Monad m =>
a -> SimplifierT name uni fun ann m a
return :: forall a. a -> SimplifierT name uni fun ann m a
Monad, (forall (m :: * -> *).
 Monad m =>
 Monad (SimplifierT name uni fun ann m)) =>
(forall (m :: * -> *) a.
 Monad m =>
 m a -> SimplifierT name uni fun ann m a)
-> MonadTrans (SimplifierT name uni fun ann)
forall name (uni :: * -> *) fun ann (m :: * -> *).
Monad m =>
Monad (SimplifierT name uni fun ann m)
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
Monad m =>
m a -> SimplifierT name uni fun ann m a
forall (m :: * -> *).
Monad m =>
Monad (SimplifierT name uni fun ann m)
forall (m :: * -> *) a.
Monad m =>
m a -> SimplifierT name uni fun ann m a
forall (t :: (* -> *) -> * -> *).
(forall (m :: * -> *). Monad m => Monad (t m)) =>
(forall (m :: * -> *) a. Monad m => m a -> t m a) -> MonadTrans t
$clift :: forall name (uni :: * -> *) fun ann (m :: * -> *) a.
Monad m =>
m a -> SimplifierT name uni fun ann m a
lift :: forall (m :: * -> *) a.
Monad m =>
m a -> SimplifierT name uni fun ann m a
MonadTrans)

instance MonadQuote m => MonadQuote (SimplifierT name uni fun ann m)

runSimplifierT
  :: SimplifierT name uni fun ann m a
  -> m (a, SimplifierTrace name uni fun ann)
runSimplifierT :: forall name (uni :: * -> *) fun ann (m :: * -> *) a.
SimplifierT name uni fun ann m a
-> m (a, SimplifierTrace name uni fun ann)
runSimplifierT = (StateT (SimplifierTrace name uni fun ann) m a
 -> SimplifierTrace name uni fun ann
 -> m (a, SimplifierTrace name uni fun ann))
-> SimplifierTrace name uni fun ann
-> StateT (SimplifierTrace name uni fun ann) m a
-> m (a, SimplifierTrace name uni fun ann)
forall a b c. (a -> b -> c) -> b -> a -> c
flip StateT (SimplifierTrace name uni fun ann) m a
-> SimplifierTrace name uni fun ann
-> m (a, SimplifierTrace name uni fun ann)
forall s (m :: * -> *) a. StateT s m a -> s -> m (a, s)
State.runStateT SimplifierTrace name uni fun ann
forall name (uni :: * -> *) fun a. SimplifierTrace name uni fun a
initSimplifierTrace (StateT (SimplifierTrace name uni fun ann) m a
 -> m (a, SimplifierTrace name uni fun ann))
-> (SimplifierT name uni fun ann m a
    -> StateT (SimplifierTrace name uni fun ann) m a)
-> SimplifierT name uni fun ann m a
-> m (a, SimplifierTrace name uni fun ann)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
getSimplifierT

evalSimplifierT
  :: Monad m => SimplifierT name uni fun ann m a -> m a
evalSimplifierT :: forall (m :: * -> *) name (uni :: * -> *) fun ann a.
Monad m =>
SimplifierT name uni fun ann m a -> m a
evalSimplifierT = (StateT (SimplifierTrace name uni fun ann) m a
 -> SimplifierTrace name uni fun ann -> m a)
-> SimplifierTrace name uni fun ann
-> StateT (SimplifierTrace name uni fun ann) m a
-> m a
forall a b c. (a -> b -> c) -> b -> a -> c
flip StateT (SimplifierTrace name uni fun ann) m a
-> SimplifierTrace name uni fun ann -> m a
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
State.evalStateT SimplifierTrace name uni fun ann
forall name (uni :: * -> *) fun a. SimplifierTrace name uni fun a
initSimplifierTrace (StateT (SimplifierTrace name uni fun ann) m a -> m a)
-> (SimplifierT name uni fun ann m a
    -> StateT (SimplifierTrace name uni fun ann) m a)
-> SimplifierT name uni fun ann m a
-> m a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
getSimplifierT

execSimplifierT
  :: Monad m => SimplifierT name uni fun ann m a -> m (SimplifierTrace name uni fun ann)
execSimplifierT :: forall (m :: * -> *) name (uni :: * -> *) fun ann a.
Monad m =>
SimplifierT name uni fun ann m a
-> m (SimplifierTrace name uni fun ann)
execSimplifierT = (StateT (SimplifierTrace name uni fun ann) m a
 -> SimplifierTrace name uni fun ann
 -> m (SimplifierTrace name uni fun ann))
-> SimplifierTrace name uni fun ann
-> StateT (SimplifierTrace name uni fun ann) m a
-> m (SimplifierTrace name uni fun ann)
forall a b c. (a -> b -> c) -> b -> a -> c
flip StateT (SimplifierTrace name uni fun ann) m a
-> SimplifierTrace name uni fun ann
-> m (SimplifierTrace name uni fun ann)
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m s
State.execStateT SimplifierTrace name uni fun ann
forall name (uni :: * -> *) fun a. SimplifierTrace name uni fun a
initSimplifierTrace (StateT (SimplifierTrace name uni fun ann) m a
 -> m (SimplifierTrace name uni fun ann))
-> (SimplifierT name uni fun ann m a
    -> StateT (SimplifierTrace name uni fun ann) m a)
-> SimplifierT name uni fun ann m a
-> m (SimplifierTrace name uni fun ann)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
SimplifierT name uni fun ann m a
-> StateT (SimplifierTrace name uni fun ann) m a
getSimplifierT

type Simplifier name uni fun ann = SimplifierT name uni fun ann Identity

runSimplifier :: Simplifier name uni fun ann a -> (a, SimplifierTrace name uni fun ann)
runSimplifier :: forall name (uni :: * -> *) fun ann a.
Simplifier name uni fun ann a
-> (a, SimplifierTrace name uni fun ann)
runSimplifier = Identity (a, SimplifierTrace name uni fun ann)
-> (a, SimplifierTrace name uni fun ann)
forall a. Identity a -> a
runIdentity (Identity (a, SimplifierTrace name uni fun ann)
 -> (a, SimplifierTrace name uni fun ann))
-> (Simplifier name uni fun ann a
    -> Identity (a, SimplifierTrace name uni fun ann))
-> Simplifier name uni fun ann a
-> (a, SimplifierTrace name uni fun ann)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Simplifier name uni fun ann a
-> Identity (a, SimplifierTrace name uni fun ann)
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
SimplifierT name uni fun ann m a
-> m (a, SimplifierTrace name uni fun ann)
runSimplifierT

evalSimplifier :: Simplifier name uni fun ann a -> a
evalSimplifier :: forall name (uni :: * -> *) fun ann a.
Simplifier name uni fun ann a -> a
evalSimplifier = Identity a -> a
forall a. Identity a -> a
runIdentity (Identity a -> a)
-> (Simplifier name uni fun ann a -> Identity a)
-> Simplifier name uni fun ann a
-> a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Simplifier name uni fun ann a -> Identity a
forall (m :: * -> *) name (uni :: * -> *) fun ann a.
Monad m =>
SimplifierT name uni fun ann m a -> m a
evalSimplifierT

execSimplifier :: Simplifier name uni fun ann a -> SimplifierTrace name uni fun ann
execSimplifier :: forall name (uni :: * -> *) fun ann a.
Simplifier name uni fun ann a -> SimplifierTrace name uni fun ann
execSimplifier = Identity (SimplifierTrace name uni fun ann)
-> SimplifierTrace name uni fun ann
forall a. Identity a -> a
runIdentity (Identity (SimplifierTrace name uni fun ann)
 -> SimplifierTrace name uni fun ann)
-> (Simplifier name uni fun ann a
    -> Identity (SimplifierTrace name uni fun ann))
-> Simplifier name uni fun ann a
-> SimplifierTrace name uni fun ann
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Simplifier name uni fun ann a
-> Identity (SimplifierTrace name uni fun ann)
forall (m :: * -> *) name (uni :: * -> *) fun ann a.
Monad m =>
SimplifierT name uni fun ann m a
-> m (SimplifierTrace name uni fun ann)
execSimplifierT

data SimplifierStage
  = FloatDelay
  | ForceDelay
  | CaseOfCase
  | CaseReduce
  | Inline
  | CSE

data Simplification name uni fun a =
  Simplification
    { forall name (uni :: * -> *) fun a.
Simplification name uni fun a -> Term name uni fun a
beforeAST :: Term name uni fun a
    , forall name (uni :: * -> *) fun a.
Simplification name uni fun a -> SimplifierStage
stage     :: SimplifierStage
    , forall name (uni :: * -> *) fun a.
Simplification name uni fun a -> Term name uni fun a
afterAST  :: Term name uni fun a
    }

-- TODO2: we probably don't want this in memory so after MVP
-- we should consider serializing this to disk
newtype SimplifierTrace name uni fun a =
  SimplifierTrace
    { forall name (uni :: * -> *) fun a.
SimplifierTrace name uni fun a -> [Simplification name uni fun a]
simplifierTrace
      :: [Simplification name uni fun a]
    }

initSimplifierTrace :: SimplifierTrace name uni fun a
initSimplifierTrace :: forall name (uni :: * -> *) fun a. SimplifierTrace name uni fun a
initSimplifierTrace = [Simplification name uni fun a] -> SimplifierTrace name uni fun a
forall name (uni :: * -> *) fun a.
[Simplification name uni fun a] -> SimplifierTrace name uni fun a
SimplifierTrace []

recordSimplification
  :: Monad m
  => Term name uni fun a
  -> SimplifierStage
  -> Term name uni fun a
  -> SimplifierT name uni fun a m ()
recordSimplification :: forall (m :: * -> *) name (uni :: * -> *) fun a.
Monad m =>
Term name uni fun a
-> SimplifierStage
-> Term name uni fun a
-> SimplifierT name uni fun a m ()
recordSimplification Term name uni fun a
beforeAST SimplifierStage
stage Term name uni fun a
afterAST =
  let simplification :: Simplification name uni fun a
simplification = Simplification { Term name uni fun a
beforeAST :: Term name uni fun a
beforeAST :: Term name uni fun a
beforeAST, SimplifierStage
stage :: SimplifierStage
stage :: SimplifierStage
stage, Term name uni fun a
afterAST :: Term name uni fun a
afterAST :: Term name uni fun a
afterAST }
    in
      (SimplifierTrace name uni fun a -> SimplifierTrace name uni fun a)
-> SimplifierT name uni fun a m ()
forall {m :: * -> *} {name} {uni :: * -> *} {fun} {ann}.
Monad m =>
(SimplifierTrace name uni fun ann
 -> SimplifierTrace name uni fun ann)
-> SimplifierT name uni fun ann m ()
modify ((SimplifierTrace name uni fun a -> SimplifierTrace name uni fun a)
 -> SimplifierT name uni fun a m ())
-> (SimplifierTrace name uni fun a
    -> SimplifierTrace name uni fun a)
-> SimplifierT name uni fun a m ()
forall a b. (a -> b) -> a -> b
$ \SimplifierTrace name uni fun a
st ->
        SimplifierTrace name uni fun a
st { simplifierTrace = simplification : simplifierTrace st }
  where
    modify :: (SimplifierTrace name uni fun ann
 -> SimplifierTrace name uni fun ann)
-> SimplifierT name uni fun ann m ()
modify SimplifierTrace name uni fun ann
-> SimplifierTrace name uni fun ann
f = StateT (SimplifierTrace name uni fun ann) m ()
-> SimplifierT name uni fun ann m ()
forall name (uni :: * -> *) fun ann (m :: * -> *) a.
StateT (SimplifierTrace name uni fun ann) m a
-> SimplifierT name uni fun ann m a
SimplifierT (StateT (SimplifierTrace name uni fun ann) m ()
 -> SimplifierT name uni fun ann m ())
-> StateT (SimplifierTrace name uni fun ann) m ()
-> SimplifierT name uni fun ann m ()
forall a b. (a -> b) -> a -> b
$ (SimplifierTrace name uni fun ann
 -> SimplifierTrace name uni fun ann)
-> StateT (SimplifierTrace name uni fun ann) m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
State.modify' SimplifierTrace name uni fun ann
-> SimplifierTrace name uni fun ann
f