{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneKindSignatures #-}

module PlutusTx.Compiler.Utils where

import PlutusTx.Compiler.Error
import PlutusTx.Compiler.Types

import PlutusCore qualified as PLC

import GHC.Core qualified as GHC
import GHC.Plugins qualified as GHC
import GHC.Types.TyThing qualified as GHC

import Control.Monad ((<=<))
import Control.Monad.Except
import Control.Monad.Reader (MonadReader, ask)

import Language.Haskell.TH.Syntax qualified as TH

import Data.Kind qualified as Kind
import Data.Map qualified as Map
import Data.Text qualified as T

{-| Identical to @SomeTypeIn@ but without an existential kind. Having the kind
of the hidden type fixed to `Type` makes it easier to pattern match on the
wrapped universe tag and to construct a different type within the universe.
See how it's used in 'compileMkNil'.

The single constructor holds a strict @uni (PLC.Esc a)@ value, with @a@
existentially quantified (hidden from the 'SomeStarIn' type). -}
type SomeStarIn :: (Kind.Type -> Kind.Type) -> Kind.Type
data SomeStarIn uni = forall a. SomeStarIn !(uni (PLC.Esc a))

{-| Resolve a Template Haskell name to the 'GHC.TyCon' recorded for it in the
builtin name info ('ccNameInfo') of the compile context.

Throws a plain 'CompilationError' (@TyCon not found: \<name\>@) both when the
name has no entry and when its entry is not a type constructor. -}
lookupGhcTyCon :: Compiling uni fun m ann => TH.Name -> m GHC.TyCon
lookupGhcTyCon thName = do
  nameInfo <- ccNameInfo <$> ask
  case Map.lookup thName nameInfo of
    Just (GHC.ATyCon tyCon) -> pure tyCon
    _ ->
      throwPlain . CompilationError $
        "TyCon not found: " <> T.pack (show thName)

{-| Resolve a Template Haskell name to the 'GHC.Name' of whatever 'GHC.TyThing'
is recorded for it in the builtin name info ('ccNameInfo') of the compile
context.

Any kind of entry is accepted; only a missing entry is an error, reported as a
plain 'CompilationError' (@Name not found: \<name\>@). -}
lookupGhcName :: Compiling uni fun m ann => TH.Name -> m GHC.Name
lookupGhcName thName = do
  CompileContext {ccNameInfo = nameInfo} <- ask
  maybe
    (throwPlain . CompilationError $ "Name not found: " <> T.pack (show thName))
    (pure . GHC.getName)
    (Map.lookup thName nameInfo)

{-| Resolve a Template Haskell name to the 'GHC.Id' recorded for it in the
builtin name info ('ccNameInfo') of the compile context.

Throws a plain 'CompilationError' (@Id not found: \<name\>@) both when the name
has no entry and when its entry is not an identifier. -}
lookupGhcId :: Compiling uni fun m ann => TH.Name -> m GHC.Id
lookupGhcId thName =
  ask >>= \ctx ->
    case Map.lookup thName (ccNameInfo ctx) of
      Just (GHC.AnId ghcId) -> pure ghcId
      _ -> throwPlain $ CompilationError ("Id not found: " <> T.pack (show thName))

{-| Render an 'GHC.SDoc' to a 'String'.

Uses the 'GHC.DynFlags' stored in the compile context ('ccFlags'), an empty
unit state, and names that are always qualified ('GHC.alwaysQualify'). -}
sdToStr :: MonadReader (CompileContext uni fun) m => GHC.SDoc -> m String
sdToStr sd = render . ccFlags <$> ask
  where
    render flags = GHC.showSDocForUser flags GHC.emptyUnitState GHC.alwaysQualify sd

{-| Like 'sdToStr', but packs the rendered document into 'T.Text'. -}
sdToTxt :: MonadReader (CompileContext uni fun) m => GHC.SDoc -> m T.Text
sdToTxt sd = T.pack <$> sdToStr sd

{-| Render an 'GHC.SDoc' with 'sdToTxt', wrap the resulting text with the given
error constructor, and throw it via 'throwPlain'. Never returns normally. -}
throwSd
  :: (MonadError (CompileError uni fun ann) m, MonadReader (CompileContext uni fun) m)
  => (T.Text -> Error uni fun ann)
  -> GHC.SDoc
  -> m a
throwSd constr sd = do
  message <- sdToTxt sd
  throwPlain (constr message)

{-| Collect the type constructors mentioned in a Core expression.

Contributions come from the types of variables and binders, embedded types and
coercions, case result types, and all sub-expressions. Literals contribute
nothing, and tick annotations are skipped (only the wrapped expression is
inspected). -}
tyConsOfExpr :: GHC.CoreExpr -> GHC.UniqSet GHC.TyCon
tyConsOfExpr expr = case expr of
  GHC.Var var -> GHC.tyConsOfType (GHC.varType var)
  GHC.Lit _ -> mempty
  GHC.App function argument -> mconcat [tyConsOfExpr function, tyConsOfExpr argument]
  GHC.Lam binder body -> mconcat [tyConsOfBndr binder, tyConsOfExpr body]
  GHC.Let binding body -> mconcat [tyConsOfBind binding, tyConsOfExpr body]
  GHC.Case scrutinee binder resultTy alternatives ->
    mconcat
      [ tyConsOfExpr scrutinee
      , tyConsOfBndr binder
      , GHC.tyConsOfType resultTy
      , foldMap tyConsOfAlt alternatives
      ]
  GHC.Cast inner co -> mconcat [tyConsOfExpr inner, coercionTyCons co]
  -- the tick annotation itself is not inspected
  GHC.Tick _ inner -> tyConsOfExpr inner
  GHC.Type ty -> GHC.tyConsOfType ty
  GHC.Coercion co -> coercionTyCons co
  where
    -- coercions are inspected by viewing them as types
    coercionTyCons = GHC.tyConsOfType . GHC.mkCoercionTy

{-| Type constructors occurring in the type of a Core binder. -}
tyConsOfBndr :: GHC.CoreBndr -> GHC.UniqSet GHC.TyCon
tyConsOfBndr bndr = GHC.tyConsOfType (GHC.varType bndr)

{-| Type constructors mentioned by a Core binding group: for each binder, those
in its type plus those in its right-hand side. Non-recursive and recursive
groups are treated alike. -}
tyConsOfBind :: GHC.Bind GHC.CoreBndr -> GHC.UniqSet GHC.TyCon
tyConsOfBind binding = case binding of
  GHC.NonRec bndr rhs -> pairTyCons (bndr, rhs)
  GHC.Rec pairs -> foldMap pairTyCons pairs
  where
    pairTyCons :: (GHC.CoreBndr, GHC.CoreExpr) -> GHC.UniqSet GHC.TyCon
    pairTyCons (b, e) = tyConsOfBndr b <> tyConsOfExpr e

{-| Type constructors mentioned by a case alternative: those in the types of
the pattern-bound variables plus those in the right-hand side. The constructor
being matched is not inspected. -}
tyConsOfAlt :: GHC.CoreAlt -> GHC.UniqSet GHC.TyCon
tyConsOfAlt alt = case alt of
  GHC.Alt _ binders rhs -> mconcat (tyConsOfExpr rhs : map tyConsOfBndr binders)

{-| Get the package name for the module being compiled.

Resolution order:

1. 'GHC.lookupUnit' on the session's unit state (works for installed packages);

2. 'GHC.thisPackageName' from the session's 'GHC.DynFlags' (works for home
   library units);

3. otherwise, strip the version and everything after it from the unit ID
   string, e.g. @foo-bar-1.2.3-inplace-component@ becomes @foo-bar@.

In step 3, only a dash-separated segment made up entirely of digits and dots
(and starting with a digit) counts as the version. Cabal package-name
components may start with a digit (e.g. @foo-2d@), so looking only at the first
character after a dash would wrongly truncate such names. The first segment is
always kept. -}
getPackageName :: GHC.HscEnv -> GHC.Module -> String
getPackageName hscEnv thisModule =
  case GHC.lookupUnit unitState unit of
    Just unitInfo -> GHC.unitPackageNameString unitInfo
    Nothing -> case GHC.thisPackageName (GHC.hsc_dflags hscEnv) of
      Just n -> n
      Nothing -> stripVersion (GHC.unitString unit)
  where
    unitState = GHC.hsc_units hscEnv
    unit = GHC.moduleUnit thisModule

    -- Extract "foo-bar" from "foo-bar-1.2.3-inplace-component": keep the
    -- first segment unconditionally, then keep following segments up to
    -- (but excluding) the first one that is a version number.
    stripVersion :: String -> String
    stripVersion s = case splitOnDash s of
      [] -> ""
      firstSeg : rest ->
        joinWithDash (firstSeg : takeWhile (not . isVersionSegment) rest)

    -- Split on every '-', keeping empty segments so joining is lossless.
    splitOnDash :: String -> [String]
    splitOnDash str = case break (== '-') str of
      (seg, []) -> [seg]
      (seg, _ : remaining) -> seg : splitOnDash remaining

    joinWithDash :: [String] -> String
    joinWithDash [] = ""
    joinWithDash (x : xs) = x ++ concatMap ('-' :) xs

    -- A version segment looks like "1", "1.2.3", ...: a leading digit and
    -- nothing but digits and dots.
    isVersionSegment :: String -> Bool
    isVersionSegment seg = case seg of
      c : _ -> isDigitChar c && all (\x -> isDigitChar x || x == '.') seg
      [] -> False

    isDigitChar :: Char -> Bool
    isDigitChar c = c >= '0' && c <= '9'