{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeApplications #-}

module UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter where

import PlutusCore.Unroll (NatToPeano, UpwardsM (..))

import Control.Monad.Primitive
import Data.Coerce (coerce)
import Data.Primitive qualified as P
import Data.Proxy
import Data.Word
import GHC.TypeNats (KnownNat, Nat, natVal)

{-| A mutable collection of 'Word8' counters that the CEK machine uses to
count steps. The phantom type-level 'Nat' @n@ is the number of counters
('newCounter' allocates that many and 'resetCounter' clears that many).

Machine words (e.g. 'Word64') could be stored here just as easily, but that
turns out to be slower in practice.

See Note [Step counter data structure]. -}
newtype StepCounter (n :: Nat) st = StepCounter (P.MutablePrimArray st Word8)

-- | Allocate a fresh 'StepCounter' holding @n@ counters, all set to zero.
newCounter :: (KnownNat n, PrimMonad m) => Proxy n -> m (StepCounter n (PrimState m))
newCounter proxy =
  -- The counter is lambda-bound (not let-bound) so that its @n@ stays tied to
  -- the result type; 'resetCounter' needs that @n@ to be the known one.
  fmap StepCounter (P.newPrimArray size) >>= \counter ->
    counter <$ resetCounter counter
  where
    -- Number of counters to allocate, taken from the type-level size.
    size :: Int
    size = fromIntegral (natVal proxy)
{-# INLINE newCounter #-}

-- | Set every counter in the given 'StepCounter' back to zero.
resetCounter :: forall n m. (KnownNat n, PrimMonad m) => StepCounter n (PrimState m) -> m ()
resetCounter (StepCounter arr) = P.setPrimArray arr 0 counterCount 0
  where
    -- How many counters to clear: the type-level size @n@.
    counterCount :: Int
    counterCount = fromIntegral (natVal (Proxy :: Proxy n))
{-# INLINE resetCounter #-}

-- | Read the current value of the counter at the given index.
readCounter :: forall m n. PrimMonad m => StepCounter n (PrimState m) -> Int -> m Word8
-- Unwrapping the newtype with 'coerce' is free; the element type 'Word8' is
-- fixed by the result type.
readCounter counter = P.readPrimArray (coerce counter)
{-# INLINE readCounter #-}

-- | Overwrite the counter at the given index with a new value.
writeCounter
  :: forall m n
   . PrimMonad m
  => StepCounter n (PrimState m)
  -> Int
  -> Word8
  -> m ()
-- Unwrapping the newtype with 'coerce' is free; the element type 'Word8' is
-- fixed by the value argument's type.
writeCounter counter = P.writePrimArray (coerce counter)
{-# INLINE writeCounter #-}

-- | Apply a function to the counter at the given index, store the result back,
-- and return the value that was stored.
modifyCounter
  :: PrimMonad m
  => Int
  -> (Word8 -> Word8)
  -> StepCounter n (PrimState m)
  -> m Word8
modifyCounter idx update counter = do
  updated <- update <$> readCounter counter idx
  writeCounter counter idx updated
  pure updated
{-# INLINE modifyCounter #-}

-- | Run an effectful function on each counter, passing the counter's index and
-- its current value.
--
-- The indices are those produced by 'upwardsM' starting from 0 over
-- @NatToPeano n@ steps (presumably @0 .. n-1@ in ascending order; see
-- 'upwardsM').
itraverseCounter_
  :: forall n m
   . (UpwardsM m (NatToPeano n), PrimMonad m)
  => (Int -> Word8 -> m ())
  -> StepCounter n (PrimState m)
  -> m ()
itraverseCounter_ f (StepCounter arr) =
  -- Each value is read inside @m@, immediately before the corresponding call
  -- to @f@. Indexing a frozen alias of the mutable array would instead hand
  -- @f@ an unevaluated thunk. That thunk can observe later writes, whether
  -- made by @f@ itself or by a subsequent 'resetCounter'/'writeCounter', if
  -- it is forced after them. Reading in @m@ also avoids the unsafe freeze
  -- entirely.
  upwardsM @_ @(NatToPeano n) 0 $ \i -> P.readPrimArray arr i >>= f i
{-# INLINE itraverseCounter_ #-}

-- | 'itraverseCounter_' with its arguments swapped: the counter comes first,
-- then the function that is given each counter's index and value.
iforCounter_
  :: (UpwardsM m (NatToPeano n), PrimMonad m)
  => StepCounter n (PrimState m)
  -> (Int -> Word8 -> m ())
  -> m ()
iforCounter_ counter action = itraverseCounter_ action counter
{-# INLINE iforCounter_ #-}

{- Note [Step counter data structure]
The step counter data structure has had several iterations.

Previously we used a "word array", which was a single 'Word64' into which we
packed 8 'Word8's. This worked pretty well: it was pure, and everything reduced
to a bunch of primitive integer operations.

However, it has a key limitation: it can only hold 8 counters.
The obvious attempt to expand it to use a 'Word128' performed badly.

The 'PrimArray' approach, on the other hand, was fairly competitive with the
original 'WordArray' and scales fine to more than 8 counters, so we switched
to using that instead.
-}