{-# LANGUAGE ConstraintKinds #-}

module PlutusTx.Lattice where

import PlutusTx.Bool
import PlutusTx.Monoid
import PlutusTx.Semigroup

{-| A join semi-lattice, i.e. a partially ordered set equipped with a
binary operation '(\/)' (the join).

Note that the mathematical definition would require an ordering constraint -
we omit that so we can define instances for e.g. '(->)'.

By the usual mathematical definition the join is associative, commutative and
idempotent; nothing in this class enforces that, so it is up to each instance.
-}
class JoinSemiLattice a where
  -- | The join of two elements.
  (\/) :: a -> a -> a

{-| A meet semi-lattice, i.e. a partially ordered set equipped with a
binary operation '(/\)' (the meet).

Note that the mathematical definition would require an ordering constraint -
we omit that so we can define instances for e.g. '(->)'.

By the usual mathematical definition the meet is associative, commutative and
idempotent; nothing in this class enforces that, so it is up to each instance.
-}
class MeetSemiLattice a where
  -- | The meet of two elements.
  (/\) :: a -> a -> a

{-| A lattice: a type that is both a 'JoinSemiLattice' and a 'MeetSemiLattice'.

This is a constraint synonym, not a class, so there is nothing to instantiate:
the constraint holds as soon as both component instances exist.
-}
type Lattice a = (JoinSemiLattice a, MeetSemiLattice a)

{-| A bounded join semi-lattice, i.e. a join semi-lattice augmented with
a distinguished element 'bottom' which is the unit of '(\/)'.
-}
class (JoinSemiLattice a) => BoundedJoinSemiLattice a where
  -- | The unit of '(\/)', i.e. @bottom \\/ x@ and @x \\/ bottom@ should both be @x@.
  bottom :: a

{-| A bounded meet semi-lattice, i.e. a meet semi-lattice augmented with
a distinguished element 'top' which is the unit of '(/\)'.
-}
class (MeetSemiLattice a) => BoundedMeetSemiLattice a where
  -- | The unit of '(/\)', i.e. @top \/\\ x@ and @x \/\\ top@ should both be @x@.
  top :: a

{-| A bounded lattice: a type that is both a 'BoundedJoinSemiLattice' (with
'bottom') and a 'BoundedMeetSemiLattice' (with 'top').

Like 'Lattice', this is a constraint synonym rather than a class.
-}
type BoundedLattice a = (BoundedJoinSemiLattice a, BoundedMeetSemiLattice a)

-- Wrappers

{-| A wrapper witnessing that a join semi-lattice is a monoid with '(\/)' and 'bottom'.

The 'Semigroup' instance only needs 'JoinSemiLattice'; the 'Monoid' instance
additionally needs 'BoundedJoinSemiLattice' to supply the identity.
-}
newtype Join a = Join a

-- | Combining two 'Join' values takes the join ('\/') of the wrapped elements.
instance (JoinSemiLattice a) => Semigroup (Join a) where
  -- INLINEABLE for consistency with the lattice instances in this module.
  {-# INLINEABLE (<>) #-}
  Join l <> Join r = Join (l \/ r)

-- | The identity for 'Join' is the wrapped 'bottom', the unit of '\/'.
instance (BoundedJoinSemiLattice a) => Monoid (Join a) where
  -- INLINEABLE for consistency with the lattice instances in this module.
  {-# INLINEABLE mempty #-}
  mempty = Join bottom

{-| A wrapper witnessing that a meet semi-lattice is a monoid with '(/\)' and 'top'.

The 'Semigroup' instance only needs 'MeetSemiLattice'; the 'Monoid' instance
additionally needs 'BoundedMeetSemiLattice' to supply the identity.
-}
newtype Meet a = Meet a

-- | Combining two 'Meet' values takes the meet ('/\') of the wrapped elements.
instance (MeetSemiLattice a) => Semigroup (Meet a) where
  -- INLINEABLE for consistency with the lattice instances in this module.
  {-# INLINEABLE (<>) #-}
  Meet l <> Meet r = Meet (l /\ r)

-- | The identity for 'Meet' is the wrapped 'top', the unit of '/\'.
instance (BoundedMeetSemiLattice a) => Monoid (Meet a) where
  -- INLINEABLE for consistency with the lattice instances in this module.
  {-# INLINEABLE mempty #-}
  mempty = Meet top

-- Instances

-- | The join on 'Bool' is disjunction.
instance JoinSemiLattice Bool where
  {-# INLINEABLE (\/) #-}
  (\/) = (||)

-- | 'False' is the unit of '(||)', hence the bottom element.
instance BoundedJoinSemiLattice Bool where
  {-# INLINEABLE bottom #-}
  bottom = False

-- | The meet on 'Bool' is conjunction.
instance MeetSemiLattice Bool where
  {-# INLINEABLE (/\) #-}
  (/\) = (&&)

-- | 'True' is the unit of '(&&)', hence the top element.
instance BoundedMeetSemiLattice Bool where
  {-# INLINEABLE top #-}
  top = True

-- | Pairs are joined componentwise.
instance (JoinSemiLattice a, JoinSemiLattice b) => JoinSemiLattice (a, b) where
  {-# INLINEABLE (\/) #-}
  (a1, b1) \/ (a2, b2) = (a1 \/ a2, b1 \/ b2)

-- | The bottom of a pair is the pair of the components' bottoms.
instance (BoundedJoinSemiLattice a, BoundedJoinSemiLattice b) => BoundedJoinSemiLattice (a, b) where
  {-# INLINEABLE bottom #-}
  bottom = (bottom, bottom)

-- | Pairs are met componentwise.
instance (MeetSemiLattice a, MeetSemiLattice b) => MeetSemiLattice (a, b) where
  {-# INLINEABLE (/\) #-}
  (a1, b1) /\ (a2, b2) = (a1 /\ a2, b1 /\ b2)

-- | The top of a pair is the pair of the components' tops.
instance (BoundedMeetSemiLattice a, BoundedMeetSemiLattice b) => BoundedMeetSemiLattice (a, b) where
  {-# INLINEABLE top #-}
  top = (top, top)

-- | Functions are joined pointwise: both are applied to the same argument and
-- the results are joined in the codomain.
instance (JoinSemiLattice b) => JoinSemiLattice (a -> b) where
  {-# INLINEABLE (\/) #-}
  (f \/ g) a = f a \/ g a

-- | The bottom function ignores its argument and returns the codomain's 'bottom'.
instance (BoundedJoinSemiLattice b) => BoundedJoinSemiLattice (a -> b) where
  {-# INLINEABLE bottom #-}
  bottom _ = bottom

-- | Functions are met pointwise: both are applied to the same argument and
-- the results are met in the codomain.
instance (MeetSemiLattice b) => MeetSemiLattice (a -> b) where
  {-# INLINEABLE (/\) #-}
  (f /\ g) a = f a /\ g a

-- | The top function ignores its argument and returns the codomain's 'top'.
instance (BoundedMeetSemiLattice b) => BoundedMeetSemiLattice (a -> b) where
  {-# INLINEABLE top #-}
  top _ = top