{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE TypeFamilies  #-}
{-# LANGUAGE UnboxedTuples #-}

{- |
Memory access primitives.

Includes code from the [store-core](https://hackage.haskell.org/package/store-core) package.
-}
module PlutusCore.Flat.Memory
  ( chunksToByteString
  , chunksToByteArray
  , ByteArray
  , pokeByteArray
  , pokeByteString
  , unsafeCreateUptoN'
  , minusPtr
  , peekByteString
  )
where

import Control.Monad (foldM_, when)
import Control.Monad.Primitive (PrimMonad (..))
import Data.ByteString qualified as B
import Data.ByteString.Internal qualified as BS
import Data.Primitive.ByteArray (ByteArray, ByteArray#, MutableByteArray (..), newByteArray,
                                 unsafeFreezeByteArray)
import Foreign (Ptr, Word8, minusPtr, plusPtr, withForeignPtr)
import Foreign.Marshal.Utils (copyBytes)
import GHC.Prim (copyAddrToByteArray#, copyByteArrayToAddr#)
import GHC.Ptr (Ptr (..))
import GHC.Types (IO (..), Int (..))
import System.IO.Unsafe (unsafeDupablePerformIO, unsafePerformIO)

-- | Pure counterpart of 'createUptoN'': allocate a buffer of @maxLen@ bytes,
-- let @fill@ write into it, and return the resulting 'BS.ByteString' (covering
-- as many bytes as @fill@ reports) together with @fill@'s extra result.
--
-- Implemented with 'unsafeDupablePerformIO', so the fill action may be run more
-- than once if evaluation of the result is duplicated; NOTE(review): callers
-- are presumably expected to pass fillers that only write into the buffer they
-- are given — confirm.
unsafeCreateUptoN' :: Int -> (Ptr Word8 -> IO (Int, a)) -> (BS.ByteString, a)
unsafeCreateUptoN' maxLen fill = unsafeDupablePerformIO $ createUptoN' maxLen fill
{-# INLINE unsafeCreateUptoN' #-}

-- | Allocate a buffer of @l@ bytes with 'BS.mallocByteString', run the filler
-- @f@ on a pointer to its start, and return a 'BS.ByteString' covering the
-- first @l'@ bytes paired with @res@, where @(l', res)@ is what @f@ returned.
--
-- No copy is made when @l' < l@: the result keeps the whole @l@-byte buffer
-- alive.
--
-- Calls 'error' if @f@ reports a negative length (which would otherwise yield
-- an invalid 'BS.ByteString') or a length larger than @l@.
-- NOTE(review): both checks run only after @f@ has returned, so an overflowing
-- filler may already have written past the end of the buffer; the check
-- reports the problem but cannot undo it.
createUptoN' :: Int -> (Ptr Word8 -> IO (Int, a)) -> IO (BS.ByteString, a)
createUptoN' l f = do
  fp <- BS.mallocByteString l
  (l', res) <- withForeignPtr fp f
  -- A negative length must never reach the ByteString constructor.
  when (l' < 0) $
    error (unwords ["Negative length, allocated:", show l, "bytes, used:", show l', "bytes"])
  when (l' > l) $
    error (unwords ["Buffer overflow, allocated:", show l, "bytes, used:", show l', "bytes"])
  return (BS.PS fp 0 l', res)
{-# INLINE createUptoN' #-}

-- | Copy every byte of a 'B.ByteString' to @dst@ and return the pointer just
-- past the last byte written (i.e. @dst@ advanced by the string's length).
--
-- No bounds check is made: the caller must ensure @dst@ has room for the whole
-- string.
pokeByteString :: B.ByteString -> Ptr Word8 -> IO (Ptr Word8)
pokeByteString (BS.PS fptr off len) dst =
  end <$ withForeignPtr fptr copyFrom
  where
    -- Destination pointer after the copy.
    end = dst `plusPtr` len
    -- The payload starts @off@ bytes into the foreign buffer.
    copyFrom src = copyBytes dst (src `plusPtr` off) len

{-| Allocate a fresh 'BS.ByteString' of @sourceLen@ bytes and fill it by
copying @sourceLen@ bytes starting at @sourcePtr@.

The result is built with 'BS.unsafeCreate', so the copy happens when the
result is evaluated; the source memory must still hold @sourceLen@ readable
bytes at that point.

@since 0.6
-}
peekByteString ::
  Ptr Word8 -- ^ sourcePtr: start of the bytes to copy
  -> Int -- ^ sourceLen: number of bytes to copy
  -> BS.ByteString
peekByteString src n = BS.unsafeCreate n fill
  where
    fill dst = copyBytes dst src n

-- | Copy @len@ bytes of an unboxed byte array, starting at byte offset
-- @sourceOffset@, to @dest@, and return the pointer just past the last byte
-- written.
--
-- The returned pointer is forced before being returned. No bounds checks are
-- performed; the caller must guarantee that both ranges are valid.
pokeByteArray :: ByteArray# -> Int -> Int -> Ptr Word8 -> IO (Ptr Word8)
pokeByteArray sourceArr sourceOffset len dest =
  copyByteArrayToAddr sourceArr sourceOffset dest len
    >> (return $! dest `plusPtr` len)
{-# INLINE pokeByteArray #-}


-- | Boxed, 'IO'-returning wrapper around the @copyByteArrayToAddr#@ primop:
-- copy @count@ bytes of @src@, starting at byte offset @srcOff@, to the memory
-- at the destination pointer.
--
-- Like the primop, it performs no bounds checking.
--
-- Copied from the store-core package
copyByteArrayToAddr :: ByteArray# -> Int -> Ptr a -> Int -> IO ()
copyByteArrayToAddr src (I# srcOff) (Ptr dstAddr) (I# count) = IO step
  where
    -- Thread the state token through the primop and return unit.
    step s0 = case copyByteArrayToAddr# src srcOff dstAddr count s0 of
      s1 -> (# s1, () #)
{-# INLINE copyByteArrayToAddr #-}

-- | Gather a sequence of chunks into one freshly allocated 'BS.ByteString'
-- whose length is the sum of the chunk lengths.
--
-- Chunks are read starting at the given pointer, each length in the list in
-- turn. After each chunk exactly one source byte is skipped before the next
-- chunk begins. NOTE(review): the skipped byte is presumably the length prefix
-- of the following chunk in Flat's chunked encoding — confirm against the
-- decoder.
chunksToByteString :: (Ptr Word8, [Int]) -> BS.ByteString
chunksToByteString (sourcePtr0, lens) =
  BS.unsafeCreate (sum lens) $ \destPtr0 -> copyChunks destPtr0 sourcePtr0 lens
  where
    -- Copy one chunk, then advance the destination by its length and the
    -- source by its length plus the one skipped byte.
    copyChunks _ _ [] = return ()
    copyChunks dst src (n : rest) = do
      copyBytes dst src n
      copyChunks (dst `plusPtr` n) (src `plusPtr` (n + 1)) rest

-- | Gather a sequence of chunks into a freshly allocated 'ByteArray', returning
-- it together with its total length (the sum of the chunk lengths).
--
-- Chunks are read starting at the given pointer, each length in the list in
-- turn. After each chunk exactly one source byte is skipped before the next
-- chunk begins. NOTE(review): the skipped byte is presumably the length prefix
-- of the following chunk in Flat's chunked encoding — confirm against the
-- decoder.
--
-- Runs via 'unsafePerformIO', so the source memory is read only when the result
-- is forced and must still be valid at that point.
chunksToByteArray :: (Ptr Word8, [Int]) -> (ByteArray, Int)
chunksToByteArray (sourcePtr0, lens) = unsafePerformIO $ do
  let total = sum lens
  marr <- newByteArray total
  -- Write each chunk at the running destination offset, skipping one source
  -- byte between chunks.
  let go _ _ [] = return ()
      go off src (n : rest) = do
        copyAddrToByteArray src marr off n
        go (off + n) (src `plusPtr` (n + 1)) rest
  go 0 sourcePtr0 lens
  -- The mutable array is not used after this point, so freezing in place is safe.
  frozen <- unsafeFreezeByteArray marr
  return (frozen, total)


-- | Boxed, 'IO'-returning wrapper around the @copyAddrToByteArray#@ primop:
-- copy @count@ bytes from the memory at the source pointer into the mutable
-- byte array, starting at byte offset @dstOff@.
--
-- Like the primop, it performs no bounds checking.
--
-- Copied from the store-core package
copyAddrToByteArray
  :: Ptr a -> MutableByteArray (PrimState IO) -> Int -> Int -> IO ()
copyAddrToByteArray (Ptr srcAddr) (MutableByteArray dst) (I# dstOff) (I# count) = IO step
  where
    -- Thread the state token through the primop and return unit.
    step s0 = case copyAddrToByteArray# srcAddr dst dstOff count s0 of
      s1 -> (# s1, () #)
{-# INLINE copyAddrToByteArray #-}