{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE TemplateHaskell  #-}
{-# LANGUAGE TypeApplications #-}

module PlutusTx.Data.List.TH where

import Data.Set (Set)
import Data.Set qualified as Set
import Data.Traversable (for)
import Language.Haskell.TH qualified as TH
import PlutusTx.Data.List qualified as List
import Prelude

-- | Generate variables bound to the given indices of a list, by walking it
-- with 'List.head' and 'List.tail'.
--
-- Sample Usage:
--
--  @
--    f :: List Integer -> Integer
--    f list =
--    $( destructList
--         "s"
--         (Set.fromList [1, 4, 5])
--         'list
--         [|s1 + s4 + s5|]
--     )
--  @
--
-- This computes the sum of list elements at indices 1, 4 and 5.
--
-- The generated code is a single @let@ wrapped around the given computation.
-- Element @i@ is bound strictly to the name @prefix ++ show i@. That name is
-- created with 'TH.mkName', so the computation can refer to it directly.
-- Tails are only taken up to the largest requested index.
--
-- An empty index set yields the computation unchanged. A negative index is
-- reported as a compile-time error.
destructList
  :: String
  -- ^ Prefix of the generated bindings
  -> Set Int
  -- ^ Element ids you need, starting from 0
  -> TH.Name
  -- ^ The list to destruct (a variable in scope at the splice site)
  -> TH.ExpQ
  -- ^ The computation that consumes the elements
  -> TH.ExpQ
  -- ^ The computation wrapped in the generated bindings
destructList p is n k
  | Just lowest <- Set.lookupMin is
  , lowest < 0 =
      fail $
        "destructList: element ids must be non-negative, got " ++ show lowest
  | otherwise = case Set.lookupMax is of
      -- Nothing is requested, so there is nothing to bind (and no partial
      -- 'maximum' on an empty set).
      Nothing -> k
      Just maxIx -> do
        -- Fresh (hygienic) names for the intermediate tails 0 .. maxIx - 1.
        tailNames <- for [0 .. maxIx - 1] $ \i -> TH.newName ("tail" ++ show i)
        let -- The list whose head is element i: the input itself for i == 0,
            -- otherwise tail (i - 1).
            sources = n : tailNames
            -- The tail bound at step i. There is none after the last
            -- requested element, since nothing further is needed.
            targets = map Just tailNames ++ [Nothing]
        decs <- for (zip3 [0 .. maxIx] sources targets) $ \(i, src, mTail) -> do
          headDecs <-
            if Set.member i is
              then [d|$(strict (elemName i)) = List.head $(TH.varE src)|]
              else pure []
          tailDecs <- case mTail of
            Nothing -> pure []
            Just t -> [d|$(tailStrictness i t) = List.tail $(TH.varE src)|]
          pure (headDecs ++ tailDecs)
        TH.letE (pure <$> concat decs) k
  where
    strict, nonstrict :: TH.Name -> TH.Q TH.Pat
    strict = TH.bangP . TH.varP
    nonstrict = TH.tildeP . TH.varP

    -- 'TH.mkName' (not 'TH.newName') so the user's computation can refer to
    -- the element bindings by their textual names.
    elemName :: Int -> TH.Name
    elemName i = TH.mkName (p ++ show i)

    -- Tail i feeds step i + 1. If element i + 1 is not requested, its only
    -- consumer is the next 'List.tail'. The binding is then made lazy (~) so
    -- that it can be inlined; otherwise it is strict.
    tailStrictness :: Int -> TH.Name -> TH.Q TH.Pat
    tailStrictness i = if Set.member (i + 1) is then strict else nonstrict