colibri_imx6: fix video stdout in default environment
[oweals/u-boot.git] / lib / rsa / rsa-sign.c
index 8c6637e3287e5d3fd069c3523bdd7b196c11544c..40ca1e1f57320aaf52d3c1aa2a8d1a76d7817fb5 100644 (file)
@@ -1,14 +1,15 @@
+// SPDX-License-Identifier: GPL-2.0+
 /*
  * Copyright (c) 2013, Google Inc.
- *
- * SPDX-License-Identifier:    GPL-2.0+
  */
 
 #include "mkimage.h"
+#include <stdlib.h>
 #include <stdio.h>
 #include <string.h>
 #include <image.h>
 #include <time.h>
+#include <openssl/bn.h>
 #include <openssl/rsa.h>
 #include <openssl/pem.h>
 #include <openssl/err.h>
 #define HAVE_ERR_REMOVE_THREAD_STATE
 #endif
 
+#if OPENSSL_VERSION_NUMBER < 0x10100000L || \
+       (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x02070000fL)
+static void RSA_get0_key(const RSA *r,
+                 const BIGNUM **n, const BIGNUM **e, const BIGNUM **d)
+{
+   if (n != NULL)
+       *n = r->n;
+   if (e != NULL)
+       *e = r->e;
+   if (d != NULL)
+       *d = r->d;
+}
+#endif
+
 static int rsa_err(const char *msg)
 {
        unsigned long sslErr = ERR_get_error();
@@ -119,13 +134,27 @@ static int rsa_engine_get_pub_key(const char *keydir, const char *name,
        engine_id = ENGINE_get_id(engine);
 
        if (engine_id && !strcmp(engine_id, "pkcs11")) {
+               if (keydir)
+                       if (strstr(keydir, "object="))
+                               snprintf(key_id, sizeof(key_id),
+                                        "pkcs11:%s;type=public",
+                                        keydir);
+                       else
+                               snprintf(key_id, sizeof(key_id),
+                                        "pkcs11:%s;object=%s;type=public",
+                                        keydir, name);
+               else
+                       snprintf(key_id, sizeof(key_id),
+                                "pkcs11:object=%s;type=public",
+                                name);
+       } else if (engine_id) {
                if (keydir)
                        snprintf(key_id, sizeof(key_id),
-                                "pkcs11:%s;object=%s;type=public",
+                                "%s%s",
                                 keydir, name);
                else
                        snprintf(key_id, sizeof(key_id),
-                                "pkcs11:object=%s;type=public",
+                                "%s",
                                 name);
        } else {
                fprintf(stderr, "Engine not supported\n");
@@ -231,12 +260,26 @@ static int rsa_engine_get_priv_key(const char *keydir, const char *name,
 
        if (engine_id && !strcmp(engine_id, "pkcs11")) {
                if (keydir)
+                       if (strstr(keydir, "object="))
+                               snprintf(key_id, sizeof(key_id),
+                                        "pkcs11:%s;type=private",
+                                        keydir);
+                       else
+                               snprintf(key_id, sizeof(key_id),
+                                        "pkcs11:%s;object=%s;type=private",
+                                        keydir, name);
+               else
                        snprintf(key_id, sizeof(key_id),
-                                "pkcs11:%s;object=%s;type=private",
+                                "pkcs11:object=%s;type=private",
+                                name);
+       } else if (engine_id) {
+               if (keydir)
+                       snprintf(key_id, sizeof(key_id),
+                                "%s%s",
                                 keydir, name);
                else
                        snprintf(key_id, sizeof(key_id),
-                                "pkcs11:object=%s;type=private",
+                                "%s",
                                 name);
        } else {
                fprintf(stderr, "Engine not supported\n");
@@ -286,16 +329,24 @@ static int rsa_init(void)
 {
        int ret;
 
+#if OPENSSL_VERSION_NUMBER < 0x10100000L || \
+       (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x02070000fL)
        ret = SSL_library_init();
+#else
+       ret = OPENSSL_init_ssl(0, NULL);
+#endif
        if (!ret) {
                fprintf(stderr, "Failure to init SSL library\n");
                return -1;
        }
+#if OPENSSL_VERSION_NUMBER < 0x10100000L || \
+       (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x02070000fL)
        SSL_load_error_strings();
 
        OpenSSL_add_all_algorithms();
        OpenSSL_add_all_digests();
        OpenSSL_add_all_ciphers();
+#endif
 
        return 0;
 }
@@ -335,12 +386,17 @@ err_set_rsa:
 err_engine_init:
        ENGINE_free(e);
 err_engine_by_id:
+#if OPENSSL_VERSION_NUMBER < 0x10100000L || \
+       (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x02070000fL)
        ENGINE_cleanup();
+#endif
        return ret;
 }
 
 static void rsa_remove(void)
 {
+#if OPENSSL_VERSION_NUMBER < 0x10100000L || \
+       (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x02070000fL)
        CRYPTO_cleanup_all_ex_data();
        ERR_free_strings();
 #ifdef HAVE_ERR_REMOVE_THREAD_STATE
@@ -349,6 +405,7 @@ static void rsa_remove(void)
        ERR_remove_state(0);
 #endif
        EVP_cleanup();
+#endif
 }
 
 static void rsa_engine_remove(ENGINE *e)
@@ -359,13 +416,16 @@ static void rsa_engine_remove(ENGINE *e)
        }
 }
 
-static int rsa_sign_with_key(RSA *rsa, struct checksum_algo *checksum_algo,
+static int rsa_sign_with_key(RSA *rsa, struct padding_algo *padding_algo,
+                            struct checksum_algo *checksum_algo,
                const struct image_region region[], int region_count,
                uint8_t **sigp, uint *sig_size)
 {
        EVP_PKEY *key;
+       EVP_PKEY_CTX *ckey;
        EVP_MD_CTX *context;
-       int size, ret = 0;
+       int ret = 0;
+       size_t size;
        uint8_t *sig;
        int i;
 
@@ -381,7 +441,7 @@ static int rsa_sign_with_key(RSA *rsa, struct checksum_algo *checksum_algo,
        size = EVP_PKEY_size(key);
        sig = malloc(size);
        if (!sig) {
-               fprintf(stderr, "Out of memory for signature (%d bytes)\n",
+               fprintf(stderr, "Out of memory for signature (%zu bytes)\n",
                        size);
                ret = -ENOMEM;
                goto err_alloc;
@@ -393,27 +453,53 @@ static int rsa_sign_with_key(RSA *rsa, struct checksum_algo *checksum_algo,
                goto err_create;
        }
        EVP_MD_CTX_init(context);
-       if (!EVP_SignInit(context, checksum_algo->calculate_sign())) {
+
+       ckey = EVP_PKEY_CTX_new(key, NULL);
+       if (!ckey) {
+               ret = rsa_err("EVP key context creation failed");
+               goto err_create;
+       }
+
+       if (EVP_DigestSignInit(context, &ckey,
+                              checksum_algo->calculate_sign(),
+                              NULL, key) <= 0) {
                ret = rsa_err("Signer setup failed");
                goto err_sign;
        }
 
+#ifdef CONFIG_FIT_ENABLE_RSASSA_PSS_SUPPORT
+       if (padding_algo && !strcmp(padding_algo->name, "pss")) {
+               if (EVP_PKEY_CTX_set_rsa_padding(ckey,
+                                                RSA_PKCS1_PSS_PADDING) <= 0) {
+                       ret = rsa_err("Signer padding setup failed");
+                       goto err_sign;
+               }
+       }
+#endif /* CONFIG_FIT_ENABLE_RSASSA_PSS_SUPPORT */
+
        for (i = 0; i < region_count; i++) {
-               if (!EVP_SignUpdate(context, region[i].data, region[i].size)) {
+               if (!EVP_DigestSignUpdate(context, region[i].data,
+                                         region[i].size)) {
                        ret = rsa_err("Signing data failed");
                        goto err_sign;
                }
        }
 
-       if (!EVP_SignFinal(context, sig, sig_size, key)) {
+       if (!EVP_DigestSignFinal(context, sig, &size)) {
                ret = rsa_err("Could not obtain signature");
                goto err_sign;
        }
-       EVP_MD_CTX_cleanup(context);
+
+       #if OPENSSL_VERSION_NUMBER < 0x10100000L || \
+               (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x02070000fL)
+               EVP_MD_CTX_cleanup(context);
+       #else
+               EVP_MD_CTX_reset(context);
+       #endif
        EVP_MD_CTX_destroy(context);
        EVP_PKEY_free(key);
 
-       debug("Got signature: %d bytes, expected %d\n", *sig_size, size);
+       debug("Got signature: %d bytes, expected %zu\n", *sig_size, size);
        *sigp = sig;
        *sig_size = size;
 
@@ -450,7 +536,7 @@ int rsa_sign(struct image_sign_info *info,
        ret = rsa_get_priv_key(info->keydir, info->keyname, e, &rsa);
        if (ret)
                goto err_priv;
-       ret = rsa_sign_with_key(rsa, info->checksum, region,
+       ret = rsa_sign_with_key(rsa, info->padding, info->checksum, region,
                                region_count, sigp, sig_len);
        if (ret)
                goto err_sign;
@@ -479,6 +565,7 @@ static int rsa_get_exponent(RSA *key, uint64_t *e)
 {
        int ret;
        BIGNUM *bn_te;
+       const BIGNUM *key_e;
        uint64_t te;
 
        ret = -EINVAL;
@@ -487,17 +574,18 @@ static int rsa_get_exponent(RSA *key, uint64_t *e)
        if (!e)
                goto cleanup;
 
-       if (BN_num_bits(key->e) > 64)
+       RSA_get0_key(key, NULL, &key_e, NULL);
+       if (BN_num_bits(key_e) > 64)
                goto cleanup;
 
-       *e = BN_get_word(key->e);
+       *e = BN_get_word(key_e);
 
-       if (BN_num_bits(key->e) < 33) {
+       if (BN_num_bits(key_e) < 33) {
                ret = 0;
                goto cleanup;
        }
 
-       bn_te = BN_dup(key->e);
+       bn_te = BN_dup(key_e);
        if (!bn_te)
                goto cleanup;
 
@@ -527,6 +615,7 @@ int rsa_get_params(RSA *key, uint64_t *exponent, uint32_t *n0_invp,
 {
        BIGNUM *big1, *big2, *big32, *big2_32;
        BIGNUM *n, *r, *r_squared, *tmp;
+       const BIGNUM *key_n;
        BN_CTX *bn_ctx = BN_CTX_new();
        int ret = 0;
 
@@ -548,7 +637,8 @@ int rsa_get_params(RSA *key, uint64_t *exponent, uint32_t *n0_invp,
        if (0 != rsa_get_exponent(key, exponent))
                ret = -1;
 
-       if (!BN_copy(n, key->n) || !BN_set_word(big1, 1L) ||
+       RSA_get0_key(key, &key_n, NULL, NULL);
+       if (!BN_copy(n, key_n) || !BN_set_word(big1, 1L) ||
            !BN_set_word(big2, 2L) || !BN_set_word(big32, 32L))
                ret = -1;
 
@@ -604,6 +694,15 @@ static int fdt_add_bignum(void *blob, int noffset, const char *prop_name,
        big2 = BN_new();
        big32 = BN_new();
        big2_32 = BN_new();
+
+       /*
+        * Note: This code assumes that all of the above succeed, or all fail.
+        * In practice memory allocations generally do not fail (unless the
+        * process is killed), so it does not seem worth handling each of these
+        * as a separate case. Technicaly this could leak memory on failure,
+        * but a) it won't happen in practice, and b) it doesn't matter as we
+        * will immediately exit with a failure code.
+        */
        if (!tmp || !big2 || !big32 || !big2_32) {
                fprintf(stderr, "Out of memory (bignum)\n");
                return -ENOMEM;
@@ -636,15 +735,13 @@ static int fdt_add_bignum(void *blob, int noffset, const char *prop_name,
         * might fail several times
         */
        ret = fdt_setprop(blob, noffset, prop_name, buf, size);
-       if (ret)
-               return -FDT_ERR_NOSPACE;
        free(buf);
        BN_free(tmp);
        BN_free(big2);
        BN_free(big32);
        BN_free(big2_32);
 
-       return ret;
+       return ret ? -FDT_ERR_NOSPACE : 0;
 }
 
 int rsa_add_verify_data(struct image_sign_info *info, void *keydest)
@@ -705,8 +802,8 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest)
        }
 
        if (!ret) {
-               ret = fdt_setprop_string(keydest, node, "key-name-hint",
-                                info->keyname);
+               ret = fdt_setprop_string(keydest, node, FIT_KEY_HINT,
+                                        info->keyname);
        }
        if (!ret)
                ret = fdt_setprop_u32(keydest, node, "rsa,num-bits", bits);
@@ -728,7 +825,7 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest)
                                         info->name);
        }
        if (!ret && info->require_keys) {
-               ret = fdt_setprop_string(keydest, node, "required",
+               ret = fdt_setprop_string(keydest, node, FIT_KEY_REQUIRED,
                                         info->require_keys);
        }
 done: