Mark DTLS records as read when we have finished with them
authorMatt Caswell <matt@openssl.org>
Wed, 2 May 2018 15:07:13 +0000 (16:07 +0100)
committerMatt Caswell <matt@openssl.org>
Tue, 15 May 2018 13:08:31 +0000 (14:08 +0100)
The TLS code marks records as read when its finished using a record. The DTLS code did
not do that. However SSL_has_pending() relies on it. So we should make DTLS consistent.

Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/6160)

ssl/record/rec_layer_d1.c
ssl/record/ssl3_record.c

index 0eeed4fc10073a24d4112c868d7c54beed46c451..6111a2e1913e50a445ebc761f8f484ac223db48c 100644 (file)
@@ -473,6 +473,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
             return -1;
         }
         SSL3_RECORD_set_length(rr, 0);
+        SSL3_RECORD_set_read(rr);
         goto start;
     }
 
@@ -482,8 +483,9 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
      */
     if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
         SSL3_RECORD_set_length(rr, 0);
+        SSL3_RECORD_set_read(rr);
         s->rwstate = SSL_NOTHING;
-        return (0);
+        return 0;
     }
 
     if (type == SSL3_RECORD_get_type(rr)
@@ -508,8 +510,16 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
         if (recvd_type != NULL)
             *recvd_type = SSL3_RECORD_get_type(rr);
 
-        if (len <= 0)
-            return (len);
+        if (len <= 0) {
+            /*
+             * Mark a zero length record as read. This ensures multiple calls to
+             * SSL_read() with a zero length buffer will eventually cause
+             * SSL_pending() to report data as being available.
+             */
+            if (SSL3_RECORD_get_length(rr) == 0)
+                SSL3_RECORD_set_read(rr);
+            return len;
+        }
 
         if ((unsigned int)len > SSL3_RECORD_get_length(rr))
             n = SSL3_RECORD_get_length(rr);
@@ -517,12 +527,16 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
             n = (unsigned int)len;
 
         memcpy(buf, &(SSL3_RECORD_get_data(rr)[SSL3_RECORD_get_off(rr)]), n);
-        if (!peek) {
+        if (peek) {
+            if (SSL3_RECORD_get_length(rr) == 0)
+                SSL3_RECORD_set_read(rr);
+        } else {
             SSL3_RECORD_sub_length(rr, n);
             SSL3_RECORD_add_off(rr, n);
             if (SSL3_RECORD_get_length(rr) == 0) {
                 s->rlayer.rstate = SSL_ST_READ_HEADER;
                 SSL3_RECORD_set_off(rr, 0);
+                SSL3_RECORD_set_read(rr);
             }
         }
 #ifndef OPENSSL_NO_SCTP
@@ -573,6 +587,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
             }
             /* Exit and notify application to read again */
             SSL3_RECORD_set_length(rr, 0);
+            SSL3_RECORD_set_read(rr);
             s->rwstate = SSL_READING;
             BIO_clear_retry_flags(SSL_get_rbio(s));
             BIO_set_retry_read(SSL_get_rbio(s));
@@ -617,6 +632,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
 #endif
                 s->rlayer.rstate = SSL_ST_READ_HEADER;
                 SSL3_RECORD_set_length(rr, 0);
+                SSL3_RECORD_set_read(rr);
                 goto start;
             }
 
@@ -626,6 +642,8 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                 SSL3_RECORD_add_off(rr, 1);
                 SSL3_RECORD_add_length(rr, -1);
             }
+            if (SSL3_RECORD_get_length(rr) == 0)
+                SSL3_RECORD_set_read(rr);
             *dest_len = dest_maxlen;
         }
     }
@@ -696,6 +714,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
             }
         } else {
             SSL3_RECORD_set_length(rr, 0);
+            SSL3_RECORD_set_read(rr);
             ssl3_send_alert(s, SSL3_AL_WARNING, SSL_AD_NO_RENEGOTIATION);
         }
         /*
@@ -720,6 +739,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                 || (s->options & SSL_OP_NO_RENEGOTIATION) != 0)) {
         s->rlayer.d->handshake_fragment_len = 0;
         SSL3_RECORD_set_length(rr, 0);
+        SSL3_RECORD_set_read(rr);
         ssl3_send_alert(s, SSL3_AL_WARNING, SSL_AD_NO_RENEGOTIATION);
         goto start;
     }
@@ -747,6 +767,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
 
         if (alert_level == SSL3_AL_WARNING) {
             s->s3->warn_alert = alert_descr;
+            SSL3_RECORD_set_read(rr);
 
             s->rlayer.alert_count++;
             if (s->rlayer.alert_count == MAX_WARN_ALERT_COUNT) {
@@ -811,6 +832,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
             BIO_snprintf(tmp, sizeof(tmp), "%d", alert_descr);
             ERR_add_error_data(2, "SSL alert number ", tmp);
             s->shutdown |= SSL_RECEIVED_SHUTDOWN;
+            SSL3_RECORD_set_read(rr);
             SSL_CTX_remove_session(s->session_ctx, s->session);
             return (0);
         } else {
@@ -826,7 +848,8 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
                                             * shutdown */
         s->rwstate = SSL_NOTHING;
         SSL3_RECORD_set_length(rr, 0);
-        return (0);
+        SSL3_RECORD_set_read(rr);
+        return 0;
     }
 
     if (SSL3_RECORD_get_type(rr) == SSL3_RT_CHANGE_CIPHER_SPEC) {
@@ -835,6 +858,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
          * are still missing, so just drop it.
          */
         SSL3_RECORD_set_length(rr, 0);
+        SSL3_RECORD_set_read(rr);
         goto start;
     }
 
@@ -849,6 +873,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
         dtls1_get_message_header(rr->data, &msg_hdr);
         if (SSL3_RECORD_get_epoch(rr) != s->rlayer.d->r_epoch) {
             SSL3_RECORD_set_length(rr, 0);
+            SSL3_RECORD_set_read(rr);
             goto start;
         }
 
@@ -862,6 +887,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
 
             dtls1_retransmit_buffered_messages(s);
             SSL3_RECORD_set_length(rr, 0);
+            SSL3_RECORD_set_read(rr);
             if (!(s->mode & SSL_MODE_AUTO_RETRY)) {
                 if (SSL3_BUFFER_get_left(&s->rlayer.rbuf) == 0) {
                     /* no read-ahead left? */
@@ -916,6 +942,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf,
         /* TLS just ignores unknown message types */
         if (s->version == TLS1_VERSION) {
             SSL3_RECORD_set_length(rr, 0);
+            SSL3_RECORD_set_read(rr);
             goto start;
         }
         al = SSL_AD_UNEXPECTED_MESSAGE;
index c7a54feb126a69d52ec6ebb28efcd7312562fb6c..c80add37f931c29d232122959f40188c1e86ea74 100644 (file)
@@ -1531,6 +1531,7 @@ int dtls1_get_record(SSL *s)
         p += 6;
 
         n2s(p, rr->length);
+        rr->read = 0;
 
         /*
          * Lets check the version. We tolerate alerts that don't have the exact
@@ -1540,6 +1541,7 @@ int dtls1_get_record(SSL *s)
             if (version != s->version) {
                 /* unexpected version, silently discard */
                 rr->length = 0;
+                rr->read = 1;
                 RECORD_LAYER_reset_packet_length(&s->rlayer);
                 goto again;
             }
@@ -1548,6 +1550,7 @@ int dtls1_get_record(SSL *s)
         if ((version & 0xff00) != (s->version & 0xff00)) {
             /* wrong version, silently discard record */
             rr->length = 0;
+            rr->read = 1;
             RECORD_LAYER_reset_packet_length(&s->rlayer);
             goto again;
         }
@@ -1555,10 +1558,10 @@ int dtls1_get_record(SSL *s)
         if (rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
             /* record too long, silently discard it */
             rr->length = 0;
+            rr->read = 1;
             RECORD_LAYER_reset_packet_length(&s->rlayer);
             goto again;
         }
-
         /* now s->rlayer.rstate == SSL_ST_READ_BODY */
     }
 
@@ -1572,6 +1575,7 @@ int dtls1_get_record(SSL *s)
         /* this packet contained a partial record, dump it */
         if (n != i) {
             rr->length = 0;
+            rr->read = 1;
             RECORD_LAYER_reset_packet_length(&s->rlayer);
             goto again;
         }
@@ -1588,6 +1592,7 @@ int dtls1_get_record(SSL *s)
     bitmap = dtls1_get_bitmap(s, rr, &is_next_epoch);
     if (bitmap == NULL) {
         rr->length = 0;
+        rr->read = 1;
         RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
         goto again;             /* get another record */
     }
@@ -1602,6 +1607,7 @@ int dtls1_get_record(SSL *s)
          */
         if (!dtls1_record_replay_check(s, bitmap)) {
             rr->length = 0;
+            rr->read = 1;
             RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
             goto again;         /* get another record */
         }
@@ -1610,8 +1616,10 @@ int dtls1_get_record(SSL *s)
 #endif
 
     /* just read a 0 length packet */
-    if (rr->length == 0)
+    if (rr->length == 0) {
+        rr->read = 1;
         goto again;
+    }
 
     /*
      * If this record is from the next epoch (either HM or ALERT), and a
@@ -1626,12 +1634,14 @@ int dtls1_get_record(SSL *s)
                 return -1;
         }
         rr->length = 0;
+        rr->read = 1;
         RECORD_LAYER_reset_packet_length(&s->rlayer);
         goto again;
     }
 
     if (!dtls1_process_record(s, bitmap)) {
         rr->length = 0;
+        rr->read = 1;
         RECORD_LAYER_reset_packet_length(&s->rlayer); /* dump this record */
         goto again;             /* get another record */
     }