{-# LANGUAGE NamedFieldPuns #-}

module PlutusIR.Transform.KnownCon (knownCon, knownConPass, knownConPassSC) where

import PlutusCore qualified as PLC
import PlutusCore.Name.Unique qualified as PLC
import PlutusIR
import PlutusIR.Contexts
import PlutusIR.Transform.Rename ()

import Control.Lens hiding (cons)
import Data.List.Extra qualified as List
import PlutusIR.Analysis.VarInfo
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

-- | The case-of-known-constructor pass, preceded by 'renamePass'.
--
-- 'knownConPass' requires globally unique names as a precondition; running
-- 'renamePass' first is intended to establish that, so this variant can be
-- run on terms whose names are not yet known to be unique.
knownConPassSC ::
    forall m uni fun a.
    ( PLC.Typecheckable uni fun, PLC.GEq uni, Ord a
    , PLC.MonadQuote m
    )
    => TC.PirTCConfig uni fun
    -> Pass m TyName Name uni fun a
knownConPassSC tcconfig = renamePass <> knownConPass tcconfig

-- | The case-of-known-constructor pass: runs 'knownCon' on the whole term.
--
-- Preconditions: the input typechecks under the given configuration and has
-- globally unique names. Post-condition (wrapped in 'ConstCondition'): the
-- result typechecks under the same configuration.
knownConPass ::
    forall m uni fun a.
    ( PLC.Typecheckable uni fun, PLC.GEq uni, Ord a
    , Applicative m)
    => TC.PirTCConfig uni fun
    -> Pass m TyName Name uni fun a
knownConPass tcconfig =
    NamedPass "case of known constructor" $
        Pass
            -- The transformation itself is pure; lift it into the pass monad.
            (pure . knownCon)
            -- Conditions required of the input term.
            [Typechecks tcconfig, GloballyUniqueNames]
            -- Conditions relating input and output.
            [ConstCondition (Typechecks tcconfig)]

{- | Simplify destructor applications, if the scrutinee is a constructor application.

As an example, given

@
    Maybe_match
      {x_type}
      (Just {x_type} x)
      {result_type}
      (\a -> <just_case_body : result_type>)
      (<nothing_case_body : result_type>)
      additional_args
@

'knownCon' turns it into

@
    (\a -> <just_case_body>) x additional_args
@

The rewrite is applied bottom-up to every subterm (via 'transformOf' over
'termSubterms'); subterms that do not have this shape are left unchanged.
-}
knownCon ::
    forall tyname name uni fun a.
    ( PLC.HasUnique name PLC.TermUnique
    , PLC.HasUnique tyname PLC.TypeUnique
    , Eq name
    ) =>
    Term tyname name uni fun a ->
    Term tyname name uni fun a
knownCon t =
    -- Variable information (which names are datatype matchers, which type
    -- variables are datatypes) is collected once from the input term and
    -- reused for every subterm. Lookups go through the names' uniques, so this
    -- relies on names being globally unique (a precondition of 'knownConPass').
    let vinfo = termVarInfo t
    in transformOf termSubterms (processTerm vinfo) t

-- | Rewrite a single term if it is a matcher application whose scrutinee is
-- an application of one of that datatype's constructors.
--
-- The matcher must receive at least one branch per constructor. Any arguments
-- beyond the branches are re-applied to the rewritten term, so
-- @match {tys} (C {tys} xs) {r} b_1 .. b_n extra@ becomes @b_C xs extra@.
-- Any term not of this shape is returned unchanged.
processTerm ::
    forall tyname name uni fun a .
    (Eq name
    , PLC.HasUnique name PLC.TermUnique
    , PLC.HasUnique tyname PLC.TypeUnique) =>
    VarsInfo tyname name uni a ->
    Term tyname name uni fun a ->
    Term tyname name uni fun a
processTerm vinfo t
    | (Var _ n, args) <- splitApplication t
    , -- The head is the matcher of some datatype ...
      Just (DatatypeMatcher parentName) <- lookupVarInfo n vinfo
    , -- ... whose declaration we can find.
      Just (DatatypeTyVar (Datatype _ _ tvs _ constructors)) <- lookupTyVarInfo parentName vinfo
    , (TermAppContext scrut _ (TypeAppContext _resTy _ branchArgs)) <-
        -- The datatype may have some type arguments, we
        -- aren't interested in them, so we drop them.
        dropAppContext (length tvs) args
    , -- The scrutinee is itself an application
      (Var _ con, conArgs) <- splitApplication scrut
    , -- ... of one of the constructors from the same datatype as the destructor
      Just i <- List.findIndex (== con) (fmap _varDeclName constructors)
    , -- ... and there is a branch for that constructor in the destructor application
      (TermAppContext branch _ _) <- dropAppContext i branchArgs
    , -- This condition ensures the destructor is fully-applied
      -- (which should always be the case in programs that come from Plutus Tx,
      -- but not necessarily in arbitrary PIR programs). Extra arguments after
      -- the branches are allowed; they are re-applied below.
      lengthContext branchArgs >= length constructors =
        fillAppContext
            (fillAppContext
                branch
                -- The arguments to the selected branch consists of the arguments
                -- to the constructor, without the leading type arguments - e.g.,
                -- if the scrutinee is `Just {integer} 1`, we only need the `1`).
                (dropAppContext (length tvs) conArgs))
            -- Whatever the matcher was applied to after its branches is applied
            -- to the selected branch's result (empty when exactly saturated).
            (dropAppContext (length constructors) branchArgs)
    | otherwise = t