-- editorconfig-checker-disable-file
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeOperators         #-}
module PlutusIR.Compiler.Types where

import PlutusIR qualified as PIR
import PlutusIR.Compiler.Provenance
import PlutusIR.Error

import Control.Monad (when)
import Control.Monad.Except (MonadError)
import Control.Monad.Reader (MonadReader, local)

import Control.Lens

import PlutusCore qualified as PLC
import PlutusCore.Annotation
import PlutusCore.Builtin qualified as PLC
import PlutusCore.MkPlc qualified as PLC
import PlutusCore.Pretty qualified as PLC
import PlutusCore.Quote
import PlutusCore.StdLib.Type qualified as Types
import PlutusCore.TypeCheck.Internal qualified as PLC
import PlutusCore.Version qualified as PLC
import PlutusIR.Transform.RewriteRules.Internal (RewriteRules)
import PlutusPrelude

import Control.Monad.Error.Lens (throwing)
import Data.Text qualified as T
import PlutusIR.Analysis.Builtins
import Prettyprinter (viaShow)

-- | Extra flag passed in the TypeCheckM Reader context. It records whether the
-- type of the PIR expression currently being typechecked is allowed to escape.
data AllowEscape
    = YesEscape
    -- ^ The expression is at the top level, so its type can escape.
    | NoEscape
    -- ^ The expression is nested, so its type is not allowed to escape.

-- | The PLC typecheck config extended with an 'AllowEscape' flag.
data PirTCConfig uni fun = PirTCConfig
    { _pirConfigTCConfig    :: PLC.TypeCheckConfig uni fun
    -- ^ The underlying PLC typecheck config.
    , _pirConfigAllowEscape :: AllowEscape
    -- ^ Whether the type of the expression being checked may escape.
    }
-- Generates the lenses 'pirConfigTCConfig' and 'pirConfigAllowEscape'.
makeLenses ''PirTCConfig

-- A PIR config has a PLC config inside it, so it can act like one: the kind
-- check config is reached by going through the embedded PLC typecheck config.
instance PLC.HasKindCheckConfig (PirTCConfig uni fun) where
    kindCheckConfig = pirConfigTCConfig . PLC.kindCheckConfig

-- The PLC typecheck config of a PIR config is the one stored in its
-- '_pirConfigTCConfig' field.
instance PLC.HasTypeCheckConfig (PirTCConfig uni fun) uni fun where
    typeCheckConfig = pirConfigTCConfig

-- | What style to use when encoding datatypes.
-- Generally, 'SumsOfProducts' is superior, unless you are targeting an
-- old Plutus Core language version.
--
-- See Note [Encoding of datatypes]
data DatatypeStyle
    = ScottEncoding
    | SumsOfProducts
    deriving stock (Show, Read, Eq)

-- Pretty-printing reuses the derived 'Show' output.
instance Pretty DatatypeStyle where
    pretty = viaShow

-- | Options controlling how datatypes are compiled.
newtype DatatypeCompilationOpts = DatatypeCompilationOpts
    { _dcoStyle :: DatatypeStyle
    -- ^ The style used to encode datatypes.
    } deriving stock (Show)

-- Generates the 'dcoStyle' optic.
makeLenses ''DatatypeCompilationOpts

-- | The default datatype compilation options: datatypes are encoded as
-- 'SumsOfProducts'.
defaultDatatypeCompilationOpts :: DatatypeCompilationOpts
defaultDatatypeCompilationOpts = DatatypeCompilationOpts SumsOfProducts

-- | The options of the PIR compiler. The type parameter @a@ is the annotation
-- type that the inline hints are stated over (wrapped in 'Provenance').
data CompilationOpts a = CompilationOpts
    { _coOptimize                       :: Bool
    -- ^ Read by 'runIfOpts' to decide whether a pass is run at all.
    , _coTypecheck                      :: Bool
    , _coPedantic                       :: Bool
    , _coVerbose                        :: Bool
    , _coDebug                          :: Bool
    , _coDatatypes                      :: DatatypeCompilationOpts
    -- ^ How datatypes are compiled.
    -- Simplifier passes
    , _coMaxSimplifierIterations        :: Int
    , _coDoSimplifierUnwrapCancel       :: Bool
    , _coDoSimplifierCaseReduce         :: Bool
    , _coDoSimplifierRewrite            :: Bool
    , _coDoSimplifierBeta               :: Bool
    , _coDoSimplifierInline             :: Bool
    , _coDoSimplifierKnownCon           :: Bool
    , _coDoSimplifierCaseOfCase         :: Bool
    , _coDoSimplifierEvaluateBuiltins   :: Bool
    , _coDoSimplifierStrictifyBindings  :: Bool
    , _coDoSimplifierRemoveDeadBindings :: Bool
    , _coInlineHints                    :: InlineHints PLC.Name (Provenance a)
    , _coInlineConstants                :: Bool
    -- Profiling
    , _coProfile                        :: Bool
    , _coRelaxedFloatin                 :: Bool
    , _coCaseOfCaseConservative         :: Bool
    , _coPreserveLogging                :: Bool
    -- ^ Whether to try and preserve the logging behaviour of the program.
    } deriving stock (Show)

-- Generates one lens per field, named after the field without its leading
-- underscore (e.g. 'coOptimize', 'coDatatypes').
makeLenses ''CompilationOpts

-- | The default compiler options. Fields are listed in the same order as in
-- the declaration of 'CompilationOpts'.
defaultCompilationOpts :: CompilationOpts a
defaultCompilationOpts = CompilationOpts
    { -- NOTE(review): this field used to carry the comment
      -- "synonymous with max-simplifier-iterations=0"; presumably that
      -- describes switching optimization off rather than this default of
      -- True. Confirm.
      _coOptimize                       = True
    , _coTypecheck                      = True
    , _coPedantic                       = False
    , _coVerbose                        = False
    , _coDebug                          = False
    , _coDatatypes                      = defaultDatatypeCompilationOpts
    -- Simplifier passes
    , _coMaxSimplifierIterations        = 12
    , _coDoSimplifierUnwrapCancel       = True
    , _coDoSimplifierCaseReduce         = True
    , _coDoSimplifierRewrite            = True
    , _coDoSimplifierBeta               = True
    , _coDoSimplifierInline             = True
    , _coDoSimplifierKnownCon           = True
    , _coDoSimplifierCaseOfCase         = True
    , _coDoSimplifierEvaluateBuiltins   = True
    , _coDoSimplifierStrictifyBindings  = True
    , _coDoSimplifierRemoveDeadBindings = True
    , _coInlineHints                    = mempty
    , _coInlineConstants                = True
    -- Profiling
    , _coProfile                        = False
    , _coRelaxedFloatin                 = True
    , _coCaseOfCaseConservative         = True
    , _coPreserveLogging                = False
    }

-- | The context that compilation runs in (it is the environment of the
-- 'MonadReader' constraint used throughout this module).
data CompilationCtx uni fun a = CompilationCtx
    { _ccOpts             :: CompilationOpts a
    -- ^ The compiler options.
    , _ccEnclosing        :: Provenance a
    -- ^ The enclosing provenance; see 'getEnclosing' and 'withEnclosing'.
    , _ccTypeCheckConfig  :: PirTCConfig uni fun
    -- ^ The config for the PIR typechecker. This field is not optional, so it
    -- cannot be used to switch type checking off.
    , _ccBuiltinsInfo     :: BuiltinsInfo uni fun
    , _ccBuiltinCostModel :: PLC.CostingPart uni fun
    , _ccRewriteRules     :: RewriteRules uni fun
    }

-- Generates one lens per field (e.g. 'ccOpts', 'ccEnclosing').
makeLenses ''CompilationCtx

-- | Build a 'CompilationCtx' from a PLC typecheck config, using defaults for
-- everything else: 'defaultCompilationOpts', 'noProvenance' as the enclosing
-- provenance, 'YesEscape' for the typechecker, and 'def' for the builtins
-- info, the builtin cost model and the rewrite rules.
toDefaultCompilationCtx
    :: ( Default (BuiltinsInfo uni fun)
       , Default (PLC.CostingPart uni fun)
       , Default (RewriteRules uni fun)
       )
    => PLC.TypeCheckConfig uni fun
    -> CompilationCtx uni fun a
toDefaultCompilationCtx configPlc = CompilationCtx
    { _ccOpts             = defaultCompilationOpts
    , _ccEnclosing        = noProvenance
    , _ccTypeCheckConfig  = PirTCConfig configPlc YesEscape
    , _ccBuiltinsInfo     = def
    , _ccBuiltinCostModel = def
    , _ccRewriteRules     = def
    }

-- | Check that the compiler options in the context are compatible with the
-- given program version.
--
-- Throws an options error (via '_OptionsError') when datatypes are to be
-- compiled as 'SumsOfProducts' but the version is less than
-- 'PLC.plcVersion110'; otherwise does nothing.
validateOpts :: Compiling m e uni fun a => PLC.Version -> m ()
validateOpts v = do
    datatypes <- view (ccOpts . coDatatypes . dcoStyle)
    when (datatypes == SumsOfProducts && v < PLC.plcVersion110) $
        throwing _OptionsError $ T.pack $
            "Cannot use sums-of-products to compile a program with version less than 1.10. Program version is:" ++ show v

-- | Read the enclosing provenance from the compilation context.
getEnclosing :: MonadReader (CompilationCtx uni fun a) m => m (Provenance a)
getEnclosing = view ccEnclosing

-- | Run an action with the enclosing provenance modified by the given
-- function. The change is made with 'local', so it only applies to the
-- supplied action.
withEnclosing
    :: MonadReader (CompilationCtx uni fun a) m
    => (Provenance a -> Provenance a)
    -> m b
    -> m b
withEnclosing f = local (over ccEnclosing f)

-- | Conditionally run a pass: the condition is evaluated first, and if it
-- yields 'True' the pass is applied to the argument; otherwise the argument
-- is returned unchanged.
runIf
    :: MonadReader (CompilationCtx uni fun a) m
    => m Bool
    -> (b -> m b)
    -> (b -> m b)
runIf condition pass arg = do
    doPass <- condition
    if doPass then pass arg else pure arg

-- | Run a pass only if the 'coOptimize' option is set in the compilation
-- context; otherwise the argument is returned unchanged (see 'runIf').
runIfOpts
    :: MonadReader (CompilationCtx uni fun a) m
    => (b -> m b)
    -> (b -> m b)
runIfOpts = runIf $ view (ccOpts . coOptimize)

-- | A PLC program over 'PLC.TyName' and 'PLC.Name', annotated with 'Provenance'.
type PLCProgram uni fun a =
    PLC.Program PLC.TyName PLC.Name uni fun (Provenance a)

-- | A PLC term over 'PLC.TyName' and 'PLC.Name', annotated with 'Provenance'.
type PLCTerm uni fun a =
    PLC.Term PLC.TyName PLC.Name uni fun (Provenance a)

-- | A PLC type over 'PLC.TyName', annotated with 'Provenance'.
type PLCType uni a =
    PLC.Type PLC.TyName uni (Provenance a)

-- | A type that may or may not be recursive.
data PLCRecType uni fun a
    = PlainType (PLCType uni a)
    -- ^ An ordinary type.
    | RecursiveType (Types.RecursiveType uni fun (Provenance a))
    -- ^ A recursive type, as a 'Types.RecursiveType'.

-- | Get the actual type inside a 'PLCRecType': the wrapped type for a
-- 'PlainType', or the '_recursiveType' field for a 'RecursiveType'.
getType :: PLCRecType uni fun a -> PLCType uni a
getType r = case r of
    PlainType t                                                -> t
    RecursiveType Types.RecursiveType {Types._recursiveType=t} -> t

-- | Wrap a term appropriately for a possibly recursive type.
--
-- For a 'PlainType' the term is returned unchanged. For a 'RecursiveType' the
-- type's '_recursiveWrap' function is applied to the given type arguments and
-- the term, and the result is passed through 'setProvenance' with the given
-- provenance.
wrap
    :: Provenance a
    -> PLCRecType uni fun a
    -> [PLCType uni a]
    -> PIRTerm uni fun a
    -> PIRTerm uni fun a
wrap p r tvs t = case r of
    PlainType _                                                      -> t
    RecursiveType Types.RecursiveType {Types._recursiveWrap=wrapper} ->
        setProvenance p $ wrapper tvs t

-- | Unwrap a term appropriately for a possibly recursive type.
--
-- For a 'PlainType' the term is returned unchanged. For a 'RecursiveType' the
-- term is put under a 'PIR.Unwrap' node annotated with the given provenance.
unwrap
    :: Provenance a
    -> PLCRecType uni fun a
    -> PIRTerm uni fun a
    -> PIRTerm uni fun a
unwrap p r t = case r of
    PlainType _                          -> t
    RecursiveType Types.RecursiveType {} -> PIR.Unwrap p t

-- | A PIR term over 'PIR.TyName' and 'PIR.Name', annotated with 'Provenance'.
type PIRTerm uni fun a =
    PIR.Term PIR.TyName PIR.Name uni fun (Provenance a)

-- | A PIR type over 'PIR.TyName', annotated with 'Provenance'.
type PIRType uni a =
    PIR.Type PIR.TyName uni (Provenance a)

-- | The bundle of constraints shared by the compiler's monadic functions:
-- @m@ is the monad, @e@ the error type, @uni@ and @fun@ the universe and the
-- builtin functions, and @a@ the annotation type.
type Compiling m e uni fun a =
  ( Monad m
  -- The compilation context is available as a reader environment.
  , MonadReader (CompilationCtx uni fun a) m
  -- The error type can represent each kind of error, and can be thrown in @m@.
  , AsTypeError e (PIR.Term PIR.TyName PIR.Name uni fun ()) uni fun (Provenance a)
  , AsTypeErrorExt e uni (Provenance a)
  , AsError e uni fun (Provenance a)
  , MonadError e m
  , MonadQuote m
  , Ord a
  , PLC.Typecheckable uni fun
  , PLC.GEq uni
  -- Pretty printing instances
  , PLC.PrettyUni uni
  , PLC.Pretty fun
  , PLC.Pretty a
  )

-- | A definition whose declaration is a 'PLC.VarDecl' and whose body is a
-- 'PIR.Term'.
type TermDef tyname name uni fun a =
    PLC.Def
        (PLC.VarDecl tyname name uni a)
        (PIR.Term tyname name uni fun a)

-- | We generate some shared definitions during compilation; this datatype
-- defines the "keys" for those definitions.
data SharedName
    = FixpointCombinator Integer
    | FixBy
    deriving stock (Show, Eq, Ord)

-- | Make a fresh name for a shared definition: @"fix"@ followed by the shown
-- number for a 'FixpointCombinator', and @"fixBy"@ for 'FixBy'.
toProgramName :: SharedName -> Quote PLC.Name
toProgramName (FixpointCombinator n) = freshName ("fix" <> T.pack (show n))
toProgramName FixBy                  = freshName "fixBy"