Return an error from BN_mod_inverse if n is 1 (or -1)
authorMatt Caswell <matt@openssl.org>
Fri, 27 Apr 2018 16:36:11 +0000 (17:36 +0100)
committerMatt Caswell <matt@openssl.org>
Thu, 3 May 2018 09:35:49 +0000 (10:35 +0100)
Calculating BN_mod_inverse where n is 1 (or -1) doesn't make sense. We
should return an error in that case. Instead we were returning a valid
result with value 0.

Fixes #6004

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/6119)

(cherry picked from commit b1860d6c71733314417d053a72af66ae72e8268e)

crypto/bn/bn_gcd.c
crypto/bn/bn_mont.c

index 067642644ec1020ec7b5c1031a922c0421ae76a2..fd1c7a2947c4a873ffcc8f12c808d6cb449f3ee6 100644 (file)
@@ -140,7 +140,14 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in,
     BIGNUM *ret = NULL;
     int sign;
 
-    if (pnoinv)
+    /* This is invalid input so we don't worry about constant time here */
+    if (BN_abs_is_word(n, 1) || BN_is_zero(n)) {
+        if (pnoinv != NULL)
+            *pnoinv = 1;
+        return NULL;
+    }
+
+    if (pnoinv != NULL)
         *pnoinv = 0;
 
     if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0)
index faef5815717d22d69ff5c11703ef156fd567b211..c0c174650cc5d3021a109cccb6d3720ad6ad6e35 100644 (file)
@@ -278,7 +278,9 @@ int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod, BN_CTX *ctx)
         if ((buf[1] = mod->top > 1 ? mod->d[1] : 0))
             tmod.top = 2;
 
-        if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
+        if (BN_is_one(&tmod))
+            BN_zero(Ri);
+        else if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
             goto err;
         if (!BN_lshift(Ri, Ri, 2 * BN_BITS2))
             goto err;           /* R*Ri */
@@ -311,7 +313,9 @@ int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod, BN_CTX *ctx)
         buf[1] = 0;
         tmod.top = buf[0] != 0 ? 1 : 0;
         /* Ri = R^-1 mod N */
-        if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
+        if (BN_is_one(&tmod))
+            BN_zero(Ri);
+        else if ((BN_mod_inverse(Ri, R, &tmod, ctx)) == NULL)
             goto err;
         if (!BN_lshift(Ri, Ri, BN_BITS2))
             goto err;           /* R*Ri */