Refactor BN_R_NO_INVERSE logic in internal functions
authorNicola Tuveri <nic.tuv@gmail.com>
Sat, 13 Jun 2020 14:29:09 +0000 (17:29 +0300)
committerNicola Tuveri <nic.tuv@gmail.com>
Sun, 21 Jun 2020 12:46:54 +0000 (15:46 +0300)
Closes #12129

As described in https://github.com/openssl/openssl/issues/12129 the
readability of the internal functions providing the two alternative
implementations for `BN_mod_inverse()` is a bit lacking.

Both these functions are now completely internal, so we have the
flexibility needed to slightly improve readability and remove
unnecessary NULL checks.

The main changes here are:
- rename `BN_mod_inverse_no_branch()` as `bn_mod_inverse_no_branch()`:
  this function is `static` so it is not even visible within the rest of
  libcrypto. By convention upcase prefixes are reserved for public
  functions.
- remove `if (pnoinv == NULL)` checks in `int_bn_mod_inverse()`: this
  function is internal to the BN module and we can guarantee that all
  callers pass non-NULL arguments.
- `bn_mod_inverse_no_branch()` takes an extra `int *pnoinv` argument, so
  that it can signal if no inverse exists for the given inputs: in this
  way the caller is in charge of raising `BN_R_NO_INVERSE` as it is the
  case for the non-consttime implementation of `int_bn_mod_inverse()`.
- `BN_mod_inverse()` is a public function and must guarantee that the
  internal functions providing the actual implementation receive valid
  arguments. If the caller passes a NULL `BN_CTX` we create a temporary
  one for internal use.
- reorder function definitions in `crypto/bn/bn_gcd.c` to avoid forward
  declaration of `static` functions (in preparation for inlining)
- inline `bn_mod_inverse_no_branch()`.

(Backport to 1.1.1 from https://github.com/openssl/openssl/pull/12142)
(cherry picked from commit 5d8b3a3ef2941b8822523742a0408ca6896aa65d)

Reviewed-by: Tomas Mraz <tmraz@fedoraproject.org>
(Merged from https://github.com/openssl/openssl/pull/12169)

crypto/bn/bn_gcd.c
include/openssl/bn.h

index ef81acb77ba6b6475597ec26d72063284d41344c..795dc0b698dfd5fcaf7c5b4f8af2073abeb4c099 100644 (file)
 #include "internal/cryptlib.h"
 #include "bn_local.h"
 
-/* solves ax == 1 (mod n) */
-static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in,
-                                        const BIGNUM *a, const BIGNUM *n,
-                                        BN_CTX *ctx);
-
-BIGNUM *BN_mod_inverse(BIGNUM *in,
-                       const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
+/*
+ * bn_mod_inverse_no_branch is a special version of BN_mod_inverse. It does
+ * not contain branches that may leak sensitive information.
+ *
+ * This is a static function, we ensure all callers in this file pass valid
+ * arguments: all passed pointers here are non-NULL.
+ */
+static ossl_inline
+BIGNUM *bn_mod_inverse_no_branch(BIGNUM *in,
+                                 const BIGNUM *a, const BIGNUM *n,
+                                 BN_CTX *ctx, int *pnoinv)
 {
-    BIGNUM *rv;
-    int noinv;
-    rv = int_bn_mod_inverse(in, a, n, ctx, &noinv);
-    if (noinv)
-        BNerr(BN_F_BN_MOD_INVERSE, BN_R_NO_INVERSE);
-    return rv;
+    BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
+    BIGNUM *ret = NULL;
+    int sign;
+
+    bn_check_top(a);
+    bn_check_top(n);
+
+    BN_CTX_start(ctx);
+    A = BN_CTX_get(ctx);
+    B = BN_CTX_get(ctx);
+    X = BN_CTX_get(ctx);
+    D = BN_CTX_get(ctx);
+    M = BN_CTX_get(ctx);
+    Y = BN_CTX_get(ctx);
+    T = BN_CTX_get(ctx);
+    if (T == NULL)
+        goto err;
+
+    if (in == NULL)
+        R = BN_new();
+    else
+        R = in;
+    if (R == NULL)
+        goto err;
+
+    BN_one(X);
+    BN_zero(Y);
+    if (BN_copy(B, a) == NULL)
+        goto err;
+    if (BN_copy(A, n) == NULL)
+        goto err;
+    A->neg = 0;
+
+    if (B->neg || (BN_ucmp(B, A) >= 0)) {
+        /*
+         * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
+         * BN_div_no_branch will be called eventually.
+         */
+         {
+            BIGNUM local_B;
+            bn_init(&local_B);
+            BN_with_flags(&local_B, B, BN_FLG_CONSTTIME);
+            if (!BN_nnmod(B, &local_B, A, ctx))
+                goto err;
+            /* Ensure local_B goes out of scope before any further use of B */
+        }
+    }
+    sign = -1;
+    /*-
+     * From  B = a mod |n|,  A = |n|  it follows that
+     *
+     *      0 <= B < A,
+     *     -sign*X*a  ==  B   (mod |n|),
+     *      sign*Y*a  ==  A   (mod |n|).
+     */
+
+    while (!BN_is_zero(B)) {
+        BIGNUM *tmp;
+
+        /*-
+         *      0 < B < A,
+         * (*) -sign*X*a  ==  B   (mod |n|),
+         *      sign*Y*a  ==  A   (mod |n|)
+         */
+
+        /*
+         * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
+         * BN_div_no_branch will be called eventually.
+         */
+        {
+            BIGNUM local_A;
+            bn_init(&local_A);
+            BN_with_flags(&local_A, A, BN_FLG_CONSTTIME);
+
+            /* (D, M) := (A/B, A%B) ... */
+            if (!BN_div(D, M, &local_A, B, ctx))
+                goto err;
+            /* Ensure local_A goes out of scope before any further use of A */
+        }
+
+        /*-
+         * Now
+         *      A = D*B + M;
+         * thus we have
+         * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
+         */
+
+        tmp = A;                /* keep the BIGNUM object, the value does not
+                                 * matter */
+
+        /* (A, B) := (B, A mod B) ... */
+        A = B;
+        B = M;
+        /* ... so we have  0 <= B < A  again */
+
+        /*-
+         * Since the former  M  is now  B  and the former  B  is now  A,
+         * (**) translates into
+         *       sign*Y*a  ==  D*A + B    (mod |n|),
+         * i.e.
+         *       sign*Y*a - D*A  ==  B    (mod |n|).
+         * Similarly, (*) translates into
+         *      -sign*X*a  ==  A          (mod |n|).
+         *
+         * Thus,
+         *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
+         * i.e.
+         *        sign*(Y + D*X)*a  ==  B  (mod |n|).
+         *
+         * So if we set  (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at
+         *      -sign*X*a  ==  B   (mod |n|),
+         *       sign*Y*a  ==  A   (mod |n|).
+         * Note that  X  and  Y  stay non-negative all the time.
+         */
+
+        if (!BN_mul(tmp, D, X, ctx))
+            goto err;
+        if (!BN_add(tmp, tmp, Y))
+            goto err;
+
+        M = Y;                  /* keep the BIGNUM object, the value does not
+                                 * matter */
+        Y = X;
+        X = tmp;
+        sign = -sign;
+    }
+
+    /*-
+     * The while loop (Euclid's algorithm) ends when
+     *      A == gcd(a,n);
+     * we have
+     *       sign*Y*a  ==  A  (mod |n|),
+     * where  Y  is non-negative.
+     */
+
+    if (sign < 0) {
+        if (!BN_sub(Y, n, Y))
+            goto err;
+    }
+    /* Now  Y*a  ==  A  (mod |n|).  */
+
+    if (BN_is_one(A)) {
+        /* Y*a == 1  (mod |n|) */
+        if (!Y->neg && BN_ucmp(Y, n) < 0) {
+            if (!BN_copy(R, Y))
+                goto err;
+        } else {
+            if (!BN_nnmod(R, Y, n, ctx))
+                goto err;
+        }
+    } else {
+        *pnoinv = 1;
+        /* caller sets the BN_R_NO_INVERSE error */
+        goto err;
+    }
+
+    ret = R;
+    *pnoinv = 0;
+
+ err:
+    if ((ret == NULL) && (in == NULL))
+        BN_free(R);
+    BN_CTX_end(ctx);
+    bn_check_top(ret);
+    return ret;
 }
 
+/*
+ * This is an internal function, we assume all callers pass valid arguments:
+ * all pointers passed here are assumed non-NULL.
+ */
 BIGNUM *int_bn_mod_inverse(BIGNUM *in,
                            const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
                            int *pnoinv)
@@ -36,17 +203,15 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in,
 
     /* This is invalid input so we don't worry about constant time here */
     if (BN_abs_is_word(n, 1) || BN_is_zero(n)) {
-        if (pnoinv != NULL)
-            *pnoinv = 1;
+        *pnoinv = 1;
         return NULL;
     }
 
-    if (pnoinv != NULL)
-        *pnoinv = 0;
+    *pnoinv = 0;
 
     if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0)
         || (BN_get_flags(n, BN_FLG_CONSTTIME) != 0)) {
-        return BN_mod_inverse_no_branch(in, a, n, ctx);
+        return bn_mod_inverse_no_branch(in, a, n, ctx, pnoinv);
     }
 
     bn_check_top(a);
@@ -332,8 +497,7 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in,
                 goto err;
         }
     } else {
-        if (pnoinv)
-            *pnoinv = 1;
+        *pnoinv = 1;
         goto err;
     }
     ret = R;
@@ -345,175 +509,27 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in,
     return ret;
 }
 
-/*
- * BN_mod_inverse_no_branch is a special version of BN_mod_inverse. It does
- * not contain branches that may leak sensitive information.
- */
-static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in,
-                                        const BIGNUM *a, const BIGNUM *n,
-                                        BN_CTX *ctx)
+/* solves ax == 1 (mod n) */
+BIGNUM *BN_mod_inverse(BIGNUM *in,
+                       const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
 {
-    BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL;
-    BIGNUM *ret = NULL;
-    int sign;
-
-    bn_check_top(a);
-    bn_check_top(n);
-
-    BN_CTX_start(ctx);
-    A = BN_CTX_get(ctx);
-    B = BN_CTX_get(ctx);
-    X = BN_CTX_get(ctx);
-    D = BN_CTX_get(ctx);
-    M = BN_CTX_get(ctx);
-    Y = BN_CTX_get(ctx);
-    T = BN_CTX_get(ctx);
-    if (T == NULL)
-        goto err;
-
-    if (in == NULL)
-        R = BN_new();
-    else
-        R = in;
-    if (R == NULL)
-        goto err;
-
-    BN_one(X);
-    BN_zero(Y);
-    if (BN_copy(B, a) == NULL)
-        goto err;
-    if (BN_copy(A, n) == NULL)
-        goto err;
-    A->neg = 0;
-
-    if (B->neg || (BN_ucmp(B, A) >= 0)) {
-        /*
-         * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
-         * BN_div_no_branch will be called eventually.
-         */
-         {
-            BIGNUM local_B;
-            bn_init(&local_B);
-            BN_with_flags(&local_B, B, BN_FLG_CONSTTIME);
-            if (!BN_nnmod(B, &local_B, A, ctx))
-                goto err;
-            /* Ensure local_B goes out of scope before any further use of B */
-        }
-    }
-    sign = -1;
-    /*-
-     * From  B = a mod |n|,  A = |n|  it follows that
-     *
-     *      0 <= B < A,
-     *     -sign*X*a  ==  B   (mod |n|),
-     *      sign*Y*a  ==  A   (mod |n|).
-     */
-
-    while (!BN_is_zero(B)) {
-        BIGNUM *tmp;
-
-        /*-
-         *      0 < B < A,
-         * (*) -sign*X*a  ==  B   (mod |n|),
-         *      sign*Y*a  ==  A   (mod |n|)
-         */
-
-        /*
-         * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked,
-         * BN_div_no_branch will be called eventually.
-         */
-        {
-            BIGNUM local_A;
-            bn_init(&local_A);
-            BN_with_flags(&local_A, A, BN_FLG_CONSTTIME);
+    BN_CTX *new_ctx = NULL;
+    BIGNUM *rv;
+    int noinv = 0;
 
-            /* (D, M) := (A/B, A%B) ... */
-            if (!BN_div(D, M, &local_A, B, ctx))
-                goto err;
-            /* Ensure local_A goes out of scope before any further use of A */
+    if (ctx == NULL) {
+        ctx = new_ctx = BN_CTX_new();
+        if (ctx == NULL) {
+            BNerr(BN_F_BN_MOD_INVERSE, ERR_R_MALLOC_FAILURE);
+            return NULL;
         }
-
-        /*-
-         * Now
-         *      A = D*B + M;
-         * thus we have
-         * (**)  sign*Y*a  ==  D*B + M   (mod |n|).
-         */
-
-        tmp = A;                /* keep the BIGNUM object, the value does not
-                                 * matter */
-
-        /* (A, B) := (B, A mod B) ... */
-        A = B;
-        B = M;
-        /* ... so we have  0 <= B < A  again */
-
-        /*-
-         * Since the former  M  is now  B  and the former  B  is now  A,
-         * (**) translates into
-         *       sign*Y*a  ==  D*A + B    (mod |n|),
-         * i.e.
-         *       sign*Y*a - D*A  ==  B    (mod |n|).
-         * Similarly, (*) translates into
-         *      -sign*X*a  ==  A          (mod |n|).
-         *
-         * Thus,
-         *   sign*Y*a + D*sign*X*a  ==  B  (mod |n|),
-         * i.e.
-         *        sign*(Y + D*X)*a  ==  B  (mod |n|).
-         *
-         * So if we set  (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at
-         *      -sign*X*a  ==  B   (mod |n|),
-         *       sign*Y*a  ==  A   (mod |n|).
-         * Note that  X  and  Y  stay non-negative all the time.
-         */
-
-        if (!BN_mul(tmp, D, X, ctx))
-            goto err;
-        if (!BN_add(tmp, tmp, Y))
-            goto err;
-
-        M = Y;                  /* keep the BIGNUM object, the value does not
-                                 * matter */
-        Y = X;
-        X = tmp;
-        sign = -sign;
-    }
-
-    /*-
-     * The while loop (Euclid's algorithm) ends when
-     *      A == gcd(a,n);
-     * we have
-     *       sign*Y*a  ==  A  (mod |n|),
-     * where  Y  is non-negative.
-     */
-
-    if (sign < 0) {
-        if (!BN_sub(Y, n, Y))
-            goto err;
     }
-    /* Now  Y*a  ==  A  (mod |n|).  */
 
-    if (BN_is_one(A)) {
-        /* Y*a == 1  (mod |n|) */
-        if (!Y->neg && BN_ucmp(Y, n) < 0) {
-            if (!BN_copy(R, Y))
-                goto err;
-        } else {
-            if (!BN_nnmod(R, Y, n, ctx))
-                goto err;
-        }
-    } else {
-        BNerr(BN_F_BN_MOD_INVERSE_NO_BRANCH, BN_R_NO_INVERSE);
-        goto err;
-    }
-    ret = R;
- err:
-    if ((ret == NULL) && (in == NULL))
-        BN_free(R);
-    BN_CTX_end(ctx);
-    bn_check_top(ret);
-    return ret;
+    rv = int_bn_mod_inverse(in, a, n, ctx, &noinv);
+    if (noinv)
+        BNerr(BN_F_BN_MOD_INVERSE, BN_R_NO_INVERSE);
+    BN_CTX_free(new_ctx);
+    return rv;
 }
 
 /*-
index 8af05d00e59a90cb7c764eeda0b8a269504dd5a1..88091df6938cc5abdf24c776a0e1ab4d9db26d94 100644 (file)
@@ -56,7 +56,7 @@ extern "C" {
  * avoid leaking exponent information through timing,
  * BN_mod_exp_mont() will call BN_mod_exp_mont_consttime,
  * BN_div() will call BN_div_no_branch,
- * BN_mod_inverse() will call BN_mod_inverse_no_branch.
+ * BN_mod_inverse() will call bn_mod_inverse_no_branch.
  */
 # define BN_FLG_CONSTTIME        0x04
 # define BN_FLG_SECURE           0x08