{-| Insert necessary or recommended GHC flags and extensions via a driver plugin.
See https://plutus.cardano.intersectmbo.org/docs/using-plinth/extensions-flags-pragmas -}
module PlutusTx.Plugin.Boilerplate where

import Data.Foldable qualified as Foldable

import PlutusTx.Compiler.Compat qualified as Compat

import GHC.Driver.Flags qualified as GHC
import GHC.Hs qualified as GHC
import GHC.LanguageExtensions qualified as GHC
import GHC.Plugins qualified as GHC
import GHC.Tc.Types qualified as GHC
import GHC.Types.SourceText qualified as GHC

{-| Plugin option that stops the driver plugin from turning on @Strict@.

The @Strict@ extension enabled by the driver plugin apparently cannot be
switched off again with @LANGUAGE NoStrict@, so this option lets users opt out
explicitly. -}
optNoStrict :: GHC.CommandLineOption
optNoStrict = "no-strict"

-- | Plugin options that are consumed by this module (currently only 'optNoStrict').
boilerplateOpts :: [GHC.CommandLineOption]
boilerplateOpts =
  [ optNoStrict
  ]

-- | Drop every option listed in 'boilerplateOpts', keeping the rest in their original order.
removeBoilerplateOpts :: [GHC.CommandLineOption] -> [GHC.CommandLineOption]
removeBoilerplateOpts opts = [opt | opt <- opts, opt `notElem` boilerplateOpts]

{-| Return @env@ with its 'GHC.DynFlags' adjusted for Plinth compilation.

  * Every flag in @disabledFlags@ is switched off: interface-pragma handling,
    full laziness, SpecConstr, specialisation, strictness analysis and
    strict-field unboxing.
  * The @Strict@ language extension is switched on, unless 'optNoStrict' is one
    of the plugin options @opts@.
-}
addFlagsAndExts :: [GHC.CommandLineOption] -> GHC.HscEnv -> IO GHC.HscEnv
addFlagsAndExts opts env = pure env {GHC.hsc_dflags = dflags}
  where
    dflags :: GHC.DynFlags
    dflags = setStrict . unsetFlags $ GHC.hsc_dflags env

    -- Clear each flag in turn. A strict left fold is used instead of the lazy
    -- 'foldl', which would only accumulate a chain of thunks.
    unsetFlags :: GHC.DynFlags -> GHC.DynFlags
    unsetFlags flags0 = Foldable.foldl' GHC.gopt_unset flags0 disabledFlags

    -- GHC flags that are always turned off for Plinth code.
    disabledFlags :: [GHC.GeneralFlag]
    disabledFlags =
      [ GHC.Opt_IgnoreInterfacePragmas
      , GHC.Opt_OmitInterfacePragmas
      , GHC.Opt_FullLaziness
      , GHC.Opt_SpecConstr
      , GHC.Opt_Specialise
      , GHC.Opt_Strictness
      , GHC.Opt_UnboxStrictFields
      , GHC.Opt_UnboxSmallStrictFields
      ]

    -- NOTE(review): 'GHC.xopt_set' inserts only the given extension and, as far
    -- as I can tell, does not apply GHC's implied extensions. So StrictData is
    -- not enabled here, unlike a Strict LANGUAGE pragma. Confirm this is intended.
    setStrict :: GHC.DynFlags -> GHC.DynFlags
    setStrict
      | optNoStrict `elem` opts = id
      | otherwise = (`GHC.xopt_set` GHC.Strict)

{-| Add an INLINABLE pragma to every binding in 'GHC.tcg_binds' that carries no
user-written inline pragma. Each binding is processed by 'addInlineable'. -}
addInlineables :: GHC.TcGblEnv -> GHC.TcM GHC.TcGblEnv
addInlineables env =
  let annotated = Compat.modifyBinds (map addInlineable) (GHC.tcg_binds env)
   in pure env {GHC.tcg_binds = annotated}

{-| Mark the polymorphic 'GHC.Id' of an @AbsBinds@ export INLINABLE when it has
no user-written inline pragma; otherwise return the export unchanged. -}
addInlineableABE :: GHC.ABExport -> GHC.ABExport
addInlineableABE abe@GHC.ABE {GHC.abe_poly = v}
  | needsInlineable v = abe {GHC.abe_poly = GHC.setInlinePragma v inlineable}
addInlineableABE abe = abe

{-| Add an INLINABLE pragma to the binder of a single located binding when it
has no user-written inline pragma.

Function, variable and pattern-synonym bindings have their binder updated.
For @AbsBinds@ (the 'GHC.XHsBindsLR' case), both the exports and the nested
bindings are processed recursively. Anything else is returned unchanged. -}
addInlineable :: GHC.LHsBindLR GHC.GhcTc GHC.GhcTc -> GHC.LHsBindLR GHC.GhcTc GHC.GhcTc
addInlineable (GHC.L loc bind) = GHC.L loc annotated
  where
    -- Attach the INLINABLE pragma to an identifier.
    mark :: GHC.Id -> GHC.Id
    mark v = GHC.setInlinePragma v inlineable

    annotated = case bind of
      b@GHC.FunBind {GHC.fun_id = GHC.L idLoc v}
        | needsInlineable v -> b {GHC.fun_id = GHC.L idLoc (mark v)}
      b@GHC.VarBind {GHC.var_id = v}
        | needsInlineable v -> b {GHC.var_id = mark v}
      GHC.PatSynBind x psb@GHC.PSB {GHC.psb_id = GHC.L idLoc v}
        | needsInlineable v -> GHC.PatSynBind x psb {GHC.psb_id = GHC.L idLoc (mark v)}
      GHC.XHsBindsLR absBinds ->
        GHC.XHsBindsLR
          absBinds
            { GHC.abs_exports = map addInlineableABE (GHC.abs_exports absBinds)
            , GHC.abs_binds = Compat.modifyBinds (map addInlineable) (GHC.abs_binds absBinds)
            }
      unchanged -> unchanged

-- | 'True' when the 'GHC.Id' still has the default inline pragma, i.e. the user wrote none.
needsInlineable :: GHC.Id -> Bool
needsInlineable v = GHC.isDefaultInlinePragma (GHC.idInlinePragma v)

{-| The pragma attached by this module: INLINABLE, always active, function-like,
with no source text and no saturation arity. -}
inlineable :: GHC.InlinePragma
inlineable =
  GHC.InlinePragma
    { GHC.inl_inline = GHC.Inlinable GHC.NoSourceText
    , GHC.inl_act = GHC.AlwaysActive
    , GHC.inl_rule = GHC.FunLike
    , GHC.inl_sat = Nothing
    , GHC.inl_src = GHC.NoSourceText
    }