{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE CPP                #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE PatternSynonyms    #-}
{-# LANGUAGE TemplateHaskell    #-}
{-# LANGUAGE TypeApplications   #-}
{-# LANGUAGE ViewPatterns       #-}

module PlutusTx.AsData (asData, asDataFor) where

import Control.Lens (ifor)
import Control.Monad (unless)
import Data.Traversable (for)

import Language.Haskell.TH qualified as TH
import Language.Haskell.TH.Datatype qualified as TH
import Language.Haskell.TH.Datatype.TyVarBndr qualified as TH

import PlutusTx.Builtins qualified as Builtins
import PlutusTx.IsData.Class (ToData, UnsafeFromData)
import PlutusTx.IsData.TH (mkConstrCreateExpr, mkUnsafeConstrMatchPattern)

import Prelude

{- | 'asData' takes a datatype declaration and "backs it" by 'BuiltinData'. It does
this by replacing the datatype with a newtype around 'BuiltinData', and providing
pattern synonyms that match the behaviour of the original datatype.

Since 'BuiltinData' can only contain 'BuiltinData', the pattern synonyms
encode and decode all the fields using 'toBuiltinData' and 'unsafeFromBuiltinData'.
That means that these operations are called on every pattern match or construction.
This *can* be very efficient if, for example, the datatypes for the fields have
also been defined with 'asData', and so have trivial conversions to/from 'BuiltinData'
(or have very cheap conversions, like 'Integer' and 'ByteString').
But it can be very expensive otherwise, so take care.

Deriving clauses are transferred from the quoted declaration to the generated newtype
declaration. Note that you may therefore need to do strange things like use
@deriving newtype@ on a data declaration.

Example:
@
  $(asData [d|
      data Example a = Ex1 Integer | Ex2 a a
        deriving newtype (Eq)
    |])
@
becomes
@
  newtype Example a = Example BuiltinData
    deriving newtype (Eq)

  -- a does not occur in the fields of Ex1, so no constraints are placed on it
  pattern Ex1 :: Integer -> Example a
  pattern Ex1 i <- Example (unsafeDataAsConstr -> ((==) 0 -> True, [unsafeFromBuiltinData -> i]))
    where Ex1 i = Example (mkConstr 0 [toBuiltinData i])

  pattern Ex2 :: (ToData a, UnsafeFromData a) => a -> a -> Example a
  pattern Ex2 a1 a2 <- Example (unsafeDataAsConstr -> ((==) 1 -> True,
    [ unsafeFromBuiltinData -> a1
    , unsafeFromBuiltinData -> a2
    ]))
    where Ex2 a1 a2 = Example (mkConstr 1 [toBuiltinData a1, toBuiltinData a2])

  {-# COMPLETE Ex1, Ex2 #-}
@
-}
asData :: TH.Q [TH.Dec] -> TH.Q [TH.Dec]
-- Run the quoted declarations and rewrite each one with 'asDataFor'.
-- A single input declaration expands into several output declarations
-- (the newtype, a COMPLETE pragma and the pattern synonyms), so the
-- per-declaration results are flattened into one list.
asData decQ = do
  decs <- decQ
  outputDecs <- for decs asDataFor
  pure (concat outputDecs)

{- | Rewrite a single @data@ declaration so that it is backed by 'Builtins.BuiltinData'.
See 'asData' for an overview and a worked example of the generated code.

The result consists of, in order:

  * a @newtype@ around 'Builtins.BuiltinData' that keeps the original type name,
    type variables and deriving clauses;
  * a @COMPLETE@ pragma listing the original constructor names;
  * for each original constructor, a pattern synonym signature followed by an
    explicitly bidirectional pattern synonym with the same name. The signature
    carries 'ToData' and 'UnsafeFromData' constraints for each type variable
    that occurs in that constructor's fields.

Fails in 'TH.Q' if 'TH.normalizeDec' rejects the declaration, if the declaration
is not a plain @data@ declaration (for example a @newtype@ or a data family
instance), or if an infix constructor does not have exactly two fields.
-}
asDataFor :: TH.Dec -> TH.Q [TH.Dec]
asDataFor dec = do
  -- th-abstraction doesn't include deriving clauses, so we have to handle that here
  let derivs = case dec of
        TH.DataD _ _ _ _ _ deriv -> deriv
        _                        -> []

  di@(
    TH.DatatypeInfo
      { TH.datatypeVariant = dVariant
      , TH.datatypeCons    = cons
      , TH.datatypeName    = name
      , TH.datatypeVars    = tTypeVars
      }
    ) <- TH.normalizeDec dec

  -- Other stuff is data families and so on
  unless (dVariant == TH.Datatype) $
    fail $ "asData: can't handle datatype variant " ++ show dVariant

  -- A fresh name for the newtype's data constructor, based on the type name.
  -- NOTE(review): 'show' (unlike 'TH.nameBase') may keep a module qualifier or
  -- unique suffix from the type name; presumably this avoids clashing with a
  -- user constructor of the same name (which becomes a pattern synonym below).
  -- Confirm before switching to 'TH.nameBase'.
  cname <- TH.newName (show name)

  -- The newtype declaration: original name, type variables and deriving
  -- clauses, wrapping a single lazy, non-unpacked BuiltinData field.
  let ntCon =
        TH.NormalC cname
          [ ( TH.Bang TH.NoSourceUnpackedness TH.NoSourceStrictness
            , TH.ConT ''Builtins.BuiltinData
            )
          ]
      ntD =
        TH.NewtypeD [] name
#if MIN_VERSION_template_haskell(2,21,0)
          (TH.changeTVFlags TH.BndrReq tTypeVars)
#else
          tTypeVars
#endif
          Nothing ntCon derivs

  -- The pattern synonyms, one for each constructor; conIx is the constructor's
  -- position in the declaration and is used as its Constr tag.
  -- NOTE(review): constructorVars / constructorContext are not consulted, so
  -- existential or GADT-style constructors get no special treatment; confirm
  -- whether they should be rejected instead.
  pats <- ifor cons $ \conIx TH.ConstructorInfo
    { TH.constructorName    = conName
    , TH.constructorFields  = fields
    , TH.constructorVariant = cVariant
    } -> do
    -- If we have a record constructor, we need to reuse the names for the
    -- matching part of the pattern synonym; otherwise any fresh names will do.
    fieldNames <- case cVariant of
      TH.RecordConstructor names -> pure names
      _                          -> ifor fields $ \fieldIx _ ->
        TH.newName ("arg" ++ show fieldIx)
    -- Separate binders for the arguments of the builder clause.
    createFieldNames <- for fieldNames (TH.newName . show)
    patSynArgs <- case cVariant of
      TH.NormalConstructor   -> pure (TH.prefixPatSyn fieldNames)
      TH.RecordConstructor _ -> pure (TH.recordPatSyn fieldNames)
      TH.InfixConstructor    -> case fieldNames of
        [f1, f2] -> pure (TH.infixPatSyn f1 f2)
        _        -> fail "asData: infix data constructor with other than two fields"
    let
      -- Matcher: the newtype constructor applied to a pattern that checks the
      -- tag conIx and binds the fields to fieldNames.
      pat = TH.conP cname [mkUnsafeConstrMatchPattern (fromIntegral conIx) fieldNames]
      -- Builder: encode the arguments under tag conIx and wrap in the newtype.
      createExpr =
        [| $(TH.conE cname) $(mkConstrCreateExpr (fromIntegral conIx) createFieldNames) |]
      clause = TH.clause (fmap TH.varP createFieldNames) (TH.normalB createExpr) []
      patSynD = TH.patSynD conName patSynArgs (TH.explBidir [clause]) pat

      dataConstraints t =
        [ TH.ConT ''ToData `TH.AppT` t
        , TH.ConT ''UnsafeFromData `TH.AppT` t
        ]
      -- Try and be a little clever and only add constraints on the variables used in the arguments
      varsInArgs = TH.freeVariablesWellScoped fields
      ctxForArgs = concatMap (dataConstraints . TH.VarT . TH.tvName) varsInArgs
      -- field_1 -> ... -> field_n -> <the datatype applied to its variables>
      conTy = foldr (\ty acc -> TH.ArrowT `TH.AppT` ty `TH.AppT` acc) (TH.datatypeType di) fields
      allFreeVars = TH.freeVariablesWellScoped [conTy]
      fullTy = TH.ForallT (TH.changeTVFlags TH.SpecifiedSpec allFreeVars) ctxForArgs conTy
      patSynSigD = pure (TH.PatSynSigD conName fullTy)

    sequence [patSynSigD, patSynD]

  -- A complete pragma, to top it off: the generated pattern synonyms together
  -- cover every value of the newtype.
  let compl = TH.PragmaD (TH.CompleteP (fmap TH.constructorName cons) Nothing)
  pure (ntD : compl : concat pats)