Convert more functions in ssl/statem/statem_dtls.c to use SSLfatal()
authorMatt Caswell <matt@openssl.org>
Thu, 23 Nov 2017 11:19:34 +0000 (11:19 +0000)
committerMatt Caswell <matt@openssl.org>
Mon, 4 Dec 2017 13:31:48 +0000 (13:31 +0000)
Reviewed-by: Richard Levitte <levitte@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/4778)

ssl/record/rec_layer_d1.c
ssl/statem/statem_dtls.c

index 3eabf71cf60e5fc7462a92ff722a33b9a3945e38..c5857a10d27c6bf8b4e03933c4a0756f9808e141 100644 (file)
@@ -397,8 +397,12 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
     }
 
     /* Check for timeout */
-    if (dtls1_handle_timeout(s) > 0)
+    if (dtls1_handle_timeout(s) > 0) {
         goto start;
+    } else if (ossl_statem_in_error(s)) {
+        /* dtls1_handle_timeout() has failed with a fatal error */
+        return -1;
+    }
 
     /* get new packet if necessary */
     if ((SSL3_RECORD_get_length(rr) == 0)
@@ -633,7 +637,12 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
             if (dtls1_check_timeout_num(s) < 0)
                 return -1;
 
-            dtls1_retransmit_buffered_messages(s);
+            if (dtls1_retransmit_buffered_messages(s) <= 0) {
+                /* Fail if we encountered a fatal error */
+                if (ossl_statem_in_error(s))
+                    return -1;
+                
+            }
             SSL3_RECORD_set_length(rr, 0);
             goto start;
         }
index 651e8c2ed80ed573fd6b2c8c60e78233b730a658..d19fe0a0c4c82e4189191d2287cc24190e6a47c3 100644 (file)
@@ -415,8 +415,9 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
     /* sanity checking */
     if ((frag_off + frag_len) > msg_len
             || msg_len > dtls1_max_handshake_message_len(s)) {
-        SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-        return SSL_AD_ILLEGAL_PARAMETER;
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_F_DTLS1_PREPROCESS_FRAGMENT,
+                 SSL_R_EXCESSIVE_MESSAGE_SIZE);
+        return 0;
     }
 
     if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */
@@ -425,8 +426,9 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
          * dtls_max_handshake_message_len(s) above
          */
         if (!BUF_MEM_grow_clean(s->init_buf, msg_len + DTLS1_HM_HEADER_LENGTH)) {
-            SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, ERR_R_BUF_LIB);
-            return SSL_AD_INTERNAL_ERROR;
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DTLS1_PREPROCESS_FRAGMENT,
+                     ERR_R_BUF_LIB);
+            return 0;
         }
 
         s->s3->tmp.message_size = msg_len;
@@ -439,13 +441,18 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
          * They must be playing with us! BTW, failure to enforce upper limit
          * would open possibility for buffer overrun.
          */
-        SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-        return SSL_AD_ILLEGAL_PARAMETER;
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_F_DTLS1_PREPROCESS_FRAGMENT,
+                 SSL_R_EXCESSIVE_MESSAGE_SIZE);
+        return 0;
     }
 
-    return 0;                   /* no error */
+    return 1;
 }
 
+/*
+ * Returns 1 if there is a buffered fragment available, 0 if not, or -1 on a
+ * fatal error.
+ */
 static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
 {
     /*-
@@ -456,7 +463,7 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
      */
     pitem *item;
     hm_fragment *frag;
-    int al;
+    int ret;
 
     do {
         item = pqueue_peek(s->d1->buffered_messages);
@@ -483,9 +490,10 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
         size_t frag_len = frag->msg_header.frag_len;
         pqueue_pop(s->d1->buffered_messages);
 
-        al = dtls1_preprocess_fragment(s, &frag->msg_header);
+        /* Calls SSLfatal() as required */
+        ret = dtls1_preprocess_fragment(s, &frag->msg_header);
 
-        if (al == 0) {          /* no alert */
+        if (ret) {
             unsigned char *p =
                 (unsigned char *)s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
             memcpy(&p[frag->msg_header.frag_off], frag->fragment,
@@ -495,14 +503,14 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, size_t *len)
         dtls1_hm_fragment_free(frag);
         pitem_free(item);
 
-        if (al == 0) {
+        if (ret) {
             *len = frag_len;
             return 1;
         }
 
-        ssl3_send_alert(s, SSL3_AL_FATAL, al);
+        /* Fatal error */
         s->init_num = 0;
-        return 0;
+        return -1;
     } else {
         return 0;
     }
@@ -719,7 +727,7 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
 {
     unsigned char wire[DTLS1_HM_HEADER_LENGTH];
     size_t mlen, frag_off, frag_len;
-    int i, al, recvd_type;
+    int i, ret, recvd_type;
     struct hm_header_st msg_hdr;
     size_t readbytes;
 
@@ -727,7 +735,12 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
 
  redo:
     /* see if we have the required fragment already */
-    if (dtls1_retrieve_buffered_fragment(s, &frag_len)) {
+    ret = dtls1_retrieve_buffered_fragment(s, &frag_len);
+    if (ret < 0) {
+        /* SSLfatal() already called */
+        return 0;
+    }
+    if (ret > 0) {
         s->init_num = frag_len;
         *len = frag_len;
         return 1;
@@ -743,9 +756,9 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
     }
     if (recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
         if (wire[0] != SSL3_MT_CCS) {
-            al = SSL_AD_UNEXPECTED_MESSAGE;
-            SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
-                   SSL_R_BAD_CHANGE_CIPHER_SPEC);
+            SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
+                     SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
+                     SSL_R_BAD_CHANGE_CIPHER_SPEC);
             goto f_err;
         }
 
@@ -760,8 +773,8 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
 
     /* Handshake fails if message header is incomplete */
     if (readbytes != DTLS1_HM_HEADER_LENGTH) {
-        al = SSL_AD_UNEXPECTED_MESSAGE;
-        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
+        SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
+                 SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
         goto f_err;
     }
 
@@ -777,8 +790,8 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
      * Fragments must not span records.
      */
     if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) {
-        al = SSL_AD_ILLEGAL_PARAMETER;
-        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER,
+                 SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
         goto f_err;
     }
 
@@ -817,15 +830,17 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
             goto redo;
         } else {                /* Incorrectly formatted Hello request */
 
-            al = SSL_AD_UNEXPECTED_MESSAGE;
-            SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
-                   SSL_R_UNEXPECTED_MESSAGE);
+            SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE,
+                     SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
+                     SSL_R_UNEXPECTED_MESSAGE);
             goto f_err;
         }
     }
 
-    if ((al = dtls1_preprocess_fragment(s, &msg_hdr)))
+    if (!dtls1_preprocess_fragment(s, &msg_hdr)) {
+        /* SSLfatal() already called */
         goto f_err;
+    }
 
     if (frag_len > 0) {
         unsigned char *p =
@@ -852,8 +867,8 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
      * to fail
      */
     if (readbytes != frag_len) {
-        al = SSL_AD_ILLEGAL_PARAMETER;
-        SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER,
+                 SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
         goto f_err;
     }
 
@@ -867,7 +882,6 @@ static int dtls_get_reassembled_message(SSL *s, int *errtype, size_t *len)
     return 1;
 
  f_err:
-    ssl3_send_alert(s, SSL3_AL_FATAL, al);
     s->init_num = 0;
     *len = 0;
     return 0;
@@ -888,8 +902,10 @@ int dtls_construct_change_cipher_spec(SSL *s, WPACKET *pkt)
         s->d1->next_handshake_write_seq++;
 
         if (!WPACKET_put_bytes_u16(pkt, s->d1->handshake_write_seq)) {
-            SSLerr(SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC, ERR_R_INTERNAL_ERROR);
-            ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR,
+                     SSL_F_DTLS_CONSTRUCT_CHANGE_CIPHER_SPEC,
+                     ERR_R_INTERNAL_ERROR);
+            return 0;
         }
     }
 
@@ -923,8 +939,9 @@ WORK_STATE dtls_wait_for_dry(SSL *s)
 int dtls1_read_failed(SSL *s, int code)
 {
     if (code > 0) {
-        SSLerr(SSL_F_DTLS1_READ_FAILED, ERR_R_INTERNAL_ERROR);
-        return 1;
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR,
+                 SSL_F_DTLS1_READ_FAILED, ERR_R_INTERNAL_ERROR);
+        return 0;
     }
 
     if (!dtls1_is_timer_expired(s)) {
@@ -1065,7 +1082,8 @@ int dtls1_retransmit_message(SSL *s, unsigned short seq, int *found)
 
     item = pqueue_find(s->d1->sent_messages, seq64be);
     if (item == NULL) {
-        SSLerr(SSL_F_DTLS1_RETRANSMIT_MESSAGE, ERR_R_INTERNAL_ERROR);
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_F_DTLS1_RETRANSMIT_MESSAGE,
+                 ERR_R_INTERNAL_ERROR);
         *found = 0;
         return 0;
     }