{-# LANGUAGE BangPatterns #-}

module PlutusCore.Evaluation.Machine.CostStream
    ( CostStream(..)
    , unconsCost
    , reconsCost
    , sumCostStream
    , mapCostStream
    , addCostStream
    , minCostStream
    ) where

import PlutusCore.Evaluation.Machine.ExMemory


{- Note [Single-element streams]
Both 'CostStream' and 'ExBudgetStream' are semantically equivalent to 'NonEmpty' (modulo strictness),
except that instead of making the first element of each of these stream types the special one, we make
the last one special. The reason for this is that we want to optimize the case of a single-element
stream as much as possible, because it's the most common one: with a 'NonEmpty'-style data type we'd
have to pattern match twice in order to extract the value from the head of the stream and make sure
it's the only one, while with our approach it takes only one pattern match. Plus we don't need to touch
any recursive parts at all when we're operating on single-element streams. This comes at a cost,
however: recursive functions over streams often need to have the single-element case hardcoded even
when the general recursion would suffice, because GHC can't inline recursive functions and we rely
heavily on inlining. So we often manually unroll one step of recursion just to make the wrapper
inlineable, which allows for optimized handling of single-element streams.
-}

-- See Note [Single-element streams]
-- | A lazy stream of 'CostingInteger's. Basically @NonEmpty CostingInteger@, except the elements
-- are stored strictly (only the spine is lazy) and the special constructor is the one holding the
-- last element rather than the first one.
--
-- The semantics of a stream are those of the sum of its elements. I.e. a stream that is a reordered
-- version of another stream is considered equal to that stream.
--
-- All costs are assumed not to be negative and functions handling 'CostStream's may rely on this
-- assumption. Negative costs (a.k.a. allowing the user to forge execution units at runtime)
-- wouldn't make sense.
data CostStream
    = CostLast {-# UNPACK #-} !CostingInteger
      -- ^ The final element of the stream.
    | CostCons {-# UNPACK #-} !CostingInteger CostStream
      -- ^ An element followed by the (lazily evaluated) rest of the stream.
    deriving stock (Show)

-- TODO: (# CostingInteger, (# (# #) | CostStream #) #)?
-- | Uncons an element from a 'CostStream' and return the rest of the stream, if not empty.
--
-- For a single-element stream the rest is 'Nothing'; otherwise it is @'Just' rest@.
unconsCost :: CostStream -> (CostingInteger, Maybe CostStream)
unconsCost (CostLast cost)       = (cost, Nothing)
unconsCost (CostCons cost costs) = (cost, Just costs)
{-# INLINE unconsCost #-}

-- | Cons an element to a 'CostStream', if given any. Otherwise create a new single-element
-- 'CostStream' using 'CostLast'. This is the inverse of 'unconsCost'.
reconsCost :: CostingInteger -> Maybe CostStream -> CostStream
reconsCost cost = maybe (CostLast cost) (CostCons cost)
{-# INLINE reconsCost #-}

{- Note [Global local functions]
Normally when defining a helper function one would put it into a @where@ or a @let@ block.
However, if the enclosing function gets inlined, then the definition of the helper gets inlined
too, which, when it happens in multiple places, can create serious GHC Core bloat, making it really
hard to analyze the generated code. Hence in some cases we optimize for a smaller amount of generated
GHC Core by turning some helper functions into global ones.

This doesn't work as well when the helper function captures a variable bound by the enclosing one,
so we leave such helper functions local. We could probably create a global helper and a local
function within it instead, but so far it doesn't appear that those capturing helpers actually get
duplicated in the generated Core.
-}

-- See Note [Global local functions].
-- | The accumulating loop behind 'sumCostStream': adds every cost of the stream to the
-- accumulator. The accumulator is forced on each step, so no chain of @(+)@ thunks builds up.
sumCostStreamGo :: CostingInteger -> CostStream -> CostingInteger
sumCostStreamGo !acc (CostLast cost)       = acc + cost
sumCostStreamGo !acc (CostCons cost costs) = sumCostStreamGo (acc + cost) costs

-- | Add up all the costs in a 'CostStream'.
sumCostStream :: CostStream -> CostingInteger
-- See Note [Single-element streams]: the single-element case is answered right here without
-- entering the recursive loop, so it is fully resolved once this wrapper gets inlined.
sumCostStream (CostLast cost0)        = cost0
sumCostStream (CostCons cost0 costs0) = sumCostStreamGo cost0 costs0
{-# INLINE sumCostStream #-}

-- See Note [Global local functions]: the helper below captures @f@, hence it is kept local.
-- | Map a function over a 'CostStream'.
--
-- The function is expected to map non-negative costs to non-negative costs, since a 'CostStream'
-- must not contain negative costs.
mapCostStream :: (CostingInteger -> CostingInteger) -> CostStream -> CostStream
-- See Note [Single-element streams]
mapCostStream f (CostLast cost0)        = CostLast (f cost0)
mapCostStream f (CostCons cost0 costs0) = CostCons (f cost0) $ go costs0 where
    go :: CostStream -> CostStream
    go (CostLast cost)       = CostLast (f cost)
    go (CostCons cost costs) = CostCons (f cost) $ go costs
{-# INLINE mapCostStream #-}

-- See Note [Global local functions].
-- | Interleave two streams: emit the head of the left stream, then continue with the two streams
-- swapped. Once the left stream reaches its last element, the whole right stream follows it.
addCostStreamGo :: CostStream -> CostStream -> CostStream
addCostStreamGo (CostLast costL)        costsR = CostCons costL costsR
addCostStreamGo (CostCons costL costsL) costsR = CostCons costL $ addCostStreamGo costsR costsL

-- | Add two streams by interleaving their elements (as opposed to draining out one of the streams
-- before starting to take elements from the other one). No particular reason to prefer
-- interleaving over draining out one of the streams first.
addCostStream :: CostStream -> CostStream -> CostStream
addCostStream costsL0 costsR0 = case (costsL0, costsR0) of
    -- See Note [Single-element streams].
    -- Two single-element streams collapse into one element holding their sum.
    (CostLast costL, CostLast costR) -> CostLast $ costL + costR
    _                                -> addCostStreamGo costsL0 costsR0
{-# INLINE addCostStream #-}

-- See Note [Global local functions].
-- Didn't attempt to optimize it.
minCostStreamGo :: CostStream -> CostStream -> CostStream
minCostStreamGo costsL costsR =
    -- Peel off a cost from each of the (non-empty) streams, compare the two costs, emit the minimum
    -- cost to the outside and recurse. If the two costs aren't equal, then we put the difference
    -- between them back into the stream that had the greater cost (thus subtracting the minimum
    -- cost from that stream, since we've just accounted for it by lazily emitting it to the
    -- outside). Proceed until one of the streams is drained out: at that point everything it
    -- contained has been emitted, and since costs are non-negative, the sum of the other stream is
    -- at least as large as what has been emitted.
    let (!costL, !mayCostsL') = unconsCost costsL
        (!costR, !mayCostsR') = unconsCost costsR
        (!costMin, !mayCostsL'', !mayCostsR'') = case costL `compare` costR of
            LT -> (costL, mayCostsL', Just $ reconsCost (costR - costL) mayCostsR')
            EQ -> (costL, mayCostsL', mayCostsR')
            GT -> (costR, Just $ reconsCost (costL - costR) mayCostsL', mayCostsR')
    -- Recurse only while both streams still have elements left; the recursive call is a lazy
    -- tail of the emitted 'CostCons'.
    in reconsCost costMin $ minCostStreamGo <$> mayCostsL'' <*> mayCostsR''

-- | Calculate the minimum of two 'CostStream's, i.e. a stream whose sum is the minimum of the sums
-- of the two input streams. May return a stream that is longer than either of the two: for inputs
-- of @m@ and @n@ elements the result has at most @m + n - 1@ elements (so never more than twice
-- the length of the longer input).
minCostStream :: CostStream -> CostStream -> CostStream
minCostStream costsL0 costsR0 = case (costsL0, costsR0) of
    -- See Note [Single-element streams].
    (CostLast costL, CostLast costR) -> CostLast $ min costL costR
    _                                -> minCostStreamGo costsL0 costsR0
{-# INLINE minCostStream #-}