-- editorconfig-checker-disable-file
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE DeriveAnyClass        #-}
{-# LANGUAGE DerivingStrategies    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE RoleAnnotations       #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module PlutusTx.Code where

import Control.Exception
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Flat (Flat (..), unflat)
import Flat.Decoder (DecodeException)
import PlutusCore qualified as PLC
import PlutusCore.Annotation
import PlutusCore.Pretty
import PlutusIR qualified as PIR
import PlutusTx.Coverage
import PlutusTx.Lift.Instances ()
import UntypedPlutusCore qualified as UPLC
-- We do not use qualified import because the whole module contains off-chain code
import PlutusPrelude
import Prelude as Haskell

-- The final type parameter is inferred to be phantom, but we give it a nominal
-- role, since it corresponds to the Haskell type of the program that was compiled into
-- this 'CompiledCodeIn'. It could be okay to give it a representational role, since
-- we compile newtypes the same as their underlying types, but people probably just
-- shouldn't coerce the final parameter regardless, so we play it safe with a nominal role.
-- The first two parameters (the built-in type universe @uni@ and the built-in
-- function type @fun@) get representational roles.
type role CompiledCodeIn representational representational nominal
-- NOTE: any changes to this type must be paralleled by changes
-- in the plugin code that generates values of this type. That is
-- done by code generation so it's not typechecked normally.
-- NOTE(review): the note below about normalized types predates storing the
-- program as untyped UPLC ('UPLC.Program'); it may be stale — confirm.
-- | A compiled Plutus Tx program. The last type parameter indicates
-- the type of the Haskell expression that was compiled, and
-- hence the type of the compiled code.
--
-- A value is either still in serialized form ('SerializedCode') or already
-- decoded ('DeserializedCode'); 'getPlc', 'getPir' and 'getCovIdx' give
-- uniform access to the contents of either form.
--
-- Note: the compiled PLC program does *not* have normalized types,
-- if you want to put it on the chain you must normalize the types first.
data CompiledCodeIn uni fun a =
    -- | Serialized UPLC code and possibly serialized PIR code with metadata used for program coverage.
    -- The byte strings are decoded with 'unflat' on demand by 'getPlc' and 'getPir'.
    SerializedCode BS.ByteString (Maybe BS.ByteString) CoverageIndex
    -- | Deserialized UPLC program, and possibly deserialized PIR program with metadata used for program coverage.
    -- This is the form produced by 'applyCode'.
    | DeserializedCode
        (UPLC.Program UPLC.NamedDeBruijn uni fun SrcSpans)
        (Maybe (PIR.Program PLC.TyName PLC.Name uni fun SrcSpans))
        CoverageIndex

-- | 'CompiledCodeIn' instantiated with default built-in types and functions.
-- Still takes the Haskell type of the compiled expression as its remaining argument.
type CompiledCode = CompiledCodeIn PLC.DefaultUni PLC.DefaultFun

-- | Apply a compiled function to a compiled argument. Will fail if the versions don't match.
--
-- Returns 'Left' with a rendered error message when:
--
-- * both sides carry PIR, but applying the PIR programs fails;
-- * PIR is present for only one side, or for neither side;
-- * applying the UPLC programs fails.
--
-- On success the result is a 'DeserializedCode' holding the applied UPLC
-- program, the applied PIR program, and the combination ('<>') of both
-- coverage indices.
applyCode
    :: (PLC.Closed uni
        , uni `PLC.Everywhere` Flat
        , Flat fun
        , Pretty fun
        , PLC.Everywhere uni PrettyConst
        , PrettyBy RenderContext (PLC.SomeTypeIn uni))
    => CompiledCodeIn uni fun (a -> b)
    -> CompiledCodeIn uni fun a
    -> Either String (CompiledCodeIn uni fun b)
applyCode fun arg = do
  -- Probably this could be done with more appropriate combinators, but the
  -- nested Maybes make it very easy to do the wrong thing here (I did it
  -- wrong first!), so I wrote it painfully explicitly.
  pir <- case (getPir fun, getPir arg) of
    (Just funPir, Just argPir) -> case PIR.applyProgram funPir argPir of
        Right appliedPir -> pure (Just appliedPir)
        -- Had PIR for both, but failed to apply them, this should fail
        Left err         -> Left $ show err
    -- Missing PIR for one or both sides: this fails, reporting what we did get.
    (Just funPir, Nothing) ->
        Left $ "Missing PIR for the argument. "
            <> "Got PIR for the function program \n"
            <> display funPir
    (Nothing, Just argPir) ->
        Left $ "Missing PIR for the function program. "
            <> "Got PIR for the argument \n"
            <> display argPir
    (Nothing, Nothing) -> Left "Missing PIR for both the function program and the argument."

  -- Report a failure to apply the UPLC programs as 'Left' instead of throwing,
  -- as the 'Either' result type promises. This runs after the PIR checks so
  -- that PIR-related errors keep taking precedence, as before.
  uplc <- case UPLC.applyProgram (getPlc fun) (getPlc arg) of
    Right appliedUplc -> pure appliedUplc
    Left err          -> Left $ show err

  pure $ DeserializedCode uplc pir (getCovIdx fun <> getCovIdx arg)

-- | Apply a compiled function to a compiled argument. Will throw if the versions don't match,
-- should only be used in non-production code.
--
-- This is 'applyCode' with any 'Left' message turned into a call to 'error',
-- so it also throws when PIR is missing for the function or the argument.
unsafeApplyCode
    :: (PLC.Closed uni
    , uni `PLC.Everywhere` Flat
    , Flat fun
    , Pretty fun
    , PLC.Everywhere uni PrettyConst
    , PrettyBy RenderContext (PLC.SomeTypeIn uni))
    => CompiledCodeIn uni fun (a -> b) -> CompiledCodeIn uni fun a -> CompiledCodeIn uni fun b
unsafeApplyCode fun arg = case applyCode fun arg of
  Right applied -> applied
  Left err      -> error err

-- | The size of a 'CompiledCodeIn', in AST nodes.
--
-- Only the UPLC program (as returned by 'getPlc') is measured; any PIR is
-- ignored. Serialized code is decoded first, so this can throw
-- 'ImpossibleDeserialisationFailure'.
sizePlc :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun) => CompiledCodeIn uni fun a -> Integer
sizePlc = UPLC.unSize . UPLC.programSize . getPlc

{- Note [Deserializing the AST]
The types suggest that deserializing the AST we embedded in the program can fail.
However, we serialized that AST ourselves, so a failure should be impossible. If one does
happen it indicates a bug, and we signal it by throwing an 'ImpossibleDeserialisationFailure'
exception.
-}
-- | Exception thrown when a program embedded in a 'CompiledCodeIn' cannot be
-- decoded; wraps the underlying 'DecodeException'.
-- See Note [Deserializing the AST].
newtype ImpossibleDeserialisationFailure = ImpossibleDeserialisationFailure DecodeException
    deriving anyclass (Exception)

-- Hand-written so the message asks users to report the bug.
instance Show ImpossibleDeserialisationFailure where
    show (ImpossibleDeserialisationFailure e) =
        "Failed to deserialise our own program! This is a bug, please report it. Caused by: " ++ show e

-- | Get the actual Plutus Core program out of a 'CompiledCodeIn'.
--
-- For 'SerializedCode' the UPLC bytes are decoded with 'unflat'. A decoding
-- failure is thrown as an 'ImpossibleDeserialisationFailure'
-- (see Note [Deserializing the AST]).
getPlc
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
    => CompiledCodeIn uni fun a -> UPLC.Program UPLC.NamedDeBruijn uni fun SrcSpans
getPlc (SerializedCode plc _ _) =
    case unflat (BSL.fromStrict plc) of
        Left e                             -> throw $ ImpossibleDeserialisationFailure e
        Right (UPLC.UnrestrictedProgram p) -> p
getPlc (DeserializedCode plc _ _) = plc

-- | Like 'getPlc', but with every 'SrcSpans' annotation replaced by @()@.
getPlcNoAnn
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
    => CompiledCodeIn uni fun a -> UPLC.Program UPLC.NamedDeBruijn uni fun ()
getPlcNoAnn = void . getPlc

-- | Get the Plutus IR program, if there is one, out of a 'CompiledCodeIn'.
--
-- For 'SerializedCode' with PIR bytes present, the bytes are decoded with
-- 'unflat' when the result is evaluated. A decoding failure is thrown as an
-- 'ImpossibleDeserialisationFailure' (see Note [Deserializing the AST]).
getPir
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
    => CompiledCodeIn uni fun a -> Maybe (PIR.Program PIR.TyName PIR.Name uni fun SrcSpans)
getPir (SerializedCode _ mPir _) = case mPir of
    Nothing -> Nothing
    -- Decode before choosing the constructor, so forcing the 'Maybe' surfaces a
    -- decoding failure (same strictness as before; not equivalent to 'fmap').
    Just bs -> case unflat (BSL.fromStrict bs) of
        Left e  -> throw $ ImpossibleDeserialisationFailure e
        Right p -> Just p
getPir (DeserializedCode _ pir _) = pir

-- | Like 'getPir', but with every 'SrcSpans' annotation replaced by @()@.
getPirNoAnn
    :: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
    => CompiledCodeIn uni fun a -> Maybe (PIR.Program PIR.TyName PIR.Name uni fun ())
getPirNoAnn = fmap void . getPir

-- | Get the coverage index stored in a 'CompiledCodeIn'. Works on either form
-- and never decodes anything.
getCovIdx :: CompiledCodeIn uni fun a -> CoverageIndex
getCovIdx (SerializedCode _ _ idx)   = idx
getCovIdx (DeserializedCode _ _ idx) = idx