Change submitted files so that they compile (in particular,
[oweals/openssl.git] / crypto / bn / bn_mont2.c
1 /*
2  *
3  *      bn_mont2.c
4  *
5  *      Montgomery Modular Arithmetic Functions.
6  *
7  *      Copyright (C) Lenka Fibikova 2000
8  *
9  *
10  */
11
12
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <assert.h>
16
17 #include "bn.h"
18 #include "bn_modfs.h"
19 #include "bn_mont2.h"
20
21 #define BN_mask_word(x, m) ((x->d[0]) & (m))
22
23 BN_MONTGOMERY *BN_mont_new()
24         {
25         BN_MONTGOMERY *ret;
26
27         ret=(BN_MONTGOMERY *)malloc(sizeof(BN_MONTGOMERY));
28
29         if (ret == NULL) return NULL;
30
31         if ((ret->p = BN_new()) == NULL)
32                 {
33                 free(ret);
34                 return NULL;
35                 }
36
37         return ret;
38         }
39
40
41 void BN_mont_clear_free(BN_MONTGOMERY *mont)
42         {
43         if (mont == NULL) return;
44
45         if (mont->p != NULL) BN_clear_free(mont->p);
46
47         mont->p_num_bytes = 0;
48         mont->R_num_bits = 0;
49         mont->p_inv_b_neg = 0;
50         }
51
52
53 int BN_to_mont(BIGNUM *x, BN_MONTGOMERY *mont, BN_CTX *ctx)
54         {
55         assert(x != NULL);
56
57         assert(mont != NULL);
58         assert(mont->p != NULL);
59
60         assert(ctx != NULL);
61
62         if (!BN_lshift(x, x, mont->R_num_bits)) return 0;
63         if (!BN_mod(x, x, mont->p, ctx)) return 0;
64
65         return 1;
66         }
67
68
69 static BN_ULONG BN_mont_inv(BIGNUM *a, int e, BN_CTX *ctx)
70 /* y = a^{-1} (mod 2^e) for an odd number a */
71         {
72         BN_ULONG y, exp, mask;
73         BIGNUM *x, *xy, *x_sh;
74         int i;
75
76         assert(a != NULL && ctx != NULL);
77         assert(e <= BN_BITS2);
78         assert(BN_is_odd(a));
79         assert(!BN_is_zero(a) && !a->neg);
80
81
82         y = 1;
83         exp = 2;
84         mask = 3;
85         if((x = BN_dup(a)) == NULL) return 0;
86         if(!BN_mask_bits(x, e)) return 0;
87
88         BN_CTX_start(ctx);
89         xy = BN_CTX_get(ctx);
90         x_sh = BN_CTX_get(ctx);
91         if (x_sh == NULL) goto err;
92
93         if (BN_copy(xy, x) == NULL) goto err;
94         if (!BN_lshift1(x_sh, x)) goto err;
95
96
97         for (i = 2; i <= e; i++)
98                 {
99                 if (exp < BN_mask_word(xy, mask))
100                         {
101                         y = y + exp;
102                         if (!BN_add(xy, xy, x_sh)) goto err;
103                         }
104
105                 exp <<= 1;
106                 if (!BN_lshift1(x_sh, x_sh)) goto err;
107                 mask <<= 1;
108                 mask++;
109                 }
110
111
112 #ifdef TEST
113         if (xy->d[0] != 1) goto err;
114 #endif
115
116         if (x != NULL) BN_clear_free(x);
117         BN_CTX_end(ctx);
118         return y;
119
120
121 err:
122         if (x != NULL) BN_clear_free(x);
123         BN_CTX_end(ctx);
124         return 0;
125         }
126
127
128 int BN_mont_set(BIGNUM *p, BN_MONTGOMERY *mont, BN_CTX *ctx)
129         {
130         assert(p != NULL && ctx != NULL);
131         assert(mont != NULL);
132         assert(mont->p != NULL);
133         assert(!BN_is_zero(p) && !p->neg);
134
135
136         mont->p_num_bytes = p->top;
137         mont->R_num_bits = (mont->p_num_bytes) * BN_BITS2;
138
139         if (BN_copy(mont->p, p) == NULL);
140         
141         mont->p_inv_b_neg =  BN_mont_inv(p, BN_BITS2, ctx);
142         mont->p_inv_b_neg = 0 - mont->p_inv_b_neg;
143
144         return 1;
145         }
146
147
148 static int BN_cpy_mul_word(BIGNUM *ret, BIGNUM *a, BN_ULONG w)
149 /* ret = a * w */
150         {
151         if (BN_copy(ret, a) == NULL) return 0;
152
153         if (!BN_mul_word(ret, w)) return 0;
154
155         return 1;
156         }
157
158
159 int BN_mont_red(BIGNUM *y, BN_MONTGOMERY *mont, BN_CTX *ctx)
160 /* yR^{-1} (mod p) */
161         {
162         int i;
163         BIGNUM *up, *p;
164         BN_ULONG u;
165
166         assert(y != NULL && mont != NULL && ctx != NULL);
167         assert(mont->p != NULL);
168         assert(BN_cmp(y, mont->p) < 0);
169         assert(!y->neg);
170
171
172         if (BN_is_zero(y)) return 1;
173
174         p = mont->p;
175
176         BN_CTX_start(ctx);
177         up = BN_CTX_get(ctx);
178         if (up == NULL) goto err;
179
180         for (i = 0; i < mont->p_num_bytes; i++)
181                 {
182                 u = (y->d[0]) * mont->p_inv_b_neg;                      /* u = y_0 * p' */
183
184                 if (!BN_cpy_mul_word(up, p, u)) goto err;       /* up = u * p */
185
186                 if (!BN_add(y, y, up)) goto err;                        
187 #ifdef TEST
188                 if (y->d[0]) goto err;
189 #endif
190                 if (!BN_rshift(y, y, BN_BITS2)) goto err;       /* y = (y + up)/b */
191                 }
192
193
194         if (BN_cmp(y, mont->p) >= 0)
195                 {
196                 if (!BN_sub(y, y, mont->p)) goto err;
197                 }
198
199         BN_CTX_end(ctx);
200         return 1;
201
202 err:
203         BN_CTX_end(ctx);
204         return 0;
205         }
206
207
208 int BN_mont_mod_mul(BIGNUM *r, BIGNUM *x, BIGNUM *y, BN_MONTGOMERY *mont, BN_CTX *ctx)
209 /* r = x * y mod p */
210 /* r != x && r! = y !!! */
211         {
212         BIGNUM *xiy, *up;
213         BN_ULONG u;
214         int i;
215         
216
217         assert(r != x && r != y);
218         assert(r != NULL && x != NULL  && y != NULL && mont != NULL && ctx != NULL);
219         assert(mont->p != NULL);
220         assert(BN_cmp(x, mont->p) < 0);
221         assert(BN_cmp(y, mont->p) < 0);
222         assert(!x->neg);
223         assert(!y->neg);
224
225         if (BN_is_zero(x) || BN_is_zero(y))
226                 {
227                 if (!BN_zero(r)) return 0;
228                 return 1;
229                 }
230
231
232
233         BN_CTX_start(ctx);
234         xiy = BN_CTX_get(ctx);
235         up = BN_CTX_get(ctx);
236         if (up == NULL) goto err;
237
238         if (!BN_zero(r)) goto err;
239
240         for (i = 0; i < x->top; i++)
241                 {
242                 u = (r->d[0] + x->d[i] * y->d[0]) * mont->p_inv_b_neg;
243
244                 if (!BN_cpy_mul_word(xiy, y, x->d[i])) goto err;
245                 if (!BN_cpy_mul_word(up, mont->p, u)) goto err;
246
247                 if (!BN_add(r, r, xiy)) goto err;
248                 if (!BN_add(r, r, up)) goto err;
249
250 #ifdef TEST
251                 if (r->d[0]) goto err;
252 #endif
253                 if (!BN_rshift(r, r, BN_BITS2)) goto err;
254                 }
255
256         for (i = x->top; i < mont->p_num_bytes; i++)
257                 {
258                 u = (r->d[0]) * mont->p_inv_b_neg;
259
260                 if (!BN_cpy_mul_word(up, mont->p, u)) goto err;
261
262                 if (!BN_add(r, r, up)) goto err;
263
264 #ifdef TEST
265                 if (r->d[0]) goto err;
266 #endif
267                 if (!BN_rshift(r, r, BN_BITS2)) goto err;
268                 }
269
270
271         if (BN_cmp(r, mont->p) >= 0)
272                 {
273                 if (!BN_sub(r, r, mont->p)) goto err;
274                 }
275
276
277         BN_CTX_end(ctx);
278         return 1;
279
280 err:
281         BN_CTX_end(ctx);
282         return 0;
283         }