{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
module UntypedPlutusCore.Core.Zip
    ( pzipWith
    , pzip
    , tzipWith
    , tzip
    ) where

import Control.Monad (void, when)
import Control.Monad.Except (MonadError, throwError)
import Data.Vector
import UntypedPlutusCore.Core.Instance.Eq ()
import UntypedPlutusCore.Core.Type

-- | Zip two programs, combining their annotations with the given function.
--
-- Throws an error (via 'throwError') if the two programs have different
-- versions, or if their bodies are not "equal" modulo annotations
-- (see 'tzipWith').
--
-- The function is "left-biased": wherever the inputs carry data other than
-- annotations (the version, and e.g. `Name`s inside the body), the output
-- program takes it from the first input program.
pzipWith :: forall p name uni fun ann1 ann2 ann3 m.
           (p ~ Program name uni fun, Eq (Term name uni fun ()), MonadError String m)
         => (ann1 -> ann2 -> ann3)
         -> p ann1
         -> p ann2
         -> m (p ann3)
pzipWith f (Program ann1 ver1 t1) (Program ann2 ver2 t2) = do
    -- The version is not an annotation, so both programs must agree on it.
    when (ver1 /= ver2) $
        throwError "zip: Versions do not match."
    -- Past the check ver1 == ver2, so keeping ver1 loses nothing.
    Program (f ann1 ann2) ver1 <$> tzipWith f t1 t2

-- | Zip two terms, combining their annotations with the given function.
--
-- Throws an error (via 'throwError') if the input terms are not "equal"
-- modulo annotations, i.e. if @'void' term1 /= 'void' term2@.
--
-- The function is "left-biased": wherever the two terms carry data other
-- than annotations (`Name`s, constants, builtins, constructor tags), the
-- output term takes it from the first input term.
--
-- TODO: this is not an optimal implementation: the terms are traversed once
-- for the equality check and then a second time for the actual zip.
tzipWith :: forall t name uni fun ann1 ann2 ann3 m.
           (t ~ Term name uni fun, Eq (t ()), MonadError String m)
         => (ann1 -> ann2 -> ann3)
         -> t ann1
         -> t ann2
         -> m (t ann3)
tzipWith f term1 term2 = do
    -- Establishing term1 == term2 first avoids the need to check for Eq uni,
    -- Eq fun and alpha-equivalence while zipping. It is slower this way
    -- because we have to re-traverse the terms.
    when (void term1 /= void term2) $
        throwError "zip: Terms do not match."
    go term1 term2
  where
    -- Walk both terms in lockstep. From the second term only the annotations
    -- are used; every other field comes from the first term.
    go :: t ann1 -> t ann2 -> m (t ann3)
    -- MAYBE: some boilerplate could be removed on the following clauses if termAnn was a lens
    go (Constant a1 s1) (Constant a2 _s2)    = pure $ Constant (f a1 a2) s1
    go (Builtin a1 f1) (Builtin a2 _f2)      = pure $ Builtin (f a1 a2) f1
    go (Var a1 n1) (Var a2 _n2)              = pure $ Var (f a1 a2) n1
    go (Error a1) (Error a2)                 = pure $ Error (f a1 a2)
    -- MAYBE: some boilerplate could be removed here if we used parallel subterm traversals/toListOf
    go (LamAbs a1 n1 t1) (LamAbs a2 _n2 t2)  = LamAbs (f a1 a2) n1 <$> go t1 t2
    go (Apply a1 t1a t1b) (Apply a2 t2a t2b) = Apply (f a1 a2) <$> go t1a t2a <*> go t1b t2b
    go (Force a1 t1) (Force a2 t2)           = Force (f a1 a2) <$> go t1 t2
    go (Delay a1 t1) (Delay a2 t2)           = Delay (f a1 a2) <$> go t1 t2
    go (Constr a1 i1 ts1) (Constr a2 _i2 ts2) =
        Constr (f a1 a2) i1 <$> zipExactWithM go ts1 ts2
    go (Case a1 t1 vs1) (Case a2 t2 vs2) =
        Case (f a1 a2)
            <$> go t1 t2
            <*> (fromList <$> zipExactWithM go (toList vs1) (toList vs2))
    -- Unreachable after the equality check above; kept so 'go' is total.
    go _ _                                   =
        throwError "zip: This should not happen, because we prior established term equality."

    -- Like a monadic zipWith, but throws instead of silently truncating when
    -- the two lists have different lengths. It uses a fresh monad variable @n@
    -- (not the scoped @m@), so the helper stays polymorphic in its monad.
    zipExactWithM :: MonadError String n => (a -> b -> n c) -> [a] -> [b] -> n [c]
    zipExactWithM g (a:as) (b:bs) = (:) <$> g a b <*> zipExactWithM g as bs
    zipExactWithM _ [] []         = pure []
    zipExactWithM _ _ _           = throwError "zipExactWithM: not exact"

-- | Zip two programs by pairing their annotations.
--
-- This is 'pzipWith' specialised to @(,)@, so it fails under the same
-- conditions and is left-biased in the same way.
pzip :: (p ~ Program name uni fun, Eq (Term name uni fun ()), MonadError String m)
     => p ann1
     -> p ann2
     -> m (p (ann1, ann2))
pzip = pzipWith (,)

-- | Zip two terms by pairing their annotations.
--
-- This is 'tzipWith' specialised to @(,)@, so it fails under the same
-- conditions and is left-biased in the same way.
tzip :: (t ~ Term name uni fun, Eq (t ()), MonadError String m)
     => t ann1
     -> t ann2
     -> m (t (ann1, ann2))
tzip = tzipWith (,)