{-# LANGUAGE LambdaCase #-}
{-|
A trivial simplification that cancels unwrap/wrap pairs.

This can only occur if we have inlined both the datatype constructors and destructors
and are deconstructing something we just constructed. This is probably rare,
and would anyway be better handled by something like case-of-known-constructor.
But the transformation is so simple that we might as well include it just in case.
-}
module PlutusIR.Transform.Unwrap (
  unwrapCancel,
  unwrapCancelPass
  ) where

import PlutusIR

import Control.Lens (transformOf)
import PlutusCore qualified as PLC
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

{-|
A single, non-recursive step of wrap/unwrap cancellation.

If the term is an 'Unwrap' whose argument is syntactically an 'IWrap', the
pair cancels and the body of the 'IWrap' is returned. The annotations of both
nodes and the two types carried by the 'IWrap' are discarded. Any other term
is returned unchanged.

Only the root of the term is inspected. Use 'unwrapCancel' to apply this
throughout a term.
-}
unwrapCancelStep
    :: Term tyname name uni fun a
    -> Term tyname name uni fun a
unwrapCancelStep = \case
    -- unwrap (iwrap _ _ body) ~> body
    Unwrap _ (IWrap _ _ _ b) -> b
    t                        -> t

{-|
Apply wrap/unwrap cancellation everywhere in a term.

'transformOf' rewrites bottom-up. Every subterm reached through
'termSubterms' is rewritten before its parent, so when 'unwrapCancelStep'
examines an 'Unwrap' node, its argument has already been simplified.
As a result, nested pairs such as
@unwrap (iwrap _ _ (unwrap (iwrap _ _ t)))@ collapse all the way to @t@ in a
single traversal.
-}
unwrapCancel
    :: Term tyname name uni fun a
    -> Term tyname name uni fun a
unwrapCancel = transformOf termSubterms unwrapCancelStep

{-|
Package 'unwrapCancel' as a compiler 'Pass' named @"wrap-unwrap cancel"@.

The typechecker configuration is handed unchanged to 'simplePass'. Presumably
it is used there to typecheck the term around the transformation; see
"PlutusIR.Pass" to confirm.
-}
unwrapCancelPass
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
unwrapCancelPass tcconfig =
  simplePass "wrap-unwrap cancel" tcconfig unwrapCancel