{-# LANGUAGE GADTs         #-}
{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE TypeOperators #-}

{- | Commute such that constants are the first arguments. Consider:

(1)    equalsInteger 1 x

(2)    equalsInteger x 1

Application is unary (curried), so each of these is a partial application applied to one further argument:

(1)    (equalsInteger 1) x

(2)    (equalsInteger x) 1

With (1), the `equalsInteger 1` node can be shared: it is identical in every place where this
comparison appears.

With (2), both nodes contain `x`, a variable that is likely to differ between invocations of
`equalsInteger`. The second form is therefore harder to share, which is worse for CSE.

So commuting `equalsInteger` so that it has the constant first both a) makes various occurrences of
`equalsInteger` more likely to look similar, and b) gives us a maximally-shareable node for CSE.

This applies to any commutative builtin function that takes constants as arguments, although we
might expect that `equalsInteger` is the one that will benefit the most.
Plutonomy only commutes `EqualsInteger` in their `commEquals`.
-}

module PlutusIR.Transform.RewriteRules.CommuteFnWithConst
    ( commuteFnWithConst
    ) where

import PlutusCore.Default
import PlutusIR.Core.Type (Term (Apply, Builtin, Constant))

-- | Whether a term is a 'Constant' node.
--
-- Only the outermost constructor is inspected: every other kind of node
-- (variables, applications, builtins, ...) is reported as non-constant, even
-- if it is built entirely out of constants.
isConstant :: Term tyname name uni fun a -> Bool
isConstant = \case
    Constant{} -> True
    _          -> False

-- | Reorder the arguments of a two-argument application of a commutative
-- builtin so that the constant argument comes first.
--
-- @[(builtin f) x c]@ is rewritten to @[(builtin f) c x]@ when
--
--   * @f@ is commutative according to 'isCommutative',
--   * @c@ is a 'Constant', and
--   * @x@ is not a 'Constant'.
--
-- Only the outermost node is examined; the function does not recurse into
-- subterms. The annotations of the rebuilt nodes are reused unchanged. Any
-- term that does not match is returned as-is.
commuteFnWithConst :: (t ~ Term tyname name uni DefaultFun a) => t -> t
commuteFnWithConst = \case
    Apply ann1 (Apply ann2 (Builtin ann3 fun) arg1) arg2
        | isCommutative fun
        -- When both arguments are constants there is nothing to gain, and
        -- swapping them anyway would make the rewrite non-idempotent.
        , not (isConstant arg1)
        , isConstant arg2
        -> Apply ann1 (Apply ann2 (Builtin ann3 fun) arg2) arg1
    t -> t

-- | Returns whether a 'DefaultFun' may have its two arguments swapped by
-- 'commuteFnWithConst'.
--
-- A catch-all pattern is deliberately avoided: with incomplete-pattern
-- warnings enabled, GHC then flags this function whenever a new 'DefaultFun'
-- constructor is added, forcing a decision for it.
isCommutative :: DefaultFun -> Bool
isCommutative = \case
  AddInteger                      -> True
  MultiplyInteger                 -> True
  EqualsInteger                   -> True
  EqualsByteString                -> True
  EqualsString                    -> True
  EqualsData                      -> True
  -- Every remaining builtin is listed explicitly (rather than via a
  -- catch-all) so that this function is revisited when a builtin is added.
  -- NOTE(review): some of these (e.g. Bls12_381_G1_add, Bls12_381_G1_equal)
  -- are mathematically symmetric too but are conservatively left as False;
  -- confirm whether they should be commuted.
  SubtractInteger                 -> False
  DivideInteger                   -> False
  QuotientInteger                 -> False
  RemainderInteger                -> False
  ModInteger                      -> False
  LessThanInteger                 -> False
  LessThanEqualsInteger           -> False
  AppendByteString                -> False
  ConsByteString                  -> False
  SliceByteString                 -> False
  LengthOfByteString              -> False
  IndexByteString                 -> False
  LessThanByteString              -> False
  LessThanEqualsByteString        -> False
  Sha2_256                        -> False
  Sha3_256                        -> False
  Blake2b_224                     -> False
  Blake2b_256                     -> False
  Keccak_256                      -> False
  Ripemd_160                      -> False
  VerifyEd25519Signature          -> False
  VerifyEcdsaSecp256k1Signature   -> False
  VerifySchnorrSecp256k1Signature -> False
  Bls12_381_G1_add                -> False
  Bls12_381_G1_neg                -> False
  Bls12_381_G1_scalarMul          -> False
  Bls12_381_G1_equal              -> False
  Bls12_381_G1_hashToGroup        -> False
  Bls12_381_G1_compress           -> False
  Bls12_381_G1_uncompress         -> False
  Bls12_381_G2_add                -> False
  Bls12_381_G2_neg                -> False
  Bls12_381_G2_scalarMul          -> False
  Bls12_381_G2_equal              -> False
  Bls12_381_G2_hashToGroup        -> False
  Bls12_381_G2_compress           -> False
  Bls12_381_G2_uncompress         -> False
  Bls12_381_millerLoop            -> False
  Bls12_381_mulMlResult           -> False
  Bls12_381_finalVerify           -> False
  AppendString                    -> False
  EncodeUtf8                      -> False
  DecodeUtf8                      -> False
  IfThenElse                      -> False
  ChooseUnit                      -> False
  Trace                           -> False
  FstPair                         -> False
  SndPair                         -> False
  ChooseList                      -> False
  MkCons                          -> False
  HeadList                        -> False
  TailList                        -> False
  NullList                        -> False
  ChooseData                      -> False
  ConstrData                      -> False
  MapData                         -> False
  ListData                        -> False
  IData                           -> False
  BData                           -> False
  UnConstrData                    -> False
  UnMapData                       -> False
  UnListData                      -> False
  UnIData                         -> False
  UnBData                         -> False
  SerialiseData                   -> False
  MkPairData                      -> False
  MkNilData                       -> False
  MkNilPairData                   -> False
  IntegerToByteString             -> False
  ByteStringToInteger             -> False
  -- The rewrite swaps the last two arguments of a two-argument application,
  -- so it needs commutativity in all arguments. The logical and bitwise
  -- operations below do not have that property.
  AndByteString                   -> False
  OrByteString                    -> False
  XorByteString                   -> False
  ComplementByteString            -> False
  ReadBit                         -> False
  WriteBits                       -> False
  ReplicateByte                   -> False
  ShiftByteString                 -> False
  RotateByteString                -> False
  CountSetBits                    -> False
  FindFirstSetBit                 -> False
  ExpModInteger                   -> False