{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module PlutusTx.Eq.TH (Eq (..), deriveEq) where

import Data.Deriving.Internal (varTToName)
import Data.Foldable
import Data.Traversable
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import PlutusTx.Bool (Bool (True), (&&))
import PlutusTx.Eq.Class hiding ((/=))
import Prelude hiding (Bool (True), Eq, (&&), (==))

{-| Derive a Plinth 'Eq' instance for a datatype or newtype.

This is the Plinth counterpart of Haskell's @deriving stock Eq@: the generated
'(==)' compares two values structurally, constructor by constructor and field
by field.

__Usage:__

@
data MyType = Con1 Integer | Con2 Bool
deriveEq ''MyType
@

__Shape of the generated instance:__

* The instance context requires 'Eq' for every type parameter whose role is
  not phantom; phantom parameters stay unconstrained.
* One equation per constructor, matching that constructor on both sides and
  combining the pairwise field comparisons with PlutusTx's '&&', so comparison
  can stop at the first unequal field. Field-less constructors compare as
  @True@.
* A trailing @_ == _ = False@ equation when there are two or more
  constructors, and a single @_ == _ = True@ equation when there are none.
* An @INLINABLE@ pragma on '(==)' for cross-module optimisation.

__Supported types:__

* Regular datatypes with any number of constructors
* Newtypes
* Types with phantom type parameters
* Types with strict or lazy fields
* Records
* Self-recursive types
* Mutually recursive types (when all types in the group have Eq instances)

__Unsupported:__

* GADTs (may produce type errors)
* Existentially quantified types
* Type families (not tested)

NOTE(review): the context applies 'Eq' to every non-phantom parameter
regardless of its kind. A higher-kinded parameter (e.g. @f@ in
@data T f = T (f Integer)@) would therefore presumably yield an ill-kinded
constraint — confirm before relying on such types.

See 'PlutusTx.Eq.Class.Eq' for the class definition. -}
deriveEq :: Name -> Q [Dec]
deriveEq name = do
  info <- reifyDatatype name
  roles <- reifyRoles name

  let tyConName = datatypeName info
      cons = datatypeCons info
      rawTyVars = datatypeInstTypes info

      -- Re-wrap each parameter as a bare 'VarT', dropping any kind signature
      -- that reification attached; otherwise the splice site would need
      -- KindSignatures whenever the type has parameters.
      bare :: Type -> Type
      bare = VarT . varTToName

      -- Every parameter appears in the instance head ...
      headArgs :: [Type]
      headArgs = map bare rawTyVars

      -- ... but only the non-phantom ones need an 'Eq' constraint.
      ctx :: Cxt
      ctx =
        [ AppT (ConT ''Eq) (bare tv)
        | (role, tv) <- zip roles rawTyVars
        , role /= PhantomR
        ]

      headTy :: Type
      headTy = AppT (ConT ''Eq) (foldl' AppT (ConT tyConName) headArgs)

      -- @_ == _ = <result>@
      catchAll :: Name -> Q Clause
      catchAll result = clause [wildP, wildP] (normalB (conE result)) []

      -- Void types are trivially equal; a single constructor needs no
      -- fallback; otherwise mismatched constructors are unequal.
      fallback :: [Q Clause]
      fallback = case cons of
        [] -> [catchAll 'True]
        [_] -> []
        _ -> [catchAll 'False]

  inst <-
    instanceD
      (pure ctx)
      (pure headTy)
      [ funD '(==) (map deriveEqCons cons ++ fallback)
      , pragInlD '(==) Inlinable FunLike AllPhases
      ]
  pure [inst]

-- | Build the '(==)' equation for a single constructor @C@ with @n@ fields:
--
-- @
-- C l1 .. ln == C r1 .. rn = (l1 == r1) && ((l2 == r2) && (.. && (ln == rn)))
-- @
--
-- The field variables are fresh names obtained from 'newName'. A constructor
-- with no fields yields @C == C = True@.
deriveEqCons :: ConstructorInfo -> Q Clause
deriveEqCons ConstructorInfo {constructorName = conName, constructorFields = fieldTys} = do
  let arity = length fieldTys
      -- Fresh variable such as l1l / r1r: the tag is used as prefix and suffix.
      fresh tag i = newName (tag <> show i <> tag)
  lhsVars <- traverse (fresh "l") [1 .. arity]
  rhsVars <- traverse (fresh "r") [1 .. arity]
  let fieldEq l r = infixE (Just (varE l)) (varE '(==)) (Just (varE r))
      conj x y = infixE (Just x) (varE '(&&)) (Just y)
      -- Right-nested conjunction of a non-empty chain of comparisons.
      chain c [] = c
      chain c (c' : cs) = conj c (chain c' cs)
      rhs = case zipWith fieldEq lhsVars rhsVars of
        [] -> conE 'True
        c : cs -> chain c cs
  clause
    [conP conName (map varP lhsVars), conP conName (map varP rhsVars)]
    (normalB rhs)
    []