Ensure that NMRP packets are at least 64 bytes
[oweals/nmrpflash.git] / nmrp.c
diff --git a/nmrp.c b/nmrp.c
index c1bfb9c747936bd60d79e33219a391d93cf74748..8cf82d6ed560b9da1812ce21b5d941fd6eecf30c 100644 (file)
--- a/nmrp.c
+++ b/nmrp.c
 #define IP_LEN 4
 #define MAX_LOOP_RECV 1024
 
+#ifndef MAX
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#endif
+
 #ifndef PACKED
 #define PACKED __attribute__((__packed__))
 #endif
@@ -88,6 +92,7 @@ struct nmrp_msg {
        uint16_t len;
        /* only opts[0] is valid! think of this as a char* */
        struct nmrp_opt opts[NMRP_MAX_OPT_NUM];
+       uint8_t padding[8];
        /* this is NOT part of the transmitted packet */
        uint32_t num_opts;
 } PACKED;
@@ -266,9 +271,10 @@ static inline void msg_init(struct nmrp_msg *msg, uint16_t code)
 }
 
 #ifdef NMRPFLASH_FUZZ
-#define ethsock_create(a, b) ethsock_create_fake(a, b)
+#define NMRP_INITIAL_TIMEOUT 0
+#define ethsock_create(a, b) ((struct ethsock*)1)
 #define ethsock_get_hwaddr(a) ethsock_get_hwaddr_fake(a)
-#define ethsock_recv(a, b, c) ethsock_recv_fake(a, b, c)
+#define ethsock_recv(sock, buf, len) read(STDIN_FILENO, buf, len)
 #define ethsock_send(a, b, c) (0)
 #define ethsock_set_timeout(a, b) (0)
 #define ethsock_ip_add(a, b, c, d) (0)
@@ -276,28 +282,19 @@ static inline void msg_init(struct nmrp_msg *msg, uint16_t code)
 #define ethsock_close(a) (0)
 #define tftp_put(a) (0)
 
-static struct ethsock* ethsock_create_fake(const char *intf, uint16_t protocol)
-{
-       return (struct ethsock*)1;
-}
-
 static uint8_t *ethsock_get_hwaddr_fake(struct ethsock* sock)
 {
-       static uint8_t hwaddr[6];
-       memset(hwaddr, 0xfa, 6);
+       static uint8_t hwaddr[6] = { 0xfa, 0xfa, 0xfa, 0xfa, 0xfa, 0xfa };
        return hwaddr;
 }
-
-static ssize_t ethsock_recv_fake(struct ethsock *sock, void *buf, size_t len)
-{
-       return read(STDIN_FILENO, buf, len);
-}
+#else
+#define NMRP_INITIAL_TIMEOUT 60
 #endif
 
 static int pkt_send(struct ethsock *sock, struct nmrp_pkt *pkt)
 {
        size_t len = ntohs(pkt->msg.len) + sizeof(pkt->eh);
-       return ethsock_send(sock, pkt, len);
+       return ethsock_send(sock, pkt, MAX(len, 64));
 }
 
 static int pkt_recv(struct ethsock *sock, struct nmrp_pkt *pkt)
@@ -385,21 +382,9 @@ static int is_valid_ip(struct ethsock *sock, struct in_addr *ipaddr,
        return status < 0 ? status : arg.result;
 }
 
-static struct ethsock *gsock = NULL;
-static struct ethsock_ip_undo *g_ip_undo = NULL;
-static struct ethsock_arp_undo *g_arp_undo = NULL;
-
 static void sigh(int sig)
 {
-       printf("\n");
-       if (gsock) {
-               ethsock_arp_del(gsock, &g_arp_undo);
-               ethsock_ip_del(gsock, &g_ip_undo);
-               ethsock_close(gsock);
-               gsock = NULL;
-       }
-
-       exit(1);
+       g_interrupted = 1;
 }
 
 static const char *spinner = "\\|/-";
@@ -413,6 +398,8 @@ int nmrp_do(struct nmrpd_args *args)
        time_t beg;
        int i, status, ulreqs, expect, upload_ok, autoip;
        struct ethsock *sock;
+       struct ethsock_ip_undo *ip_undo = NULL;
+       struct ethsock_arp_undo *arp_undo = NULL;
        uint32_t intf_addr;
        void (*sigh_orig)(int);
        struct {
@@ -493,7 +480,6 @@ int nmrp_do(struct nmrpd_args *args)
                return 1;
        }
 
-       gsock = sock;
        sigh_orig = signal(SIGINT, sigh);
 
        if (!autoip) {
@@ -510,7 +496,7 @@ int nmrp_do(struct nmrpd_args *args)
                        printf("Adding %s to interface %s.\n", args->ipaddr_intf, args->intf);
                }
 
-               if (ethsock_ip_add(sock, intf_addr, ipconf.mask.s_addr, &g_ip_undo) != 0) {
+               if (ethsock_ip_add(sock, intf_addr, ipconf.mask.s_addr, &ip_undo) != 0) {
                        goto out;
                }
        }
@@ -536,14 +522,14 @@ int nmrp_do(struct nmrpd_args *args)
        upload_ok = 0;
        beg = time_monotonic();
 
-       while (1) {
+       while (!g_interrupted) {
                printf("\rAdvertising NMRP server on %s ... %c",
                                args->intf, spinner[i]);
                fflush(stdout);
                i = (i + 1) & 3;
 
                if (pkt_send(sock, &tx) < 0) {
-                       perror("sendto");
+                       xperror("sendto");
                        goto out;
                }
 
@@ -553,13 +539,12 @@ int nmrp_do(struct nmrpd_args *args)
                } else if (status == 1) {
                        goto out;
                } else {
-                       if ((time_monotonic() - beg) >= 60) {
+                       /* because we don't want nmrpflash's exit status to be zero */
+                       status = 1;
+                       if ((time_monotonic() - beg) >= NMRP_INITIAL_TIMEOUT) {
                                printf("\nNo response after 60 seconds. Bailing out.\n");
                                goto out;
                        }
-#ifdef NMRPFLASH_FUZZ
-                       goto out;
-#endif
                }
        }
 
@@ -568,7 +553,7 @@ int nmrp_do(struct nmrpd_args *args)
        expect = NMRP_C_CONF_REQ;
        ulreqs = 0;
 
-       do {
+       while (!g_interrupted) {
                if (expect != NMRP_C_NONE && rx.msg.code != expect) {
                        fprintf(stderr, "Received %s while waiting for %s!\n",
                                        msg_code_str(rx.msg.code), msg_code_str(expect));
@@ -606,7 +591,7 @@ int nmrp_do(struct nmrpd_args *args)
                                printf("Sending configuration: %s, netmask %s.\n",
                                                args->ipaddr, args->ipmask);
 
-                               if (ethsock_arp_add(sock, rx.eh.ether_shost, ipconf.addr.s_addr, &g_arp_undo) != 0) {
+                               if (ethsock_arp_add(sock, rx.eh.ether_shost, ipconf.addr.s_addr, &arp_undo) != 0) {
                                        goto out;
                                }
 
@@ -677,11 +662,14 @@ int nmrp_do(struct nmrpd_args *args)
                                                printf("Uploading %s ... ", leafname(args->file_local));
                                        }
                                        fflush(stdout);
-                                       status = tftp_put(args);
+                                       if (!(status = tftp_put(args))) {
+                                               printf("OK\n");
+                                       }
+
                                }
 
                                if (!status) {
-                                       printf("OK\nWaiting for remote to respond.\n");
+                                       printf("Waiting for remote to respond.\n");
                                        upload_ok = 1;
                                        ethsock_set_timeout(sock, args->ul_timeout);
                                        tx.msg.code = NMRP_C_KEEP_ALIVE_REQ;
@@ -714,7 +702,7 @@ int nmrp_do(struct nmrpd_args *args)
                        msg_hton(&tx.msg);
 
                        if (pkt_send(sock, &tx) < 0) {
-                               perror("sendto");
+                               xperror("sendto");
                                goto out;
                        }
 
@@ -739,21 +727,21 @@ int nmrp_do(struct nmrpd_args *args)
 
                ethsock_set_timeout(sock, args->rx_timeout);
 
-       } while (1);
-
-       status = 0;
+       }
 
-       if (ulreqs) {
-               printf("Reboot your device now.\n");
-       } else {
-               printf("No upload request received.\n");
+       if (!g_interrupted) {
+               status = 0;
+               if (ulreqs) {
+                       printf("Reboot your device now.\n");
+               } else {
+                       printf("No upload request received.\n");
+               }
        }
 
 out:
        signal(SIGINT, sigh_orig);
-       gsock = NULL;
-       ethsock_arp_del(sock, &g_arp_undo);
-       ethsock_ip_del(sock, &g_ip_undo);
+       ethsock_arp_del(sock, &arp_undo);
+       ethsock_ip_del(sock, &ip_undo);
        ethsock_close(sock);
        return status;
 }