{-# LANGUAGE GADTs            #-}
{-# LANGUAGE RankNTypes       #-}
{-# LANGUAGE TypeApplications #-}

module UntypedPlutusCore.Simplify (
    module Opts,
    simplifyTerm,
    simplifyProgram,
    InlineHints (..),
) where

import PlutusCore.Compiler.Types
import PlutusCore.Default qualified as PLC
import PlutusCore.Default.Builtins
import PlutusCore.Name.Unique
import UntypedPlutusCore.Core.Type
import UntypedPlutusCore.Simplify.Opts as Opts
import UntypedPlutusCore.Transform.CaseOfCase
import UntypedPlutusCore.Transform.CaseReduce
import UntypedPlutusCore.Transform.Cse
import UntypedPlutusCore.Transform.FloatDelay (floatDelay)
import UntypedPlutusCore.Transform.ForceDelay (forceDelay)
import UntypedPlutusCore.Transform.Inline (InlineHints (..), inline)

import Control.Monad
import Control.Monad.State.Class (MonadState)
import Control.Monad.State.Class qualified as State
import Data.List as List (foldl')
import Data.Typeable

-- | Simplify a whole UPLC 'Program'.
--
-- The annotation and version of the program are carried over unchanged; only
-- the body term is rewritten, by delegating to 'simplifyTerm' with the same
-- options and builtin semantics variant.
simplifyProgram ::
    forall name uni fun m a.
    ( Compiling m uni fun name a
    , MonadState (UPLCSimplifierTrace name uni fun a) m
    ) =>
    SimplifyOpts name a ->
    BuiltinSemanticsVariant fun ->
    Program name uni fun a ->
    m (Program name uni fun a)
simplifyProgram opts builtinSemanticsVariant (Program ann ver body) = do
    body' <- simplifyTerm opts builtinSemanticsVariant body
    pure (Program ann ver body')

-- | Run the UPLC simplifier pipeline over a term.
--
-- The pipeline runs in two phases:
--
--   1. 'simplifyStep' is applied @_soMaxSimplifierIterations opts@ times.
--   2. 'cseTimes' rounds follow, each made of one CSE step and then one more
--      simplifier step (see Note [CSE]). 'cseTimes' is @0@ when
--      '_soConservativeOpts' is set, so this phase is skipped in that case.
--
-- Every simplifier step appends the AST before its first pass and after each
-- of its passes to the 'uplcSimplifierTrace' field of the 'MonadState' state.
simplifyTerm ::
    forall name uni fun m a.
    ( Compiling m uni fun name a
    , MonadState (UPLCSimplifierTrace name uni fun a) m
    ) =>
    SimplifyOpts name a ->
    BuiltinSemanticsVariant fun ->
    Term name uni fun a ->
    m (Term name uni fun a)
simplifyTerm opts builtinSemanticsVariant =
    simplifyNTimes (_soMaxSimplifierIterations opts) >=> cseNTimes cseTimes
  where
    -- Compose monadic passes left to right with Kleisli composition.
    -- An empty list of passes is the identity pass 'pure'.
    chain ::
        [Term name uni fun a -> m (Term name uni fun a)] ->
        Term name uni fun a ->
        m (Term name uni fun a)
    chain = List.foldl' (>=>) pure

    -- Run the simplifier @n@ times
    simplifyNTimes :: Int -> Term name uni fun a -> m (Term name uni fun a)
    simplifyNTimes n = chain $ map simplifyStep [1 .. n]

    -- Run CSE @n@ times, interleaved with the simplifier.
    -- See Note [CSE]
    cseNTimes :: Int -> Term name uni fun a -> m (Term name uni fun a)
    cseNTimes n = chain $ concatMap (\i -> [cseStep i, simplifyStep i]) [1 .. n]

    -- One simplifier step. The AST is recorded before the first pass and after
    -- every pass. The 'Int' (iteration index) is currently unused.
    simplifyStep :: Int -> Term name uni fun a -> m (Term name uni fun a)
    simplifyStep _ =
      traceAST
        >=> floatDelay
        >=> traceAST
        >=> pure . forceDelay
        >=> traceAST
        >=> pure . caseOfCase'
        >=> traceAST
        >=> pure . caseReduce
        >=> traceAST
        >=> inline (_soInlineConstants opts) (_soInlineHints opts) builtinSemanticsVariant
        >=> traceAST

    -- 'caseOfCase' is only applicable when the builtin set is 'DefaultFun';
    -- for any other @fun@ this pass is the identity.
    caseOfCase' :: Term name uni fun a -> Term name uni fun a
    caseOfCase' = case eqT @fun @DefaultFun of
      Just Refl -> caseOfCase
      Nothing   -> id

    -- 'cse' is only run when names are 'Name' and the universe is
    -- 'PLC.DefaultUni'; otherwise this step is a no-op. The 'Int' is unused.
    cseStep :: Int -> Term name uni fun a -> m (Term name uni fun a)
    cseStep _ =
      case (eqT @name @Name, eqT @uni @PLC.DefaultUni) of
        (Just Refl, Just Refl) -> cse builtinSemanticsVariant
        _                      -> pure

    -- Append the current AST to the simplifier trace and pass it through
    -- unchanged.
    -- NOTE(review): @++ [ast]@ is O(length of trace) per call, so recording
    -- is quadratic overall. This is acceptable for the small iteration counts
    -- used here, but switching to a snoc-friendly container would require
    -- changing 'UPLCSimplifierTrace', which is declared elsewhere.
    traceAST :: Term name uni fun a -> m (Term name uni fun a)
    traceAST ast = do
      State.modify' (\st -> st { uplcSimplifierTrace = uplcSimplifierTrace st ++ [ast] })
      pure ast

    -- Number of CSE rounds: none in conservative mode, otherwise the
    -- configured maximum.
    cseTimes :: Int
    cseTimes = if _soConservativeOpts opts then 0 else _soMaxCseIterations opts