{-# LANGUAGE BlockArguments  #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Utilities for space-time tradeoff, such as recursion unrolling.
module PlutusTx.Optimize.SpaceTime (peel, unroll) where

import Prelude

import Language.Haskell.TH.Syntax.Compat qualified as TH
import PlutusTx.Function (fix)

{-| Given @n@, and the step function for a recursive function, peel @n@ layers
off of the recursion.

For example @peel 3 (\\self -> [|| \\case [] -> 0; _ : ys -> 1 + self ys ||])@
yields the equivalent of the following function:

@
  lengthPeeled :: [a] -> Integer
  lengthPeeled xs =
    case xs of                     -- first recursion step
      []     -> 0
      _ : ys -> 1 +
        case ys of                 -- second recursion step
          []     -> 0
          _ : zs -> 1 +
            case zs of             -- third recursion step
              []     -> 0
              _ : ws -> 1 +
                ( fix \\self qs ->  -- rest of recursion steps in a tight loop
                    case qs of
                      []     -> 0
                      _ : ts -> 1 + self ts
                ) ws
@

Calling 'error' (at splice time) when @n@ is negative.
-}
peel
  :: forall a b
   . Int
  -- ^ How many recursion steps to move outside of the recursion loop.
  -> (TH.SpliceQ (a -> b) -> TH.SpliceQ (a -> b))
  {- ^ Function that given a continuation splice returns
  a splice representing a single recursion step calling this continuation.
  -}
  -> TH.SpliceQ (a -> b)
-- No layers left to peel: the remaining recursion becomes an ordinary
-- 'fix'-based loop whose step calls back into itself.
peel 0 f = [||fix \self -> $$(f [||self||])||]
peel n f
  -- Emit one step outside the loop whose continuation is the splice for
  -- the remaining (n - 1) peeled steps followed by the loop.
  | n > 0 = f (peel (n - 1) f)
  | otherwise =
      error $ "PlutusTx.Optimize.SpaceTime.peel: negative n: " <> show n

{-| Given @n@, and the step function for a recursive function,
    unroll recursion @n@ layers at a time.

For example @unroll 3 (\\self -> [|| \\case [] -> 0; _ : ys -> 1 + self ys ||])@
yields the equivalent of the following function:

@
  lengthUnrolled :: [a] -> Integer
  lengthUnrolled =
    fix \\self xs ->                   -- beginning of the recursion "loop"
      case xs of                      -- first recursion step
        []     -> 0
        _ : ys -> 1 +
          case ys of                  -- second recursion step
            []     -> 0
            _ : zs -> 1 +
              case zs of              -- third recursion step
                []     -> 0
                _ : ws -> 1 + self ws -- end of the "loop"
@

@n@ must be at least 1. Otherwise 'error' is called at splice time:
with @n = 0@ the loop body would be just @self@, i.e. a function that
never terminates.
-}
unroll
  :: forall a b
   . Int
  -- ^ How many recursion steps to perform inside the recursion loop
  -- (must be at least 1).
  -> (TH.SpliceQ (a -> b) -> TH.SpliceQ (a -> b))
  {- ^ Function that given a continuation splice returns
  a splice representing a single recursion step calling this continuation.
  -}
  -> TH.SpliceQ (a -> b)
unroll n f
  -- The loop body is n copies of the step, the innermost one jumping back
  -- to the start of the loop via @self@.
  | n >= 1 = [||fix \self -> $$(nTimes n f [||self||])||]
  | otherwise =
      error $ "PlutusTx.Optimize.SpaceTime.unroll: n must be positive: " <> show n

-- | Apply a function @n@ times to a given value.
--
-- @nTimes 0 f@ is 'id'. Calls 'error' when @n@ is negative, since a negative
-- count has no base case to stop at.
nTimes :: Int -> (a -> a) -> (a -> a)
nTimes 0 _ = id
-- Avoids composing a trailing @. id@ onto the last application.
nTimes 1 f = f
nTimes n f
  | n > 1 = f . nTimes (n - 1) f
  | otherwise =
      error $ "PlutusTx.Optimize.SpaceTime.nTimes: negative n: " <> show n