module Codec.Extras.FlatViaSerialise
    ( FlatViaSerialise (..)
    ) where

import Codec.Serialise (Serialise, deserialiseOrFail, serialise)
import Data.ByteString.Lazy qualified as BSL (toStrict)
import Flat

{- Note [Flat serialisation for strict and lazy bytestrings]
The `flat` serialisation of a bytestring consists of a sequence of chunks, with each chunk preceded
by a single byte saying how long it is.  The end of a serialised bytestring is marked by a
zero-length chunk.  In the Plutus Core specification we recommend that all bytestrings should be
serialised in a canonical way as a sequence of zero or more 255-byte chunks followed by an optional
final chunk of length less than 255 followed by a zero-length chunk (i.e., a 0x00 byte), although
decoders are permitted to accept non-canonical encodings.  The `flat` library always encodes strict Haskell
bytestrings in this way, but lazy bytestrings, which are essentially lists of strict bytestrings,
may be encoded non-canonically since it's more efficient just to emit a short chunk as is.  The
Plutus Core `bytestring` type is strict so bytestring values are always encoded canonically.
However, we serialise `Data` objects (and perhaps objects of other types as well) by encoding them
to CBOR and then flat-serialising the resulting bytestring; but the `serialise` method from
`Codec.Serialise` produces lazy bytestrings and if we were to serialise them directly then we could
end up with non-canonical encodings, which would mean that identical `Data` objects might be
serialised into different bytestrings.  To avoid this we convert the output of `serialise` into a
strict bytestring before flat-encoding it.  This may lead to a small loss of efficiency during
encoding, but this doesn't matter because we only ever do flat serialisation off the chain.  We can
convert `Data` objects to bytestrings on the chain using the `serialiseData` builtin, but this
performs CBOR serialisation and the result is always in a canonical form. -}

-- | A wrapper for deriving 'Flat' instances via 'Serialise'.
--
-- The 'Flat' instance for this type CBOR-serialises the wrapped value with
-- 'serialise', converts the result to a strict bytestring and flat-encodes that
-- bytestring. Decoding reverses the process, and a CBOR decoding failure is
-- reported as a 'Flat' decoding failure.
--
-- Typical use, with @DerivingVia@:
--
-- > deriving Flat via FlatViaSerialise Foo
newtype FlatViaSerialise a = FlatViaSerialise
    { unFlatViaSerialise :: a
    -- ^ The wrapped value that is actually serialised.
    }

instance Serialise a => Flat (FlatViaSerialise a) where
    -- CBOR-encode the wrapped value, then flat-encode the bytes as a *strict*
    -- bytestring so that the chunking is canonical.
    -- See Note [Flat serialisation for strict and lazy bytestrings]
    encode :: FlatViaSerialise a -> Encoding
    encode (FlatViaSerialise x) = encode (BSL.toStrict (serialise x))

    -- Read a flat-encoded (lazy) bytestring and CBOR-decode it. A
    -- 'Serialise' error is embedded into a 'Flat' one by 'fail'ing with its
    -- 'show' text.
    decode :: Get (FlatViaSerialise a)
    decode =
        decode >>= either (fail . show) (pure . FlatViaSerialise) . deserialiseOrFail

    -- Must agree with 'encode': measure the same strict CBOR bytestring.
    size :: FlatViaSerialise a -> NumBits -> NumBits
    size (FlatViaSerialise x) = size (BSL.toStrict (serialise x))