80e817314b5867d692d5b656b45d9b9a5e90ff75
[oweals/u-boot.git] / lib / rsa / rsa-verify.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * Copyright (c) 2013, Google Inc.
4  */
5
6 #ifndef USE_HOSTCC
7 #include <common.h>
8 #include <fdtdec.h>
9 #include <malloc.h>
10 #include <asm/types.h>
11 #include <asm/byteorder.h>
12 #include <linux/errno.h>
13 #include <asm/types.h>
14 #include <asm/unaligned.h>
15 #include <dm.h>
16 #else
17 #include "fdt_host.h"
18 #include "mkimage.h"
19 #include <fdt_support.h>
20 #endif
21 #include <linux/kconfig.h>
22 #include <u-boot/rsa-mod-exp.h>
23 #include <u-boot/rsa.h>
24
25 #ifndef __UBOOT__
26 /*
27  * NOTE:
28  * Since host tools, like mkimage, make use of openssl library for
29  * RSA encryption, rsa_verify_with_pkey()/rsa_gen_key_prop() are
30  * of no use and should not be compiled in.
31  * So just turn off CONFIG_RSA_VERIFY_WITH_PKEY.
32  */
33
34 #undef CONFIG_RSA_VERIFY_WITH_PKEY
35 #endif
36
37 /* Default public exponent for backward compatibility */
38 #define RSA_DEFAULT_PUBEXP      65537
39
40 /**
41  * rsa_verify_padding() - Verify RSA message padding is valid
42  *
43  * Verify a RSA message's padding is consistent with PKCS1.5
44  * padding as described in the RSA PKCS#1 v2.1 standard.
45  *
46  * @msg:        Padded message
47  * @pad_len:    Number of expected padding bytes
48  * @algo:       Checksum algo structure having information on DER encoding etc.
49  * @return 0 on success, != 0 on failure
50  */
51 static int rsa_verify_padding(const uint8_t *msg, const int pad_len,
52                               struct checksum_algo *algo)
53 {
54         int ff_len;
55         int ret;
56
57         /* first byte must be 0x00 */
58         ret = *msg++;
59         /* second byte must be 0x01 */
60         ret |= *msg++ ^ 0x01;
61         /* next ff_len bytes must be 0xff */
62         ff_len = pad_len - algo->der_len - 3;
63         ret |= *msg ^ 0xff;
64         ret |= memcmp(msg, msg+1, ff_len-1);
65         msg += ff_len;
66         /* next byte must be 0x00 */
67         ret |= *msg++;
68         /* next der_len bytes must match der_prefix */
69         ret |= memcmp(msg, algo->der_prefix, algo->der_len);
70
71         return ret;
72 }
73
74 int padding_pkcs_15_verify(struct image_sign_info *info,
75                            uint8_t *msg, int msg_len,
76                            const uint8_t *hash, int hash_len)
77 {
78         struct checksum_algo *checksum = info->checksum;
79         int ret, pad_len = msg_len - checksum->checksum_len;
80
81         /* Check pkcs1.5 padding bytes. */
82         ret = rsa_verify_padding(msg, pad_len, checksum);
83         if (ret) {
84                 debug("In RSAVerify(): Padding check failed!\n");
85                 return -EINVAL;
86         }
87
88         /* Check hash. */
89         if (memcmp((uint8_t *)msg + pad_len, hash, msg_len - pad_len)) {
90                 debug("In RSAVerify(): Hash check failed!\n");
91                 return -EACCES;
92         }
93
94         return 0;
95 }
96
97 #ifdef CONFIG_FIT_ENABLE_RSASSA_PSS_SUPPORT
98 static void u32_i2osp(uint32_t val, uint8_t *buf)
99 {
100         buf[0] = (uint8_t)((val >> 24) & 0xff);
101         buf[1] = (uint8_t)((val >> 16) & 0xff);
102         buf[2] = (uint8_t)((val >>  8) & 0xff);
103         buf[3] = (uint8_t)((val >>  0) & 0xff);
104 }
105
106 /**
107  * mask_generation_function1() - generate an octet string
108  *
109  * Generate an octet string used to check rsa signature.
110  * It use an input octet string and a hash function.
111  *
112  * @checksum:   A Hash function
113  * @seed:       Specifies an input variable octet string
114  * @seed_len:   Size of the input octet string
115  * @output:     Specifies the output octet string
116  * @output_len: Size of the output octet string
117  * @return 0 if the octet string was correctly generated, others on error
118  */
119 static int mask_generation_function1(struct checksum_algo *checksum,
120                                      uint8_t *seed, int seed_len,
121                                      uint8_t *output, int output_len)
122 {
123         struct image_region region[2];
124         int ret = 0, i, i_output = 0, region_count = 2;
125         uint32_t counter = 0;
126         uint8_t buf_counter[4], *tmp;
127         int hash_len = checksum->checksum_len;
128
129         memset(output, 0, output_len);
130
131         region[0].data = seed;
132         region[0].size = seed_len;
133         region[1].data = &buf_counter[0];
134         region[1].size = 4;
135
136         tmp = malloc(hash_len);
137         if (!tmp) {
138                 debug("%s: can't allocate array tmp\n", __func__);
139                 ret = -ENOMEM;
140                 goto out;
141         }
142
143         while (i_output < output_len) {
144                 u32_i2osp(counter, &buf_counter[0]);
145
146                 ret = checksum->calculate(checksum->name,
147                                           region, region_count,
148                                           tmp);
149                 if (ret < 0) {
150                         debug("%s: Error in checksum calculation\n", __func__);
151                         goto out;
152                 }
153
154                 i = 0;
155                 while ((i_output < output_len) && (i < hash_len)) {
156                         output[i_output] = tmp[i];
157                         i_output++;
158                         i++;
159                 }
160
161                 counter++;
162         }
163
164 out:
165         free(tmp);
166
167         return ret;
168 }
169
170 static int compute_hash_prime(struct checksum_algo *checksum,
171                               uint8_t *pad, int pad_len,
172                               uint8_t *hash, int hash_len,
173                               uint8_t *salt, int salt_len,
174                               uint8_t *hprime)
175 {
176         struct image_region region[3];
177         int ret, region_count = 3;
178
179         region[0].data = pad;
180         region[0].size = pad_len;
181         region[1].data = hash;
182         region[1].size = hash_len;
183         region[2].data = salt;
184         region[2].size = salt_len;
185
186         ret = checksum->calculate(checksum->name, region, region_count, hprime);
187         if (ret < 0) {
188                 debug("%s: Error in checksum calculation\n", __func__);
189                 goto out;
190         }
191
192 out:
193         return ret;
194 }
195
196 int padding_pss_verify(struct image_sign_info *info,
197                        uint8_t *msg, int msg_len,
198                        const uint8_t *hash, int hash_len)
199 {
200         uint8_t *masked_db = NULL;
201         int masked_db_len = msg_len - hash_len - 1;
202         uint8_t *h = NULL, *hprime = NULL;
203         int h_len = hash_len;
204         uint8_t *db_mask = NULL;
205         int db_mask_len = masked_db_len;
206         uint8_t *db = NULL, *salt = NULL;
207         int db_len = masked_db_len, salt_len = msg_len - hash_len - 2;
208         uint8_t pad_zero[8] = { 0 };
209         int ret, i, leftmost_bits = 1;
210         uint8_t leftmost_mask;
211         struct checksum_algo *checksum = info->checksum;
212
213         /* first, allocate everything */
214         masked_db = malloc(masked_db_len);
215         h = malloc(h_len);
216         db_mask = malloc(db_mask_len);
217         db = malloc(db_len);
218         salt = malloc(salt_len);
219         hprime = malloc(hash_len);
220         if (!masked_db || !h || !db_mask || !db || !salt || !hprime) {
221                 printf("%s: can't allocate some buffer\n", __func__);
222                 ret = -ENOMEM;
223                 goto out;
224         }
225
226         /* step 4: check if the last byte is 0xbc */
227         if (msg[msg_len - 1] != 0xbc) {
228                 printf("%s: invalid pss padding (0xbc is missing)\n", __func__);
229                 ret = -EINVAL;
230                 goto out;
231         }
232
233         /* step 5 */
234         memcpy(masked_db, msg, masked_db_len);
235         memcpy(h, msg + masked_db_len, h_len);
236
237         /* step 6 */
238         leftmost_mask = (0xff >> (8 - leftmost_bits)) << (8 - leftmost_bits);
239         if (masked_db[0] & leftmost_mask) {
240                 printf("%s: invalid pss padding ", __func__);
241                 printf("(leftmost bit of maskedDB not zero)\n");
242                 ret = -EINVAL;
243                 goto out;
244         }
245
246         /* step 7 */
247         mask_generation_function1(checksum, h, h_len, db_mask, db_mask_len);
248
249         /* step 8 */
250         for (i = 0; i < db_len; i++)
251                 db[i] = masked_db[i] ^ db_mask[i];
252
253         /* step 9 */
254         db[0] &= 0xff >> leftmost_bits;
255
256         /* step 10 */
257         if (db[0] != 0x01) {
258                 printf("%s: invalid pss padding ", __func__);
259                 printf("(leftmost byte of db isn't 0x01)\n");
260                 ret = EINVAL;
261                 goto out;
262         }
263
264         /* step 11 */
265         memcpy(salt, &db[1], salt_len);
266
267         /* step 12 & 13 */
268         compute_hash_prime(checksum, pad_zero, 8,
269                            (uint8_t *)hash, hash_len,
270                            salt, salt_len, hprime);
271
272         /* step 14 */
273         ret = memcmp(h, hprime, hash_len);
274
275 out:
276         free(hprime);
277         free(salt);
278         free(db);
279         free(db_mask);
280         free(h);
281         free(masked_db);
282
283         return ret;
284 }
285 #endif
286
287 #if CONFIG_IS_ENABLED(FIT_SIGNATURE) || IS_ENABLED(CONFIG_RSA_VERIFY_WITH_PKEY)
288 /**
289  * rsa_verify_key() - Verify a signature against some data using RSA Key
290  *
291  * Verify a RSA PKCS1.5 signature against an expected hash using
292  * the RSA Key properties in prop structure.
293  *
294  * @info:       Specifies key and FIT information
295  * @prop:       Specifies key
296  * @sig:        Signature
297  * @sig_len:    Number of bytes in signature
298  * @hash:       Pointer to the expected hash
299  * @key_len:    Number of bytes in rsa key
300  * @return 0 if verified, -ve on error
301  */
302 static int rsa_verify_key(struct image_sign_info *info,
303                           struct key_prop *prop, const uint8_t *sig,
304                           const uint32_t sig_len, const uint8_t *hash,
305                           const uint32_t key_len)
306 {
307         int ret;
308 #if !defined(USE_HOSTCC)
309         struct udevice *mod_exp_dev;
310 #endif
311         struct checksum_algo *checksum = info->checksum;
312         struct padding_algo *padding = info->padding;
313         int hash_len;
314
315         if (!prop || !sig || !hash || !checksum)
316                 return -EIO;
317
318         if (sig_len != (prop->num_bits / 8)) {
319                 debug("Signature is of incorrect length %d\n", sig_len);
320                 return -EINVAL;
321         }
322
323         debug("Checksum algorithm: %s", checksum->name);
324
325         /* Sanity check for stack size */
326         if (sig_len > RSA_MAX_SIG_BITS / 8) {
327                 debug("Signature length %u exceeds maximum %d\n", sig_len,
328                       RSA_MAX_SIG_BITS / 8);
329                 return -EINVAL;
330         }
331
332         uint8_t buf[sig_len];
333         hash_len = checksum->checksum_len;
334
335 #if !defined(USE_HOSTCC)
336         ret = uclass_get_device(UCLASS_MOD_EXP, 0, &mod_exp_dev);
337         if (ret) {
338                 printf("RSA: Can't find Modular Exp implementation\n");
339                 return -EINVAL;
340         }
341
342         ret = rsa_mod_exp(mod_exp_dev, sig, sig_len, prop, buf);
343 #else
344         ret = rsa_mod_exp_sw(sig, sig_len, prop, buf);
345 #endif
346         if (ret) {
347                 debug("Error in Modular exponentation\n");
348                 return ret;
349         }
350
351         ret = padding->verify(info, buf, key_len, hash, hash_len);
352         if (ret) {
353                 debug("In RSAVerify(): padding check failed!\n");
354                 return ret;
355         }
356
357         return 0;
358 }
359 #endif
360
361 #ifdef CONFIG_RSA_VERIFY_WITH_PKEY
362 /**
363  * rsa_verify_with_pkey() - Verify a signature against some data using
364  * only modulus and exponent as RSA key properties.
365  * @info:       Specifies key information
366  * @hash:       Pointer to the expected hash
367  * @sig:        Signature
368  * @sig_len:    Number of bytes in signature
369  *
370  * Parse a RSA public key blob in DER format pointed to in @info and fill
371  * a key_prop structure with properties of the key. Then verify a RSA PKCS1.5
372  * signature against an expected hash using the calculated properties.
373  *
374  * Return       0 if verified, -ve on error
375  */
376 static int rsa_verify_with_pkey(struct image_sign_info *info,
377                                 const void *hash, uint8_t *sig, uint sig_len)
378 {
379         struct key_prop *prop;
380         int ret;
381
382         /* Public key is self-described to fill key_prop */
383         ret = rsa_gen_key_prop(info->key, info->keylen, &prop);
384         if (ret) {
385                 debug("Generating necessary parameter for decoding failed\n");
386                 return ret;
387         }
388
389         ret = rsa_verify_key(info, prop, sig, sig_len, hash,
390                              info->crypto->key_len);
391
392         rsa_free_key_prop(prop);
393
394         return ret;
395 }
396 #else
397 static int rsa_verify_with_pkey(struct image_sign_info *info,
398                                 const void *hash, uint8_t *sig, uint sig_len)
399 {
400         return -EACCES;
401 }
402 #endif
403
404 #if CONFIG_IS_ENABLED(FIT_SIGNATURE)
405 /**
406  * rsa_verify_with_keynode() - Verify a signature against some data using
407  * information in node with prperties of RSA Key like modulus, exponent etc.
408  *
409  * Parse sign-node and fill a key_prop structure with properties of the
410  * key.  Verify a RSA PKCS1.5 signature against an expected hash using
411  * the properties parsed
412  *
413  * @info:       Specifies key and FIT information
414  * @hash:       Pointer to the expected hash
415  * @sig:        Signature
416  * @sig_len:    Number of bytes in signature
417  * @node:       Node having the RSA Key properties
418  * @return 0 if verified, -ve on error
419  */
420 static int rsa_verify_with_keynode(struct image_sign_info *info,
421                                    const void *hash, uint8_t *sig,
422                                    uint sig_len, int node)
423 {
424         const void *blob = info->fdt_blob;
425         struct key_prop prop;
426         int length;
427         int ret = 0;
428
429         if (node < 0) {
430                 debug("%s: Skipping invalid node", __func__);
431                 return -EBADF;
432         }
433
434         prop.num_bits = fdtdec_get_int(blob, node, "rsa,num-bits", 0);
435
436         prop.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0);
437
438         prop.public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length);
439         if (!prop.public_exponent || length < sizeof(uint64_t))
440                 prop.public_exponent = NULL;
441
442         prop.exp_len = sizeof(uint64_t);
443
444         prop.modulus = fdt_getprop(blob, node, "rsa,modulus", NULL);
445
446         prop.rr = fdt_getprop(blob, node, "rsa,r-squared", NULL);
447
448         if (!prop.num_bits || !prop.modulus) {
449                 debug("%s: Missing RSA key info", __func__);
450                 return -EFAULT;
451         }
452
453         ret = rsa_verify_key(info, &prop, sig, sig_len, hash,
454                              info->crypto->key_len);
455
456         return ret;
457 }
458 #else
459 static int rsa_verify_with_keynode(struct image_sign_info *info,
460                                    const void *hash, uint8_t *sig,
461                                    uint sig_len, int node)
462 {
463         return -EACCES;
464 }
465 #endif
466
467 int rsa_verify(struct image_sign_info *info,
468                const struct image_region region[], int region_count,
469                uint8_t *sig, uint sig_len)
470 {
471         /* Reserve memory for maximum checksum-length */
472         uint8_t hash[info->crypto->key_len];
473         int ret = -EACCES;
474
475         /*
476          * Verify that the checksum-length does not exceed the
477          * rsa-signature-length
478          */
479         if (info->checksum->checksum_len >
480             info->crypto->key_len) {
481                 debug("%s: invlaid checksum-algorithm %s for %s\n",
482                       __func__, info->checksum->name, info->crypto->name);
483                 return -EINVAL;
484         }
485
486         /* Calculate checksum with checksum-algorithm */
487         ret = info->checksum->calculate(info->checksum->name,
488                                         region, region_count, hash);
489         if (ret < 0) {
490                 debug("%s: Error in checksum calculation\n", __func__);
491                 return -EINVAL;
492         }
493
494         if (IS_ENABLED(CONFIG_RSA_VERIFY_WITH_PKEY) && !info->fdt_blob) {
495                 /* don't rely on fdt properties */
496                 ret = rsa_verify_with_pkey(info, hash, sig, sig_len);
497
498                 return ret;
499         }
500
501         if (CONFIG_IS_ENABLED(FIT_SIGNATURE)) {
502                 const void *blob = info->fdt_blob;
503                 int ndepth, noffset;
504                 int sig_node, node;
505                 char name[100];
506
507                 sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME);
508                 if (sig_node < 0) {
509                         debug("%s: No signature node found\n", __func__);
510                         return -ENOENT;
511                 }
512
513                 /* See if we must use a particular key */
514                 if (info->required_keynode != -1) {
515                         ret = rsa_verify_with_keynode(info, hash, sig, sig_len,
516                                                       info->required_keynode);
517                         return ret;
518                 }
519
520                 /* Look for a key that matches our hint */
521                 snprintf(name, sizeof(name), "key-%s", info->keyname);
522                 node = fdt_subnode_offset(blob, sig_node, name);
523                 ret = rsa_verify_with_keynode(info, hash, sig, sig_len, node);
524                 if (!ret)
525                         return ret;
526
527                 /* No luck, so try each of the keys in turn */
528                 for (ndepth = 0, noffset = fdt_next_node(info->fit, sig_node,
529                                                          &ndepth);
530                      (noffset >= 0) && (ndepth > 0);
531                      noffset = fdt_next_node(info->fit, noffset, &ndepth)) {
532                         if (ndepth == 1 && noffset != node) {
533                                 ret = rsa_verify_with_keynode(info, hash,
534                                                               sig, sig_len,
535                                                               noffset);
536                                 if (!ret)
537                                         break;
538                         }
539                 }
540         }
541
542         return ret;
543 }