Fix RSA-PSS in FIPS mode by switching digest implementations.
authorDr. Stephen Henson <steve@openssl.org>
Sat, 22 Jul 2017 14:54:48 +0000 (15:54 +0100)
committerDr. Stephen Henson <steve@openssl.org>
Sat, 22 Jul 2017 23:17:33 +0000 (00:17 +0100)
Fixes #2718

Reviewed-by: Tim Hudson <tjh@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3996)

crypto/rsa/rsa_pmeth.c

index 8896e2e977149b787c1a6c91bdfdc35487e0780a..00e730ffa95887ca818a299a72e2be23bfc126a7 100644 (file)
@@ -180,27 +180,25 @@ static void pkey_rsa_cleanup(EVP_PKEY_CTX *ctx)
  * FIPS mode.
  */
 
-static int pkey_fips_check_ctx(EVP_PKEY_CTX *ctx)
+static int pkey_fips_check_rsa(const RSA *rsa, const EVP_MD **pmd,
+                               const EVP_MD **pmgf1md)
 {
-    RSA_PKEY_CTX *rctx = ctx->data;
-    RSA *rsa = ctx->pkey->pkey.rsa;
     int rv = -1;
+
     if (!FIPS_mode())
         return 0;
     if (rsa->flags & RSA_FLAG_NON_FIPS_ALLOW)
         rv = 0;
     if (!(rsa->meth->flags & RSA_FLAG_FIPS_METHOD) && rv)
         return -1;
-    if (rctx->md) {
-        const EVP_MD *fmd;
-        fmd = FIPS_get_digestbynid(EVP_MD_type(rctx->md));
-        if (!fmd || !(fmd->flags & EVP_MD_FLAG_FIPS))
+    if (*pmd != NULL) {
+        *pmd = FIPS_get_digestbynid(EVP_MD_type(*pmd));
+        if (*pmd == NULL || !((*pmd)->flags & EVP_MD_FLAG_FIPS))
             return rv;
     }
-    if (rctx->mgf1md && !(rctx->mgf1md->flags & EVP_MD_FLAG_FIPS)) {
-        const EVP_MD *fmd;
-        fmd = FIPS_get_digestbynid(EVP_MD_type(rctx->mgf1md));
-        if (!fmd || !(fmd->flags & EVP_MD_FLAG_FIPS))
+    if (*pmgf1md != NULL) {
+        *pmgf1md = FIPS_get_digestbynid(EVP_MD_type(*pmgf1md));
+        if (*pmgf1md == NULL || !((*pmgf1md)->flags & EVP_MD_FLAG_FIPS))
             return rv;
     }
     return 1;
@@ -214,27 +212,27 @@ static int pkey_rsa_sign(EVP_PKEY_CTX *ctx, unsigned char *sig,
     int ret;
     RSA_PKEY_CTX *rctx = ctx->data;
     RSA *rsa = ctx->pkey->pkey.rsa;
+    const EVP_MD *md = rctx->md;
+    const EVP_MD *mgf1md = rctx->mgf1md;
 
 #ifdef OPENSSL_FIPS
-    ret = pkey_fips_check_ctx(ctx);
+    ret = pkey_fips_check_rsa(rsa, &md, &mgf1md);
     if (ret < 0) {
         RSAerr(RSA_F_PKEY_RSA_SIGN, RSA_R_OPERATION_NOT_ALLOWED_IN_FIPS_MODE);
         return -1;
     }
 #endif
 
-    if (rctx->md) {
-        if (tbslen != (size_t)EVP_MD_size(rctx->md)) {
+    if (md != NULL) {
+        if (tbslen != (size_t)EVP_MD_size(md)) {
             RSAerr(RSA_F_PKEY_RSA_SIGN, RSA_R_INVALID_DIGEST_LENGTH);
             return -1;
         }
 #ifdef OPENSSL_FIPS
         if (ret > 0) {
             unsigned int slen;
-            ret = FIPS_rsa_sign_digest(rsa, tbs, tbslen, rctx->md,
-                                       rctx->pad_mode,
-                                       rctx->saltlen,
-                                       rctx->mgf1md, sig, &slen);
+            ret = FIPS_rsa_sign_digest(rsa, tbs, tbslen, md, rctx->pad_mode,
+                                       rctx->saltlen, mgf1md, sig, &slen);
             if (ret > 0)
                 *siglen = slen;
             else
@@ -243,12 +241,12 @@ static int pkey_rsa_sign(EVP_PKEY_CTX *ctx, unsigned char *sig,
         }
 #endif
 
-        if (EVP_MD_type(rctx->md) == NID_mdc2) {
+        if (EVP_MD_type(md) == NID_mdc2) {
             unsigned int sltmp;
             if (rctx->pad_mode != RSA_PKCS1_PADDING)
                 return -1;
-            ret = RSA_sign_ASN1_OCTET_STRING(NID_mdc2,
-                                             tbs, tbslen, sig, &sltmp, rsa);
+            ret = RSA_sign_ASN1_OCTET_STRING(NID_mdc2, tbs, tbslen, sig, &sltmp,
+                                             rsa);
 
             if (ret <= 0)
                 return ret;
@@ -263,23 +261,20 @@ static int pkey_rsa_sign(EVP_PKEY_CTX *ctx, unsigned char *sig,
                 return -1;
             }
             memcpy(rctx->tbuf, tbs, tbslen);
-            rctx->tbuf[tbslen] = RSA_X931_hash_id(EVP_MD_type(rctx->md));
+            rctx->tbuf[tbslen] = RSA_X931_hash_id(EVP_MD_type(md));
             ret = RSA_private_encrypt(tbslen + 1, rctx->tbuf,
                                       sig, rsa, RSA_X931_PADDING);
         } else if (rctx->pad_mode == RSA_PKCS1_PADDING) {
             unsigned int sltmp;
-            ret = RSA_sign(EVP_MD_type(rctx->md),
-                           tbs, tbslen, sig, &sltmp, rsa);
+            ret = RSA_sign(EVP_MD_type(md), tbs, tbslen, sig, &sltmp, rsa);
             if (ret <= 0)
                 return ret;
             ret = sltmp;
         } else if (rctx->pad_mode == RSA_PKCS1_PSS_PADDING) {
             if (!setup_tbuf(rctx, ctx))
                 return -1;
-            if (!RSA_padding_add_PKCS1_PSS_mgf1(rsa,
-                                                rctx->tbuf, tbs,
-                                                rctx->md, rctx->mgf1md,
-                                                rctx->saltlen))
+            if (!RSA_padding_add_PKCS1_PSS_mgf1(rsa, rctx->tbuf, tbs,
+                                                md, mgf1md, rctx->saltlen))
                 return -1;
             ret = RSA_private_encrypt(RSA_size(rsa), rctx->tbuf,
                                       sig, rsa, RSA_NO_PADDING);
@@ -348,32 +343,31 @@ static int pkey_rsa_verify(EVP_PKEY_CTX *ctx,
 {
     RSA_PKEY_CTX *rctx = ctx->data;
     RSA *rsa = ctx->pkey->pkey.rsa;
+    const EVP_MD *md = rctx->md;
+    const EVP_MD *mgf1md = rctx->mgf1md;
     size_t rslen;
+
 #ifdef OPENSSL_FIPS
-    int rv;
-    rv = pkey_fips_check_ctx(ctx);
+    int rv = pkey_fips_check_rsa(rsa, &md, &mgf1md);
+
     if (rv < 0) {
         RSAerr(RSA_F_PKEY_RSA_VERIFY,
                RSA_R_OPERATION_NOT_ALLOWED_IN_FIPS_MODE);
         return -1;
     }
 #endif
-    if (rctx->md) {
+    if (md != NULL) {
 #ifdef OPENSSL_FIPS
         if (rv > 0) {
-            return FIPS_rsa_verify_digest(rsa,
-                                          tbs, tbslen,
-                                          rctx->md,
-                                          rctx->pad_mode,
-                                          rctx->saltlen,
-                                          rctx->mgf1md, sig, siglen);
+            return FIPS_rsa_verify_digest(rsa, tbs, tbslen, md, rctx->pad_mode,
+                                          rctx->saltlen, mgf1md, sig, siglen);
 
         }
 #endif
         if (rctx->pad_mode == RSA_PKCS1_PADDING)
-            return RSA_verify(EVP_MD_type(rctx->md), tbs, tbslen,
+            return RSA_verify(EVP_MD_type(md), tbs, tbslen,
                               sig, siglen, rsa);
-        if (tbslen != (size_t)EVP_MD_size(rctx->md)) {
+        if (tbslen != (size_t)EVP_MD_size(md)) {
             RSAerr(RSA_F_PKEY_RSA_VERIFY, RSA_R_INVALID_DIGEST_LENGTH);
             return -1;
         }
@@ -388,8 +382,7 @@ static int pkey_rsa_verify(EVP_PKEY_CTX *ctx,
                                      rsa, RSA_NO_PADDING);
             if (ret <= 0)
                 return 0;
-            ret = RSA_verify_PKCS1_PSS_mgf1(rsa, tbs,
-                                            rctx->md, rctx->mgf1md,
+            ret = RSA_verify_PKCS1_PSS_mgf1(rsa, tbs, md, mgf1md,
                                             rctx->tbuf, rctx->saltlen);
             if (ret <= 0)
                 return 0;