{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module PlutusTx.Ord.TH (deriveOrd) where

import Data.Deriving.Internal (varTToName)
import Data.Foldable
import Data.Traversable
import Language.Haskell.TH as TH
import Language.Haskell.TH.Datatype as TH
import PlutusTx.Ord.Class
import Prelude hiding
  ( Bool (True)
  , Eq ((==))
  , Ord (..)
  , Ordering (..)
  , (&&)
  )

{-| Derive a PlutusTx 'Ord' instance for a datatype or newtype, in the spirit of
Haskell's @deriving stock Ord@.

The generated instance:

* places an 'Ord' constraint on every type parameter whose role is not phantom;
* defines 'compare' from one clause per constructor that compares two values
  built with that same constructor field by field, followed by the clauses that
  order distinct constructors by their declaration position;
* marks 'compare' as INLINABLE.

One shortcoming compared to Haskell's deriving is that you cannot
`PlutusTx.deriveOrd` for polymorphic phantom types. -}
deriveOrd :: TH.Name -> TH.Q [TH.Dec]
deriveOrd name = do
  info <- TH.reifyDatatype name
  roles <- reifyRoles name
  let tyConName = TH.datatypeName info
      tyVars0 = TH.datatypeInstTypes info
      cons = TH.datatypeCons info

      -- Rebuilding each variable through `varTToName` drops any kind signature
      -- attached to it in `tyVars0`; otherwise the generated code would need
      -- the `KindSignatures` extension for any parameterised type.
      bareVar :: TH.Type -> TH.Type
      bareVar = TH.VarT . varTToName

      ordOf :: TH.Type -> TH.Type
      ordOf = TH.AppT (TH.ConT ''Ord)

      -- Parameters with a phantom role get no `Ord` constraint.
      constraints :: TH.Cxt
      constraints =
        [ordOf (bareVar tv) | (role, tv) <- zip roles tyVars0, role /= PhantomR]

      instanceHead :: TH.Type
      instanceHead =
        ordOf (foldl' TH.AppT (TH.ConT tyConName) (map bareVar tyVars0))

      compareDec, inlinableDec :: TH.Q TH.Dec
      compareDec =
        funD 'compare (map deriveOrdSame cons ++ deriveOrdDifferent cons)
      inlinableDec = TH.pragInlD 'compare TH.Inlinable TH.FunLike TH.AllPhases

  decl <- instanceD (pure constraints) (pure instanceHead) [compareDec, inlinableDec]
  pure [decl]

{-| Generate the 'compare' clause for two values built with the same constructor.

Fields are compared pairwise, left to right, and the first result that is not
'EQ' is returned. A constructor without fields compares 'EQ'. For a two-field
constructor @C@ the generated clause is

> compare (C l1 l2) (C r1 r2) =
>   case compare l1 r1 of
>     LT -> LT
>     EQ -> compare l2 r2
>     GT -> GT

so a later field is only compared when every earlier field compared 'EQ'. This
is the same shape GHC uses for stock-derived 'Ord'. Chaining the comparisons
with '<>' would instead pass each of them as an argument to '<>', which forces
all of them under strict evaluation. -}
deriveOrdSame :: ConstructorInfo -> Q Clause
deriveOrdSame ConstructorInfo {constructorName = name, constructorFields = fields} = do
  argsL <- for [1 .. length fields] $ \i -> TH.newName ("l" <> show i <> "l")
  argsR <- for [1 .. length fields] $ \i -> TH.newName ("r" <> show i <> "r")
  pure
    ( TH.Clause
        [ConP name [] (fmap VarP argsL), ConP name [] (fmap VarP argsR)]
        (NormalB (compareFields (zip argsL argsR)))
        []
    )
  where
    -- Lexicographic comparison of the paired field variables, short-circuiting
    -- on the first non-EQ result.
    compareFields :: [(Name, Name)] -> Exp
    compareFields [] = TH.ConE 'EQ
    compareFields [(l, r)] = compareVars l r
    compareFields ((l, r) : rest) =
      TH.CaseE
        (compareVars l r)
        [ TH.Match (ConP 'LT [] []) (NormalB (TH.ConE 'LT)) []
        , TH.Match (ConP 'EQ [] []) (NormalB (compareFields rest)) []
        , TH.Match (ConP 'GT [] []) (NormalB (TH.ConE 'GT)) []
        ]

    -- The expression `l `compare` r` for two bound field variables.
    compareVars :: Name -> Name -> Exp
    compareVars l r =
      TH.InfixE (Just (TH.VarE l)) (TH.VarE 'compare) (Just (TH.VarE r))

{-| Generate the 'compare' clauses for values built with different constructors.

These clauses are placed after the same-constructor clauses produced by
'deriveOrdSame', so wildcard patterns are safe here. Every constructor @Ci@
except the last contributes the pair

>   compare Ci{} _   = LT   -- Ci is less than any later constructor
>   compare _   Ci{} = GT   -- any later constructor is greater than Ci

which yields O(n) clauses instead of O(n^2).

* No constructors: a single @compare _ _ = EQ@ clause.
* One constructor: no extra clauses; the same-constructor clause covers it. -}
deriveOrdDifferent :: [ConstructorInfo] -> [Q Clause]
deriveOrdDifferent [] =
  [clause [wildP, wildP] (normalB (conE 'EQ)) []]
deriveOrdDifferent allCons = walk allCons
  where
    -- Emit an LT/GT pair for each constructor that still has a successor,
    -- i.e. every constructor except the last one.
    walk :: [ConstructorInfo] -> [Q Clause]
    walk (con : later@(_ : _)) = orderedAgainstLater (constructorName con) ++ walk later
    walk _ = []

    orderedAgainstLater :: Name -> [Q Clause]
    orderedAgainstLater conName =
      [ clause [recP conName [], wildP] (normalB (conE 'LT)) []
      , clause [wildP, recP conName []] (normalB (conE 'GT)) []
      ]