Add .gitignore
[oweals/nmrpflash.git] / tftp.c
diff --git a/tftp.c b/tftp.c
index 57ccc12c9034b02e87a021f044d2d5a45de3f7a6..ec15c4b1d6f422c59b50541f55df14c249ed38f6 100644 (file)
--- a/tftp.c
+++ b/tftp.c
@@ -85,7 +85,7 @@ static void pkt_mkwrq(char *pkt, const char *filename)
 
        filename = leafname(filename);
        if (!is_netascii(filename) || strlen(filename) > 500) {
-               fprintf(stderr, "Overlong/illegal filename; using 'firmware.bin'.");
+               fprintf(stderr, "Overlong/illegal filename; using 'firmware.bin'.\n");
                filename = "firmware.bin";
        }
 
@@ -111,10 +111,16 @@ static inline void pkt_print(char *pkt, FILE *fp)
        }
 }
 
-static ssize_t tftp_recvfrom(int sock, char *pkt, struct sockaddr_in *src,
+static ssize_t tftp_recvfrom(int sock, char *pkt, uint16_t* port,
                unsigned timeout)
 {
        ssize_t len;
+       struct sockaddr_in src;
+#ifndef NMRPFLASH_WINDOWS
+       socklen_t alen;
+#else
+       int alen;
+#endif
 
        len = select_fd(sock, timeout);
        if (len < 0) {
@@ -123,12 +129,15 @@ static ssize_t tftp_recvfrom(int sock, char *pkt, struct sockaddr_in *src,
                return 0;
        }
 
-       len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, NULL, NULL);
+       alen = sizeof(src);
+       len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, (struct sockaddr*)&src, &alen);
        if (len < 0) {
                sock_perror("recvfrom");
                return -1;
        }
 
+       *port = ntohs(src.sin_port);
+
        uint16_t opcode = pkt_num(pkt);
 
        if (opcode == ERR) {
@@ -140,12 +149,18 @@ static ssize_t tftp_recvfrom(int sock, char *pkt, struct sockaddr_in *src,
                 * at offset 0. The limit of 32 chars is arbitrary.
                 */
                fprintf(stderr, "Error: %.32s\n", pkt);
-               return -3;
-       } else {
+               return -2;
+       } else if (!opcode || opcode > ERR) {
                fprintf(stderr, "Received invalid packet: ");
                pkt_print(pkt, stderr);
                fprintf(stderr, ".\n");
-               return -2;
+               return -1;
+       }
+
+       if (verbosity > 2) {
+               printf(">> ");
+               pkt_print(pkt, stdout);
+               printf("\n");
        }
 
        return len;
@@ -178,6 +193,12 @@ static ssize_t tftp_sendto(int sock, char *pkt, size_t len,
                        return -1;
        }
 
+       if (verbosity > 2) {
+               printf("<< ");
+               pkt_print(pkt, stdout);
+               printf("\n");
+       }
+
        sent = sendto(sock, pkt, len, 0, (struct sockaddr*)dst, sizeof(*dst));
        if (sent < 0) {
                sock_perror("sendto");
@@ -201,9 +222,9 @@ inline void sock_perror(const char *msg)
 int tftp_put(struct nmrpd_args *args)
 {
        struct sockaddr_in addr;
-       uint16_t block;
-       ssize_t len;
-       int fd, sock, ret, timeout, last_len;
+       uint16_t block, port;
+       ssize_t len, last_len;
+       int fd, sock, ret, timeout, errors, ackblock;
        char rx[TFTP_PKT_SIZE], tx[TFTP_PKT_SIZE];
 
        sock = -1;
@@ -234,13 +255,16 @@ int tftp_put(struct nmrpd_args *args)
        block = 0;
        last_len = -1;
        len = 0;
+       errors = 0;
        /* Not really, but this way the loop sends our WRQ before receiving */
        timeout = 1;
 
        pkt_mkwrq(tx, args->filename);
 
        do {
-               if (timeout || (pkt_num(rx) == ACK && pkt_num(rx + 2) == block)) {
+               ackblock = pkt_num(rx) == ACK ? pkt_num(rx + 2) : -1;
+
+               if (timeout || ackblock == block) {
                        if (!timeout) {
                                ++block;
                                pkt_mknum(tx, DATA);
@@ -251,7 +275,7 @@ int tftp_put(struct nmrpd_args *args)
                                        ret = len;
                                        goto cleanup;
                                } else if (!len) {
-                                       if (last_len != 512) {
+                                       if (last_len != 512 && last_len != -1) {
                                                break;
                                        }
                                }
@@ -263,24 +287,43 @@ int tftp_put(struct nmrpd_args *args)
                        if (ret < 0) {
                                goto cleanup;
                        }
-               } else if (pkt_num(rx) != ACK) {
-                       fprintf(stderr, "Expected ACK(%d), got ", block);
-                       pkt_print(rx, stderr);
-                       fprintf(stderr, "!\n");
+               } else if (pkt_num(rx) != ACK || ackblock > block) {
+                       if (verbosity) {
+                               fprintf(stderr, "Expected ACK(%d), got ", block);
+                               pkt_print(rx, stderr);
+                               fprintf(stderr, ".\n");
+                       }
+
+                       if (ackblock != -1 && ++errors > 5) {
+                               fprintf(stderr, "Protocol error; bailing out.\n");
+                               ret = -1;
+                               goto cleanup;
+                       }
                }
 
-               ret = tftp_recvfrom(sock, rx, &addr, args->rx_timeout);
+               ret = tftp_recvfrom(sock, rx, &port, args->rx_timeout);
                if (ret < 0) {
                        goto cleanup;
                } else if (!ret) {
                        if (++timeout < 5) {
                                continue;
+                       } else if (block) {
+                               fprintf(stderr, "Timeout while waiting for ACK(%d).\n", block);
+                       } else {
+                               fprintf(stderr, "Timeout while waiting for initial reply.\n");
                        }
-                       fprintf(stderr, "Timeout while waiting for ACK(%d).\n", block);
+                       ret = -1;
                        goto cleanup;
                } else {
                        timeout = 0;
                        ret = 0;
+
+                       if (!block && port != args->port) {
+                               if (verbosity > 1) {
+                                       printf("Switching to port %d\n", port);
+                               }
+                               addr.sin_port = htons(port);
+                       }
                }
        } while(1);