{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs             #-}
{-# LANGUAGE RankNTypes        #-}

module PlutusIR.Transform.RewriteRules.Internal
  ( RewriteRules (..)
  , defaultUniRewriteRules
  ) where

import PlutusCore.Default (DefaultFun, DefaultUni)
import PlutusCore.Name.Unique (Name)
import PlutusCore.Quote (MonadQuote)
import PlutusIR.Analysis.VarInfo (VarsInfo)
import PlutusIR.Core.Type qualified as PIR
import PlutusIR.Transform.RewriteRules.CommuteFnWithConst (commuteFnWithConst)
import PlutusIR.Transform.RewriteRules.UnConstrConstrData (unConstrConstrData)
import PlutusPrelude (Default (..), (>=>))

-- | A bundle of composed `RewriteRules`, to be passed at entrypoint of the compiler.
--
-- The wrapped function takes the 'VarsInfo' for the term being processed and
-- rewrites a 'PIR.Term' inside any 'MonadQuote' monad. The function is
-- polymorphic in the type-name type, the monad and the annotation type, so one
-- bundle can be reused across all of them.
--
-- The 'Monoid' constraint on the annotation presumably lets rules build
-- annotations for nodes they create (e.g. via 'mempty') — confirm against the
-- individual rules.
newtype RewriteRules uni fun where
  RewriteRules
    :: { unRewriteRules
          :: forall tyname m a
           . (MonadQuote m, Monoid a)
          => VarsInfo tyname Name uni a
          -> PIR.Term tyname Name uni fun a
          -> m (PIR.Term tyname Name uni fun a)
       }
    -> RewriteRules uni fun

-- | The rules for the Default Universe/Builtin.
--
-- Runs 'commuteFnWithConst' (a pure term-to-term rewrite) and then
-- 'unConstrConstrData'. The latter receives the default builtins information
-- ('def') and the caller-supplied 'VarsInfo'.
defaultUniRewriteRules :: RewriteRules DefaultUni DefaultFun
defaultUniRewriteRules = RewriteRules $ \varsInfo ->
  -- The rules are composed from left to right.
  -- 'commuteFnWithConst' is pure, so it is lifted into the monad with 'pure'
  -- before being Kleisli-composed ('>=>') with the monadic
  -- 'unConstrConstrData'.
  pure . commuteFnWithConst
    >=> unConstrConstrData def varsInfo

-- | The default bundle for the default universe is 'defaultUniRewriteRules'.
instance Default (RewriteRules DefaultUni DefaultFun) where
  def = defaultUniRewriteRules

-- | Sequential composition of rule bundles.
--
-- @r1 <> r2@ runs @r1@ first and feeds its resulting term to @r2@ (Kleisli
-- composition, '>=>'). Both bundles receive the same 'VarsInfo'.
instance Semigroup (RewriteRules uni fun) where
  RewriteRules r1 <> RewriteRules r2 =
    RewriteRules (\varsInfo -> r1 varsInfo >=> r2 varsInfo)

-- | The empty bundle ignores the 'VarsInfo' and returns the term unchanged.
--
-- 'pure' is the identity of Kleisli composition ('>=>'), so this is the unit
-- for the sequential '<>'.
instance Monoid (RewriteRules uni fun) where
  mempty = RewriteRules (const pure)