{-# LANGUAGE MagicHash   #-}
{-# LANGUAGE UnboxedSums #-}
module PlutusCore.Crypto.ExpMod
    ( expMod
    ) where

import PlutusCore.Builtin

import GHC.Natural
import GHC.Num.Integer

-- | Modular exponentiation.
--
-- @expMod b e m@ delegates to 'GHC.Num.Integer.integerPowMod#' for base @b@,
-- exponent @e@ and modulus @m@, after handling a few inputs explicitly:
--
--   * @m == 0@ fails with an "invalid modulus" message;
--   * @m == 1@ always succeeds with @0@;
--   * @b == 0@ with a negative exponent fails, because @0@ is not invertible.
--
-- When 'integerPowMod#' itself reports failure (the unit alternative of its
-- unboxed-sum result), this fails with a message saying that @b@ is not
-- invertible modulo @m@.
--
-- 'integerPowMod#' gives the wrong answer in some cases, which is why the
-- guards below exist.  TODO: we'll be able to remove some of the guards
-- when/if 'integerPowMod#' gets fixed.
expMod :: Integer -> Integer -> Natural -> BuiltinResult Natural
expMod b e m
  -- We can't have m < 0 when m is a Natural, but we may as well be paranoid.
  | m <= 0 = fail "expMod: invalid modulus"
  -- Just in case: GHC.Num.Integer.integerRecip# gets this wrong.  Note that 0
  -- is invertible modulo 1, with inverse 0.
  | m == 1 = pure 0
  -- integerPowMod# incorrectly returns 0 in this case.
  | b == 0 && e < 0 = failNonInvertible 0 m
  | otherwise =
      case integerPowMod# b e m of
        (# n | #)  -> pure n
        (# | () #) -> failNonInvertible b m
  where
    -- Fail with a message naming the base and the modulus it could not be
    -- inverted against.
    failNonInvertible :: Integer -> Natural -> BuiltinResult Natural
    failNonInvertible b1 m1 =
      fail ("expMod: " ++ show b1 ++ " is not invertible modulo " ++ show m1)
{-# INLINE expMod #-}