Various clean-ups
[oweals/openssl.git] / ssl / t1_lib.c
index e21bb8b7d19b6771b163298716f52c72446b26a1..28b25e122f511485826b80a62fe595b9673708c4 100644 (file)
@@ -186,12 +186,12 @@ static const uint16_t suiteb_curves[] = {
     TLSEXT_curve_P_384
 };
 
-const TLS_GROUP_INFO *tls1_group_id_lookup(uint16_t curve_id)
+const TLS_GROUP_INFO *tls1_group_id_lookup(uint16_t group_id)
 {
     /* ECC curves from RFC 4492 and RFC 7027 */
-    if (curve_id < 1 || curve_id > OSSL_NELEM(nid_list))
+    if (group_id < 1 || group_id > OSSL_NELEM(nid_list))
         return NULL;
-    return &nid_list[curve_id - 1];
+    return &nid_list[group_id - 1];
 }
 
 static uint16_t tls1_nid2group_id(int nid)
@@ -205,63 +205,50 @@ static uint16_t tls1_nid2group_id(int nid)
 }
 
 /*
- * Get curves list, if "sess" is set return client curves otherwise
- * preferred list.
- * Sets |num_curves| to the number of curves in the list, i.e.,
- * the length of |pcurves| is num_curves.
- * Returns 1 on success and 0 if the client curves list has invalid format.
- * The latter indicates an internal error: we should not be accepting such
- * lists in the first place.
+ * Set *pgroups to the supported groups list and *pgroupslen to
+ * the number of groups supported.
  */
-int tls1_get_curvelist(SSL *s, int sess, const uint16_t **pcurves,
-                       size_t *num_curves)
+void tls1_get_supported_groups(SSL *s, const uint16_t **pgroups,
+                               size_t *pgroupslen)
 {
-    size_t pcurveslen = 0;
 
-    if (sess) {
-        *pcurves = s->session->ext.supportedgroups;
-        pcurveslen = s->session->ext.supportedgroups_len;
-    } else {
-        /* For Suite B mode only include P-256, P-384 */
-        switch (tls1_suiteb(s)) {
-        case SSL_CERT_FLAG_SUITEB_128_LOS:
-            *pcurves = suiteb_curves;
-            pcurveslen = OSSL_NELEM(suiteb_curves);
-            break;
+    /* For Suite B mode only include P-256, P-384 */
+    switch (tls1_suiteb(s)) {
+    case SSL_CERT_FLAG_SUITEB_128_LOS:
+        *pgroups = suiteb_curves;
+        *pgroupslen = OSSL_NELEM(suiteb_curves);
+        break;
 
-        case SSL_CERT_FLAG_SUITEB_128_LOS_ONLY:
-            *pcurves = suiteb_curves;
-            pcurveslen = 1;
-            break;
+    case SSL_CERT_FLAG_SUITEB_128_LOS_ONLY:
+        *pgroups = suiteb_curves;
+        *pgroupslen = 1;
+        break;
 
-        case SSL_CERT_FLAG_SUITEB_192_LOS:
-            *pcurves = suiteb_curves + 1;
-            pcurveslen = 1;
-            break;
-        default:
-            *pcurves = s->ext.supportedgroups;
-            pcurveslen = s->ext.supportedgroups_len;
-        }
-        if (!*pcurves) {
-            *pcurves = eccurves_default;
-            pcurveslen = OSSL_NELEM(eccurves_default);
+    case SSL_CERT_FLAG_SUITEB_192_LOS:
+        *pgroups = suiteb_curves + 1;
+        *pgroupslen = 1;
+        break;
+
+    default:
+        if (s->ext.supportedgroups == NULL) {
+            *pgroups = eccurves_default;
+            *pgroupslen = OSSL_NELEM(eccurves_default);
+        } else {
+            *pgroups = s->ext.supportedgroups;
+            *pgroupslen = s->ext.supportedgroups_len;
         }
+        break;
     }
-
-    *num_curves = pcurveslen;
-    return 1;
 }
 
 /* See if curve is allowed by security callback */
 int tls_curve_allowed(SSL *s, uint16_t curve, int op)
 {
-    const TLS_GROUP_INFO *cinfo;
+    const TLS_GROUP_INFO *cinfo = tls1_group_id_lookup(curve);
     unsigned char ctmp[2];
-    if (curve > 0xff)
-        return 1;
-    if (curve < 1 || curve > OSSL_NELEM(nid_list))
+
+    if (cinfo == NULL)
         return 0;
-    cinfo = &nid_list[curve - 1];
 # ifdef OPENSSL_NO_EC2M
     if (cinfo->flags & TLS_CURVE_CHAR2)
         return 0;
@@ -271,34 +258,13 @@ int tls_curve_allowed(SSL *s, uint16_t curve, int op)
     return ssl_security(s, op, cinfo->secbits, cinfo->nid, (void *)ctmp);
 }
 
-/* Check a curve is one of our preferences */
-int tls1_check_curve(SSL *s, const unsigned char *p, size_t len)
+/* Return 1 if "id" is in "list" */
+static int tls1_in_list(uint16_t id, const uint16_t *list, size_t listlen)
 {
-    const uint16_t *curves;
-    uint16_t curve_id;
-    size_t num_curves, i;
-    unsigned int suiteb_flags = tls1_suiteb(s);
-    if (len != 3 || p[0] != NAMED_CURVE_TYPE)
-        return 0;
-    curve_id = (p[1] << 8) | p[2];
-    /* Check curve matches Suite B preferences */
-    if (suiteb_flags) {
-        unsigned long cid = s->s3->tmp.new_cipher->id;
-        if (cid == TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) {
-            if (curve_id != TLSEXT_curve_P_256)
-                return 0;
-        } else if (cid == TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) {
-            if (curve_id != TLSEXT_curve_P_384)
-                return 0;
-        } else                  /* Should never happen */
-            return 0;
-    }
-    if (!tls1_get_curvelist(s, 0, &curves, &num_curves))
-        return 0;
-    for (i = 0; i < num_curves; i++) {
-        if (curve_id == curves[i])
-            return tls_curve_allowed(s, curve_id, SSL_SECOP_CURVE_CHECK);
-    }
+    size_t i;
+    for (i = 0; i < listlen; i++)
+        if (list[i] == id)
+            return 1;
     return 0;
 }
 
@@ -307,12 +273,12 @@ int tls1_check_curve(SSL *s, const unsigned char *p, size_t len)
  * if there is no match.
  * For nmatch == -1, return number of matches
  * For nmatch == -2, return the id of the group to use for
- * an tmp key, or 0 if there is no match.
+ * a tmp key, or 0 if there is no match.
  */
 uint16_t tls1_shared_group(SSL *s, int nmatch)
 {
     const uint16_t *pref, *supp;
-    size_t num_pref, num_supp, i, j;
+    size_t num_pref, num_supp, i;
     int k;
 
     /* Can't do anything on client side */
@@ -337,30 +303,26 @@ uint16_t tls1_shared_group(SSL *s, int nmatch)
         nmatch = 0;
     }
     /*
-     * Avoid truncation. tls1_get_curvelist takes an int
-     * but s->options is a long...
+     * If server preference set, our groups are the preference order
+     * otherwise peer decides.
      */
-    if (!tls1_get_curvelist(s,
-            (s->options & SSL_OP_CIPHER_SERVER_PREFERENCE) != 0,
-            &supp, &num_supp))
-        return 0;
-    if (!tls1_get_curvelist(s,
-            (s->options & SSL_OP_CIPHER_SERVER_PREFERENCE) == 0,
-            &pref, &num_pref))
-        return 0;
+    if (s->options & SSL_OP_CIPHER_SERVER_PREFERENCE) {
+        tls1_get_supported_groups(s, &pref, &num_pref);
+        tls1_get_peer_groups(s, &supp, &num_supp);
+    } else {
+        tls1_get_peer_groups(s, &pref, &num_pref);
+        tls1_get_supported_groups(s, &supp, &num_supp);
+    }
 
     for (k = 0, i = 0; i < num_pref; i++) {
         uint16_t id = pref[i];
 
-        for (j = 0; j < num_supp; j++) {
-            if (id == supp[j]) {
-                if (!tls_curve_allowed(s, id, SSL_SECOP_CURVE_SHARED))
+        if (!tls1_in_list(id, supp, num_supp)
+            || !tls_curve_allowed(s, id, SSL_SECOP_CURVE_SHARED))
                     continue;
-                if (nmatch == k)
-                    return id;
-                k++;
-            }
-        }
+        if (nmatch == k)
+            return id;
+         k++;
     }
     if (nmatch == -1)
         return k;
@@ -483,7 +445,7 @@ static int tls1_check_pkey_comp(SSL *s, EVP_PKEY *pkey)
 
         if (field_type == NID_X9_62_prime_field)
             comp_id = TLSEXT_ECPOINTFORMAT_ansiX962_compressed_prime;
-        else if (field_type == NID_X9_62_prime_field)
+        else if (field_type == NID_X9_62_characteristic_two_field)
             comp_id = TLSEXT_ECPOINTFORMAT_ansiX962_compressed_char2;
         else
             return 0;
@@ -501,23 +463,38 @@ static int tls1_check_pkey_comp(SSL *s, EVP_PKEY *pkey)
     }
     return 0;
 }
+
 /* Check a group id matches preferences */
-static int tls1_check_group_id(SSL *s, uint16_t group_id)
+int tls1_check_group_id(SSL *s, uint16_t group_id)
     {
     const uint16_t *groups;
-    size_t i, groups_len;
+    size_t groups_len;
 
     if (group_id == 0)
         return 0;
 
+    /* Check for Suite B compliance */
+    if (tls1_suiteb(s) && s->s3->tmp.new_cipher != NULL) {
+        unsigned long cid = s->s3->tmp.new_cipher->id;
+
+        if (cid == TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) {
+            if (group_id != TLSEXT_curve_P_256)
+                return 0;
+        } else if (cid == TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) {
+            if (group_id != TLSEXT_curve_P_384)
+                return 0;
+        } else {
+            /* Should never happen */
+            return 0;
+        }
+    }
+
     /* Check group is one of our preferences */
-    if (!tls1_get_curvelist(s, 0, &groups, &groups_len))
+    tls1_get_supported_groups(s, &groups, &groups_len);
+    if (!tls1_in_list(group_id, groups, groups_len))
         return 0;
-    for (i = 0; i < groups_len; i++) {
-        if (groups[i] == group_id)
-            break;
-    }
-    if (i == groups_len)
+
+    if (!tls_curve_allowed(s, group_id, SSL_SECOP_CURVE_CHECK))
         return 0;
 
     /* For clients, nothing more to check */
@@ -525,8 +502,7 @@ static int tls1_check_group_id(SSL *s, uint16_t group_id)
         return 1;
 
     /* Check group is one of peers preferences */
-    if (!tls1_get_curvelist(s, 1, &groups, &groups_len))
-        return 0;
+    tls1_get_peer_groups(s, &groups, &groups_len);
 
     /*
      * RFC 4492 does not require the supported elliptic curves extension
@@ -536,12 +512,7 @@ static int tls1_check_group_id(SSL *s, uint16_t group_id)
      */
     if (groups_len == 0)
             return 1;
-
-    for (i = 0; i < groups_len; i++) {
-        if (groups[i] == group_id)
-            return 1;
-    }
-    return 0;
+    return tls1_in_list(group_id, groups, groups_len);
 }
 
 void tls1_get_formatlist(SSL *s, const unsigned char **pformats,
@@ -1432,7 +1403,7 @@ void ssl_set_sig_mask(uint32_t *pmask_a, SSL *s, int op)
      * in disabled_mask.
      */
     sigalgslen = tls12_get_psigalgs(s, 1, &sigalgs);
-    for (i = 0; i < sigalgslen; i ++, sigalgs++) {
+    for (i = 0; i < sigalgslen; i++, sigalgs++) {
         const SIGALG_LOOKUP *lu = tls1_lookup_sigalg(*sigalgs);
         const SSL_CERT_LOOKUP *clu;
 
@@ -1440,6 +1411,8 @@ void ssl_set_sig_mask(uint32_t *pmask_a, SSL *s, int op)
             continue;
 
         clu = ssl_cert_lookup_by_idx(lu->sig_idx);
+       if (clu == NULL)
+               continue;
 
         /* If algorithm is disabled see if we can enable it */
         if ((clu->amask & disabled_mask) != 0