{-# LANGUAGE BangPatterns         #-}
{-# LANGUAGE LambdaCase           #-}
{-# LANGUAGE PatternSynonyms      #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns         #-}
module Data.RandomAccessList.SkewBinary
    ( RAList(Cons,Nil)
    , contIndexZero
    , contIndexOne
    , safeIndexZero
    , unsafeIndexZero
    , safeIndexOne
    , unsafeIndexOne
    , Data.RandomAccessList.SkewBinary.null
    , uncons
    ) where

import Data.Bits (setBit, unsafeShiftL, unsafeShiftR)
import Data.Word
import GHC.Exts

import Data.RandomAccessList.Class qualified as RAL

-- 'Node' appears first to make it more likely for GHC to reorder pattern matches to make the 'Node'
-- one appear first (which makes it more efficient).
-- | A complete binary tree.
-- Note: the size of the tree is not stored/cached here. When the tree appears as a root tree of an
-- 'RAList', its size is stored alongside it in the 'BHead' constructor.
data Tree a = Node a !(Tree a) !(Tree a) -- ^ an element followed by two subtrees of equal size
            | Leaf a                     -- ^ a single element
            deriving stock (Eq, Show)

-- | A strict list of complete binary trees accompanied by their size.
-- The trees appear in non-decreasing order of size, i.e. the head tree is the smallest one.
-- Note: this list is strict in its spine, unlike the Prelude list.
data RAList a = BHead
               {-# UNPACK #-} !Word64 -- ^ the size of the head tree
               !(Tree a) -- ^ the head tree
               !(RAList a) -- ^ the tail trees
             | Nil
             -- the derived Eq instance is correct,
             -- because binary skew numbers have unique representation
             -- and hence all trees of the same size will have the same structure
             deriving stock (Eq, Show)
             deriving (IsList) via RAL.AsRAL (RAList a)

-- | /O(1)/. Check whether the list has no elements.
null :: RAList a -> Bool
null Nil = True
null _   = False
{-# INLINE null #-}

{-# COMPLETE Cons, Nil #-}
{-# COMPLETE BHead, Nil #-}

-- | /O(1)/. Bidirectional pattern viewing an 'RAList' as its first element and the rest.
-- Matching is implemented via 'uncons'; construction is implemented via 'cons'.
pattern Cons :: a -> RAList a -> RAList a
pattern Cons x xs <- (uncons -> Just (x, xs)) where
  Cons x xs = cons x xs

-- | /O(1)/ worst-case. Prepend an element.
--
-- If the first two root trees have the same size @w@, they are merged under a new root holding
-- the element, giving a single tree of size @2*w+1@. Otherwise the element is pushed as a new
-- singleton 'Leaf' of size 1.
cons :: a -> RAList a -> RAList a
cons x = \case
    (BHead w1 t1 (BHead w2 t2 ts')) | w1 == w2 ->
        -- 'unsafeShiftL w1 1 `setBit`0' is supposed to be a faster version of '(2*w1)+1'
        BHead (unsafeShiftL w1 1 `setBit` 0) (Node x t1 t2) ts'
    ts -> BHead 1 (Leaf x) ts
{-# INLINE cons #-}

-- | /O(1)/. Split off the first element, or return 'Nothing' for an empty list.
--
-- A 'Leaf' head tree is simply dropped. For a 'Node' head tree, its root element is returned and
-- its two subtrees become the new first two root trees, each recorded with half the size.
uncons :: RAList a -> Maybe (a, RAList a)
uncons = \case
    BHead _ (Leaf x) ts -> Just (x, ts)
    BHead treeSize (Node x t1 t2) ts ->
        -- probably faster than `div treeSize 2`
        let halfSize = unsafeShiftR treeSize 1
            -- split the node in two
        in Just (x, BHead halfSize t1 $ BHead halfSize t2 ts)
    Nil -> Nothing
{-# INLINE uncons #-}

{- Note [Optimizations of contIndexZero]
Bangs in the local definitions of 'contIndexZero' are needed to tell GHC that the functions are
strict in the 'Word64' argument, so that GHC produces workers operating on @Word64#@.

The function itself is CPS-ed: the local definitions refer to the continuation arguments @z@ and
@f@, which forces them to be retained within 'contIndexZero' instead of being pulled out via
full-laziness or some other optimization pass. This ensures that when 'contIndexZero' gets inlined,
the local definitions appear directly in the GHC Core, allowing GHC to inline the arguments of
'contIndexZero' and transform the whole thing into a beautiful recursive join point full of
@Word64#@s, i.e. allocating very little if anything at all.
-}

-- | Look up an element by its 0-based index, in continuation-passing style.
--
-- @contIndexZero z f xs i@ returns @f x@ if @x@ is the element at index @i@ of @xs@, and @z@ if
-- @i@ is out of bounds. Whole root trees are skipped using their stored sizes. Then the search
-- descends into the tree containing the index.
--
-- See Note [Optimizations of contIndexZero].
contIndexZero :: forall a b. b -> (a -> b) -> RAList a -> Word64 -> b
contIndexZero z f = findTree where
    -- Walk the root trees, subtracting the size of every tree that is skipped.
    findTree :: RAList a -> Word64 -> b
    -- See Note [Optimizations of contIndexZero].
    findTree Nil !_ = z
    findTree (BHead w t ts) i =
        if i < w
        then indexTree w i t
        else findTree ts (i - w)

    -- Index into a complete tree of the given size, laid out in pre-order.
    -- The root is at offset 0, the left subtree holds offsets 1..halfSize, and the right subtree
    -- holds the remaining offsets.
    indexTree :: Word64 -> Word64 -> Tree a -> b
    -- See Note [Optimizations of contIndexZero].
    indexTree !w 0 t = case t of
        Node x _ _ -> f x
        Leaf x     -> if w == 1 then f x else z
    indexTree _ _ (Leaf _) = z
    indexTree treeSize offset (Node _ t1 t2) =
        let halfSize = unsafeShiftR treeSize 1 -- probably faster than `div treeSize 2`
        in if offset <= halfSize
           then indexTree halfSize (offset - 1) t1
           else indexTree halfSize (offset - 1 - halfSize) t2
{-# INLINE contIndexZero #-}

-- | Like 'contIndexZero', but with a 1-based index.
-- Index @0@ is always out of bounds and yields @z@.
contIndexOne :: forall a b. b -> (a -> b) -> RAList a -> Word64 -> b
contIndexOne z _ _ 0 = z
contIndexOne z f t n = contIndexZero z f t (n - 1)
{-# INLINE contIndexOne #-}

-- | Look up an element by its 0-based index.
-- Calls 'error' with @"out of bounds"@ if the index is out of bounds.
unsafeIndexZero :: RAList a -> Word64 -> a
unsafeIndexZero = contIndexZero (error "out of bounds") id
{-# INLINE unsafeIndexZero #-}

-- | Look up an element by its 0-based index.
-- Returns 'Nothing' if the index is out of bounds.
safeIndexZero :: RAList a -> Word64 -> Maybe a
safeIndexZero = contIndexZero Nothing Just
{-# INLINE safeIndexZero #-}

-- | Look up an element by its 1-based index.
-- Calls 'error' with @"out of bounds"@ if the index is @0@ or past the end.
unsafeIndexOne :: RAList a -> Word64 -> a
unsafeIndexOne = contIndexOne (error "out of bounds") id
{-# INLINE unsafeIndexOne #-}

-- | Look up an element by its 1-based index.
-- Returns 'Nothing' if the index is @0@ or past the end.
safeIndexOne :: RAList a -> Word64 -> Maybe a
safeIndexOne = contIndexOne Nothing Just
{-# INLINE safeIndexOne #-}

-- | Every method delegates to the corresponding top-level function of this module.
instance RAL.RandomAccessList (RAList a) where
    type Element (RAList a) = a

    empty = Nil
    {-# INLINE empty #-}

    cons = Cons
    {-# INLINE cons #-}

    uncons = uncons
    {-# INLINE uncons #-}

    -- Sum the sizes stored with each root tree; the trees themselves are never traversed.
    length = go 0 where
        go !acc Nil             = acc
        go !acc (BHead sz _ tl) = go (acc + sz) tl
    {-# INLINE length #-}

    indexZero = safeIndexZero
    {-# INLINE indexZero #-}

    indexOne = safeIndexOne
    {-# INLINE indexOne #-}

    unsafeIndexZero = unsafeIndexZero
    {-# INLINE unsafeIndexZero #-}

    unsafeIndexOne = unsafeIndexOne
    {-# INLINE unsafeIndexOne #-}