{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

{-| Drops redundant @unsafeCaseList@ calls produced by AsData.

See Note [Dropping redundant unsafeCaseList calls produced by AsData]. -}
module PlutusIR.Transform.DeadCase
  ( deadCase
  , deadCasePass
  , deadCasePassSC
  ) where

import PlutusCore qualified as PLC
import PlutusCore.Annotation
import PlutusCore.Name.Unique
import PlutusIR
import PlutusIR.Analysis.Usages qualified as Usages
import PlutusIR.Pass
import PlutusIR.Transform.Rename ()
import PlutusIR.TypeCheck qualified as TC

import Control.Lens (transformOf)

{-| The dead-case pass preceded by a renaming pass.

Composes 'renamePass' with 'deadCasePass' via '<>', with renaming on the left,
so usage counts in 'deadCase' are taken over freshly renamed binders
(assumes '<>' on passes runs its left operand first — confirm in
"PlutusIR.Pass"). The extra 'PLC.MonadQuote' and 'Ord' constraints are the
ones 'renamePass' needs. -}
deadCasePassSC
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, PLC.MonadQuote m, Ord a, AnnCase a)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
deadCasePassSC tcconfig = renamePass <> dropDeadCases
  where
    -- The actual case-elimination step, configured with the typechecker config.
    dropDeadCases = deadCasePass tcconfig

{-| Package 'deadCase' as a named, pure compiler pass.

The term rewrite is lifted into @m@ with 'pure'. The precondition list holds a
single 'Typechecks' check against the given config. The same check is reused,
wrapped in 'ConstCondition', as the sole postcondition — presumably meaning
"the output must also typecheck", independent of the input; confirm against
"PlutusIR.Pass". -}
deadCasePass
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m, AnnCase a)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
deadCasePass tcconfig = NamedPass "eliminate dead cases" pass
  where
    -- One typechecking condition, shared between pre- and post-conditions.
    typechecks = Typechecks tcconfig

    pass =
      Pass
        (\term -> pure (deadCase term))
        [typechecks]
        [ConstCondition typechecks]

{-| Remove every droppable @Case@ in a term, as recognised by 'processTerm'.

A droppable @Case@ is annotated safe-to-drop and has exactly one branch
@\\x y -> body@ in which neither binder is used. The rewrite is applied with
'transformOf' over 'termSubterms', which visits children before parents. An
inner elimination is therefore already in place when its enclosing node is
examined, so nested droppable cases collapse in a single traversal. -}
deadCase
  :: (HasUnique name TermUnique, AnnCase a)
  => Term TyName name uni fun a
  -> Term TyName name uni fun a
deadCase term = transformOf termSubterms processTerm term

{-| Single-node rewrite used by 'deadCase'.

A @Case@ is replaced by the body of its branch when all of these hold:

  * its annotation satisfies 'annIsSafeToDrop';
  * it has exactly one branch, of the shape @\\x y -> body@;
  * neither @x@ nor @y@ is used in @body@, per 'Usages.termUsages'.

The scrutinee and result type are discarded. Dropping the scrutinee also drops
whatever evaluating it would have done, which is presumably why the
annotation gate is required. Any other term is returned unchanged. Usages are
counted by unique, so a shadowing inner binder can only make the check more
conservative. -}
processTerm
  :: (HasUnique name TermUnique, AnnCase a)
  => Term TyName name uni fun a
  -> Term TyName name uni fun a
processTerm term =
  case term of
    Case ann _resTy _scrut [LamAbs _ x _ (LamAbs _ y _ branchBody)]
      | annIsSafeToDrop ann
      , all isDead [x, y] ->
          branchBody
      where
        -- Computed lazily: only forced once the annotation check has passed.
        bodyUsages = Usages.termUsages branchBody
        isDead v = Usages.getUsageCount v bodyUsages == 0
    _ -> term