Merge branch 'master' of github.com:jclehner/nmrp-flash
[oweals/nmrpflash.git] / tftp.c
1 /**
2  * nmrpflash - Netgear Unbrick Utility
3  * Copyright (C) 2016 Joseph Lehner <joseph.c.lehner@gmail.com>
4  *
5  * nmrpflash is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * nmrpflash is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with nmrpflash.  If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19
20 #include <string.h>
21 #include <unistd.h>
22 #include <stdlib.h>
23 #include <stdio.h>
24 #include <errno.h>
25 #include <fcntl.h>
26 #include <ctype.h>
27 #include "nmrpd.h"
28
29 #define TFTP_PKT_SIZE 516
30
31 static const char *opcode_names[] = {
32         "RRQ", "WRQ", "DATA", "ACK", "ERR"
33 };
34
35 enum tftp_opcode {
36         RRQ  = 1,
37         WRQ  = 2,
38         DATA = 3,
39         ACK  = 4,
40         ERR  = 5
41 };
42
43 static const char *leafname(const char *path)
44 {
45         const char *slash, *bslash;
46
47         slash = strrchr(path, '/');
48         bslash = strrchr(path, '\\');
49
50         if (slash && bslash) {
51                 path = 1 + (slash > bslash ? slash : bslash);
52         } else if (slash) {
53                 path = 1 + slash;
54         } else if (bslash) {
55                 path = 1 + bslash;
56         }
57
58         return path;
59 }
60
61 static bool is_netascii(const char *str)
62 {
63         uint8_t *p = (uint8_t*)str;
64
65         for (; *p; ++p) {
66                 if (*p < 0x20 || *p > 0x7f) {
67                         return false;
68                 }
69         }
70
71         return true;
72 }
73
74 static inline void pkt_mknum(char *pkt, uint16_t n)
75 {
76         *(uint16_t*)pkt = htons(n);
77 }
78
79 static inline uint16_t pkt_num(char *pkt)
80 {
81         return ntohs(*(uint16_t*)pkt);
82 }
83
84 static void pkt_mkwrq(char *pkt, const char *filename)
85 {
86         size_t len = 2;
87
88         filename = leafname(filename);
89         if (!is_netascii(filename) || strlen(filename) > 500) {
90                 fprintf(stderr, "Overlong/illegal filename; using 'firmware.bin'.\n");
91                 filename = "firmware.bin";
92         }
93
94         pkt_mknum(pkt, WRQ);
95
96         strcpy(pkt + len, filename);
97         len += strlen(filename) + 1;
98         strcpy(pkt + len, "octet");
99 }
100
101 static inline void pkt_print(char *pkt, FILE *fp)
102 {
103         uint16_t opcode = pkt_num(pkt);
104         if (!opcode || opcode > ERR) {
105                 fprintf(fp, "(%d)", opcode);
106         } else {
107                 fprintf(fp, "%s", opcode_names[opcode - 1]);
108                 if (opcode == ACK || opcode == DATA) {
109                         fprintf(fp, "(%d)", pkt_num(pkt + 2));
110                 } else if (opcode == WRQ || opcode == RRQ) {
111                         fprintf(fp, "(%s, %s)", pkt + 2, pkt + 2 + strlen(pkt + 2) + 1);
112                 }
113         }
114 }
115
116 static ssize_t tftp_recvfrom(int sock, char *pkt, uint16_t* port,
117                 unsigned timeout)
118 {
119         ssize_t len;
120         struct sockaddr_in src;
121 #ifndef NMRPFLASH_WINDOWS
122         socklen_t alen;
123 #else
124         int alen;
125 #endif
126
127         len = select_fd(sock, timeout);
128         if (len < 0) {
129                 return -1;
130         } else if (!len) {
131                 return 0;
132         }
133
134         alen = sizeof(src);
135         len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, (struct sockaddr*)&src, &alen);
136         if (len < 0) {
137                 sock_perror("recvfrom");
138                 return -1;
139         }
140
141         *port = ntohs(src.sin_port);
142
143         uint16_t opcode = pkt_num(pkt);
144
145         if (opcode == ERR) {
146                 fprintf(stderr, "Error (%d): %.511s\n", pkt_num(pkt + 2), pkt + 4);
147                 return -1;
148         } else if (isprint(pkt[0])) {
149                 /* In case of a firmware checksum error, the EX2700 I've tested this
150                  * on sends a raw UDP packet containing just an error message starting
151                  * at offset 0. The limit of 32 chars is arbitrary.
152                  */
153                 fprintf(stderr, "Error: %.32s\n", pkt);
154                 return -2;
155         } else if (!opcode || opcode > ERR) {
156                 fprintf(stderr, "Received invalid packet: ");
157                 pkt_print(pkt, stderr);
158                 fprintf(stderr, ".\n");
159                 return -1;
160         }
161
162         if (verbosity > 2) {
163                 printf(">> ");
164                 pkt_print(pkt, stdout);
165                 printf("\n");
166         }
167
168         return len;
169 }
170
171 static ssize_t tftp_sendto(int sock, char *pkt, size_t len,
172                 struct sockaddr_in *dst)
173 {
174         ssize_t sent;
175
176         switch (pkt_num(pkt)) {
177                 case RRQ:
178                 case WRQ:
179                         len = 2 + strlen(pkt + 2) + 1;
180                         len += strlen(pkt + len) + 1;
181                         break;
182                 case DATA:
183                         len += 4;
184                         break;
185                 case ACK:
186                         len = 4;
187                         break;
188                 case ERR:
189                         len = 4 + strlen(pkt + 4);
190                         break;
191                 default:
192                         fprintf(stderr, "Attempted to send invalid packet ");
193                         pkt_print(pkt, stderr);
194                         fprintf(stderr, "; this is a bug!\n");
195                         return -1;
196         }
197
198         if (verbosity > 2) {
199                 printf("<< ");
200                 pkt_print(pkt, stdout);
201                 printf("\n");
202         }
203
204         sent = sendto(sock, pkt, len, 0, (struct sockaddr*)dst, sizeof(*dst));
205         if (sent < 0) {
206                 sock_perror("sendto");
207         }
208
209         return sent;
210 }
211
212 #ifdef NMRPFLASH_WINDOWS
213 void sock_perror(const char *msg)
214 {
215         win_perror2(msg, WSAGetLastError());
216 }
217 #else
218 inline void sock_perror(const char *msg)
219 {
220         perror(msg);
221 }
222 #endif
223
224 int tftp_put(struct nmrpd_args *args)
225 {
226         struct sockaddr_in addr;
227         uint16_t block, port;
228         ssize_t len, last_len;
229         int fd, sock, ret, timeout, errors, ackblock;
230         char rx[TFTP_PKT_SIZE], tx[TFTP_PKT_SIZE];
231
232         sock = -1;
233         ret = -1;
234
235         fd = open(args->filename, O_RDONLY);
236         if (fd < 0) {
237                 perror("open");
238                 ret = fd;
239                 goto cleanup;
240         }
241
242         sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
243         if (sock < 0) {
244                 sock_perror("socket");
245                 ret = sock;
246                 goto cleanup;
247         }
248
249         if ((addr.sin_addr.s_addr = inet_addr(args->ipaddr)) == INADDR_NONE) {
250                 perror("inet_addr");
251                 goto cleanup;
252         }
253
254         addr.sin_family = AF_INET;
255         addr.sin_port = htons(args->port);
256
257         block = 0;
258         last_len = -1;
259         len = 0;
260         errors = 0;
261         /* Not really, but this way the loop sends our WRQ before receiving */
262         timeout = 1;
263
264         pkt_mkwrq(tx, args->filename);
265
266         do {
267                 if (!timeout && pkt_num(rx) == ACK) {
268                         ackblock = pkt_num(rx + 2);
269                 } else {
270                         ackblock = -1;
271                 }
272
273                 if (timeout || ackblock == block) {
274                         if (!timeout) {
275                                 ++block;
276                                 pkt_mknum(tx, DATA);
277                                 pkt_mknum(tx + 2, block);
278                                 len = read(fd, tx + 4, 512);
279                                 if (len < 0) {
280                                         perror("read");
281                                         ret = len;
282                                         goto cleanup;
283                                 } else if (!len) {
284                                         if (last_len != 512 && last_len != -1) {
285                                                 break;
286                                         }
287                                 }
288
289                                 last_len = len;
290                         }
291
292                         ret = tftp_sendto(sock, tx, len, &addr);
293                         if (ret < 0) {
294                                 goto cleanup;
295                         }
296                 } else if (pkt_num(rx) != ACK || ackblock > block) {
297                         if (verbosity) {
298                                 fprintf(stderr, "Expected ACK(%d), got ", block);
299                                 pkt_print(rx, stderr);
300                                 fprintf(stderr, ".\n");
301                         }
302
303                         if (ackblock != -1 && ++errors > 5) {
304                                 fprintf(stderr, "Protocol error; bailing out.\n");
305                                 ret = -1;
306                                 goto cleanup;
307                         }
308                 }
309
310                 ret = tftp_recvfrom(sock, rx, &port, args->rx_timeout);
311                 if (ret < 0) {
312                         goto cleanup;
313                 } else if (!ret) {
314                         if (++timeout < 5) {
315                                 continue;
316                         } else if (block) {
317                                 fprintf(stderr, "Timeout while waiting for ACK(%d).\n", block);
318                         } else {
319                                 fprintf(stderr, "Timeout while waiting for initial reply.\n");
320                         }
321                         ret = -1;
322                         goto cleanup;
323                 } else {
324                         timeout = 0;
325                         ret = 0;
326
327                         if (!block && port != args->port) {
328                                 if (verbosity > 1) {
329                                         printf("Switching to port %d\n", port);
330                                 }
331                                 addr.sin_port = htons(port);
332                         }
333                 }
334         } while(1);
335
336         ret = 0;
337
338 cleanup:
339         if (fd >= 0) {
340                 close(fd);
341         }
342
343         if (sock >= 0) {
344 #ifndef NMRPFLASH_WINDOWS
345                 shutdown(sock, SHUT_RDWR);
346                 close(sock);
347 #else
348                 shutdown(sock, SD_BOTH);
349                 closesocket(sock);
350 #endif
351         }
352
353         return ret;
354 }