Add GPLv3 notices
[oweals/nmrpflash.git] / tftp.c
1 /**
2  * nmrp-flash - Netgear Unbrick Utility
3  * Copyright (C) 2016 Joseph Lehner <joseph.c.lehner@gmail.com>
4  *
5  * nmrp-flash is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * nmrp-flash is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with nmrp-flash.  If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19
20 #define _BSD_SOURCE
21 #include <arpa/inet.h>
22 #include <sys/socket.h>
23 #include <net/if.h>
24 #include <string.h>
25 #include <unistd.h>
26 #include <stdio.h>
27 #include <errno.h>
28 #include <fcntl.h>
29 #include "nmrpd.h"
30
31 #define TFTP_PKT_SIZE 516
32
33 static const char *opcode_names[] = { 
34         "RRQ", "WRQ", "DATA", "ACK", "ERR"
35 };
36
37 enum tftp_opcode {
38         RRQ  = 1,
39         WRQ  = 2,
40         DATA = 3,
41         ACK  = 4,
42         ERR  = 5
43 };
44
45 static inline void pkt_mknum(char *pkt, uint16_t n)
46 {
47         *(uint16_t*)pkt = htons(n);
48 }
49
50 static inline uint16_t pkt_num(char *pkt)
51 {
52         return ntohs(*(uint16_t*)pkt);
53 }
54
55 static void pkt_mkwrq(char *pkt, const char *filename, const char *mode)
56 {
57         size_t len = 2;
58
59         pkt_mknum(pkt, WRQ);
60
61         strcpy(pkt + len, filename);
62         len += strlen(filename) + 1;
63         strcpy(pkt + len, mode);
64         len += strlen(mode) + 1;
65 }
66
67 static inline void pkt_print(char *pkt, FILE *fp)
68 {
69         uint16_t opcode = pkt_num(pkt);
70         if (!opcode || opcode > ERR) {
71                 fprintf(fp, "(%d)", opcode);
72         } else {
73                 fprintf(fp, "%s", opcode_names[opcode - 1]);
74                 if (opcode == ACK || opcode == DATA) {
75                         fprintf(fp, "(%d)", pkt_num(pkt + 2));
76                 } else if (opcode == WRQ) {
77                         fprintf(fp, "(%s, %s)", pkt + 2, pkt + 2 + strlen(pkt + 2) + 1);
78                 }
79         }
80 }
81
82 static ssize_t tftp_recvfrom(int sock, char *pkt, struct sockaddr_in *src)
83 {
84         socklen_t socklen;
85         ssize_t len;
86
87         (void)src, (void)socklen;
88
89         len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, NULL, NULL);
90         if (len < 0) {
91                 if (errno != EAGAIN) {
92                         perror("recvfrom");
93                         return -1;
94                 }
95
96                 return -2;
97         }
98
99         uint16_t opcode = pkt_num(pkt);
100
101         if (opcode == ERR) {
102                 fprintf(stderr, "Error (%d): %.511s\n", pkt_num(pkt + 2), pkt + 4);
103                 return -1;
104         } else if (!opcode || opcode > ERR) {
105                 /* The EX2700 I've tested this on sends a raw TFTP packet with no
106                  * opcode, and an error message starting at offset 0.
107                  */
108                 fprintf(stderr, "Error: %.32s\n", pkt);
109                 return -3;
110         }
111
112         return len;
113 }
114
115 static ssize_t tftp_sendto(int sock, char *pkt, size_t len, 
116                 struct sockaddr_in *dst)
117 {
118         ssize_t sent;
119
120         switch (pkt_num(pkt)) {
121                 case RRQ:
122                 case WRQ:
123                         len = 2 + strlen(pkt + 2) + 1;
124                         len += strlen(pkt + len) + 1;
125                         break;
126                 case DATA:
127                         len += 4;
128                         break;
129                 case ACK:
130                         len = 4;
131                         break;
132                 case ERR:
133                         len = 4 + strlen(pkt + 4);
134                         break;
135                 default:
136                         fprintf(stderr, "Error: Invalid packet ");
137                         pkt_print(pkt, stderr);
138                         return -1;
139         }
140
141         sent = sendto(sock, pkt, len, 0, (struct sockaddr*)dst, sizeof(*dst));
142         if (sent < 0) {
143                 perror("sendto");
144         }
145
146         return sent;
147 }
148
149 int sock_set_rx_timeout(int fd, unsigned msec)
150 {
151         struct timeval tv;
152
153         if (msec) {
154                 tv.tv_usec = (msec % 1000) * 1000;
155                 tv.tv_sec = msec / 1000;
156                 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) {
157                         perror("setsockopt(SO_RCVTIMEO)");
158                         return 1;
159                 }
160         }
161
162         return 0;
163 }
164
165 int tftp_put(struct nmrpd_args *args)
166 {
167         struct sockaddr_in addr;
168         uint16_t block;
169         ssize_t len;
170         int fd, sock, err, timeout, last_len;
171         char rx[TFTP_PKT_SIZE], tx[TFTP_PKT_SIZE];
172
173         fd = open(args->filename, O_RDONLY);
174         if (fd < 0) {
175                 perror("open");
176                 err = fd;
177                 goto cleanup;
178         }
179
180         sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
181         if (sock < 0) {
182                 perror("socket");
183                 err = sock;
184                 goto cleanup;
185         }
186
187         err = sock_set_rx_timeout(sock, args->rx_timeout);
188         if (err) {
189                 goto cleanup;
190         }
191
192         err = !inet_aton(args->ipaddr, &addr.sin_addr);
193         if (err) {
194                 perror("inet_aton");
195                 goto cleanup;
196         }
197
198         addr.sin_family = AF_INET;
199         addr.sin_port = htons(args->port);
200
201         pkt_mkwrq(tx, args->filename, "octet");
202
203         len = tftp_sendto(sock, tx, 0, &addr);
204         if (len < 0) {
205                 err = len;
206                 goto cleanup;
207         }
208
209         len = tftp_recvfrom(sock, rx, &addr);
210         if (len < 0) {
211                 err = len;
212                 goto cleanup;
213         }
214
215         timeout = 0;
216         block = 0;
217         last_len = -1;
218
219         do {
220                 if (timeout || (pkt_num(rx) == ACK && pkt_num(rx + 2) == block)) {
221                         if (!timeout) {
222                                 ++block;
223                                 pkt_mknum(tx, DATA);
224                                 pkt_mknum(tx + 2, block);
225                                 len = read(fd, tx + 4, 512);
226                                 if (len < 0) {
227                                         perror("read");
228                                         err = len;
229                                         goto cleanup;
230                                 } else if (!len) {
231                                         if (last_len != 512) {
232                                                 break;
233                                         }
234                                 }
235
236                                 last_len = len;
237                         }
238
239                         err = tftp_sendto(sock, tx, len, &addr);
240                         if (err < 0) {
241                                 goto cleanup;
242                         }
243                 } else if (pkt_num(rx) != ACK) {
244                         fprintf(stderr, "Expected ACK(%d), got ", block);
245                         pkt_print(rx, stderr);
246                         fprintf(stderr, "!\n");
247                 }
248
249                 err = tftp_recvfrom(sock, rx, &addr);
250                 if (err < 0) {
251                         if (err == -2) {
252                                 if (++timeout < 5) {
253                                         continue;
254                                 }
255                                 fprintf(stderr, "Timeout while waiting for ACK(%d)\n.", block);
256                         }
257                         goto cleanup;
258                 } else {
259                         timeout = 0;
260                         err = 0;
261                 }
262         } while(1);
263
264         err = 0;
265
266 cleanup:
267         if (fd >= 0) {
268                 close(fd);
269         }
270
271         if (sock >= 0) {
272                 close(sock);
273         }
274
275         return err;
276 }