{-# LANGUAGE LambdaCase   #-}
{-# LANGUAGE ViewPatterns #-}
{-|
A simple beta-reduction pass.
-}
module PlutusIR.Transform.Beta (
  beta,
  betaPass,
  betaPassSC
  ) where

import Control.Lens (over)
import Data.List.NonEmpty qualified as NE
import PlutusCore qualified as PLC
import PlutusIR
import PlutusIR.Pass
import PlutusIR.TypeCheck qualified as TC

{- Note [Multi-beta]
Consider two examples where applying beta should be helpful.

1: [(\x . [(\y . t) b]) a]
2: [[(\x . (\y . t)) a] b]

(1) is the typical "let-binding" pattern: each binding corresponds to an immediately-applied lambda.
(2) is the typical "function application" pattern: a multi-argument function applied to multiple
arguments.

In both cases we would like to produce something like

let
  x = a
  y = b
in t

However, if we naively do a bottom-up pattern-matching transformation on the AST
to look for immediately-applied lambda abstractions then we will get the following:

1:
  [(\x . [(\y . t) b]) a]
  -->
  [(\x . let y = b in t) a]
  -->
  let x = a in let y = b in t

2:
  [[(\x . (\y . t)) a] b]
  -->
  [(let x = a in (\y . t)) b]

Now, if we later lift the let out, then we will be able to see that we can transform (2) further.
But that means that
a) we'd have to do the expensive let-floating pass in every iteration of the
simplifier, and
b) we can only inline one function argument per iteration of the simplifier, so for a function of
arity N we *must* do at least N passes.

This isn't great, so the solution is to recognize case (2) properly and handle all the arguments in
one go. That will also match cases like (1) just fine, since it's just made up of unary function
applications.

That does mean that we need to do a manual traversal rather than doing standard bottom-up
processing.

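Note that multi-beta requires globally unique names. In example (2) above, we end up with
`b` inside the scope of the new binding for `x`, whereas originally it was outside it. So if
`b` refers to some other `x` bound in the environment, that reference would be captured
(shadowed) by the new binding.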

Note that multi-beta cannot be used on TypeBinds. For instance, it is unsound to turn

(/\a \(b : a). t) {x} (y : x)

into

let a = x in let b = y in t

because in order to check that `b` and `y` have the same type, we need to know that `a = x`,
but we don't - type-lets are opaque inside their bodies.
-}

{-| Extract the list of bindings from a term, a bit like a "multi-beta" reduction.

The term is split into a head applied to a spine of term arguments ('Apply').
The head's leading term lambdas ('LamAbs') are then paired, in order, with those
arguments, and each pair becomes a 'Strict' 'TermBind'. The result is 'Just'
only if every argument was consumed by a lambda and at least one binding was
produced. Lambdas left over after the arguments run out are fine: they remain
part of the returned body.

Some examples will help:

[(\x . t) a] -> Just ([x |-> a], t)

[[[(\x . (\y . (\z . t))) a] b] c] -> Just ([x |-> a, y |-> b, z |-> c], t)

[[(\x . t) a] b] -> Nothing   (when @t@ is not itself a lambda)

Type instantiations ('TyInst') are never collected; see Note [Multi-beta] for why.
-}
extractBindings ::
  Term tyname name uni fun a
  -> Maybe (NE.NonEmpty (Binding tyname name uni fun a), Term tyname name uni fun a)
extractBindings = collectArgs []
  where
    -- Walk down the application spine. Each argument is pushed onto the stack,
    -- so the innermost (i.e. first) argument ends up on top.
    collectArgs
      :: [Term tyname name uni fun a]
      -> Term tyname name uni fun a
      -> Maybe (NE.NonEmpty (Binding tyname name uni fun a), Term tyname name uni fun a)
    collectArgs argStack (Apply _ f arg) = collectArgs (arg : argStack) f
    collectArgs argStack t               = matchArgs argStack [] t

    -- Pair each remaining argument with the next enclosing lambda.
    -- The accumulator holds the bindings built so far, most recent first.
    matchArgs
      :: [Term tyname name uni fun a]
      -> [Binding tyname name uni fun a]
      -> Term tyname name uni fun a
      -> Maybe (NE.NonEmpty (Binding tyname name uni fun a), Term tyname name uni fun a)
    matchArgs (arg : rest) acc (LamAbs a n ty body) =
      matchArgs rest (TermBind a Strict (VarDecl a n ty) arg : acc) body
    -- All arguments consumed. Succeed only if at least one binding was made,
    -- restoring source order for the bindings.
    matchArgs [] acc t =
      (\bs -> (bs, t)) <$> NE.nonEmpty (reverse acc)
    -- Arguments remain, but the head is no longer a lambda: give up.
    matchArgs (_ : _) _ _ = Nothing

{-|
Recursively apply the beta transformation on the code, both for the terms

@
    (\ (x : A). M) N
    ==>
    let x : A = N in M
@

(including the multi-argument form, see Note [Multi-beta] and 'extractBindings')
and types

@
    (/\ a. \(x : a) . x) {A}
    ==>
    let a : * = A in
    (\ (x : a). x)
@

No substitution is performed. The body is left unchanged, and the bound variable
is instead bound by a new non-recursive let. Term bindings produced this way are
'Strict'.

At each node the rewrite is applied first. 'beta' then recurses into the
subterms of the result, so the right-hand sides and bodies of newly created lets
are processed too. Multi-beta requires globally unique names (see
Note [Multi-beta]).
-}
beta
    :: Term tyname name uni fun a
    -> Term tyname name uni fun a
beta = over termSubterms beta . localTransform
  where
    -- Rewrite only the current node; recursion is handled by 'beta' above.
    localTransform :: Term tyname name uni fun a -> Term tyname name uni fun a
    localTransform = \case
      -- See Note [Multi-beta]
      -- The let reuses the body's annotation. This maybe isn't the best
      -- annotation for this term, but it will do.
      (extractBindings -> Just (bs, t)) -> Let (termAnn t) NonRec bs t
      -- A single type instantiation of a type abstraction becomes a type-let.
      -- See Note [Multi-beta] for why we don't perform multi-beta on `TyInst`.
      TyInst _ (TyAbs a n k body) tyArg ->
          let b = TypeBind a (TyVarDecl a n k) tyArg
          in Let (termAnn body) NonRec (pure b) body
      t -> t

-- | 'betaPass', preceded by 'renamePass'.
--
-- 'betaPass' requires 'GloballyUniqueNames' as a precondition (multi-beta relies
-- on it; see Note [Multi-beta]). Running 'renamePass' first is presumably what
-- establishes that property on arbitrary input. NOTE(review): confirm against
-- 'renamePass'.
betaPassSC
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, PLC.MonadQuote m, Ord a)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
betaPassSC = (renamePass <>) . betaPass

-- | The @"beta"@ pass: runs 'beta' as a pure term transformation.
--
-- Preconditions:
--
-- * the term typechecks under the given configuration, and
-- * it has globally unique names (needed by multi-beta; see Note [Multi-beta]).
--
-- Post-condition: a 'ConstCondition' wrapping the same typechecking condition,
-- presumably checked against the transformed term.
betaPass
  :: (PLC.Typecheckable uni fun, PLC.GEq uni, Applicative m, Ord a)
  => TC.PirTCConfig uni fun
  -> Pass m TyName Name uni fun a
betaPass tcconfig = NamedPass "beta" (Pass transform preconditions postconditions)
  where
    transform = pure . beta
    preconditions = [Typechecks tcconfig, GloballyUniqueNames]
    postconditions = [ConstCondition (Typechecks tcconfig)]