PR: 1828
authorDr. Stephen Henson <steve@openssl.org>
Wed, 15 Apr 2009 14:49:36 +0000 (14:49 +0000)
committerDr. Stephen Henson <steve@openssl.org>
Wed, 15 Apr 2009 14:49:36 +0000 (14:49 +0000)
Submitted by: Robin Seggelmann <seggelmann@fh-muenster.de>
Approved by: steve@openssl.org

Updated DTLS Rentransmission bug patch.

ssl/d1_both.c
ssl/d1_pkt.c
ssl/dtls1.h
ssl/ssl_locl.h

index 8e2058ed33b726d6a4a551215c1cf55c3634a480..913098361143506950154849ce0f8e58c49c8806 100644 (file)
@@ -136,7 +136,6 @@ static unsigned char *dtls1_write_message_header(SSL *s,
 static void dtls1_set_message_header_int(SSL *s, unsigned char mt,
        unsigned long len, unsigned short seq_num, unsigned long frag_off, 
        unsigned long frag_len);
-static int dtls1_retransmit_buffered_messages(SSL *s);
 static long dtls1_get_message_fragment(SSL *s, int st1, int stn, 
        long max, int *ok);
 
@@ -932,8 +931,21 @@ int dtls1_read_failed(SSL *s, int code)
        return dtls1_retransmit_buffered_messages(s) ;
        }
 
+int
+dtls1_get_queue_priority(unsigned short seq, int is_ccs)
+       {
+       /* The index of the retransmission queue actually is the message sequence number,
+        * since the queue only contains messages of a single handshake. However, the
+        * ChangeCipherSpec has no message sequence number and so using only the sequence
+        * will result in the CCS and Finished having the same index. To prevent this,
+        * the sequence number is multiplied by 2. In case of a CCS 1 is subtracted.
+        * This does not only differ CSS and Finished, it also maintains the order of the
+        * index (important for priority queues) and fits in the unsigned short variable.
+        */     
+       return seq * 2 - is_ccs;
+       }
 
-static int
+int
 dtls1_retransmit_buffered_messages(SSL *s)
        {
        pqueue sent = s->d1->sent_messages;
@@ -947,8 +959,9 @@ dtls1_retransmit_buffered_messages(SSL *s)
        for ( item = pqueue_next(&iter); item != NULL; item = pqueue_next(&iter))
                {
                frag = (hm_fragment *)item->data;
-               if ( dtls1_retransmit_message(s, frag->msg_header.seq, 0, &found) <= 0 &&
-                       found)
+                       if ( dtls1_retransmit_message(s,
+                               dtls1_get_queue_priority(frag->msg_header.seq, frag->msg_header.is_ccs),
+                               0, &found) <= 0 && found)
                        {
                        fprintf(stderr, "dtls1_retransmit_message() failed\n");
                        return -1;
@@ -964,7 +977,6 @@ dtls1_buffer_message(SSL *s, int is_ccs)
        pitem *item;
        hm_fragment *frag;
        unsigned char seq64be[8];
-       unsigned int epoch = s->d1->w_epoch;
 
        /* this function is called immediately after a message has 
         * been serialized */
@@ -978,7 +990,6 @@ dtls1_buffer_message(SSL *s, int is_ccs)
                {
                OPENSSL_assert(s->d1->w_msg_hdr.msg_len + 
                        DTLS1_CCS_HEADER_LENGTH == (unsigned int)s->init_num);
-               epoch++;
                }
        else
                {
@@ -993,11 +1004,18 @@ dtls1_buffer_message(SSL *s, int is_ccs)
        frag->msg_header.frag_len = s->d1->w_msg_hdr.msg_len;
        frag->msg_header.is_ccs = is_ccs;
 
+       /* save current state*/
+       frag->msg_header.saved_retransmit_state.enc_write_ctx = s->enc_write_ctx;
+       frag->msg_header.saved_retransmit_state.write_hash = s->write_hash;
+       frag->msg_header.saved_retransmit_state.compress = s->compress;
+       frag->msg_header.saved_retransmit_state.session = s->session;
+       frag->msg_header.saved_retransmit_state.epoch = s->d1->w_epoch;
+       
        memset(seq64be,0,sizeof(seq64be));
-       seq64be[0] = (unsigned char)(epoch>>8);
-       seq64be[1] = (unsigned char)(epoch);
-       seq64be[6] = (unsigned char)(frag->msg_header.seq>>8);
-       seq64be[7] = (unsigned char)(frag->msg_header.seq);
+       seq64be[6] = (unsigned char)(dtls1_get_queue_priority(frag->msg_header.seq,
+                                                                                                                 frag->msg_header.is_ccs)>>8);
+       seq64be[7] = (unsigned char)(dtls1_get_queue_priority(frag->msg_header.seq,
+                                                                                                                 frag->msg_header.is_ccs));
 
        item = pitem_new(seq64be, frag);
        if ( item == NULL)
@@ -1026,6 +1044,8 @@ dtls1_retransmit_message(SSL *s, unsigned short seq, unsigned long frag_off,
        hm_fragment *frag ;
        unsigned long header_length;
        unsigned char seq64be[8];
+       struct dtls1_retransmit_state saved_state;
+       unsigned char save_write_sequence[8];
 
        /*
          OPENSSL_assert(s->init_num == 0);
@@ -1061,9 +1081,45 @@ dtls1_retransmit_message(SSL *s, unsigned short seq, unsigned long frag_off,
                frag->msg_header.msg_len, frag->msg_header.seq, 0, 
                frag->msg_header.frag_len);
 
+       /* save current state */
+       saved_state.enc_write_ctx = s->enc_write_ctx;
+       saved_state.write_hash = s->write_hash;
+       saved_state.compress = s->compress;
+       saved_state.session = s->session;
+       saved_state.epoch = s->d1->w_epoch;
+       saved_state.epoch = s->d1->w_epoch;
+       
        s->d1->retransmitting = 1;
+       
+       /* restore state in which the message was originally sent */
+       s->enc_write_ctx = frag->msg_header.saved_retransmit_state.enc_write_ctx;
+       s->write_hash = frag->msg_header.saved_retransmit_state.write_hash;
+       s->compress = frag->msg_header.saved_retransmit_state.compress;
+       s->session = frag->msg_header.saved_retransmit_state.session;
+       s->d1->w_epoch = frag->msg_header.saved_retransmit_state.epoch;
+       
+       if (frag->msg_header.saved_retransmit_state.epoch == saved_state.epoch - 1)
+       {
+               memcpy(save_write_sequence, s->s3->write_sequence, sizeof(s->s3->write_sequence));
+               memcpy(s->s3->write_sequence, s->d1->last_write_sequence, sizeof(s->s3->write_sequence));
+       }
+       
        ret = dtls1_do_write(s, frag->msg_header.is_ccs ? 
-               SSL3_RT_CHANGE_CIPHER_SPEC : SSL3_RT_HANDSHAKE);
+                                                SSL3_RT_CHANGE_CIPHER_SPEC : SSL3_RT_HANDSHAKE);
+       
+       /* restore current state */
+       s->enc_write_ctx = saved_state.enc_write_ctx;
+       s->write_hash = saved_state.write_hash;
+       s->compress = saved_state.compress;
+       s->session = saved_state.session;
+       s->d1->w_epoch = saved_state.epoch;
+       
+       if (frag->msg_header.saved_retransmit_state.epoch == saved_state.epoch - 1)
+       {
+               memcpy(s->d1->last_write_sequence, s->s3->write_sequence, sizeof(s->s3->write_sequence));
+               memcpy(s->s3->write_sequence, save_write_sequence, sizeof(s->s3->write_sequence));
+       }
+
        s->d1->retransmitting = 0;
 
        (void)BIO_flush(SSL_get_wbio(s));
index c215d7096adb0b267fd6c19733266d1121864572..2e9d5452f7beb23679ea58275b534c52ff779804 100644 (file)
@@ -1020,7 +1020,9 @@ start:
                                n2s(p, seq);
                                n2l3(p, frag_off);
 
-                               dtls1_retransmit_message(s, seq, frag_off, &found);
+                               dtls1_retransmit_message(s,
+                                                                                dtls1_get_queue_priority(frag->msg_header.seq, 0),
+                                                                                frag_off, &found);
                                if ( ! found  && SSL_in_init(s))
                                        {
                                        /* fprintf( stderr,"in init = %d\n", SSL_in_init(s)); */
@@ -1109,6 +1111,16 @@ start:
                        goto start;
                        }
 
+               /* If we are server, we may have a repeated FINISHED of the
+                * client here, then retransmit our CCS and FINISHED.
+                */
+               if (msg_hdr.type == SSL3_MT_FINISHED)
+                       {
+                       dtls1_retransmit_buffered_messages(s);
+                       rr->length = 0;
+                       goto start;
+                       }
+
                if (((s->state&SSL_ST_MASK) == SSL_ST_OK) &&
                        !(s->s3->flags & SSL3_FLAGS_NO_RENEGOTIATE_CIPHERS))
                        {
@@ -1763,6 +1775,7 @@ dtls1_reset_seq_numbers(SSL *s, int rw)
        else
                {
                seq = s->s3->write_sequence;
+               memcpy(s->d1->last_write_sequence, seq, sizeof(s->s3->write_sequence));
                s->d1->w_epoch++;
                }
 
index 6ecbc49314381ec72d112057b0cacf8eb69def22..cb8bd7cdfe10268c8fcd25224ecb0665bce8603e 100644 (file)
@@ -102,6 +102,19 @@ typedef struct dtls1_bitmap_st
                                           encoding */
        } DTLS1_BITMAP;
 
+struct dtls1_retransmit_state
+       {
+       EVP_CIPHER_CTX *enc_write_ctx;  /* cryptographic state */
+       EVP_MD_CTX *write_hash;                 /* used for mac generation */
+#ifndef OPENSSL_NO_COMP
+       COMP_CTX *compress;                             /* compression */
+#else
+       char *compress; 
+#endif
+       SSL_SESSION *session;
+       unsigned short epoch;
+       };
+
 struct hm_header_st
        {
        unsigned char type;
@@ -110,6 +123,7 @@ struct hm_header_st
        unsigned long frag_off;
        unsigned long frag_len;
        unsigned int is_ccs;
+       struct dtls1_retransmit_state saved_retransmit_state;
        };
 
 struct ccs_header_st
@@ -169,6 +183,9 @@ typedef struct dtls1_state_st
 
        unsigned short handshake_read_seq;
 
+       /* save last sequence number for retransmissions */
+       unsigned char last_write_sequence[8];
+
        /* Received handshake records (processed and unprocessed) */
        record_pqueue unprocessed_rcds;
        record_pqueue processed_rcds;
index 9b6aadd9504dc7e84c984686fb5003eee77db564..dd9fa8780c7cb475c3ee8aa12ebefca1421308ef 100644 (file)
@@ -935,6 +935,8 @@ int dtls1_read_failed(SSL *s, int code);
 int dtls1_buffer_message(SSL *s, int ccs);
 int dtls1_retransmit_message(SSL *s, unsigned short seq, 
        unsigned long frag_off, int *found);
+int dtls1_get_queue_priority(unsigned short seq, int is_ccs);
+int dtls1_retransmit_buffered_messages(SSL *s);
 void dtls1_clear_record_buffer(SSL *s);
 void dtls1_get_message_header(unsigned char *data, struct hm_header_st *msg_hdr);
 void dtls1_get_ccs_header(unsigned char *data, struct ccs_header_st *ccs_hdr);