Integrate ec_err.[co].
[oweals/openssl.git] / crypto / bn / bn_gcd.c
index ea6816a43fc32a6deb82a68f3fb552e3bf19fb0c..d5caf5136f1248007dc6bb6f215665351cc055a9 100644 (file)
@@ -204,7 +204,7 @@ err:
 BIGNUM *BN_mod_inverse(BIGNUM *in,
        const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx)
        {
-       BIGNUM *A,*B,*X,*Y,*M,*D,*R=NULL;
+       BIGNUM *A,*B,*X,*Y,*M,*D,*T,*R=NULL;
        BIGNUM *ret=NULL;
        int sign;
 
@@ -218,7 +218,8 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
        D = BN_CTX_get(ctx);
        M = BN_CTX_get(ctx);
        Y = BN_CTX_get(ctx);
-       if (Y == NULL) goto err;
+       T = BN_CTX_get(ctx);
+       if (T == NULL) goto err;
 
        if (in == NULL)
                R=BN_new();
@@ -239,7 +240,7 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
        /* From  B = a mod |n|,  A = |n|  it follows that
         *
         *      0 <= B < A,
-        *           X*a  ==  B   (mod |n|),
+        *      sign*X*a  ==  B   (mod |n|),
         *     -sign*Y*a  ==  A   (mod |n|).
         */
 
@@ -249,11 +250,51 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
 
                /*
                 *      0 < B < A,
-                * (*)       X*a  ==  B   (mod |n|),
+                * (*)  sign*X*a  ==  B   (mod |n|),
                 *     -sign*Y*a  ==  A   (mod |n|)
                 */
 
-               if (!BN_div(D,M,A,B,ctx)) goto err;
+               /* (D, M) := (A/B, A%B) ... */
+               if (BN_num_bits(A) == BN_num_bits(B))
+                       {
+                       if (!BN_one(D)) goto err;
+                       if (!BN_sub(M,A,B)) goto err;
+                       }
+               else if (BN_num_bits(A) == BN_num_bits(B) + 1)
+                       {
+                       /* A/B is 1, 2, or 3 */
+                       if (!BN_lshift1(T,B)) goto err;
+                       if (BN_ucmp(A,T) < 0)
+                               {
+                               /* A < 2*B, so D=1 */
+                               if (!BN_one(D)) goto err;
+                               if (!BN_sub(M,A,B)) goto err;
+                               }
+                       else
+                               {
+                               /* A >= 2*B, so D=2 or D=3 */
+                               if (!BN_sub(M,A,T)) goto err;
+                               if (!BN_add(D,T,B)) goto err; /* use D (:= 3*B) as temp */
+                               if (BN_ucmp(A,D) < 0)
+                                       {
+                                       /* A < 3*B, so D=2 */
+                                       if (!BN_set_word(D,2)) goto err;
+                                       /* M (= A - 2*B) already has the correct value */
+                                       }
+                               else
+                                       {
+                                       /* only D=3 remains */
+                                       if (!BN_set_word(D,3)) goto err;
+                                       /* currently  M = A - 2*B,  but we need  M = A - 3*B */
+                                       if (!BN_sub(M,M,B)) goto err;
+                                       }
+                               }
+                       }
+               else
+                       {
+                       if (!BN_div(D,M,A,B,ctx)) goto err;
+                       }
+               
                /* Now
                 *      A = D*B + M;
                 * thus we have
@@ -273,21 +314,46 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
                 * i.e.
                 *      -sign*Y*a - D*A  ==  B    (mod |n|).
                 * Similarly, (*) translates into
-                *      X*a  ==  A          (mod |n|).
+                *       sign*X*a  ==  A          (mod |n|).
                 *
                 * Thus,
-                *      -sign*Y*a - D*X*a  ==  B  (mod |n|),
+                *  -sign*Y*a - D*sign*X*a  ==  B  (mod |n|),
                 * i.e.
-                *      -sign*(Y + D*X)*a  ==  B  (mod |n|).
+                *       -sign*(Y + D*X)*a  ==  B  (mod |n|).
                 *
                 * So if we set  (X, Y, sign) := (Y + D*X, X, -sign),  we arrive back at
-                *            X*a  ==  B   (mod |n|),
+                *       sign*X*a  ==  B   (mod |n|),
                 *      -sign*Y*a  ==  A   (mod |n|).
                 * Note that  X  and  Y  stay non-negative all the time.
                 */
 
-               if (!BN_mul(tmp,D,X,ctx)) goto err;
-               if (!BN_add(tmp,tmp,Y)) goto err;
+               /* most of the time D is very small, so we can optimize tmp := D*X+Y */
+               if (BN_is_one(D))
+                       {
+                       if (!BN_add(tmp,X,Y)) goto err;
+                       }
+               else
+                       {
+                       if (BN_is_word(D,2))
+                               {
+                               if (!BN_lshift1(tmp,X)) goto err;
+                               }
+                       else if (BN_is_word(D,4))
+                               {
+                               if (!BN_lshift(tmp,X,2)) goto err;
+                               }
+                       else if (D->top == 1)
+                               {
+                               if (!BN_copy(tmp,X)) goto err;
+                               if (!BN_mul_word(tmp,D->d[0])) goto err;
+                               }
+                       else
+                               {
+                               if (!BN_mul(tmp,D,X,ctx)) goto err;
+                               }
+                       if (!BN_add(tmp,tmp,Y)) goto err;
+                       }
+               
                M=Y; /* keep the BIGNUM object, the value does not matter */
                Y=X;
                X=tmp;
@@ -295,7 +361,7 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
                }
 
        /*
-        * The while loop ends when
+        * The while loop (Euclid's algorithm) ends when
         *      A == gcd(a,n);
         * we have
         *      -sign*Y*a  ==  A  (mod |n|),
@@ -312,7 +378,14 @@ BIGNUM *BN_mod_inverse(BIGNUM *in,
        if (BN_is_one(A))
                {
                /* Y*a == 1  (mod |n|) */
-               if (!BN_mod(R,Y,n,ctx)) goto err;
+               if (BN_ucmp(Y,n) < 0)
+                       {
+                       if (!BN_copy(R,Y)) goto err;
+                       }
+               else
+                       {
+                       if (!BN_nnmod(R,Y,n,ctx)) goto err;
+                       }
                }
        else
                {
@@ -325,4 +398,3 @@ err:
        BN_CTX_end(ctx);
        return(ret);
        }
-