X-Git-Url: https://git.librecmc.org/?p=oweals%2Fnmrpflash.git;a=blobdiff_plain;f=tftp.c;h=735f992e36e635379c180087657e489b88c8a9c7;hp=e18c0aab665c22aaa618c4e2cb9eee372b25bb57;hb=HEAD;hpb=98607d141bd8deec3cedb599277589b3d1dc6280 diff --git a/tftp.c b/tftp.c index e18c0aa..735f992 100644 --- a/tftp.c +++ b/tftp.c @@ -26,10 +26,14 @@ #include #include "nmrpd.h" -#define TFTP_PKT_SIZE 516 +#ifndef O_BINARY +#define O_BINARY 0 +#endif + +#define TFTP_BLKSIZE 1456 static const char *opcode_names[] = { - "RRQ", "WRQ", "DATA", "ACK", "ERR" + "RRQ", "WRQ", "DATA", "ACK", "ERR", "OACK" }; enum tftp_opcode { @@ -37,27 +41,10 @@ enum tftp_opcode { WRQ = 2, DATA = 3, ACK = 4, - ERR = 5 + ERR = 5, + OACK = 6 }; -static const char *leafname(const char *path) -{ - const char *slash, *bslash; - - slash = strrchr(path, '/'); - bslash = strrchr(path, '\\'); - - if (slash && bslash) { - path = 1 + (slash > bslash ? slash : bslash); - } else if (slash) { - path = 1 + slash; - } else if (bslash) { - path = 1 + bslash; - } - - return path; -} - static bool is_netascii(const char *str) { uint8_t *p = (uint8_t*)str; @@ -71,9 +58,10 @@ static bool is_netascii(const char *str) return true; } -static inline void pkt_mknum(char *pkt, uint16_t n) +static inline char *pkt_mknum(char *pkt, uint16_t n) { *(uint16_t*)pkt = htons(n); + return pkt + 2; } static inline uint16_t pkt_num(char *pkt) @@ -81,29 +69,93 @@ static inline uint16_t pkt_num(char *pkt) return ntohs(*(uint16_t*)pkt); } -static void pkt_mkwrq(char *pkt, const char *filename) +static char *pkt_mkopt(char *pkt, const char *opt, const char* val) { - size_t len = 2; + strcpy(pkt, opt); + pkt += strlen(opt) + 1; + strcpy(pkt, val); + pkt += strlen(val) + 1; + return pkt; +} +static bool pkt_nextstr(char **pkt, char **str, size_t *rem) +{ + size_t len; + + if (!isprint(**pkt) || !(len = strnlen(*pkt, *rem))) { + return false; + } else if (str) { + *str = *pkt; + } + + *pkt += len + 1; + + if (*rem > 1) { + *rem -= len + 1; + } else { + *rem = 0; + } + + return true; +} + +static bool pkt_nextopt(char **pkt, char **opt, char **val, size_t *rem) +{ + return pkt_nextstr(pkt, opt, rem) && pkt_nextstr(pkt, val, rem); +} + +static char *pkt_optval(char* pkt, const char* name) +{ + size_t rem = 512; + char *opt, *val; + pkt += 2; + + while (pkt_nextopt(&pkt, &opt, &val, &rem)) { + if (!strcasecmp(name, opt)) { + return val; + } + } + + return NULL; +} + +static size_t pkt_xrqlen(char *pkt) +{ + size_t rem = 512; + + pkt += 2; + while (pkt_nextopt(&pkt, NULL, NULL, &rem)) { + ; + } + + return 514 - rem; +} + +static void pkt_mkwrq(char *pkt, const char *filename, unsigned blksize) +{ filename = leafname(filename); - if (!is_netascii(filename) || strlen(filename) > 500) { + if (!tftp_is_valid_filename(filename)) { fprintf(stderr, "Overlong/illegal filename; using 'firmware'.\n"); filename = "firmware"; } else if (!strcmp(filename, "-")) { filename = "firmware"; } - pkt_mknum(pkt, WRQ); + pkt = pkt_mknum(pkt, WRQ); + pkt = pkt_mkopt(pkt, filename, "octet"); - strcpy(pkt + len, filename); - len += strlen(filename) + 1; - strcpy(pkt + len, "octet"); + if (blksize && blksize != 512) { + pkt = pkt_mkopt(pkt, "blksize", lltostr(blksize, 10)); + } } static inline void pkt_print(char *pkt, FILE *fp) { uint16_t opcode = pkt_num(pkt); - if (!opcode || opcode > ERR) { + size_t rem; + char *opt, *val; + + if (!opcode || opcode > OACK) { fprintf(fp, "(%d)", opcode); } else { fprintf(fp, "%s", opcode_names[opcode - 1]); @@ -111,12 +163,20 @@ static inline void pkt_print(char *pkt, FILE *fp) fprintf(fp, "(%d)", pkt_num(pkt + 2)); } else if (opcode == WRQ || opcode == RRQ) { fprintf(fp, "(%s, %s)", pkt + 2, pkt + 2 + strlen(pkt + 2) + 1); + } else if (opcode == OACK) { + fprintf(fp, "("); + rem = 512; + pkt += 2; + while (pkt_nextopt(&pkt, &opt, &val, &rem)) { + fprintf(fp, " %s=%s ", opt, val); + } + fprintf(fp, ")"); } } } static ssize_t tftp_recvfrom(int sock, char *pkt, uint16_t* port, - unsigned timeout) + unsigned timeout, size_t pktlen) { ssize_t len; struct sockaddr_in src; @@ -134,7 +194,7 @@ static ssize_t tftp_recvfrom(int sock, char *pkt, uint16_t* port, } alen = sizeof(src); - len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, (struct sockaddr*)&src, &alen); + len = recvfrom(sock, pkt, pktlen, 0, (struct sockaddr*)&src, &alen); if (len < 0) { sock_perror("recvfrom"); return -1; @@ -154,7 +214,7 @@ static ssize_t tftp_recvfrom(int sock, char *pkt, uint16_t* port, */ fprintf(stderr, "Error: %.32s\n", pkt); return -2; - } else if (!opcode || opcode > ERR) { + } else if (!opcode || opcode > OACK) { fprintf(stderr, "Received invalid packet: "); pkt_print(pkt, stderr); fprintf(stderr, ".\n"); @@ -178,8 +238,8 @@ static ssize_t tftp_sendto(int sock, char *pkt, size_t len, switch (pkt_num(pkt)) { case RRQ: case WRQ: - len = 2 + strlen(pkt + 2) + 1; - len += strlen(pkt + len) + 1; + case OACK: + len = pkt_xrqlen(pkt); break; case DATA: len += 4; @@ -211,37 +271,76 @@ static ssize_t tftp_sendto(int sock, char *pkt, size_t len, return sent; } +const char *leafname(const char *path) +{ + if (!path) { + return NULL; + } + + const char *slash, *bslash; + + slash = strrchr(path, '/'); + bslash = strrchr(path, '\\'); + + if (slash && bslash) { + path = 1 + (slash > bslash ? slash : bslash); + } else if (slash) { + path = 1 + slash; + } else if (bslash) { + path = 1 + bslash; + } + + return path; +} + #ifdef NMRPFLASH_WINDOWS void sock_perror(const char *msg) { win_perror2(msg, WSAGetLastError()); } -#else -inline void sock_perror(const char *msg) +#endif + +inline bool tftp_is_valid_filename(const char *filename) { - perror(msg); + return strlen(filename) <= 255 && is_netascii(filename); } -#endif + +static const char *spinner = "\\|/-"; int tftp_put(struct nmrpd_args *args) { struct sockaddr_in addr; - uint16_t block, port; + uint16_t block, port, op, blksize; ssize_t len, last_len; - int fd, sock, ret, timeout, errors, ackblock; - char rx[TFTP_PKT_SIZE], tx[TFTP_PKT_SIZE]; + int fd, sock, ret, timeouts, errors, ackblock; + char rx[2048], tx[2048]; + const char *file_remote = args->file_remote; + char *val, *end; + bool rollover; + const unsigned rx_timeout = MAX(args->rx_timeout / (args->blind ? 50 : 5), 2000); + const unsigned max_timeouts = args->blind ? 3 : 5; sock = -1; ret = -1; + fd = -1; + + if (g_interrupted) { + goto cleanup; + } if (!strcmp(args->file_local, "-")) { fd = STDIN_FILENO; + if (!file_remote) { + file_remote = "firmware"; + } } else { - fd = open(args->file_local, O_RDONLY); + fd = open(args->file_local, O_RDONLY | O_BINARY); if (fd < 0) { - perror("open"); + xperror("open"); ret = fd; goto cleanup; + } else if (!file_remote) { + file_remote = args->file_local; } } @@ -252,42 +351,86 @@ int tftp_put(struct nmrpd_args *args) goto cleanup; } + memset(&addr, 0, sizeof(addr)); + + addr.sin_family = AF_INET; + + if (args->ipaddr_intf) { + if ((addr.sin_addr.s_addr = inet_addr(args->ipaddr_intf)) == INADDR_NONE) { + xperror("inet_addr"); + goto cleanup; + } + + if (bind(sock, (struct sockaddr*)&addr, sizeof(addr)) != 0) { + sock_perror("bind"); + goto cleanup; + } + } + if ((addr.sin_addr.s_addr = inet_addr(args->ipaddr)) == INADDR_NONE) { - perror("inet_addr"); + xperror("inet_addr"); goto cleanup; } - addr.sin_family = AF_INET; addr.sin_port = htons(args->port); + blksize = 512; block = 0; last_len = -1; len = 0; errors = 0; + rollover = false; /* Not really, but this way the loop sends our WRQ before receiving */ - timeout = 1; - - pkt_mkwrq(tx, args->file_local); + timeouts = 1; + + pkt_mkwrq(tx, file_remote, TFTP_BLKSIZE); + + while (!g_interrupted) { + ackblock = -1; + op = pkt_num(rx); + + if (!timeouts) { + if (op == ACK) { + ackblock = pkt_num(rx + 2); + } else if (op == OACK) { + ackblock = 0; + if ((val = pkt_optval(rx, "blksize"))) { + blksize = strtol(val, &end, 10); + if (*end != '\0' || blksize < 8 || blksize > TFTP_BLKSIZE) { + fprintf(stderr, "Error: invalid blksize in OACK: %s\n", val); + ret = -1; + goto cleanup; + } - do { - if (!timeout && pkt_num(rx) == ACK) { - ackblock = pkt_num(rx + 2); - } else { - ackblock = -1; + if (verbosity) { + printf("Remote accepted blksize option: %d b\n", blksize); + } + } + } } - if (timeout || ackblock == block) { - if (!timeout) { - ++block; + if (timeouts || ackblock == block) { + if (!timeouts) { + if (++block == 0) { + if (!rollover) { + printf("Warning: TFTP block rollover. Upload might fail!\n"); + rollover = true; + } + } + + printf("%c ", spinner[block & 3]); + fflush(stdout); + printf("\b\b"); + pkt_mknum(tx, DATA); pkt_mknum(tx + 2, block); - len = read(fd, tx + 4, 512); + len = read(fd, tx + 4, blksize); if (len < 0) { - perror("read"); + xperror("read"); ret = len; goto cleanup; } else if (!len) { - if (last_len != 512 && last_len != -1) { + if (last_len != blksize && last_len != -1) { break; } } @@ -299,7 +442,7 @@ int tftp_put(struct nmrpd_args *args) if (ret < 0) { goto cleanup; } - } else if (pkt_num(rx) != ACK || ackblock > block) { + } else if ((op != OACK && op != ACK) || ackblock > block) { if (verbosity) { fprintf(stderr, "Expected ACK(%d), got ", block); pkt_print(rx, stderr); @@ -313,21 +456,27 @@ int tftp_put(struct nmrpd_args *args) } } - ret = tftp_recvfrom(sock, rx, &port, args->rx_timeout); + ret = tftp_recvfrom(sock, rx, &port, rx_timeout, blksize + 4); if (ret < 0) { goto cleanup; } else if (!ret) { - if (++timeout < 5) { + if (++timeouts < max_timeouts || (!block && timeouts < (max_timeouts * 4))) { + continue; + } else if (args->blind) { + timeouts = 0; + // fake an ACK packet + pkt_mknum(rx, ACK); + pkt_mknum(rx + 2, block); continue; } else if (block) { fprintf(stderr, "Timeout while waiting for ACK(%d).\n", block); } else { - fprintf(stderr, "Timeout while waiting for initial reply.\n"); + fprintf(stderr, "Timeout while waiting for ACK(0)/OACK.\n"); } ret = -1; goto cleanup; } else { - timeout = 0; + timeouts = 0; ret = 0; if (!block && port != args->port) { @@ -337,9 +486,9 @@ int tftp_put(struct nmrpd_args *args) addr.sin_port = htons(port); } } - } while(1); + } - ret = 0; + ret = !g_interrupted ? 0 : -1; cleanup: if (fd >= 0) {