{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs             #-}
{-# LANGUAGE RankNTypes        #-}
{-# LANGUAGE TypeOperators     #-}
module PlutusIR.Transform.RewriteRules
    ( rewriteWith
    , rewritePass
    , rewritePassSC
    , RewriteRules
    , unRewriteRules
    , defaultUniRewriteRules
    ) where

import PlutusCore qualified as PLC
import PlutusCore.Core (HasUniques)
import PlutusCore.Name.Unique
import PlutusCore.Quote
import PlutusIR as PIR
import PlutusIR.Analysis.VarInfo
import PlutusIR.Transform.RewriteRules.Internal

import Control.Lens
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

-- | A variant of 'rewritePass' that is combined with 'renamePass' using the
-- 'Semigroup' instance of 'Pass', so that callers do not have to arrange for
-- renaming themselves.
--
-- NOTE(review): presumably the renamer is sequenced before the rewrite (so that
-- the @GloballyUniqueNames@ condition required by 'rewritePass' holds) -- confirm
-- against the 'Semigroup' instance of 'Pass'.
rewritePassSC ::
    forall m uni fun a.
    ( PLC.Typecheckable uni fun, PLC.GEq uni, Ord a
    , PLC.MonadQuote m, Monoid a
    ) =>
    TC.PirTCConfig uni fun ->
    RewriteRules uni fun ->
    Pass m TyName Name uni fun a
rewritePassSC tcconfig rules =
    renamePass <> rewritePass tcconfig rules

-- | The rewrite-rules transformation packaged as a 'Pass' named @"rewrite rules"@.
--
-- The pass runs 'rewriteWith' with the supplied rules. It is constructed with:
--
-- * the conditions @Typechecks tcconfig@ and @GloballyUniqueNames@;
-- * the bi-condition @ConstCondition (Typechecks tcconfig)@.
--
-- The type checker configuration is only used to build those conditions; it is
-- not consulted by the rewrite itself.
rewritePass ::
    forall m uni fun a.
    ( PLC.Typecheckable uni fun, PLC.GEq uni, Ord a
    , PLC.MonadQuote m, Monoid a
    ) =>
    TC.PirTCConfig uni fun ->
    RewriteRules uni fun ->
    Pass m TyName Name uni fun a
rewritePass tcconfig rules =
    NamedPass "rewrite rules" $
        Pass
            (rewriteWith rules)
            [Typechecks tcconfig, GloballyUniqueNames]
            [ConstCondition (Typechecks tcconfig)]

{- | Rewrite a 'Term' using the given 'RewriteRules' (which behave like functions
of type @Term -> Term@). Normally the rewrite rules are configured at the
entrypoint of the compiler.

The rules are applied to the subterms of the input (as selected by
'termSubterms') in a bottom-up manner via 'transformMOf'.

The supplied rewrite rules must be type-preserving; this is not checked here.
MAYBE: enforce this with a `through typeCheckTerm`?
-}
rewriteWith :: ( Monoid a, t ~ Term tyname Name uni fun a
              , HasUniques t
              , MonadQuote m
              )
            => RewriteRules uni fun
            -> t
            -> m t
rewriteWith rules t =
    -- We collect `VarsInfo` once, on the whole program term, and pass it as an argument
    -- to each rewrite rule. This has the limitation that any variables newly introduced
    -- by the rules are not accounted for in `VarsInfo`. This is currently fine, because
    -- we only rely on `VarsInfo` for isPure, and isPure is safe w.r.t. "open" terms.
    let vinfo = termVarInfo t
    in transformMOf termSubterms (unRewriteRules rules vinfo) t