{-# LANGUAGE ConstraintKinds   #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE OverloadedStrings #-}

module PlutusTx.Compiler.Utils where

import PlutusTx.Compiler.Error
import PlutusTx.Compiler.Types

import GHC.Core qualified as GHC
import GHC.Plugins qualified as GHC
import GHC.Types.TyThing qualified as GHC

import Control.Monad ((<=<))
import Control.Monad.Except (MonadError)
import Control.Monad.Reader (MonadReader, ask)

import Language.Haskell.TH.Syntax qualified as TH

import Data.Map qualified as Map
import Data.Text qualified as T

-- | Get the 'GHC.TyCon' for a given 'TH.Name' stored in the builtin name info.
--
-- Fails with a 'CompilationError' if the name is missing from the name info,
-- or if it is present but bound to a 'GHC.TyThing' other than 'GHC.ATyCon'
-- (both cases are reported with the same "TyCon not found" message).
lookupGhcTyCon :: Compiling uni fun m ann => TH.Name -> m GHC.TyCon
lookupGhcTyCon thName = do
  CompileContext { ccNameInfo } <- ask
  case Map.lookup thName ccNameInfo of
    Just (GHC.ATyCon tc) -> pure tc
    -- Covers both 'Nothing' and a 'Just' holding a non-TyCon thing.
    _ ->
      throwPlain $
        CompilationError $
          "TyCon not found: " <> T.pack (show thName)

-- | Get the 'GHC.Name' for a given 'TH.Name' stored in the builtin name info.
--
-- Any kind of 'GHC.TyThing' is accepted; its name is extracted with
-- 'GHC.getName'. Fails with a 'CompilationError' if the 'TH.Name' is not
-- present in the name info.
lookupGhcName :: Compiling uni fun m ann => TH.Name -> m GHC.Name
lookupGhcName thName = do
  CompileContext { ccNameInfo } <- ask
  case Map.lookup thName ccNameInfo of
    Just thing -> pure (GHC.getName thing)
    Nothing ->
      throwPlain $
        CompilationError $
          "Name not found: " <> T.pack (show thName)

-- | Get the 'GHC.Id' for a given 'TH.Name' stored in the builtin name info.
--
-- Fails with a 'CompilationError' if the name is missing from the name info,
-- or if it is present but bound to a 'GHC.TyThing' other than 'GHC.AnId'
-- (both cases are reported with the same "Id not found" message).
lookupGhcId :: Compiling uni fun m ann => TH.Name -> m GHC.Id
lookupGhcId thName = do
  CompileContext { ccNameInfo } <- ask
  case Map.lookup thName ccNameInfo of
    Just (GHC.AnId ghcId) -> pure ghcId
    -- Covers both 'Nothing' and a 'Just' holding a non-Id thing.
    _ ->
      throwPlain $
        CompilationError $
          "Id not found: " <> T.pack (show thName)

-- | Render a 'GHC.SDoc' to a 'String'.
--
-- The 'GHC.DynFlags' are taken from the 'ccFlags' field of the ambient
-- 'CompileContext'; rendering uses 'GHC.showSDocForUser' with
-- 'GHC.emptyUnitState' and the 'GHC.alwaysQualify' name printing context.
sdToStr :: MonadReader (CompileContext uni fun) m => GHC.SDoc -> m String
sdToStr sd = do
  CompileContext { ccFlags = flags } <- ask
  pure $ GHC.showSDocForUser flags GHC.emptyUnitState GHC.alwaysQualify sd

-- | Render a 'GHC.SDoc' to 'T.Text'.
--
-- This is 'sdToStr' followed by 'T.pack'.
sdToTxt :: MonadReader (CompileContext uni fun) m => GHC.SDoc -> m T.Text
sdToTxt = fmap T.pack . sdToStr

-- | Throw an error whose message is a rendered 'GHC.SDoc'.
--
-- The document is first rendered to 'T.Text' with 'sdToTxt', then wrapped in
-- an 'Error' using the supplied constructor, and finally thrown with
-- 'throwPlain'.
throwSd ::
    (MonadError (CompileError uni fun ann) m, MonadReader (CompileContext uni fun) m) =>
    -- | Constructor turning the rendered message into an 'Error'.
    (T.Text -> Error uni fun ann) ->
    -- | The document to render as the error message.
    GHC.SDoc ->
    m a
throwSd constr = (throwPlain . constr) <=< sdToTxt

-- | Collect the 'GHC.TyCon's mentioned in a Core expression.
--
-- The result is the union of the 'GHC.tyConsOfType' of every type, coercion
-- (via 'GHC.mkCoercionTy'), variable type and binder type encountered while
-- walking the expression, including nested bindings and case alternatives.
tyConsOfExpr :: GHC.CoreExpr -> GHC.UniqSet GHC.TyCon
tyConsOfExpr = \case
    GHC.Type ty -> GHC.tyConsOfType ty
    GHC.Coercion co -> GHC.tyConsOfType $ GHC.mkCoercionTy co
    -- For a variable occurrence only its type is inspected.
    GHC.Var v -> GHC.tyConsOfType (GHC.varType v)
    GHC.Lit _ -> mempty
    -- Ignore anything in the tick annotation; only the wrapped expression
    -- is traversed.
    GHC.Tick _ e -> tyConsOfExpr e
    GHC.App e1 e2 -> tyConsOfExpr e1 <> tyConsOfExpr e2
    GHC.Lam bndr e -> tyConsOfBndr bndr <> tyConsOfExpr e
    GHC.Cast e co -> tyConsOfExpr e <> GHC.tyConsOfType (GHC.mkCoercionTy co)
    GHC.Case scrut bndr ty alts ->
        tyConsOfExpr scrut <>
        tyConsOfBndr bndr <>
        GHC.tyConsOfType ty <>
        foldMap tyConsOfAlt alts
    GHC.Let bind body -> tyConsOfBind bind <> tyConsOfExpr body

-- | Collect the 'GHC.TyCon's mentioned in the type of a binder
-- (its 'GHC.varType').
tyConsOfBndr :: GHC.CoreBndr -> GHC.UniqSet GHC.TyCon
tyConsOfBndr = GHC.tyConsOfType . GHC.varType

-- | Collect the 'GHC.TyCon's mentioned in a Core binding group.
--
-- For every binder/right-hand-side pair (one for 'GHC.NonRec', all of them
-- for 'GHC.Rec') this unions the TyCons of the binder's type with those of
-- the right-hand-side expression.
tyConsOfBind :: GHC.Bind GHC.CoreBndr -> GHC.UniqSet GHC.TyCon
tyConsOfBind = \case
    GHC.NonRec bndr rhs -> binderTyCons bndr rhs
    GHC.Rec bndrs       -> foldMap (uncurry binderTyCons) bndrs
    where
        -- TyCons of a single binder together with its right-hand side.
        binderTyCons :: GHC.CoreBndr -> GHC.CoreExpr -> GHC.UniqSet GHC.TyCon
        binderTyCons bndr rhs = tyConsOfBndr bndr <> tyConsOfExpr rhs

-- | Collect the 'GHC.TyCon's mentioned in a case alternative.
--
-- The alternative's constructor is ignored; the result is the union of the
-- TyCons of the types of the bound variables and of the alternative's body.
tyConsOfAlt :: GHC.CoreAlt -> GHC.UniqSet GHC.TyCon
tyConsOfAlt (GHC.Alt _ vars e) = foldMap tyConsOfBndr vars <> tyConsOfExpr e