TFTP fixes
[oweals/nmrpflash.git] / tftp.c
1 #define _BSD_SOURCE
2 #include <arpa/inet.h>
3 #include <sys/socket.h>
4 #include <net/if.h>
5 #include <string.h>
6 #include <unistd.h>
7 #include <stdio.h>
8 #include <errno.h>
9 #include <fcntl.h>
10
11 #define TFTP_PKT_SIZE 516
12
13 static const char *opcode_names[] = { 
14         "RRQ", "WRQ", "DATA", "ACK", "ERR"
15 };
16
17 enum tftp_opcode {
18         RRQ  = 1,
19         WRQ  = 2,
20         DATA = 3,
21         ACK  = 4,
22         ERR  = 5
23 };
24
25 static inline void pkt_mknum(char *pkt, uint16_t n)
26 {
27         *(uint16_t*)pkt = htons(n);
28 }
29
30 static inline uint16_t pkt_num(char *pkt)
31 {
32         return ntohs(*(uint16_t*)pkt);
33 }
34
35 static void pkt_mkwrq(char *pkt, const char *filename, const char *mode)
36 {
37         size_t len = 2;
38
39         pkt_mknum(pkt, WRQ);
40
41         strcpy(pkt + len, filename);
42         len += strlen(filename) + 1;
43         strcpy(pkt + len, mode);
44         len += strlen(mode) + 1;
45 }
46
47 static inline void pkt_print(char *pkt, FILE *fp)
48 {
49         uint16_t opcode = pkt_num(pkt);
50         if (!opcode || opcode > ERR) {
51                 fprintf(fp, "(%d)", opcode);
52         } else {
53                 fprintf(fp, "%s", opcode_names[opcode - 1]);
54                 if (opcode == ACK || opcode == DATA) {
55                         fprintf(fp, "(%d)", pkt_num(pkt + 2));
56                 } else if (opcode == WRQ) {
57                         fprintf(fp, "(%s, %s)", pkt + 2, pkt + 2 + strlen(pkt + 2) + 1);
58                 }
59         }
60 }
61
62 static ssize_t tftp_recvfrom(int sock, char *pkt, struct sockaddr_in *src)
63 {
64         socklen_t socklen;
65         ssize_t len;
66
67         (void)src, (void)socklen;
68
69         len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, NULL, NULL);
70         if (len < 0) {
71                 if (errno != EAGAIN) {
72                         perror("recvfrom");
73                         return -1;
74                 }
75
76                 return -2;
77         }
78
79         uint16_t opcode = pkt_num(pkt);
80
81         if (opcode == ERR) {
82                 fprintf(stderr, "Error (%d): %.511s\n", pkt_num(pkt + 2), pkt + 4);
83                 return -1;
84         } else if (!opcode || opcode > ERR) {
85                 /* The EX2700 I've tested this on sends a raw TFTP packet with no
86                  * opcode, and an error message starting at offset 0.
87                  */
88                 fprintf(stderr, "Error: %.32s\n", pkt);
89                 return -3;
90         }
91
92         return len;
93 }
94
95 static ssize_t tftp_sendto(int sock, char *pkt, size_t len, 
96                 struct sockaddr_in *dst)
97 {
98         ssize_t sent;
99
100         switch (pkt_num(pkt)) {
101                 case RRQ:
102                 case WRQ:
103                         len = 2 + strlen(pkt + 2) + 1;
104                         len += strlen(pkt + len) + 1;
105                         break;
106                 case DATA:
107                         len += 4;
108                         break;
109                 case ACK:
110                         len = 4;
111                         break;
112                 case ERR:
113                         len = 4 + strlen(pkt + 4);
114                         break;
115                 default:
116                         fprintf(stderr, "Error: Invalid packet ");
117                         pkt_print(pkt, stderr);
118                         return -1;
119         }
120
121         sent = sendto(sock, pkt, len, 0, (struct sockaddr*)dst, sizeof(*dst));
122         if (sent < 0) {
123                 perror("sendto");
124         }
125
126         return sent;
127 }
128
129 int sock_set_rx_timeout(int fd, unsigned msec)
130 {
131         struct timeval tv;
132
133         if (msec) {
134                 tv.tv_usec = (msec % 1000) * 1000;
135                 tv.tv_sec = msec / 1000;
136                 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) {
137                         perror("setsockopt(SO_RCVTIMEO)");
138                         return 1;
139                 }
140         }
141
142         return 0;
143 }
144
145 int tftp_put(const char *filename, const char *ipaddr, uint16_t port)
146 {
147         struct sockaddr_in addr;
148         uint16_t block;
149         ssize_t len;
150         int fd, sock, err, timeout, last_len;
151         char rx[TFTP_PKT_SIZE], tx[TFTP_PKT_SIZE];
152
153         fd = open(filename, O_RDONLY);
154         if (fd < 0) {
155                 perror("open");
156                 err = fd;
157                 goto cleanup;
158         }
159
160         sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
161         if (sock < 0) {
162                 perror("socket");
163                 err = sock;
164                 goto cleanup;
165         }
166
167         err = sock_set_rx_timeout(sock, 999);
168         if (err) {
169                 goto cleanup;
170         }
171
172         err = !inet_aton(ipaddr, &addr.sin_addr);
173         if (err) {
174                 perror("inet_aton");
175                 goto cleanup;
176         }
177
178         addr.sin_family = AF_INET;
179         addr.sin_port = htons(port);
180
181         pkt_mkwrq(tx, filename, "octet");
182
183         len = tftp_sendto(sock, tx, 0, &addr);
184         if (len < 0) {
185                 err = len;
186                 goto cleanup;
187         }
188
189         len = tftp_recvfrom(sock, rx, &addr);
190         if (len < 0) {
191                 err = len;
192                 goto cleanup;
193         }
194
195         timeout = 0;
196         block = 0;
197         last_len = -1;
198
199         do {
200                 if (timeout || (pkt_num(rx) == ACK && pkt_num(rx + 2) == block)) {
201                         if (!timeout) {
202                                 ++block;
203                                 pkt_mknum(tx, DATA);
204                                 pkt_mknum(tx + 2, block);
205                                 len = read(fd, tx + 4, 512);
206                                 if (len < 0) {
207                                         perror("read");
208                                         err = len;
209                                         goto cleanup;
210                                 } else if (!len) {
211                                         if (last_len != 512) {
212                                                 break;
213                                         }
214                                 }
215
216                                 last_len = len;
217                         }
218
219                         err = tftp_sendto(sock, tx, len, &addr);
220                         if (err < 0) {
221                                 goto cleanup;
222                         }
223                 } else if (pkt_num(rx) != ACK) {
224                         fprintf(stderr, "Expected ACK(%d), got ", block);
225                         pkt_print(rx, stderr);
226                         fprintf(stderr, "!\n");
227                 }
228
229                 err = tftp_recvfrom(sock, rx, &addr);
230                 if (err < 0) {
231                         if (err == -2) {
232                                 if (++timeout < 5) {
233                                         continue;
234                                 }
235                                 fprintf(stderr, "Timeout while waiting for ACK(%d)\n.", block);
236                         }
237                         goto cleanup;
238                 } else {
239                         timeout = 0;
240                 }
241         } while(1);
242
243         err = 0;
244
245 cleanup:
246         if (fd >= 0) {
247                 close(fd);
248         }
249
250         if (sock >= 0) {
251                 close(sock);
252         }
253
254         return err;
255 }