From 35bb0e44c6168facbb3acedbc7d4f2dcbdd65224 Mon Sep 17 00:00:00 2001 From: Nicola Tuveri Date: Sat, 13 Jun 2020 17:29:09 +0300 Subject: [PATCH] Refactor BN_R_NO_INVERSE logic in internal functions Closes #12129 As described in https://github.com/openssl/openssl/issues/12129 the readability of the internal functions providing the two alternative implementations for `BN_mod_inverse()` is a bit lacking. Both these functions are now completely internal, so we have the flexibility needed to slightly improve readability and remove unnecessary NULL checks. The main changes here are: - rename `BN_mod_inverse_no_branch()` as `bn_mod_inverse_no_branch()`: this function is `static` so it is not even visible within the rest of libcrypto. By convention upcase prefixes are reserved for public functions. - remove `if (pnoinv == NULL)` checks in `int_bn_mod_inverse()`: this function is internal to the BN module and we can guarantee that all callers pass non-NULL arguments. - `bn_mod_inverse_no_branch()` takes an extra `int *pnoinv` argument, so that it can signal if no inverse exists for the given inputs: in this way the caller is in charge of raising `BN_R_NO_INVERSE` as it is the case for the non-consttime implementation of `int_bn_mod_inverse()`. - `BN_mod_inverse()` is a public function and must guarantee that the internal functions providing the actual implementation receive valid arguments. If the caller passes a NULL `BN_CTX` we create a temporary one for internal use. - reorder function definitions in `crypto/bn/bn_gcd.c` to avoid forward declaration of `static` functions (in preparation for inlining) - inline `bn_mod_inverse_no_branch()`. (Backport to 1.1.1 from https://github.com/openssl/openssl/pull/12142) (cherry picked from commit 5d8b3a3ef2941b8822523742a0408ca6896aa65d) Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/12169) --- crypto/bn/bn_gcd.c | 384 ++++++++++++++++++++++--------------------- include/openssl/bn.h | 2 +- 2 files changed, 201 insertions(+), 185 deletions(-) diff --git a/crypto/bn/bn_gcd.c b/crypto/bn/bn_gcd.c index ef81acb77b..795dc0b698 100644 --- a/crypto/bn/bn_gcd.c +++ b/crypto/bn/bn_gcd.c @@ -10,22 +10,189 @@ #include "internal/cryptlib.h" #include "bn_local.h" -/* solves ax == 1 (mod n) */ -static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in, - const BIGNUM *a, const BIGNUM *n, - BN_CTX *ctx); - -BIGNUM *BN_mod_inverse(BIGNUM *in, - const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx) +/* + * bn_mod_inverse_no_branch is a special version of BN_mod_inverse. It does + * not contain branches that may leak sensitive information. + * + * This is a static function, we ensure all callers in this file pass valid + * arguments: all passed pointers here are non-NULL. + */ +static ossl_inline +BIGNUM *bn_mod_inverse_no_branch(BIGNUM *in, + const BIGNUM *a, const BIGNUM *n, + BN_CTX *ctx, int *pnoinv) { - BIGNUM *rv; - int noinv; - rv = int_bn_mod_inverse(in, a, n, ctx, &noinv); - if (noinv) - BNerr(BN_F_BN_MOD_INVERSE, BN_R_NO_INVERSE); - return rv; + BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL; + BIGNUM *ret = NULL; + int sign; + + bn_check_top(a); + bn_check_top(n); + + BN_CTX_start(ctx); + A = BN_CTX_get(ctx); + B = BN_CTX_get(ctx); + X = BN_CTX_get(ctx); + D = BN_CTX_get(ctx); + M = BN_CTX_get(ctx); + Y = BN_CTX_get(ctx); + T = BN_CTX_get(ctx); + if (T == NULL) + goto err; + + if (in == NULL) + R = BN_new(); + else + R = in; + if (R == NULL) + goto err; + + BN_one(X); + BN_zero(Y); + if (BN_copy(B, a) == NULL) + goto err; + if (BN_copy(A, n) == NULL) + goto err; + A->neg = 0; + + if (B->neg || (BN_ucmp(B, A) >= 0)) { + /* + * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked, + * BN_div_no_branch will be called eventually. + */ + { + BIGNUM local_B; + bn_init(&local_B); + BN_with_flags(&local_B, B, BN_FLG_CONSTTIME); + if (!BN_nnmod(B, &local_B, A, ctx)) + goto err; + /* Ensure local_B goes out of scope before any further use of B */ + } + } + sign = -1; + /*- + * From B = a mod |n|, A = |n| it follows that + * + * 0 <= B < A, + * -sign*X*a == B (mod |n|), + * sign*Y*a == A (mod |n|). + */ + + while (!BN_is_zero(B)) { + BIGNUM *tmp; + + /*- + * 0 < B < A, + * (*) -sign*X*a == B (mod |n|), + * sign*Y*a == A (mod |n|) + */ + + /* + * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked, + * BN_div_no_branch will be called eventually. + */ + { + BIGNUM local_A; + bn_init(&local_A); + BN_with_flags(&local_A, A, BN_FLG_CONSTTIME); + + /* (D, M) := (A/B, A%B) ... */ + if (!BN_div(D, M, &local_A, B, ctx)) + goto err; + /* Ensure local_A goes out of scope before any further use of A */ + } + + /*- + * Now + * A = D*B + M; + * thus we have + * (**) sign*Y*a == D*B + M (mod |n|). + */ + + tmp = A; /* keep the BIGNUM object, the value does not + * matter */ + + /* (A, B) := (B, A mod B) ... */ + A = B; + B = M; + /* ... so we have 0 <= B < A again */ + + /*- + * Since the former M is now B and the former B is now A, + * (**) translates into + * sign*Y*a == D*A + B (mod |n|), + * i.e. + * sign*Y*a - D*A == B (mod |n|). + * Similarly, (*) translates into + * -sign*X*a == A (mod |n|). + * + * Thus, + * sign*Y*a + D*sign*X*a == B (mod |n|), + * i.e. + * sign*(Y + D*X)*a == B (mod |n|). + * + * So if we set (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at + * -sign*X*a == B (mod |n|), + * sign*Y*a == A (mod |n|). + * Note that X and Y stay non-negative all the time. + */ + + if (!BN_mul(tmp, D, X, ctx)) + goto err; + if (!BN_add(tmp, tmp, Y)) + goto err; + + M = Y; /* keep the BIGNUM object, the value does not + * matter */ + Y = X; + X = tmp; + sign = -sign; + } + + /*- + * The while loop (Euclid's algorithm) ends when + * A == gcd(a,n); + * we have + * sign*Y*a == A (mod |n|), + * where Y is non-negative. + */ + + if (sign < 0) { + if (!BN_sub(Y, n, Y)) + goto err; + } + /* Now Y*a == A (mod |n|). */ + + if (BN_is_one(A)) { + /* Y*a == 1 (mod |n|) */ + if (!Y->neg && BN_ucmp(Y, n) < 0) { + if (!BN_copy(R, Y)) + goto err; + } else { + if (!BN_nnmod(R, Y, n, ctx)) + goto err; + } + } else { + *pnoinv = 1; + /* caller sets the BN_R_NO_INVERSE error */ + goto err; + } + + ret = R; + *pnoinv = 0; + + err: + if ((ret == NULL) && (in == NULL)) + BN_free(R); + BN_CTX_end(ctx); + bn_check_top(ret); + return ret; } +/* + * This is an internal function, we assume all callers pass valid arguments: + * all pointers passed here are assumed non-NULL. + */ BIGNUM *int_bn_mod_inverse(BIGNUM *in, const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx, int *pnoinv) @@ -36,17 +203,15 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in, /* This is invalid input so we don't worry about constant time here */ if (BN_abs_is_word(n, 1) || BN_is_zero(n)) { - if (pnoinv != NULL) - *pnoinv = 1; + *pnoinv = 1; return NULL; } - if (pnoinv != NULL) - *pnoinv = 0; + *pnoinv = 0; if ((BN_get_flags(a, BN_FLG_CONSTTIME) != 0) || (BN_get_flags(n, BN_FLG_CONSTTIME) != 0)) { - return BN_mod_inverse_no_branch(in, a, n, ctx); + return bn_mod_inverse_no_branch(in, a, n, ctx, pnoinv); } bn_check_top(a); @@ -332,8 +497,7 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in, goto err; } } else { - if (pnoinv) - *pnoinv = 1; + *pnoinv = 1; goto err; } ret = R; @@ -345,175 +509,27 @@ BIGNUM *int_bn_mod_inverse(BIGNUM *in, return ret; } -/* - * BN_mod_inverse_no_branch is a special version of BN_mod_inverse. It does - * not contain branches that may leak sensitive information. - */ -static BIGNUM *BN_mod_inverse_no_branch(BIGNUM *in, - const BIGNUM *a, const BIGNUM *n, - BN_CTX *ctx) +/* solves ax == 1 (mod n) */ +BIGNUM *BN_mod_inverse(BIGNUM *in, + const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx) { - BIGNUM *A, *B, *X, *Y, *M, *D, *T, *R = NULL; - BIGNUM *ret = NULL; - int sign; - - bn_check_top(a); - bn_check_top(n); - - BN_CTX_start(ctx); - A = BN_CTX_get(ctx); - B = BN_CTX_get(ctx); - X = BN_CTX_get(ctx); - D = BN_CTX_get(ctx); - M = BN_CTX_get(ctx); - Y = BN_CTX_get(ctx); - T = BN_CTX_get(ctx); - if (T == NULL) - goto err; - - if (in == NULL) - R = BN_new(); - else - R = in; - if (R == NULL) - goto err; - - BN_one(X); - BN_zero(Y); - if (BN_copy(B, a) == NULL) - goto err; - if (BN_copy(A, n) == NULL) - goto err; - A->neg = 0; - - if (B->neg || (BN_ucmp(B, A) >= 0)) { - /* - * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked, - * BN_div_no_branch will be called eventually. - */ - { - BIGNUM local_B; - bn_init(&local_B); - BN_with_flags(&local_B, B, BN_FLG_CONSTTIME); - if (!BN_nnmod(B, &local_B, A, ctx)) - goto err; - /* Ensure local_B goes out of scope before any further use of B */ - } - } - sign = -1; - /*- - * From B = a mod |n|, A = |n| it follows that - * - * 0 <= B < A, - * -sign*X*a == B (mod |n|), - * sign*Y*a == A (mod |n|). - */ - - while (!BN_is_zero(B)) { - BIGNUM *tmp; - - /*- - * 0 < B < A, - * (*) -sign*X*a == B (mod |n|), - * sign*Y*a == A (mod |n|) - */ - - /* - * Turn BN_FLG_CONSTTIME flag on, so that when BN_div is invoked, - * BN_div_no_branch will be called eventually. - */ - { - BIGNUM local_A; - bn_init(&local_A); - BN_with_flags(&local_A, A, BN_FLG_CONSTTIME); + BN_CTX *new_ctx = NULL; + BIGNUM *rv; + int noinv = 0; - /* (D, M) := (A/B, A%B) ... */ - if (!BN_div(D, M, &local_A, B, ctx)) - goto err; - /* Ensure local_A goes out of scope before any further use of A */ + if (ctx == NULL) { + ctx = new_ctx = BN_CTX_new(); + if (ctx == NULL) { + BNerr(BN_F_BN_MOD_INVERSE, ERR_R_MALLOC_FAILURE); + return NULL; } - - /*- - * Now - * A = D*B + M; - * thus we have - * (**) sign*Y*a == D*B + M (mod |n|). - */ - - tmp = A; /* keep the BIGNUM object, the value does not - * matter */ - - /* (A, B) := (B, A mod B) ... */ - A = B; - B = M; - /* ... so we have 0 <= B < A again */ - - /*- - * Since the former M is now B and the former B is now A, - * (**) translates into - * sign*Y*a == D*A + B (mod |n|), - * i.e. - * sign*Y*a - D*A == B (mod |n|). - * Similarly, (*) translates into - * -sign*X*a == A (mod |n|). - * - * Thus, - * sign*Y*a + D*sign*X*a == B (mod |n|), - * i.e. - * sign*(Y + D*X)*a == B (mod |n|). - * - * So if we set (X, Y, sign) := (Y + D*X, X, -sign), we arrive back at - * -sign*X*a == B (mod |n|), - * sign*Y*a == A (mod |n|). - * Note that X and Y stay non-negative all the time. - */ - - if (!BN_mul(tmp, D, X, ctx)) - goto err; - if (!BN_add(tmp, tmp, Y)) - goto err; - - M = Y; /* keep the BIGNUM object, the value does not - * matter */ - Y = X; - X = tmp; - sign = -sign; - } - - /*- - * The while loop (Euclid's algorithm) ends when - * A == gcd(a,n); - * we have - * sign*Y*a == A (mod |n|), - * where Y is non-negative. - */ - - if (sign < 0) { - if (!BN_sub(Y, n, Y)) - goto err; } - /* Now Y*a == A (mod |n|). */ - if (BN_is_one(A)) { - /* Y*a == 1 (mod |n|) */ - if (!Y->neg && BN_ucmp(Y, n) < 0) { - if (!BN_copy(R, Y)) - goto err; - } else { - if (!BN_nnmod(R, Y, n, ctx)) - goto err; - } - } else { - BNerr(BN_F_BN_MOD_INVERSE_NO_BRANCH, BN_R_NO_INVERSE); - goto err; - } - ret = R; - err: - if ((ret == NULL) && (in == NULL)) - BN_free(R); - BN_CTX_end(ctx); - bn_check_top(ret); - return ret; + rv = int_bn_mod_inverse(in, a, n, ctx, &noinv); + if (noinv) + BNerr(BN_F_BN_MOD_INVERSE, BN_R_NO_INVERSE); + BN_CTX_free(new_ctx); + return rv; } /*- diff --git a/include/openssl/bn.h b/include/openssl/bn.h index 8af05d00e5..88091df693 100644 --- a/include/openssl/bn.h +++ b/include/openssl/bn.h @@ -56,7 +56,7 @@ extern "C" { * avoid leaking exponent information through timing, * BN_mod_exp_mont() will call BN_mod_exp_mont_consttime, * BN_div() will call BN_div_no_branch, - * BN_mod_inverse() will call BN_mod_inverse_no_branch. + * BN_mod_inverse() will call bn_mod_inverse_no_branch. */ # define BN_FLG_CONSTTIME 0x04 # define BN_FLG_SECURE 0x08 -- 2.25.1