Add datagram mode to the SPTPS protocol.
authorGuus Sliepen <guus@tinc-vpn.org>
Sun, 18 Mar 2012 15:42:02 +0000 (16:42 +0100)
committerGuus Sliepen <guus@tinc-vpn.org>
Sun, 18 Mar 2012 15:42:02 +0000 (16:42 +0100)
* Everything is identical except the headers of the records.
* Instead of sending explicit message length and having an implicit sequence
  number, datagram mode has an implicit message length and an explicit sequence
  number.
* The sequence number is used to set the most significant bytes of the counter.

src/protocol_auth.c
src/sptps.c
src/sptps.h
src/sptps_test.c

index d3eef21e0eb4c2cc8903446ba5bcd2a89666c14f..3bf18b21cb2dbf4747b038b4ef52106fa1eedd22 100644 (file)
@@ -144,7 +144,7 @@ bool id_h(connection_t *c, char *request) {
                else
                        snprintf(label, sizeof label, "tinc TCP key expansion %s %s", c->name, myself->name);
 
-               return sptps_start(&c->sptps, c, c->outgoing, myself->connection->ecdsa, c->ecdsa, label, sizeof label, send_meta_sptps, receive_meta_sptps);
+               return sptps_start(&c->sptps, c, c->outgoing, false, myself->connection->ecdsa, c->ecdsa, label, sizeof label, send_meta_sptps, receive_meta_sptps);
        } else {
                return send_metakey(c);
        }
index 395c92fc540f38897fe71a80869056e75fcb62c0..fa1594db842dc95555e9056d6df035010530c630 100644 (file)
@@ -54,8 +54,41 @@ static bool error(sptps_t *s, int s_errno, const char *msg) {
        return false;
 }
 
+// Send a record (datagram version, accepts all record types, handles encryption and authentication).
+static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
+       char buffer[len + 23UL];
+
+       // Create header with sequence number, length and record type
+       uint32_t seqno = htonl(s->outseqno++);
+       uint16_t netlen = htons(len);
+
+       memcpy(buffer, &netlen, 2);
+       memcpy(buffer + 2, &seqno, 4);
+       buffer[6] = type;
+
+       // Add plaintext (TODO: avoid unnecessary copy)
+       memcpy(buffer + 7, data, len);
+
+       if(s->outstate) {
+               // If first handshake has finished, encrypt and HMAC
+               cipher_set_counter(&s->outcipher, &seqno, sizeof seqno);
+               if(!cipher_counter_xor(&s->outcipher, buffer + 6, len + 1UL, buffer + 6))
+                       return false;
+
+               if(!digest_create(&s->outdigest, buffer, len + 7UL, buffer + 7UL + len))
+                       return false;
+
+               return s->send_data(s->handle, buffer + 2, len + 21UL);
+       } else {
+               // Otherwise send as plaintext
+               return s->send_data(s->handle, buffer + 2, len + 5UL);
+       }
+}
 // Send a record (private version, accepts all record types, handles encryption and authentication).
 static bool send_record_priv(sptps_t *s, uint8_t type, const char *data, uint16_t len) {
+       if(s->datagram)
+               return send_record_priv_datagram(s, type, data, len);
+
        char buffer[len + 23UL];
 
        // Create header with sequence number, length and record type
@@ -102,6 +135,8 @@ static bool send_kex(sptps_t *s) {
        size_t keylen = ECDH_SIZE;
 
        // Make room for our KEX message, which we will keep around since send_sig() needs it.
+       if(s->mykex)
+               abort();
        s->mykex = realloc(s->mykex, 1 + 32 + keylen);
        if(!s->mykex)
                return error(s, errno, strerror(errno));
@@ -219,6 +254,8 @@ static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
        // Ignore version number for now.
 
        // Make a copy of the KEX message, send_sig() and receive_sig() need it
+       if(s->hiskex)
+               abort();
        s->hiskex = realloc(s->hiskex, len);
        if(!s->hiskex)
                return error(s, errno, strerror(errno));
@@ -315,7 +352,6 @@ static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
                        // If we already sent our secondary public ECDH key, we expect the peer to send his.
                        if(!receive_sig(s, data, len))
                                return false;
-                       // s->state = SPTPS_ACK;
                        s->state = SPTPS_ACK;
                        return true;
                case SPTPS_ACK:
@@ -331,8 +367,79 @@ static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
        }
 }
 
+// Receive incoming data, datagram version.
+static bool sptps_receive_data_datagram(sptps_t *s, const char *data, size_t len) {
+       if(len < (s->instate ? 21 : 5))
+               return error(s, EIO, "Received short packet");
+
+       uint32_t seqno;
+       memcpy(&seqno, data, 4);
+       seqno = ntohl(seqno);
+
+       if(!s->instate) {
+               if(seqno != s->inseqno) {
+                       fprintf(stderr, "Received invalid packet seqno: %d != %d\n", seqno, s->inseqno);
+                       return error(s, EIO, "Invalid packet seqno");
+               }
+
+               s->inseqno = seqno + 1;
+
+               uint8_t type = data[4];
+
+               if(type != SPTPS_HANDSHAKE)
+                       return error(s, EIO, "Application record received before handshake finished");
+
+               return receive_handshake(s, data + 5, len - 5);
+       }
+
+       if(seqno < s->inseqno) {
+               fprintf(stderr, "Received late or replayed packet: %d < %d\n", seqno, s->inseqno);
+               return true;
+       }
+
+       if(seqno > s->inseqno)
+               fprintf(stderr, "Missed %d packets\n", seqno - s->inseqno);
+
+       s->inseqno = seqno + 1;
+
+       uint16_t netlen = htons(len - 21);
+
+       char buffer[len + 23];
+
+       memcpy(buffer, &netlen, 2);
+       memcpy(buffer + 2, data, len);
+
+       memcpy(&seqno, buffer + 2, 4);
+
+       // Check HMAC and decrypt.
+       if(!digest_verify(&s->indigest, buffer, len - 14, buffer + len - 14))
+               return error(s, EIO, "Invalid HMAC");
+
+       cipher_set_counter(&s->incipher, &seqno, sizeof seqno);
+       if(!cipher_counter_xor(&s->incipher, buffer + 6, len - 4, buffer + 6))
+               return false;
+
+       // Append a NULL byte for safety.
+       buffer[len - 14] = 0;
+
+       uint8_t type = buffer[6];
+
+       if(type < SPTPS_HANDSHAKE) {
+               if(!s->instate)
+                       return error(s, EIO, "Application record received before handshake finished");
+               if(!s->receive_record(s->handle, type, buffer + 7, len - 21))
+                       return false;
+       } else {
+               return error(s, EIO, "Invalid record type");
+       }
+
+       return true;
+}
 // Receive incoming data. Check if it contains a complete record, if so, handle it.
 bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
+       if(s->datagram)
+               return sptps_receive_data_datagram(s, data, len);
+
        while(len) {
                // First read the 2 length bytes.
                if(s->buflen < 6) {
@@ -422,12 +529,13 @@ bool sptps_receive_data(sptps_t *s, const char *data, size_t len) {
 }
 
 // Start a SPTPS session.
-bool sptps_start(sptps_t *s, void *handle, bool initiator, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
+bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
        // Initialise struct sptps
        memset(s, 0, sizeof *s);
 
        s->handle = handle;
        s->initiator = initiator;
+       s->datagram = datagram;
        s->mykey = mykey;
        s->hiskey = hiskey;
 
@@ -435,11 +543,13 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, ecdsa_t mykey, ecdsa_
        if(!s->label)
                return error(s, errno, strerror(errno));
 
-       s->inbuf = malloc(7);
-       if(!s->inbuf)
-               return error(s, errno, strerror(errno));
-       s->buflen = 4;
-       memset(s->inbuf, 0, 4);
+       if(!datagram) {
+               s->inbuf = malloc(7);
+               if(!s->inbuf)
+                       return error(s, errno, strerror(errno));
+               s->buflen = 4;
+               memset(s->inbuf, 0, 4);
+       }
 
        memcpy(s->label, label, labellen);
        s->labellen = labellen;
index 065c6a099d24d146a2120e4cf489b4d2005e88d0..3854ec24a6eedaf878e9e56eaac27aaaa063bb24 100644 (file)
@@ -45,6 +45,7 @@ typedef bool (*receive_record_t)(void *handle, uint8_t type, const char *data, u
 
 typedef struct sptps {
        bool initiator;
+       bool datagram;
        int state;
 
        char *inbuf;
@@ -76,7 +77,7 @@ typedef struct sptps {
        receive_record_t receive_record;
 } sptps_t;
 
-extern bool sptps_start(sptps_t *s, void *handle, bool initiator, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record);
+extern bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t mykey, ecdsa_t hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record);
 extern bool sptps_stop(sptps_t *s);
 extern bool sptps_send_record(sptps_t *s, uint8_t type, const char *data, uint16_t len);
 extern bool sptps_receive_data(sptps_t *s, const char *data, size_t len);
index 56dcc88634d4061ea01bd8ad5178736c817b0fde..3ee7ab69b63eedaa862bc73275637ddbd8de3da8 100644 (file)
@@ -51,9 +51,16 @@ static bool receive_record(void *handle, uint8_t type, const char *data, uint16_
 
 int main(int argc, char *argv[]) {
        bool initiator = false;
+       bool datagram = false;
 
-       if(argc < 3) {
-               fprintf(stderr, "Usage: %s my_ecdsa_key_file his_ecdsa_key_file [host] port\n", argv[0]);
+       if(argc > 1 && !strcmp(argv[1], "-d")) {
+               datagram = true;
+               argc--;
+               argv++;
+       }
+
+       if(argc < 4) {
+               fprintf(stderr, "Usage: %s [-d] my_ecdsa_key_file his_ecdsa_key_file [host] port\n", argv[0]);
                return 1;
        }
 
@@ -123,7 +130,7 @@ int main(int argc, char *argv[]) {
        fprintf(stderr, "Keys loaded\n");
 
        sptps_t s;
-       if(!sptps_start(&s, &sock, initiator, mykey, hiskey, "sptps_test", 10, send_data, receive_record))
+       if(!sptps_start(&s, &sock, initiator, datagram, mykey, hiskey, "sptps_test", 10, send_data, receive_record))
                return 1;
 
        while(true) {