Teach the RSA implementation about TLS RSA Key Transport
[oweals/openssl.git] / crypto / rsa / rsa_lib.c
1 /*
2  * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9
10 #include <stdio.h>
11 #include <openssl/crypto.h>
12 #include <openssl/core_names.h>
13 #include <openssl/engine.h>
14 #include <openssl/evp.h>
15 #include "internal/cryptlib.h"
16 #include "internal/refcount.h"
17 #include "crypto/bn.h"
18 #include "crypto/evp.h"
19 #include "crypto/rsa.h"
20 #include "rsa_local.h"
21
22 RSA *RSA_new(void)
23 {
24     return RSA_new_method(NULL);
25 }
26
27 const RSA_METHOD *RSA_get_method(const RSA *rsa)
28 {
29     return rsa->meth;
30 }
31
32 int RSA_set_method(RSA *rsa, const RSA_METHOD *meth)
33 {
34     /*
35      * NB: The caller is specifically setting a method, so it's not up to us
36      * to deal with which ENGINE it comes from.
37      */
38     const RSA_METHOD *mtmp;
39     mtmp = rsa->meth;
40     if (mtmp->finish)
41         mtmp->finish(rsa);
42 #ifndef OPENSSL_NO_ENGINE
43     ENGINE_finish(rsa->engine);
44     rsa->engine = NULL;
45 #endif
46     rsa->meth = meth;
47     if (meth->init)
48         meth->init(rsa);
49     return 1;
50 }
51
52 RSA *RSA_new_method(ENGINE *engine)
53 {
54     RSA *ret = OPENSSL_zalloc(sizeof(*ret));
55
56     if (ret == NULL) {
57         RSAerr(RSA_F_RSA_NEW_METHOD, ERR_R_MALLOC_FAILURE);
58         return NULL;
59     }
60
61     ret->references = 1;
62     ret->lock = CRYPTO_THREAD_lock_new();
63     if (ret->lock == NULL) {
64         RSAerr(RSA_F_RSA_NEW_METHOD, ERR_R_MALLOC_FAILURE);
65         OPENSSL_free(ret);
66         return NULL;
67     }
68
69     ret->meth = RSA_get_default_method();
70 #ifndef OPENSSL_NO_ENGINE
71     ret->flags = ret->meth->flags & ~RSA_FLAG_NON_FIPS_ALLOW;
72     if (engine) {
73         if (!ENGINE_init(engine)) {
74             RSAerr(RSA_F_RSA_NEW_METHOD, ERR_R_ENGINE_LIB);
75             goto err;
76         }
77         ret->engine = engine;
78     } else {
79         ret->engine = ENGINE_get_default_RSA();
80     }
81     if (ret->engine) {
82         ret->meth = ENGINE_get_RSA(ret->engine);
83         if (ret->meth == NULL) {
84             RSAerr(RSA_F_RSA_NEW_METHOD, ERR_R_ENGINE_LIB);
85             goto err;
86         }
87     }
88 #endif
89
90     ret->flags = ret->meth->flags & ~RSA_FLAG_NON_FIPS_ALLOW;
91     if (!CRYPTO_new_ex_data(CRYPTO_EX_INDEX_RSA, ret, &ret->ex_data)) {
92         goto err;
93     }
94
95     if ((ret->meth->init != NULL) && !ret->meth->init(ret)) {
96         RSAerr(RSA_F_RSA_NEW_METHOD, ERR_R_INIT_FAIL);
97         goto err;
98     }
99
100     return ret;
101
102  err:
103     RSA_free(ret);
104     return NULL;
105 }
106
107 void RSA_free(RSA *r)
108 {
109     int i;
110
111     if (r == NULL)
112         return;
113
114     CRYPTO_DOWN_REF(&r->references, &i, r->lock);
115     REF_PRINT_COUNT("RSA", r);
116     if (i > 0)
117         return;
118     REF_ASSERT_ISNT(i < 0);
119
120     if (r->meth != NULL && r->meth->finish != NULL)
121         r->meth->finish(r);
122 #ifndef OPENSSL_NO_ENGINE
123     ENGINE_finish(r->engine);
124 #endif
125
126     CRYPTO_free_ex_data(CRYPTO_EX_INDEX_RSA, r, &r->ex_data);
127
128     CRYPTO_THREAD_lock_free(r->lock);
129
130     BN_free(r->n);
131     BN_free(r->e);
132     BN_clear_free(r->d);
133     BN_clear_free(r->p);
134     BN_clear_free(r->q);
135     BN_clear_free(r->dmp1);
136     BN_clear_free(r->dmq1);
137     BN_clear_free(r->iqmp);
138     RSA_PSS_PARAMS_free(r->pss);
139     sk_RSA_PRIME_INFO_pop_free(r->prime_infos, rsa_multip_info_free);
140     BN_BLINDING_free(r->blinding);
141     BN_BLINDING_free(r->mt_blinding);
142     OPENSSL_free(r->bignum_data);
143     OPENSSL_free(r);
144 }
145
146 int RSA_up_ref(RSA *r)
147 {
148     int i;
149
150     if (CRYPTO_UP_REF(&r->references, &i, r->lock) <= 0)
151         return 0;
152
153     REF_PRINT_COUNT("RSA", r);
154     REF_ASSERT_ISNT(i < 2);
155     return i > 1 ? 1 : 0;
156 }
157
158 int RSA_set_ex_data(RSA *r, int idx, void *arg)
159 {
160     return CRYPTO_set_ex_data(&r->ex_data, idx, arg);
161 }
162
163 void *RSA_get_ex_data(const RSA *r, int idx)
164 {
165     return CRYPTO_get_ex_data(&r->ex_data, idx);
166 }
167
168 /*
169  * Define a scaling constant for our fixed point arithmetic.
170  * This value must be a power of two because the base two logarithm code
171  * makes this assumption.  The exponent must also be a multiple of three so
172  * that the scale factor has an exact cube root.  Finally, the scale factor
173  * should not be so large that a multiplication of two scaled numbers
174  * overflows a 64 bit unsigned integer.
175  */
176 static const unsigned int scale = 1 << 18;
177 static const unsigned int cbrt_scale = 1 << (2 * 18 / 3);
178
179 /* Define some constants, none exceed 32 bits */
180 static const unsigned int log_2  = 0x02c5c8;    /* scale * log(2) */
181 static const unsigned int log_e  = 0x05c551;    /* scale * log2(M_E) */
182 static const unsigned int c1_923 = 0x07b126;    /* scale * 1.923 */
183 static const unsigned int c4_690 = 0x12c28f;    /* scale * 4.690 */
184
185 /*
186  * Multiply two scaled integers together and rescale the result.
187  */
188 static ossl_inline uint64_t mul2(uint64_t a, uint64_t b)
189 {
190     return a * b / scale;
191 }
192
193 /*
194  * Calculate the cube root of a 64 bit scaled integer.
195  * Although the cube root of a 64 bit number does fit into a 32 bit unsigned
196  * integer, this is not guaranteed after scaling, so this function has a
197  * 64 bit return.  This uses the shifting nth root algorithm with some
198  * algebraic simplifications.
199  */
200 static uint64_t icbrt64(uint64_t x)
201 {
202     uint64_t r = 0;
203     uint64_t b;
204     int s;
205
206     for (s = 63; s >= 0; s -= 3) {
207         r <<= 1;
208         b = 3 * r * (r + 1) + 1;
209         if ((x >> s) >= b) {
210             x -= b << s;
211             r++;
212         }
213     }
214     return r * cbrt_scale;
215 }
216
217 /*
218  * Calculate the natural logarithm of a 64 bit scaled integer.
219  * This is done by calculating a base two logarithm and scaling.
220  * The maximum logarithm (base 2) is 64 and this reduces base e, so
221  * a 32 bit result should not overflow.  The argument passed must be
222  * greater than unity so we don't need to handle negative results.
223  */
224 static uint32_t ilog_e(uint64_t v)
225 {
226     uint32_t i, r = 0;
227
228     /*
229      * Scale down the value into the range 1 .. 2.
230      *
231      * If fractional numbers need to be processed, another loop needs
232      * to go here that checks v < scale and if so multiplies it by 2 and
233      * reduces r by scale.  This also means making r signed.
234      */
235     while (v >= 2 * scale) {
236         v >>= 1;
237         r += scale;
238     }
239     for (i = scale / 2; i != 0; i /= 2) {
240         v = mul2(v, v);
241         if (v >= 2 * scale) {
242             v >>= 1;
243             r += i;
244         }
245     }
246     r = (r * (uint64_t)scale) / log_e;
247     return r;
248 }
249
250 /*
251  * NIST SP 800-56B rev 2 Appendix D: Maximum Security Strength Estimates for IFC
252  * Modulus Lengths.
253  *
254  * E = \frac{1.923 \sqrt[3]{nBits \cdot log_e(2)}
255  *           \cdot(log_e(nBits \cdot log_e(2))^{2/3} - 4.69}{log_e(2)}
256  * The two cube roots are merged together here.
257  */
258 uint16_t rsa_compute_security_bits(int n)
259 {
260     uint64_t x;
261     uint32_t lx;
262     uint16_t y;
263
264     /* Look for common values as listed in SP 800-56B rev 2 Appendix D */
265     switch (n) {
266     case 2048:
267         return 112;
268     case 3072:
269         return 128;
270     case 4096:
271         return 152;
272     case 6144:
273         return 176;
274     case 8192:
275         return 200;
276     }
277     /*
278      * The first incorrect result (i.e. not accurate or off by one low) occurs
279      * for n = 699668.  The true value here is 1200.  Instead of using this n
280      * as the check threshold, the smallest n such that the correct result is
281      * 1200 is used instead.
282      */
283     if (n >= 687737)
284         return 1200;
285     if (n < 8)
286         return 0;
287
288     x = n * (uint64_t)log_2;
289     lx = ilog_e(x);
290     y = (uint16_t)((mul2(c1_923, icbrt64(mul2(mul2(x, lx), lx))) - c4_690)
291                    / log_2);
292     return (y + 4) & ~7;
293 }
294
295 int RSA_security_bits(const RSA *rsa)
296 {
297     int bits = BN_num_bits(rsa->n);
298
299     if (rsa->version == RSA_ASN1_VERSION_MULTI) {
300         /* This ought to mean that we have private key at hand. */
301         int ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos);
302
303         if (ex_primes <= 0 || (ex_primes + 2) > rsa_multip_cap(bits))
304             return 0;
305     }
306     return rsa_compute_security_bits(bits);
307 }
308
309 int RSA_set0_key(RSA *r, BIGNUM *n, BIGNUM *e, BIGNUM *d)
310 {
311     /* If the fields n and e in r are NULL, the corresponding input
312      * parameters MUST be non-NULL for n and e.  d may be
313      * left NULL (in case only the public key is used).
314      */
315     if ((r->n == NULL && n == NULL)
316         || (r->e == NULL && e == NULL))
317         return 0;
318
319     if (n != NULL) {
320         BN_free(r->n);
321         r->n = n;
322     }
323     if (e != NULL) {
324         BN_free(r->e);
325         r->e = e;
326     }
327     if (d != NULL) {
328         BN_clear_free(r->d);
329         r->d = d;
330         BN_set_flags(r->d, BN_FLG_CONSTTIME);
331     }
332     r->dirty_cnt++;
333
334     return 1;
335 }
336
337 int RSA_set0_factors(RSA *r, BIGNUM *p, BIGNUM *q)
338 {
339     /* If the fields p and q in r are NULL, the corresponding input
340      * parameters MUST be non-NULL.
341      */
342     if ((r->p == NULL && p == NULL)
343         || (r->q == NULL && q == NULL))
344         return 0;
345
346     if (p != NULL) {
347         BN_clear_free(r->p);
348         r->p = p;
349         BN_set_flags(r->p, BN_FLG_CONSTTIME);
350     }
351     if (q != NULL) {
352         BN_clear_free(r->q);
353         r->q = q;
354         BN_set_flags(r->q, BN_FLG_CONSTTIME);
355     }
356     r->dirty_cnt++;
357
358     return 1;
359 }
360
361 int RSA_set0_crt_params(RSA *r, BIGNUM *dmp1, BIGNUM *dmq1, BIGNUM *iqmp)
362 {
363     /* If the fields dmp1, dmq1 and iqmp in r are NULL, the corresponding input
364      * parameters MUST be non-NULL.
365      */
366     if ((r->dmp1 == NULL && dmp1 == NULL)
367         || (r->dmq1 == NULL && dmq1 == NULL)
368         || (r->iqmp == NULL && iqmp == NULL))
369         return 0;
370
371     if (dmp1 != NULL) {
372         BN_clear_free(r->dmp1);
373         r->dmp1 = dmp1;
374         BN_set_flags(r->dmp1, BN_FLG_CONSTTIME);
375     }
376     if (dmq1 != NULL) {
377         BN_clear_free(r->dmq1);
378         r->dmq1 = dmq1;
379         BN_set_flags(r->dmq1, BN_FLG_CONSTTIME);
380     }
381     if (iqmp != NULL) {
382         BN_clear_free(r->iqmp);
383         r->iqmp = iqmp;
384         BN_set_flags(r->iqmp, BN_FLG_CONSTTIME);
385     }
386     r->dirty_cnt++;
387
388     return 1;
389 }
390
391 /*
392  * Is it better to export RSA_PRIME_INFO structure
393  * and related functions to let user pass a triplet?
394  */
395 int RSA_set0_multi_prime_params(RSA *r, BIGNUM *primes[], BIGNUM *exps[],
396                                 BIGNUM *coeffs[], int pnum)
397 {
398     STACK_OF(RSA_PRIME_INFO) *prime_infos, *old = NULL;
399     RSA_PRIME_INFO *pinfo;
400     int i;
401
402     if (primes == NULL || exps == NULL || coeffs == NULL || pnum == 0)
403         return 0;
404
405     prime_infos = sk_RSA_PRIME_INFO_new_reserve(NULL, pnum);
406     if (prime_infos == NULL)
407         return 0;
408
409     if (r->prime_infos != NULL)
410         old = r->prime_infos;
411
412     for (i = 0; i < pnum; i++) {
413         pinfo = rsa_multip_info_new();
414         if (pinfo == NULL)
415             goto err;
416         if (primes[i] != NULL && exps[i] != NULL && coeffs[i] != NULL) {
417             BN_clear_free(pinfo->r);
418             BN_clear_free(pinfo->d);
419             BN_clear_free(pinfo->t);
420             pinfo->r = primes[i];
421             pinfo->d = exps[i];
422             pinfo->t = coeffs[i];
423             BN_set_flags(pinfo->r, BN_FLG_CONSTTIME);
424             BN_set_flags(pinfo->d, BN_FLG_CONSTTIME);
425             BN_set_flags(pinfo->t, BN_FLG_CONSTTIME);
426         } else {
427             rsa_multip_info_free(pinfo);
428             goto err;
429         }
430         (void)sk_RSA_PRIME_INFO_push(prime_infos, pinfo);
431     }
432
433     r->prime_infos = prime_infos;
434
435     if (!rsa_multip_calc_product(r)) {
436         r->prime_infos = old;
437         goto err;
438     }
439
440     if (old != NULL) {
441         /*
442          * This is hard to deal with, since the old infos could
443          * also be set by this function and r, d, t should not
444          * be freed in that case. So currently, stay consistent
445          * with other *set0* functions: just free it...
446          */
447         sk_RSA_PRIME_INFO_pop_free(old, rsa_multip_info_free);
448     }
449
450     r->version = RSA_ASN1_VERSION_MULTI;
451     r->dirty_cnt++;
452
453     return 1;
454  err:
455     /* r, d, t should not be freed */
456     sk_RSA_PRIME_INFO_pop_free(prime_infos, rsa_multip_info_free_ex);
457     return 0;
458 }
459
460 void RSA_get0_key(const RSA *r,
461                   const BIGNUM **n, const BIGNUM **e, const BIGNUM **d)
462 {
463     if (n != NULL)
464         *n = r->n;
465     if (e != NULL)
466         *e = r->e;
467     if (d != NULL)
468         *d = r->d;
469 }
470
471 void RSA_get0_factors(const RSA *r, const BIGNUM **p, const BIGNUM **q)
472 {
473     if (p != NULL)
474         *p = r->p;
475     if (q != NULL)
476         *q = r->q;
477 }
478
479 int RSA_get_multi_prime_extra_count(const RSA *r)
480 {
481     int pnum;
482
483     pnum = sk_RSA_PRIME_INFO_num(r->prime_infos);
484     if (pnum <= 0)
485         pnum = 0;
486     return pnum;
487 }
488
489 int RSA_get0_multi_prime_factors(const RSA *r, const BIGNUM *primes[])
490 {
491     int pnum, i;
492     RSA_PRIME_INFO *pinfo;
493
494     if ((pnum = RSA_get_multi_prime_extra_count(r)) == 0)
495         return 0;
496
497     /*
498      * return other primes
499      * it's caller's responsibility to allocate oth_primes[pnum]
500      */
501     for (i = 0; i < pnum; i++) {
502         pinfo = sk_RSA_PRIME_INFO_value(r->prime_infos, i);
503         primes[i] = pinfo->r;
504     }
505
506     return 1;
507 }
508
509 void RSA_get0_crt_params(const RSA *r,
510                          const BIGNUM **dmp1, const BIGNUM **dmq1,
511                          const BIGNUM **iqmp)
512 {
513     if (dmp1 != NULL)
514         *dmp1 = r->dmp1;
515     if (dmq1 != NULL)
516         *dmq1 = r->dmq1;
517     if (iqmp != NULL)
518         *iqmp = r->iqmp;
519 }
520
521 int RSA_get0_multi_prime_crt_params(const RSA *r, const BIGNUM *exps[],
522                                     const BIGNUM *coeffs[])
523 {
524     int pnum;
525
526     if ((pnum = RSA_get_multi_prime_extra_count(r)) == 0)
527         return 0;
528
529     /* return other primes */
530     if (exps != NULL || coeffs != NULL) {
531         RSA_PRIME_INFO *pinfo;
532         int i;
533
534         /* it's the user's job to guarantee the buffer length */
535         for (i = 0; i < pnum; i++) {
536             pinfo = sk_RSA_PRIME_INFO_value(r->prime_infos, i);
537             if (exps != NULL)
538                 exps[i] = pinfo->d;
539             if (coeffs != NULL)
540                 coeffs[i] = pinfo->t;
541         }
542     }
543
544     return 1;
545 }
546
547 const BIGNUM *RSA_get0_n(const RSA *r)
548 {
549     return r->n;
550 }
551
552 const BIGNUM *RSA_get0_e(const RSA *r)
553 {
554     return r->e;
555 }
556
557 const BIGNUM *RSA_get0_d(const RSA *r)
558 {
559     return r->d;
560 }
561
562 const BIGNUM *RSA_get0_p(const RSA *r)
563 {
564     return r->p;
565 }
566
567 const BIGNUM *RSA_get0_q(const RSA *r)
568 {
569     return r->q;
570 }
571
572 const BIGNUM *RSA_get0_dmp1(const RSA *r)
573 {
574     return r->dmp1;
575 }
576
577 const BIGNUM *RSA_get0_dmq1(const RSA *r)
578 {
579     return r->dmq1;
580 }
581
582 const BIGNUM *RSA_get0_iqmp(const RSA *r)
583 {
584     return r->iqmp;
585 }
586
587 const RSA_PSS_PARAMS *RSA_get0_pss_params(const RSA *r)
588 {
589     return r->pss;
590 }
591
592 void RSA_clear_flags(RSA *r, int flags)
593 {
594     r->flags &= ~flags;
595 }
596
597 int RSA_test_flags(const RSA *r, int flags)
598 {
599     return r->flags & flags;
600 }
601
602 void RSA_set_flags(RSA *r, int flags)
603 {
604     r->flags |= flags;
605 }
606
607 int RSA_get_version(RSA *r)
608 {
609     /* { two-prime(0), multi(1) } */
610     return r->version;
611 }
612
613 ENGINE *RSA_get0_engine(const RSA *r)
614 {
615     return r->engine;
616 }
617
618 int RSA_pkey_ctx_ctrl(EVP_PKEY_CTX *ctx, int optype, int cmd, int p1, void *p2)
619 {
620     /* If key type not RSA or RSA-PSS return error */
621     if (ctx != NULL && ctx->pmeth != NULL
622         && ctx->pmeth->pkey_id != EVP_PKEY_RSA
623         && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
624         return -1;
625      return EVP_PKEY_CTX_ctrl(ctx, -1, optype, cmd, p1, p2);
626 }
627
628 DEFINE_STACK_OF(BIGNUM)
629
630 int rsa_set0_all_params(RSA *r, const STACK_OF(BIGNUM) *primes,
631                         const STACK_OF(BIGNUM) *exps,
632                         const STACK_OF(BIGNUM) *coeffs)
633 {
634     STACK_OF(RSA_PRIME_INFO) *prime_infos, *old_infos = NULL;
635     int pnum;
636
637     if (primes == NULL || exps == NULL || coeffs == NULL)
638         return 0;
639
640     pnum = sk_BIGNUM_num(primes);
641     if (pnum < 2
642         || pnum != sk_BIGNUM_num(exps)
643         || pnum != sk_BIGNUM_num(coeffs) + 1)
644         return 0;
645
646     if (!RSA_set0_factors(r, sk_BIGNUM_value(primes, 0),
647                           sk_BIGNUM_value(primes, 1))
648         || !RSA_set0_crt_params(r, sk_BIGNUM_value(exps, 0),
649                                 sk_BIGNUM_value(exps, 1),
650                                 sk_BIGNUM_value(coeffs, 0)))
651         return 0;
652
653     old_infos = r->prime_infos;
654
655     if (pnum > 2) {
656         int i;
657
658         prime_infos = sk_RSA_PRIME_INFO_new_reserve(NULL, pnum);
659         if (prime_infos == NULL)
660             return 0;
661
662         for (i = 2; i < pnum; i++) {
663             BIGNUM *prime = sk_BIGNUM_value(primes, i);
664             BIGNUM *exp = sk_BIGNUM_value(exps, i);
665             BIGNUM *coeff = sk_BIGNUM_value(coeffs, i - 1);
666             RSA_PRIME_INFO *pinfo = NULL;
667
668             if (!ossl_assert(prime != NULL && exp != NULL && coeff != NULL))
669                 goto err;
670
671             /* Using rsa_multip_info_new() is wasteful, so allocate directly */
672             if ((pinfo = OPENSSL_zalloc(sizeof(*pinfo))) == NULL) {
673                 ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
674                 goto err;
675             }
676
677             pinfo->r = prime;
678             pinfo->d = exp;
679             pinfo->t = coeff;
680             BN_set_flags(pinfo->r, BN_FLG_CONSTTIME);
681             BN_set_flags(pinfo->d, BN_FLG_CONSTTIME);
682             BN_set_flags(pinfo->t, BN_FLG_CONSTTIME);
683             (void)sk_RSA_PRIME_INFO_push(prime_infos, pinfo);
684         }
685
686         r->prime_infos = prime_infos;
687
688         if (!rsa_multip_calc_product(r)) {
689             r->prime_infos = old_infos;
690             goto err;
691         }
692     }
693
694     if (old_infos != NULL) {
695         /*
696          * This is hard to deal with, since the old infos could
697          * also be set by this function and r, d, t should not
698          * be freed in that case. So currently, stay consistent
699          * with other *set0* functions: just free it...
700          */
701         sk_RSA_PRIME_INFO_pop_free(old_infos, rsa_multip_info_free);
702     }
703
704     r->version = pnum > 2 ? RSA_ASN1_VERSION_MULTI : RSA_ASN1_VERSION_DEFAULT;
705     r->dirty_cnt++;
706
707     return 1;
708  err:
709     /* r, d, t should not be freed */
710     sk_RSA_PRIME_INFO_pop_free(prime_infos, rsa_multip_info_free_ex);
711     return 0;
712 }
713
714 DEFINE_SPECIAL_STACK_OF_CONST(BIGNUM_const, BIGNUM)
715
716 int rsa_get0_all_params(RSA *r, STACK_OF(BIGNUM_const) *primes,
717                         STACK_OF(BIGNUM_const) *exps,
718                         STACK_OF(BIGNUM_const) *coeffs)
719 {
720     RSA_PRIME_INFO *pinfo;
721     int i, pnum;
722
723     if (r == NULL)
724         return 0;
725
726     pnum = RSA_get_multi_prime_extra_count(r);
727
728     sk_BIGNUM_const_push(primes, RSA_get0_p(r));
729     sk_BIGNUM_const_push(primes, RSA_get0_q(r));
730     sk_BIGNUM_const_push(exps, RSA_get0_dmp1(r));
731     sk_BIGNUM_const_push(exps, RSA_get0_dmq1(r));
732     sk_BIGNUM_const_push(coeffs, RSA_get0_iqmp(r));
733     for (i = 0; i < pnum; i++) {
734         pinfo = sk_RSA_PRIME_INFO_value(r->prime_infos, i);
735         sk_BIGNUM_const_push(primes, pinfo->r);
736         sk_BIGNUM_const_push(exps, pinfo->d);
737         sk_BIGNUM_const_push(coeffs, pinfo->t);
738     }
739
740     return 1;
741 }
742
743 int EVP_PKEY_CTX_set_rsa_padding(EVP_PKEY_CTX *ctx, int pad_mode)
744 {
745     OSSL_PARAM pad_params[2], *p = pad_params;
746
747     if (ctx == NULL) {
748         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
749         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
750         return -2;
751     }
752
753     /* If key type not RSA or RSA-PSS return error */
754     if (ctx->pmeth != NULL
755             && ctx->pmeth->pkey_id != EVP_PKEY_RSA
756             && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
757         return -1;
758
759     /* TODO(3.0): Remove this eventually when no more legacy */
760     if (!EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
761             || ctx->op.ciph.ciphprovctx == NULL)
762         return EVP_PKEY_CTX_ctrl(ctx, -1, -1, EVP_PKEY_CTRL_RSA_PADDING,
763                                  pad_mode, NULL);
764
765     *p++ = OSSL_PARAM_construct_int(OSSL_ASYM_CIPHER_PARAM_PAD_MODE, &pad_mode);
766     *p++ = OSSL_PARAM_construct_end();
767
768     return EVP_PKEY_CTX_set_params(ctx, pad_params);
769 }
770
771 int EVP_PKEY_CTX_get_rsa_padding(EVP_PKEY_CTX *ctx, int *pad_mode)
772 {
773     OSSL_PARAM pad_params[2], *p = pad_params;
774
775     if (ctx == NULL || pad_mode == NULL) {
776         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
777         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
778         return -2;
779     }
780
781     /* If key type not RSA or RSA-PSS return error */
782     if (ctx->pmeth != NULL
783             && ctx->pmeth->pkey_id != EVP_PKEY_RSA
784             && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
785         return -1;
786
787     /* TODO(3.0): Remove this eventually when no more legacy */
788     if (!EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
789             || ctx->op.ciph.ciphprovctx == NULL)
790         return EVP_PKEY_CTX_ctrl(ctx, -1, -1, EVP_PKEY_CTRL_GET_RSA_PADDING, 0,
791                                  pad_mode);
792
793     *p++ = OSSL_PARAM_construct_int(OSSL_ASYM_CIPHER_PARAM_PAD_MODE, pad_mode);
794     *p++ = OSSL_PARAM_construct_end();
795
796     if (!EVP_PKEY_CTX_get_params(ctx, pad_params))
797         return 0;
798
799     return 1;
800
801 }
802
803 int EVP_PKEY_CTX_set_rsa_oaep_md(EVP_PKEY_CTX *ctx, const EVP_MD *md)
804 {
805     const char *name;
806
807     if (ctx == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
808         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
809         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
810         return -2;
811     }
812
813     /* If key type not RSA return error */
814     if (ctx->pmeth != NULL && ctx->pmeth->pkey_id != EVP_PKEY_RSA)
815         return -1;
816
817     /* TODO(3.0): Remove this eventually when no more legacy */
818     if (ctx->op.ciph.ciphprovctx == NULL)
819         return EVP_PKEY_CTX_ctrl(ctx, EVP_PKEY_RSA, EVP_PKEY_OP_TYPE_CRYPT,
820                                  EVP_PKEY_CTRL_RSA_OAEP_MD, 0, (void *)md);
821
822     name = (md == NULL) ? "" : EVP_MD_name(md);
823
824     return EVP_PKEY_CTX_set_rsa_oaep_md_name(ctx, name, NULL);
825 }
826
827 int EVP_PKEY_CTX_set_rsa_oaep_md_name(EVP_PKEY_CTX *ctx, const char *mdname,
828                                       const char *mdprops)
829 {
830     OSSL_PARAM rsa_params[3], *p = rsa_params;
831
832     if (ctx == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
833         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
834         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
835         return -2;
836     }
837
838     /* If key type not RSA return error */
839     if (ctx->pmeth != NULL && ctx->pmeth->pkey_id != EVP_PKEY_RSA)
840         return -1;
841
842
843     *p++ = OSSL_PARAM_construct_utf8_string(OSSL_ASYM_CIPHER_PARAM_OAEP_DIGEST,
844                                             /*
845                                              * Cast away the const. This is read
846                                              * only so should be safe
847                                              */
848                                             (char *)mdname,
849                                             strlen(mdname) + 1);
850     if (mdprops != NULL) {
851         *p++ = OSSL_PARAM_construct_utf8_string(
852                     OSSL_ASYM_CIPHER_PARAM_OAEP_DIGEST_PROPS,
853                     /*
854                      * Cast away the const. This is read
855                      * only so should be safe
856                      */
857                     (char *)mdprops,
858                     strlen(mdprops) + 1);
859     }
860     *p++ = OSSL_PARAM_construct_end();
861
862     return EVP_PKEY_CTX_set_params(ctx, rsa_params);
863 }
864
865 int EVP_PKEY_CTX_get_rsa_oaep_md_name(EVP_PKEY_CTX *ctx, char *name,
866                                       size_t namelen)
867 {
868     OSSL_PARAM rsa_params[2], *p = rsa_params;
869
870     if (ctx == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
871         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
872         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
873         return -2;
874     }
875
876     /* If key type not RSA return error */
877     if (ctx->pmeth != NULL && ctx->pmeth->pkey_id != EVP_PKEY_RSA)
878         return -1;
879
880     *p++ = OSSL_PARAM_construct_utf8_string(OSSL_ASYM_CIPHER_PARAM_OAEP_DIGEST,
881                                             name, namelen);
882     *p++ = OSSL_PARAM_construct_end();
883
884     if (!EVP_PKEY_CTX_get_params(ctx, rsa_params))
885         return -1;
886
887     return 1;
888 }
889
890 int EVP_PKEY_CTX_get_rsa_oaep_md(EVP_PKEY_CTX *ctx, const EVP_MD **md)
891 {
892     /* 80 should be big enough */
893     char name[80] = "";
894
895     if (ctx == NULL || md == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
896         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
897         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
898         return -2;
899     }
900
901     /* If key type not RSA return error */
902     if (ctx->pmeth != NULL && ctx->pmeth->pkey_id != EVP_PKEY_RSA)
903         return -1;
904
905     /* TODO(3.0): Remove this eventually when no more legacy */
906     if (ctx->op.ciph.ciphprovctx == NULL)
907         return EVP_PKEY_CTX_ctrl(ctx, EVP_PKEY_RSA, EVP_PKEY_OP_TYPE_CRYPT,
908                                  EVP_PKEY_CTRL_GET_RSA_OAEP_MD, 0, (void *)md);
909
910     if (EVP_PKEY_CTX_get_rsa_oaep_md_name(ctx, name, sizeof(name)) <= 0)
911         return -1;
912
913     /* May be NULL meaning "unknown" */
914     *md = EVP_get_digestbyname(name);
915
916     return 1;
917 }
918
919 int EVP_PKEY_CTX_set_rsa_mgf1_md(EVP_PKEY_CTX *ctx, const EVP_MD *md)
920 {
921     const char *name;
922
923     if (ctx == NULL
924             || (!EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
925                 && !EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx))) {
926         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
927         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
928         return -2;
929     }
930
931     /* If key type not RSA return error */
932     if (ctx->pmeth != NULL
933             && ctx->pmeth->pkey_id != EVP_PKEY_RSA
934             && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
935         return -1;
936
937     /* TODO(3.0): Remove this eventually when no more legacy */
938     if ((EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
939                 && ctx->op.ciph.ciphprovctx == NULL)
940             || (EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx)
941                 && ctx->op.sig.sigprovctx == NULL))
942         return EVP_PKEY_CTX_ctrl(ctx, EVP_PKEY_RSA,
943                                  EVP_PKEY_OP_TYPE_SIG | EVP_PKEY_OP_TYPE_CRYPT,
944                                  EVP_PKEY_CTRL_RSA_MGF1_MD, 0, (void *)md);
945
946     name = (md == NULL) ? "" : EVP_MD_name(md);
947
948     return EVP_PKEY_CTX_set_rsa_mgf1_md_name(ctx, name, NULL);
949 }
950
951 int EVP_PKEY_CTX_set_rsa_mgf1_md_name(EVP_PKEY_CTX *ctx, const char *mdname,
952                                       const char *mdprops)
953 {
954     OSSL_PARAM rsa_params[3], *p = rsa_params;
955
956     if (ctx == NULL
957             || mdname == NULL
958             || (!EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
959                 && !EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx))) {
960         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
961         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
962         return -2;
963     }
964
965     /* If key type not RSA return error */
966     if (ctx->pmeth != NULL
967             && ctx->pmeth->pkey_id != EVP_PKEY_RSA
968             && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
969         return -1;
970
971     *p++ = OSSL_PARAM_construct_utf8_string(OSSL_ASYM_CIPHER_PARAM_MGF1_DIGEST,
972                                             /*
973                                              * Cast away the const. This is read
974                                              * only so should be safe
975                                              */
976                                             (char *)mdname,
977                                             strlen(mdname) + 1);
978     if (mdprops != NULL) {
979         *p++ = OSSL_PARAM_construct_utf8_string(
980                     OSSL_ASYM_CIPHER_PARAM_MGF1_DIGEST_PROPS,
981                     /*
982                      * Cast away the const. This is read
983                      * only so should be safe
984                      */
985                     (char *)mdprops,
986                     strlen(mdprops) + 1);
987     }
988     *p++ = OSSL_PARAM_construct_end();
989
990     return EVP_PKEY_CTX_set_params(ctx, rsa_params);
991 }
992
993 int EVP_PKEY_CTX_get_rsa_mgf1_md_name(EVP_PKEY_CTX *ctx, char *name,
994                                       size_t namelen)
995 {
996     OSSL_PARAM rsa_params[2], *p = rsa_params;
997
998     if (ctx == NULL
999             || (!EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
1000                 && !EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx))) {
1001         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
1002         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
1003         return -2;
1004     }
1005
1006     /* If key type not RSA or RSA-PSS return error */
1007     if (ctx->pmeth != NULL
1008             && ctx->pmeth->pkey_id != EVP_PKEY_RSA
1009             && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
1010         return -1;
1011
1012     *p++ = OSSL_PARAM_construct_utf8_string(OSSL_ASYM_CIPHER_PARAM_MGF1_DIGEST,
1013                                             name, namelen);
1014     *p++ = OSSL_PARAM_construct_end();
1015
1016     if (!EVP_PKEY_CTX_get_params(ctx, rsa_params))
1017         return -1;
1018
1019     return 1;
1020 }
1021
1022 int EVP_PKEY_CTX_get_rsa_mgf1_md(EVP_PKEY_CTX *ctx, const EVP_MD **md)
1023 {
1024     /* 80 should be big enough */
1025     char name[80] = "";
1026
1027     if (ctx == NULL
1028             || (!EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
1029                 && !EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx))) {
1030         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
1031         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
1032         return -2;
1033     }
1034
1035     /* If key type not RSA or RSA-PSS return error */
1036     if (ctx->pmeth != NULL
1037             && ctx->pmeth->pkey_id != EVP_PKEY_RSA
1038             && ctx->pmeth->pkey_id != EVP_PKEY_RSA_PSS)
1039         return -1;
1040
1041     /* TODO(3.0): Remove this eventually when no more legacy */
1042     if ((EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)
1043                 && ctx->op.ciph.ciphprovctx == NULL)
1044             || (EVP_PKEY_CTX_IS_SIGNATURE_OP(ctx)
1045                 && ctx->op.sig.sigprovctx == NULL))
1046         return EVP_PKEY_CTX_ctrl(ctx, -1,
1047                                  EVP_PKEY_OP_TYPE_SIG | EVP_PKEY_OP_TYPE_CRYPT,
1048                                  EVP_PKEY_CTRL_GET_RSA_MGF1_MD, 0, (void *)md);
1049
1050     if (EVP_PKEY_CTX_get_rsa_mgf1_md_name(ctx, name, sizeof(name)) <= 0)
1051         return -1;
1052
1053     /* May be NULL meaning "unknown" */
1054     *md = EVP_get_digestbyname(name);
1055
1056     return 1;
1057 }
1058
1059 int EVP_PKEY_CTX_set0_rsa_oaep_label(EVP_PKEY_CTX *ctx, void *label, int llen)
1060 {
1061     OSSL_PARAM rsa_params[2], *p = rsa_params;
1062
1063     if (ctx == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
1064         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
1065         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
1066         return -2;
1067     }
1068
1069     /* If key type not RSA return error */
1070     if (ctx->pmeth != NULL && ctx->pmeth->pkey_id != EVP_PKEY_RSA)
1071         return -1;
1072
1073     /* TODO(3.0): Remove this eventually when no more legacy */
1074     if (ctx->op.ciph.ciphprovctx == NULL)
1075         return EVP_PKEY_CTX_ctrl(ctx, EVP_PKEY_RSA, EVP_PKEY_OP_TYPE_CRYPT,
1076                                  EVP_PKEY_CTRL_RSA_OAEP_LABEL, llen,
1077                                  (void *)label);
1078
1079     *p++ = OSSL_PARAM_construct_octet_string(OSSL_ASYM_CIPHER_PARAM_OAEP_LABEL,
1080                                             /*
1081                                              * Cast away the const. This is read
1082                                              * only so should be safe
1083                                              */
1084                                             (void *)label,
1085                                             (size_t)llen);
1086     *p++ = OSSL_PARAM_construct_end();
1087
1088     if (!EVP_PKEY_CTX_set_params(ctx, rsa_params))
1089         return 0;
1090
1091     OPENSSL_free(label);
1092     return 1;
1093 }
1094
1095 int EVP_PKEY_CTX_get0_rsa_oaep_label(EVP_PKEY_CTX *ctx, unsigned char **label)
1096 {
1097     OSSL_PARAM rsa_params[3], *p = rsa_params;
1098     size_t labellen;
1099
1100     if (ctx == NULL || !EVP_PKEY_CTX_IS_ASYM_CIPHER_OP(ctx)) {
1101         ERR_raise(ERR_LIB_EVP, EVP_R_COMMAND_NOT_SUPPORTED);
1102         /* Uses the same return values as EVP_PKEY_CTX_ctrl */
1103         return -2;
1104     }
1105
1106     /* If key type not RSA return error */
1107     if (ctx->pmeth != NULL && ctx->pmeth->pkey_id != EVP_PKEY_RSA)
1108         return -1;
1109
1110     /* TODO(3.0): Remove this eventually when no more legacy */
1111     if (ctx->op.ciph.ciphprovctx == NULL)
1112         return EVP_PKEY_CTX_ctrl(ctx, EVP_PKEY_RSA, EVP_PKEY_OP_TYPE_CRYPT,
1113                                  EVP_PKEY_CTRL_GET_RSA_OAEP_LABEL, 0,
1114                                  (void *)label);
1115
1116     *p++ = OSSL_PARAM_construct_octet_ptr(OSSL_ASYM_CIPHER_PARAM_OAEP_LABEL,
1117                                           (void **)label, 0);
1118     *p++ = OSSL_PARAM_construct_size_t(OSSL_ASYM_CIPHER_PARAM_OAEP_LABEL_LEN,
1119                                        &labellen);
1120     *p++ = OSSL_PARAM_construct_end();
1121
1122     if (!EVP_PKEY_CTX_get_params(ctx, rsa_params))
1123         return -1;
1124
1125     if (labellen > INT_MAX)
1126         return -1;
1127
1128     return (int)labellen;
1129 }