{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module PlutusTx.Enum.TH (Enum (..), deriveEnum) where

import Control.Monad
import Data.Deriving.Internal (varTToName)
import Data.Foldable
import Data.Tuple
import Language.Haskell.TH as TH
import Language.Haskell.TH.Datatype as TH
import PlutusTx.Enum.Class
import PlutusTx.ErrorCodes
import PlutusTx.Trace
import Prelude hiding (Bool (True), Enum (..), Eq, (&&), (==))

-- | Selects which neighbouring constructor 'deriveSuccPred' targets:
-- 'Succ' builds @succ@ equations, 'Pred' builds @pred@ equations.
data SuccPred
  = Succ
  | Pred
  deriving stock (Show)

{-| Derive a PlutusTx.Enum 'Enum' instance for a datatype, much like
@deriving stock Enum@ does for Haskell's own 'Prelude.Enum'.

The datatype must have at least one constructor, and every constructor must be
nullary. Constructors are numbered from 0 in declaration order:

* @fromEnum@ maps the i-th constructor to @i@;
* @toEnum@ maps @i@ back to the i-th constructor and calls 'traceError' with
  'toEnumBadArgumentError' for any other integer;
* @succ@ / @pred@ step to the next / previous constructor and call 'traceError'
  with 'succBadArgumentError' / 'predBadArgumentError' at the ends.

All generated methods are marked @INLINABLE@.

Note: requires enabling OverloadedStrings language extension -}
deriveEnum :: TH.Name -> TH.Q [TH.Dec]
deriveEnum name = do
  TH.DatatypeInfo
    { TH.datatypeName = tyConName
    , TH.datatypeInstTypes = tyVars0
    , TH.datatypeCons = cons
    } <-
    TH.reifyDatatype name
  let
    -- The purpose of the `TH.VarT . varTToName` roundtrip is to remove the kind
    -- signatures attached to the type variables in `tyVars0`. Otherwise, the
    -- `KindSignatures` extension would be needed whenever `length tyVars0 > 0`.
    tyVars = TH.VarT . varTToName <$> tyVars0

    instanceType :: TH.Type
    instanceType = TH.AppT (TH.ConT ''Enum) $ foldl' TH.AppT (TH.ConT tyConName) tyVars

    -- Constructor number i (0-based, declaration order) <-> integer literal i.
    table :: [(Lit, Name)]
    table = zip (fmap IntegerL [0 ..]) (fmap constructorName cons)

    -- Each constructor paired with the one after it; the last has none.
    -- `drop 1` (rather than the partial `tail`) keeps this total.
    succPairs :: [(ConstructorInfo, Maybe ConstructorInfo)]
    succPairs = zip cons (fmap Just (drop 1 cons) ++ [Nothing])

    -- Each constructor paired with the one before it; the first has none.
    -- `zip` truncates to `length cons`, so the surplus final element is
    -- discarded without needing the partial `init`.
    predPairs :: [(ConstructorInfo, Maybe ConstructorInfo)]
    predPairs = zip cons (Nothing : fmap Just cons)

    -- Every generated method gets the same INLINABLE pragma.
    inlinable method = TH.pragInlD method TH.Inlinable TH.FunLike TH.AllPhases

  when (null cons) $
    fail $
      "Can't make a derived instance of `Enum "
        ++ show tyConName
        ++ "`: "
        ++ show tyConName
        ++ " must be an enumeration type (an enumeration consists of one or more nullary, non-GADT constructors)"

  pure
    <$> instanceD
      (pure [])
      (pure instanceType)
      [ funD 'succ (fmap (deriveSuccPred Succ) succPairs)
      , inlinable 'succ
      , funD 'pred (fmap (deriveSuccPred Pred) predPairs)
      , inlinable 'pred
      , -- One clause per constructor, then a catch-all for out-of-range input.
        funD 'toEnum $ fmap deriveToEnum table <> [pure toEnumDefaultClause]
      , inlinable 'toEnum
      , funD 'fromEnum $ fmap (deriveFromEnum . swap) table
      , inlinable 'fromEnum
      ]

-- | Catch-all @toEnum@ equation, placed after the per-constructor equations.
-- It matches any argument and evaluates to
-- @traceError toEnumBadArgumentError@.
toEnumDefaultClause :: Clause
toEnumDefaultClause = TH.Clause [WildP] (TH.NormalB failure) []
  where
    failure = VarE 'traceError `AppE` VarE 'toEnumBadArgumentError

-- | Build a single @toEnum@ equation: a literal pattern on the left, the
-- corresponding constructor expression on the right.
deriveToEnum :: (Lit, Name) -> Q Clause
deriveToEnum (lit, conName) = pure clause
  where
    clause = TH.Clause [LitP lit] (NormalB (ConE conName)) []

-- | Build a single @fromEnum@ equation: a pattern on the (argument-less)
-- constructor on the left, its literal on the right.
deriveFromEnum :: (Name, Lit) -> Q Clause
deriveFromEnum (conName, lit) =
  let pat = ConP conName [] []
      body = NormalB (LitE lit)
   in pure (TH.Clause [pat] body [])

-- | Build one @succ@ or @pred@ equation for a constructor.
--
-- The input pairs a constructor with its neighbour in the requested direction
-- ('Nothing' when there is none):
--
-- * with a neighbour, the equation maps the constructor to that neighbour;
-- * without one, it calls 'traceError' with 'succBadArgumentError' or
--   'predBadArgumentError', according to the 'SuccPred' argument.
--
-- Fails in 'Q' if either constructor involved has fields.
deriveSuccPred :: SuccPred -> (ConstructorInfo, Maybe ConstructorInfo) -> Q Clause
deriveSuccPred direction (current, neighbour)
  | isNullary current
  , Just rhs <- neighbourExp neighbour =
      pure (TH.Clause [ConP (constructorName current) [] []] (NormalB rhs) [])
  | otherwise =
      fail "Can't make a derived instance of Enum when constructor has fields"
  where
    isNullary = null . constructorFields

    -- Right-hand side for the equation, or Nothing if the neighbour has fields.
    neighbourExp Nothing = Just (AppE (VarE 'traceError) (VarE boundaryError))
    neighbourExp (Just next)
      | isNullary next = Just (ConE (constructorName next))
      | otherwise = Nothing

    boundaryError = case direction of
      Succ -> 'succBadArgumentError
      Pred -> 'predBadArgumentError