{-# LANGUAGE AllowAmbiguousTypes      #-}
{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE DataKinds                #-}
{-# LANGUAGE FlexibleInstances        #-}
{-# LANGUAGE KindSignatures           #-}
{-# LANGUAGE MultiParamTypeClasses    #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications         #-}
{-# LANGUAGE TypeFamilies             #-}
{-# LANGUAGE TypeOperators            #-}
{-# LANGUAGE UndecidableInstances     #-}
module UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter where

import Control.Monad.Primitive
import Data.Coerce (coerce)
import Data.Kind
import Data.Primitive qualified as P
import Data.Proxy
import Data.Word
import GHC.TypeNats (KnownNat, Nat, natVal, type (-))

-- See Note [Step counter data structure].
-- Any element type would do here, and machine words (i.e. 'Word64') might look
-- like the natural choice, but they are actually slower than 'Word8'.
type StepCounter :: Nat -> Type -> Type
-- | A fixed-size set of 'Word8' counters that the CEK machine uses to count
-- steps. The type-level @n@ is the number of counters and @s@ is the state
-- token of the monad the underlying mutable array lives in.
newtype StepCounter n s = StepCounter (P.MutablePrimArray s Word8)

-- | Make a new 'StepCounter' with the given number of counters.
--
-- The number of counters is the type-level natural @n@, reflected with
-- 'natVal'. 'P.newPrimArray' leaves the fresh array's contents uninitialised,
-- so every counter is zeroed with 'resetCounter' before the counter is handed
-- out.
newCounter :: (KnownNat n, PrimMonad m) => Proxy n -> m (StepCounter n (PrimState m))
newCounter p = do
  let sz :: Int
      sz = fromIntegral $ natVal p
  c <- StepCounter <$> P.newPrimArray sz
  resetCounter c
  pure c
{-# INLINE newCounter #-}

-- | Reset all the counters in the given 'StepCounter' to zero.
--
-- The number of counters comes from the type-level @n@, so this sets indices
-- @0 .. n - 1@ of the underlying array to @0@ with a single 'P.setPrimArray'.
resetCounter :: forall n m . (KnownNat n, PrimMonad m) => StepCounter n (PrimState m) -> m ()
resetCounter (StepCounter arr) =
  let sz :: Int
      sz = fromIntegral $ natVal (Proxy @n)
  in P.setPrimArray arr 0 sz 0
{-# INLINE resetCounter #-}

-- | Read the value of a counter.
--
-- The index is not range-checked here; callers are expected to pass an index
-- in @[0, n)@.
readCounter :: forall m n . PrimMonad m => StepCounter n (PrimState m) -> Int -> m Word8
readCounter =
  -- 'coerce' strips the 'StepCounter' newtype at zero cost, so this is exactly
  -- 'P.readPrimArray' on the underlying array.
  coerce
    @(P.MutablePrimArray (PrimState m) Word8 -> Int -> m Word8)
    @(StepCounter n (PrimState m) -> Int -> m Word8)
    P.readPrimArray
{-# INLINE readCounter #-}

-- | Write to a counter.
--
-- The index is not range-checked here; callers are expected to pass an index
-- in @[0, n)@.
writeCounter
  :: forall m n
  . PrimMonad m
  => StepCounter n (PrimState m)
  -> Int
  -> Word8
  -> m ()
writeCounter =
  -- 'coerce' strips the 'StepCounter' newtype at zero cost, so this is exactly
  -- 'P.writePrimArray' on the underlying array.
  coerce
    @(P.MutablePrimArray (PrimState m) Word8 -> Int -> Word8 -> m ())
    @(StepCounter n (PrimState m) -> Int -> Word8 -> m ())
    P.writePrimArray
{-# INLINE writeCounter #-}

-- | Modify the value of a counter. Returns the modified value.
--
-- This is a 'readCounter' followed by a 'writeCounter' at the same index @i@,
-- applying @f@ to the value read. The index is not range-checked here.
modifyCounter
  :: PrimMonad m
  => Int
  -> (Word8 -> Word8)
  -> StepCounter n (PrimState m)
  -> m Word8
modifyCounter i f c = do
  v <- readCounter c i
  let modified = f v
  writeCounter c i modified
  pure modified
{-# INLINE modifyCounter #-}

-- | Natural numbers in unary (Peano) form: 'Z' is zero and @'S' p@ is the
-- successor of @p@. Used promoted, at the type level, to count loop iterations
-- in 'UpwardsM'.
data Peano = Z | S Peano

-- | Convert a type-level 'Nat' to its 'Peano' form, one 'S' per unit,
-- e.g. @NatToPeano 2@ reduces to @'S ('S 'Z)@.
type NatToPeano :: Nat -> Peano
type family NatToPeano k where
  NatToPeano 0 = 'Z
  NatToPeano k = 'S (NatToPeano (k - 1))

type UpwardsM :: (Type -> Type) -> Peano -> Constraint
-- | A loop over @n@ consecutive indices, where @n@ is a type-level 'Peano'
-- number, unrolled at compile time through instance resolution.
class Applicative f => UpwardsM f n where
  -- | @upwardsM i k@ means @k i *> k (i + 1) *> ... *> k (i + n - 1)@.
  --
  -- 'itraverseCounter_' relies on this to statically unroll its loop, which
  -- makes the @validation@ benchmarks a couple of percent faster.
  upwardsM
    :: Int            -- ^ @i@, the first index
    -> (Int -> f ())  -- ^ @k@, the action run at each index
    -> f ()

-- | Base case: zero iterations remain, so nothing is run.
instance Applicative f => UpwardsM f 'Z where
  upwardsM _ _ = pure ()
  {-# INLINE upwardsM #-}

-- | Step case: run @k i@, then the remaining @n@ iterations starting at
-- @i + 1@.
instance UpwardsM f n => UpwardsM f ('S n) where
  -- The bang forces the index at every step, so no chain of @(+ 1)@ thunks
  -- builds up across the unrolled calls.
  upwardsM !i k = k i *> upwardsM @f @n (i + 1) k
  {-# INLINE upwardsM #-}

-- | Traverse the counters with an effectful function.
--
-- @f@ is called once for every counter index @i@, from @0@ up to @n - 1@,
-- together with the value of that counter. The loop is unrolled at compile
-- time via 'UpwardsM'.
itraverseCounter_
  :: forall n m
  . (UpwardsM m (NatToPeano n), PrimMonad m)
  => (Int -> Word8 -> m ())
  -> StepCounter n (PrimState m)
  -> m ()
itraverseCounter_ f (StepCounter arr) = do
  -- The safety of this operation is a little subtle. The frozen array is only
  -- safe to use if the underlying mutable array is not mutated 'afterwards'.
  -- In our case it likely _will_ be mutated afterwards... but not until we
  -- are done with the frozen version. That ordering is enforced by the fact that
  -- the whole thing runs in 'm': future accesses to the mutable array can't
  -- happen until this whole function is finished.
  arr' <- P.unsafeFreezePrimArray arr
  -- That argument only holds if every element is actually read during the
  -- traversal. 'P.indexPrimArray' is pure, so passing it lazily would hand @f@
  -- a thunk that could escape and be forced only after the mutable array has
  -- been written again (e.g. by 'resetCounter'), observing the wrong value.
  -- '$!' performs the read before @f@ is called.
  upwardsM @_ @(NatToPeano n) 0 $ \i -> f i $! P.indexPrimArray arr' i
{-# INLINE itraverseCounter_ #-}

-- | Traverse the counters with an effectful function.
--
-- 'itraverseCounter_' with its arguments flipped, for use with the counter
-- first and a (typically lambda) action last.
iforCounter_
  :: (UpwardsM m (NatToPeano n), PrimMonad m)
  => StepCounter n (PrimState m)
  -> (Int -> Word8 -> m ())
  -> m ()
iforCounter_ = flip itraverseCounter_
{-# INLINE iforCounter_ #-}

{- Note [Step counter data structure]
The step counter data structure has had several iterations.

Previously we used a "word array", which was a single 'Word64' into which we
packed 8 'Word8's. This worked pretty well: it was pure, and everything reduced
to a bunch of primitive integer operations.

However, it has a key limitation: it can only hold 8 counters.
The obvious attempt to extend it by using a 'Word128' instead performed badly.

The 'PrimArray' approach on the other hand was fairly competitive with the
original 'WordArray', and scales fine to more than 8 counters, so we switched
to using that instead.
-}