{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}

{-| Commute such that constants are the first arguments. Consider:

(1)    equalsInteger 1 x

(2)    equalsInteger x 1

We have unary application, so these are two partial applications:

(1)    (equalsInteger 1) x

(2)    (equalsInteger x) 1

With (1), we can share the `equalsInteger 1` node, and it will be the same across any place where
we do this.

With (2), both the nodes here include x, which is a variable that will likely be different in other
invocations of `equalsInteger`. So the second one is harder to share, which is worse for CSE.

So commuting `equalsInteger` so that it has the constant first both a) makes various occurrences of
`equalsInteger` more likely to look similar, and b) gives us a maximally-shareable node for CSE.

This applies to any commutative builtin function that takes constants as arguments, although we
might expect that `equalsInteger` is the one that will benefit the most.
Plutonomy only commutes `EqualsInteger` in their `commEquals`. -}
module PlutusIR.Transform.RewriteRules.CommuteFnWithConst
  ( commuteFnWithConst
  ) where

import PlutusCore.Default
import PlutusIR.Core.Type (Term (Apply, Builtin, Constant))

{-| Returns 'True' exactly when the term is built with the 'Constant'
constructor; every other term yields 'False'. -}
isConstant :: Term tyname name uni fun a -> Bool
isConstant = \case
  Constant {} -> True
  _ -> False

{-| Rewrites a two-argument application of a commutative builtin so that a
constant argument comes first.

The rewrite fires only when all of the following hold:

* the term has the shape @Apply _ (Apply _ (Builtin _ fun) arg1) arg2@;
* @fun@ is commutative according to 'isCommutative';
* @arg1@ is not a constant and @arg2@ is a constant.

In that case @arg1@ and @arg2@ swap places while the three annotations stay
attached to the same nodes. Any other term is returned unchanged. -}
commuteFnWithConst :: t ~ Term tyname name uni DefaultFun a => t -> t
commuteFnWithConst = \case
  Apply ann1 (Apply ann2 (Builtin ann3 fun) arg1) arg2
    | isCommutative fun
    , -- If both arguments are constants there is nothing to gain from swapping.
      not (isConstant arg1)
    , isConstant arg2 ->
        Apply ann1 (Apply ann2 (Builtin ann3 fun) arg2) arg1
  t -> t

{-| Returns whether a 'DefaultFun' is commutative.

Every constructor is listed explicitly instead of ending with a catch-all
alternative, so that adding a new 'DefaultFun' forces this function to be
revisited. -}
isCommutative :: DefaultFun -> Bool
isCommutative = \case
  AddInteger -> True
  MultiplyInteger -> True
  EqualsInteger -> True
  EqualsByteString -> True
  EqualsString -> True
  EqualsData -> True
  -- The remaining builtins are spelled out one by one (no wildcard), so that
  -- this function has to be revisited whenever a new builtin is added.
  SubtractInteger -> False
  DivideInteger -> False
  QuotientInteger -> False
  RemainderInteger -> False
  ModInteger -> False
  LessThanInteger -> False
  LessThanEqualsInteger -> False
  AppendByteString -> False
  ConsByteString -> False
  SliceByteString -> False
  LengthOfByteString -> False
  IndexByteString -> False
  LessThanByteString -> False
  LessThanEqualsByteString -> False
  Sha2_256 -> False
  Sha3_256 -> False
  Blake2b_224 -> False
  Blake2b_256 -> False
  Keccak_256 -> False
  Ripemd_160 -> False
  VerifyEd25519Signature -> False
  VerifyEcdsaSecp256k1Signature -> False
  VerifySchnorrSecp256k1Signature -> False
  Bls12_381_G1_add -> False
  Bls12_381_G1_neg -> False
  Bls12_381_G1_scalarMul -> False
  Bls12_381_G1_multiScalarMul -> False
  Bls12_381_G1_equal -> False
  Bls12_381_G1_hashToGroup -> False
  Bls12_381_G1_compress -> False
  Bls12_381_G1_uncompress -> False
  Bls12_381_G2_add -> False
  Bls12_381_G2_neg -> False
  Bls12_381_G2_scalarMul -> False
  Bls12_381_G2_multiScalarMul -> False
  Bls12_381_G2_equal -> False
  Bls12_381_G2_hashToGroup -> False
  Bls12_381_G2_compress -> False
  Bls12_381_G2_uncompress -> False
  Bls12_381_millerLoop -> False
  Bls12_381_mulMlResult -> False
  Bls12_381_finalVerify -> False
  AppendString -> False
  EncodeUtf8 -> False
  DecodeUtf8 -> False
  IfThenElse -> False
  ChooseUnit -> False
  Trace -> False
  FstPair -> False
  SndPair -> False
  ChooseList -> False
  MkCons -> False
  HeadList -> False
  TailList -> False
  NullList -> False
  LengthOfArray -> False
  ListToArray -> False
  IndexArray -> False
  ChooseData -> False
  ConstrData -> False
  MapData -> False
  ListData -> False
  IData -> False
  BData -> False
  UnConstrData -> False
  UnMapData -> False
  UnListData -> False
  UnIData -> False
  UnBData -> False
  SerialiseData -> False
  MkPairData -> False
  MkNilData -> False
  MkNilPairData -> False
  IntegerToByteString -> False
  ByteStringToInteger -> False
  -- The rewrite currently requires a builtin to be commutative in all of its
  -- arguments, which the logical and bitwise operations are not.
  AndByteString -> False
  OrByteString -> False
  XorByteString -> False
  ComplementByteString -> False
  ReadBit -> False
  WriteBits -> False
  ReplicateByte -> False
  ShiftByteString -> False
  RotateByteString -> False
  CountSetBits -> False
  FindFirstSetBit -> False
  ExpModInteger -> False
  DropList -> False
  InsertCoin -> False
  LookupCoin -> False
  -- Note: the only builtin outside the leading group that is marked commutative.
  UnionValue -> True
  ValueContains -> False
  ValueData -> False
  UnValueData -> False
  ScaleValue -> False