Some changes related to "-f -"
[oweals/nmrpflash.git] / tftp.c
1 /**
2  * nmrpflash - Netgear Unbrick Utility
3  * Copyright (C) 2016 Joseph Lehner <joseph.c.lehner@gmail.com>
4  *
5  * nmrpflash is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * nmrpflash is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with nmrpflash.  If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19
20 #include <string.h>
21 #include <unistd.h>
22 #include <stdlib.h>
23 #include <stdio.h>
24 #include <errno.h>
25 #include <fcntl.h>
26 #include <ctype.h>
27 #include "nmrpd.h"
28
29 #define TFTP_PKT_SIZE 516
30
31 static const char *opcode_names[] = {
32         "RRQ", "WRQ", "DATA", "ACK", "ERR"
33 };
34
35 enum tftp_opcode {
36         RRQ  = 1,
37         WRQ  = 2,
38         DATA = 3,
39         ACK  = 4,
40         ERR  = 5
41 };
42
43 static const char *leafname(const char *path)
44 {
45         const char *slash, *bslash;
46
47         slash = strrchr(path, '/');
48         bslash = strrchr(path, '\\');
49
50         if (slash && bslash) {
51                 path = 1 + (slash > bslash ? slash : bslash);
52         } else if (slash) {
53                 path = 1 + slash;
54         } else if (bslash) {
55                 path = 1 + bslash;
56         }
57
58         return path;
59 }
60
61 static bool is_netascii(const char *str)
62 {
63         uint8_t *p = (uint8_t*)str;
64
65         for (; *p; ++p) {
66                 if (*p < 0x20 || *p > 0x7f) {
67                         return false;
68                 }
69         }
70
71         return true;
72 }
73
74 static inline void pkt_mknum(char *pkt, uint16_t n)
75 {
76         *(uint16_t*)pkt = htons(n);
77 }
78
79 static inline uint16_t pkt_num(char *pkt)
80 {
81         return ntohs(*(uint16_t*)pkt);
82 }
83
84 static void pkt_mkwrq(char *pkt, const char *filename)
85 {
86         size_t len = 2;
87
88         filename = leafname(filename);
89         if (!is_netascii(filename) || strlen(filename) > 500) {
90                 fprintf(stderr, "Overlong/illegal filename; using 'firmware'.\n");
91                 filename = "firmware";
92         } else if (!strcmp(filename, "-")) {
93                 filename = "firmware";
94         }
95
96         pkt_mknum(pkt, WRQ);
97
98         strcpy(pkt + len, filename);
99         len += strlen(filename) + 1;
100         strcpy(pkt + len, "octet");
101 }
102
103 static inline void pkt_print(char *pkt, FILE *fp)
104 {
105         uint16_t opcode = pkt_num(pkt);
106         if (!opcode || opcode > ERR) {
107                 fprintf(fp, "(%d)", opcode);
108         } else {
109                 fprintf(fp, "%s", opcode_names[opcode - 1]);
110                 if (opcode == ACK || opcode == DATA) {
111                         fprintf(fp, "(%d)", pkt_num(pkt + 2));
112                 } else if (opcode == WRQ || opcode == RRQ) {
113                         fprintf(fp, "(%s, %s)", pkt + 2, pkt + 2 + strlen(pkt + 2) + 1);
114                 }
115         }
116 }
117
118 static ssize_t tftp_recvfrom(int sock, char *pkt, uint16_t* port,
119                 unsigned timeout)
120 {
121         ssize_t len;
122         struct sockaddr_in src;
123 #ifndef NMRPFLASH_WINDOWS
124         socklen_t alen;
125 #else
126         int alen;
127 #endif
128
129         len = select_fd(sock, timeout);
130         if (len < 0) {
131                 return -1;
132         } else if (!len) {
133                 return 0;
134         }
135
136         alen = sizeof(src);
137         len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, (struct sockaddr*)&src, &alen);
138         if (len < 0) {
139                 sock_perror("recvfrom");
140                 return -1;
141         }
142
143         *port = ntohs(src.sin_port);
144
145         uint16_t opcode = pkt_num(pkt);
146
147         if (opcode == ERR) {
148                 fprintf(stderr, "Error (%d): %.511s\n", pkt_num(pkt + 2), pkt + 4);
149                 return -1;
150         } else if (isprint(pkt[0])) {
151                 /* In case of a firmware checksum error, the EX2700 I've tested this
152                  * on sends a raw UDP packet containing just an error message starting
153                  * at offset 0. The limit of 32 chars is arbitrary.
154                  */
155                 fprintf(stderr, "Error: %.32s\n", pkt);
156                 return -2;
157         } else if (!opcode || opcode > ERR) {
158                 fprintf(stderr, "Received invalid packet: ");
159                 pkt_print(pkt, stderr);
160                 fprintf(stderr, ".\n");
161                 return -1;
162         }
163
164         if (verbosity > 2) {
165                 printf(">> ");
166                 pkt_print(pkt, stdout);
167                 printf("\n");
168         }
169
170         return len;
171 }
172
173 static ssize_t tftp_sendto(int sock, char *pkt, size_t len,
174                 struct sockaddr_in *dst)
175 {
176         ssize_t sent;
177
178         switch (pkt_num(pkt)) {
179                 case RRQ:
180                 case WRQ:
181                         len = 2 + strlen(pkt + 2) + 1;
182                         len += strlen(pkt + len) + 1;
183                         break;
184                 case DATA:
185                         len += 4;
186                         break;
187                 case ACK:
188                         len = 4;
189                         break;
190                 case ERR:
191                         len = 4 + strlen(pkt + 4);
192                         break;
193                 default:
194                         fprintf(stderr, "Attempted to send invalid packet ");
195                         pkt_print(pkt, stderr);
196                         fprintf(stderr, "; this is a bug!\n");
197                         return -1;
198         }
199
200         if (verbosity > 2) {
201                 printf("<< ");
202                 pkt_print(pkt, stdout);
203                 printf("\n");
204         }
205
206         sent = sendto(sock, pkt, len, 0, (struct sockaddr*)dst, sizeof(*dst));
207         if (sent < 0) {
208                 sock_perror("sendto");
209         }
210
211         return sent;
212 }
213
214 #ifdef NMRPFLASH_WINDOWS
215 void sock_perror(const char *msg)
216 {
217         win_perror2(msg, WSAGetLastError());
218 }
219 #else
220 inline void sock_perror(const char *msg)
221 {
222         perror(msg);
223 }
224 #endif
225
226 int tftp_put(struct nmrpd_args *args)
227 {
228         struct sockaddr_in addr;
229         uint16_t block, port;
230         ssize_t len, last_len;
231         int fd, sock, ret, timeout, errors, ackblock;
232         char rx[TFTP_PKT_SIZE], tx[TFTP_PKT_SIZE];
233
234         sock = -1;
235         ret = -1;
236
237         if (!strcmp(args->filename, "-")) {
238                 fd = STDIN_FILENO;
239         } else {
240                 fd = open(args->filename, O_RDONLY);
241                 if (fd < 0) {
242                         perror("open");
243                         ret = fd;
244                         goto cleanup;
245                 }
246         }
247
248         sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
249         if (sock < 0) {
250                 sock_perror("socket");
251                 ret = sock;
252                 goto cleanup;
253         }
254
255         if ((addr.sin_addr.s_addr = inet_addr(args->ipaddr)) == INADDR_NONE) {
256                 perror("inet_addr");
257                 goto cleanup;
258         }
259
260         addr.sin_family = AF_INET;
261         addr.sin_port = htons(args->port);
262
263         block = 0;
264         last_len = -1;
265         len = 0;
266         errors = 0;
267         /* Not really, but this way the loop sends our WRQ before receiving */
268         timeout = 1;
269
270         pkt_mkwrq(tx, args->filename);
271
272         do {
273                 if (!timeout && pkt_num(rx) == ACK) {
274                         ackblock = pkt_num(rx + 2);
275                 } else {
276                         ackblock = -1;
277                 }
278
279                 if (timeout || ackblock == block) {
280                         if (!timeout) {
281                                 ++block;
282                                 pkt_mknum(tx, DATA);
283                                 pkt_mknum(tx + 2, block);
284                                 len = read(fd, tx + 4, 512);
285                                 if (len < 0) {
286                                         perror("read");
287                                         ret = len;
288                                         goto cleanup;
289                                 } else if (!len) {
290                                         if (last_len != 512 && last_len != -1) {
291                                                 break;
292                                         }
293                                 }
294
295                                 last_len = len;
296                         }
297
298                         ret = tftp_sendto(sock, tx, len, &addr);
299                         if (ret < 0) {
300                                 goto cleanup;
301                         }
302                 } else if (pkt_num(rx) != ACK || ackblock > block) {
303                         if (verbosity) {
304                                 fprintf(stderr, "Expected ACK(%d), got ", block);
305                                 pkt_print(rx, stderr);
306                                 fprintf(stderr, ".\n");
307                         }
308
309                         if (ackblock != -1 && ++errors > 5) {
310                                 fprintf(stderr, "Protocol error; bailing out.\n");
311                                 ret = -1;
312                                 goto cleanup;
313                         }
314                 }
315
316                 ret = tftp_recvfrom(sock, rx, &port, args->rx_timeout);
317                 if (ret < 0) {
318                         goto cleanup;
319                 } else if (!ret) {
320                         if (++timeout < 5) {
321                                 continue;
322                         } else if (block) {
323                                 fprintf(stderr, "Timeout while waiting for ACK(%d).\n", block);
324                         } else {
325                                 fprintf(stderr, "Timeout while waiting for initial reply.\n");
326                         }
327                         ret = -1;
328                         goto cleanup;
329                 } else {
330                         timeout = 0;
331                         ret = 0;
332
333                         if (!block && port != args->port) {
334                                 if (verbosity > 1) {
335                                         printf("Switching to port %d\n", port);
336                                 }
337                                 addr.sin_port = htons(port);
338                         }
339                 }
340         } while(1);
341
342         ret = 0;
343
344 cleanup:
345         if (fd >= 0) {
346                 close(fd);
347         }
348
349         if (sock >= 0) {
350 #ifndef NMRPFLASH_WINDOWS
351                 shutdown(sock, SHUT_RDWR);
352                 close(sock);
353 #else
354                 shutdown(sock, SD_BOTH);
355                 closesocket(sock);
356 #endif
357         }
358
359         return ret;
360 }