{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE OverloadedStrings #-}

module PlutusTx.Compiler.Trace where

import PlutusTx.Compiler.Error
import PlutusTx.Compiler.Types
import PlutusTx.Compiler.Utils

import Control.Monad.Except
import Control.Monad.Extra
import Control.Monad.Reader
import Control.Monad.State
import Data.Maybe
import Data.Text (Text)
import Debug.Trace
import GHC.Plugins qualified as GHC

-- | A combination of 'withContextM' and 'traceCompilationStep'.
--
-- The compilation action is first wrapped by 'traceCompilationStep', and
-- the result is then wrapped by 'withContextM', using the rendered @SDoc@
-- as the context entry.
--
-- 'withContextM' emits a stack trace when the compilation fails, and can be
-- turned on via @-fcontext-level=<level>@.
--
-- 'traceCompilationStep' dumps the full compilation trace, and can be
-- turned on via @-fdump-compilation-trace@.
traceCompilation ::
  ( MonadReader (CompileContext uni fun) m
  , MonadState CompileState m
  , MonadError (WithContext Text e) m
  ) =>
  -- | Context level
  Int ->
  -- | The thing (expr, type, kind, etc.) being compiled
  GHC.SDoc ->
  -- | The compilation action
  m a ->
  m a
traceCompilation p sd = withContextM p (sdToTxt sd) . traceCompilationStep sd

-- | Run a compilation action. When 'ccDebugTraceOn' is set in the
-- 'CompileContext', emit a numbered trace of the step via 'traceM'.
--
-- A traced step takes the next number from the 'CompileState' step counter.
-- It pushes that number onto the stack of in-progress steps, so that nested
-- steps can report it as their parent.
--
-- The step then prints a @\<Step ...\>@ header, followed by the rendered
-- 'GHC.SDoc'. It runs the action and prints a @\<Completed step ...\>@
-- footer. Finally, it restores the in-progress stack it saw on entry.
--
-- When tracing is off, the action runs unchanged and the state is not
-- touched.
--
-- NOTE(review): if the action throws, neither the footer nor the stack
-- restore in this step runs. Any enclosing traced step that survives the
-- error still restores its own saved stack. Whether state changes survive a
-- caught error depends on how the concrete monad stack layers state and
-- errors; confirm against the callers.
traceCompilationStep ::
  (MonadReader (CompileContext uni fun) m, MonadState CompileState m) =>
  -- | The thing (expr, type, kind, etc.) being compiled
  GHC.SDoc ->
  -- | The compilation action
  m a ->
  m a
traceCompilationStep sd compile = ifM (asks ccDebugTraceOn) tracedStep compile
  where
    -- The compilation action wrapped in step-tracing bookkeeping.
    tracedStep = do
      CompileState nextStep prevSteps <- get
      -- Claim a fresh step number and make it the innermost in-progress step.
      put $ CompileState (nextStep + 1) (nextStep : prevSteps)
      -- The innermost step that was in progress before this one, if any.
      let mbParentStep :: Maybe Int
          mbParentStep = listToMaybe prevSteps
          -- Describe the parent step with the given label; empty at top level.
          describeParent :: String -> String
          describeParent label =
            maybe "" (\parentStep -> label <> show parentStep) mbParentStep
      s <- sdToStr sd
      traceM $
        "<Step " <> show nextStep <> describeParent ", Parent Step: " <> ">"
      traceM s
      res <- compile
      traceM $
        "<Completed step "
          <> show nextStep
          <> describeParent ", Returning to step "
          <> ">"
      -- Keep the step counter, which nested steps may have advanced. Restore
      -- the exact stack seen on entry rather than popping one entry, so a
      -- nested step that never popped (e.g. its action threw and the error
      -- was caught inside 'compile') cannot skew later parent reports.
      modify' $ \(CompileState nextStep' _) -> CompileState nextStep' prevSteps
      pure res