-- editorconfig-checker-disable
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

module PlutusCore.Crypto.BLS12_381.G2
  ( Element (..)
  , add
  , neg
  , scalarMul
  , hashToGroup
  , compress
  , uncompress
  , offchain_zero
  , compressed_zero
  , compressed_generator
  , memSizeBytes
  , compressedSizeBytes
  , multiScalarMul
  ) where

import Cardano.Crypto.EllipticCurve.BLS12_381 qualified as BlstBindings
import Cardano.Crypto.EllipticCurve.BLS12_381.Internal qualified as BlstBindings.Internal

import PlutusCore.Builtin.Result (BuiltinResult (..))
import PlutusCore.Crypto.BLS12_381.Bounds (msmScalarOutOfBounds)
import PlutusCore.Crypto.BLS12_381.Error (BLS12_381_Error (..))
import PlutusCore.Crypto.Utils (byteStringAsHex)
import PlutusCore.Pretty.PrettyConst (ConstConfig)
import Text.PrettyBy (PrettyBy)

import Control.DeepSeq
  ( NFData
  , rnf
  , rwhnf
  )
import Data.ByteString
  ( ByteString
  , length
  )
import Data.Coerce (coerce)
import Data.Hashable
import Data.Proxy (Proxy (..))
import PlutusCore.Flat
import Prettyprinter

{-| A G2 group element, represented by the underlying @blst@ point type.
Equality is inherited directly from the wrapped 'BlstBindings.Point2'.
See Note [Wrapping the BLS12-381 types in Plutus Core]. -}
newtype Element = Element
  { unElement :: BlstBindings.Point2
  }
  deriving newtype (Eq)

-- | Renders an element as the hexadecimal form of its compressed encoding.
instance Show Element where
  show element = byteStringAsHex (compress element)
-- | Pretty-printing reuses the 'Show' rendering (hex of the compressed bytes).
instance Pretty Element where
  pretty element = pretty (show element)

-- No 'ConstConfig'-specific rendering is defined here: the instance relies
-- entirely on the class's default method.
instance PrettyBy ConstConfig Element

{-| We don't support direct flat encoding of G2 elements because of the expense
   of on-chain uncompression.  Users should convert between G2 elements and
   bytestrings using `compress` and `uncompress`: the bytestrings can be
   flat-encoded in the usual way. -}
instance Flat Element where
  -- This might happen on the chain, so `fail` rather than `error`.
  decode =
    fail "Flat decoding is not supported for objects of type bls12_381_G2_element: use bls12_381_G2_uncompress on a bytestring instead."

  -- This will be a Haskell runtime error, but encoding doesn't happen on chain,
  -- so it's not too bad.
  encode =
    error "Flat encoding is not supported for objects of type bls12_381_G2_element: use bls12_381_G2_compress to obtain a bytestring instead."

  -- Elements are never actually encoded, so they contribute nothing to the
  -- size: the running bit count is passed through unchanged.
  size _ = id

-- | Evaluates the wrapped point only to weak head normal form.  Presumably the
-- blst point is an opaque, foreign-backed value with nothing further to
-- force — confirm if the bindings' representation changes.
instance NFData Element where
  rnf = rwhnf . unElement

-- | Hashes an element via its compressed bytestring encoding.
instance Hashable Element where
  hashWithSalt salt element = hashWithSalt salt (compress element)

{-| Add two G2 group elements.  Delegates to the bindings' 'blsAddOrDouble'
for the G2 curve. -}
add :: Element -> Element -> Element
add (Element p) (Element q) =
  Element (BlstBindings.blsAddOrDouble @BlstBindings.Curve2 p q)
{-# INLINE add #-}

-- | Negate a G2 group element, via the bindings' 'blsNeg' for the G2 curve.
neg :: Element -> Element
neg (Element p) = Element (BlstBindings.blsNeg @BlstBindings.Curve2 p)
{-# INLINE neg #-}

{-| Multiply a G2 element by an integer scalar.  Note the argument order: the
scalar comes first, which is the other way round from the library function
'BlstBindings.blsMult' (which takes the point first). -}
scalarMul :: Integer -> Element -> Element
scalarMul n (Element p) = Element (BlstBindings.blsMult @BlstBindings.Curve2 p n)
{-# INLINE scalarMul #-}

{-| Compress a G2 element to a bytestring. This serialises a curve point to its x
 coordinate only, using an extra bit to determine which of two possible y
 coordinates the point has. The compressed bytestring is 96 bytes long. See
 https://github.com/supranational/blst#serialization-format -}
compress :: Element -> ByteString
compress = BlstBindings.blsCompress @BlstBindings.Curve2 . unElement
{-# INLINE compress #-}

{-| Uncompress a bytestring to get a G2 point.  This will fail if any of the
   following are true:
     * The bytestring is not exactly 96 bytes long
     * The most significant three bits are used incorrectly
     * The bytestring encodes a field element which is not the
       x coordinate of a point on the E2 curve
     * The bytestring does represent a point on the E2 curve, but the
       point is not in the G2 subgroup -}
uncompress :: ByteString -> Either BlstBindings.BLSTError Element
uncompress = coerce decodePoint
  where
    -- The raw bindings call; 'coerce' rewraps the successful result as an
    -- 'Element' without rebuilding the 'Either'.
    decodePoint :: ByteString -> Either BlstBindings.BLSTError BlstBindings.Point2
    decodePoint = BlstBindings.blsUncompress @BlstBindings.Curve2
{-# INLINE uncompress #-}

{-| Take an arbitrary bytestring and a Domain Separation Tag (DST) and hash
them to get a point in G2.  See Note [Hashing and Domain Separation Tags].

Returns @Left 'HashToCurveDstTooBig'@ if the DST is longer than 255 bytes
(presumably mirroring the DST length limit of the hash-to-curve
specification, RFC 9380 — confirm).  Otherwise the message is hashed with the
DST, passing 'Nothing' as the bindings' third (optional) argument. -}
hashToGroup :: ByteString -> ByteString -> Either BLS12_381_Error Element
hashToGroup msg dst
  | Data.ByteString.length dst > 255 = Left HashToCurveDstTooBig
  | otherwise =
      Right . Element $
        BlstBindings.blsHash @BlstBindings.Curve2 msg (Just dst) Nothing

{-| The zero element (identity) of G2.  This cannot be flat-serialised and is
provided only for off-chain testing. -}
offchain_zero :: Element
offchain_zero = Element (BlstBindings.Internal.blsZero @BlstBindings.Curve2)

{-| The zero element of G2 compressed into a bytestring.  This is provided for
convenience in PlutusTx and is not exported as a builtin. -}
compressed_zero :: ByteString
compressed_zero = compress offchain_zero

{-| The standard generator of G2 compressed into a bytestring.  This is
provided for convenience in PlutusTx and is not exported as a builtin. -}
compressed_generator :: ByteString
compressed_generator = compress generator
  where
    generator = Element (BlstBindings.Internal.blsGenerator @BlstBindings.Curve2)

-- Utilities (not exposed as builtins)

-- | Memory usage of a G2 point (288 bytes), as reported by the bindings.
memSizeBytes :: Int
memSizeBytes = BlstBindings.Internal.sizePoint (Proxy :: Proxy BlstBindings.Curve2)

-- | Compressed size of a G2 point (96 bytes), as reported by the bindings.
compressedSizeBytes :: Int
compressedSizeBytes =
  BlstBindings.Internal.compressedSizePoint (Proxy :: Proxy BlstBindings.Curve2)

{-| Multi-scalar multiplication of G2 points: the sum of each scalar times its
corresponding point.  We limit the allowable size of scalars to simplify
costing, so the computation fails if any scalar is out of bounds.  The two
lists are paired with 'zip', so surplus elements of the longer list are
ignored. -}
multiScalarMul :: [Integer] -> [Element] -> BuiltinResult Element
multiScalarMul scalars points =
  if any msmScalarOutOfBounds scalars
    then fail "Scalar exceeds 512-byte bound for G2.multiScalarMul"
    else pure (Element (BlstBindings.blsMSM @BlstBindings.Curve2 pairs))
  where
    -- 'coerce' unwraps the whole list of elements at no runtime cost.
    pairs :: [(Integer, BlstBindings.Point2)]
    pairs = zip scalars (coerce points)