Working TFTP upload
[oweals/nmrpflash.git] / tftp.c
1 #define _BSD_SOURCE
2 #include <arpa/inet.h>
3 #include <sys/socket.h>
4 #include <net/if.h>
5 #include <string.h>
6 #include <unistd.h>
7 #include <stdio.h>
8 #include <errno.h>
9 #include <fcntl.h>
10
11 #define TFTP_PKT_SIZE 516
12
13 static const char *opcode_names[] = { 
14         "RRQ", "WRQ", "DATA", "ACK", "ERR"
15 };
16
17 enum tftp_opcode {
18         RRQ  = 1,
19         WRQ  = 2,
20         DATA = 3,
21         ACK  = 4,
22         ERR  = 5,
23         NETGEAR_ERR = 0x4669
24 };
25
26 static inline void pkt_mknum(char *pkt, uint16_t n)
27 {
28         *(uint16_t*)pkt = htons(n);
29 }
30
31 static inline uint16_t pkt_num(char *pkt)
32 {
33         return ntohs(*(uint16_t*)pkt);
34 }
35
36 static void pkt_mkwrq(char *pkt, const char *filename, const char *mode)
37 {
38         size_t len = 2;
39
40         pkt_mknum(pkt, WRQ);
41
42         strcpy(pkt + len, filename);
43         len += strlen(filename) + 1;
44         strcpy(pkt + len, mode);
45         len += strlen(mode) + 1;
46 }
47
48 static inline void pkt_print(char *pkt, FILE *fp)
49 {
50         uint16_t opcode = pkt_num(pkt);
51         if (!opcode || opcode > ERR) {
52                 fprintf(fp, "(%d)", opcode);
53         } else {
54                 fprintf(fp, "%s", opcode_names[opcode - 1]);
55                 if (opcode == ACK || opcode == DATA) {
56                         fprintf(fp, "(%d)", pkt_num(pkt + 2));
57                 } else if (opcode == WRQ) {
58                         fprintf(fp, "(%s, %s)", pkt + 2, pkt + 2 + strlen(pkt + 2) + 1);
59                 }
60         }
61 }
62
63 static ssize_t tftp_recvfrom(int sock, char *pkt, struct sockaddr_in *src)
64 {
65         static int fail = 0;
66         socklen_t socklen;
67         ssize_t len;
68
69         len = recvfrom(sock, pkt, TFTP_PKT_SIZE, 0, NULL, NULL);
70         if (len < 0) {
71                 if (errno != EAGAIN) {
72                         perror("recvfrom");
73                         return -1;
74                 }
75
76                 return -2;
77         }
78
79         uint16_t opcode = pkt_num(pkt);
80
81         if (opcode == ERR) {
82                 fprintf(stderr, "Error (%d): %.511s\n", pkt_num(pkt + 2), pkt + 4);
83                 return -1;
84         } else if (!opcode || opcode > ERR) {
85                 /* The EX2700 I've tested this on sends a raw TFTP packet with no
86                  * opcode, and an error message starting at offset 0.
87                  */
88                 fprintf(stderr, "Error: %.32s\n", pkt);
89                 return -3;
90         }
91
92         return len;
93 }
94
95 static ssize_t tftp_sendto(int sock, char *pkt, size_t len, 
96                 struct sockaddr_in *dst)
97 {
98         ssize_t sent;
99
100         switch (pkt_num(pkt)) {
101                 case RRQ:
102                 case WRQ:
103                         len = 2 + strlen(pkt + 2) + 1;
104                         len += strlen(pkt + len) + 1;
105                         break;
106                 case DATA:
107                         len += 4;
108                         break;
109                 case ACK:
110                         len = 4;
111                         break;
112                 case ERR:
113                         len = 4 + strlen(pkt + 4);
114                         break;
115                 default:
116                         fprintf(stderr, "Error: Invalid packet ");
117                         pkt_print(pkt, stderr);
118                         return -1;
119         }
120
121         sent = sendto(sock, pkt, len, 0, (struct sockaddr*)dst, sizeof(*dst));
122         if (sent < 0) {
123                 perror("sendto");
124         }
125
126         return sent;
127 }
128
129 int sock_set_rx_timeout(int fd, unsigned msec)
130 {
131         struct timeval tv;
132
133         if (msec) {
134                 tv.tv_usec = (msec % 1000) * 1000;
135                 tv.tv_sec = msec / 1000;
136                 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) {
137                         perror("setsockopt(SO_RCVTIMEO)");
138                         return 1;
139                 }
140         }
141
142         return 0;
143 }
144
145 int tftp_put(const char *filename, const char *ipaddr, uint16_t port)
146 {
147         struct sockaddr_in dst, src;
148         enum tftp_opcode opcode;
149         struct timeval tv;
150         uint16_t block;
151         ssize_t len;
152         int fd, sock, err, done, i, last_len;
153         char pkt[TFTP_PKT_SIZE];
154
155         fd = open(filename, O_RDONLY);
156         if (fd < 0) {
157                 perror("open");
158                 err = fd;
159                 goto cleanup;
160         }
161
162         sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
163         if (sock < 0) {
164                 perror("socket");
165                 err = sock;
166                 goto cleanup;
167         }
168
169         err = sock_set_rx_timeout(sock, 999);
170         if (err) {
171                 goto cleanup;
172         }
173
174         err = !inet_aton(ipaddr, &dst.sin_addr);
175         if (err) {
176                 perror("inet_aton");
177                 goto cleanup;
178         }
179
180         dst.sin_family = AF_INET;
181         dst.sin_port = htons(port);
182
183         pkt_mkwrq(pkt, filename, "octet");
184
185         len = tftp_sendto(sock, pkt, 0, &dst);
186         if (len < 0) {
187                 err = len;
188                 goto cleanup;
189         }
190
191         len = tftp_recvfrom(sock, pkt, &dst);
192         if (len < 0) {
193                 err = len;
194                 goto cleanup;
195         }
196
197         //dst.sin_port = src.sin_port;
198
199         block = 0;
200         done = 0;
201         last_len = -1;
202
203         do {
204                 if (pkt_num(pkt) == ACK && pkt_num(pkt + 2) == block) {
205                         ++block;
206                         pkt_mknum(pkt, DATA);
207                         pkt_mknum(pkt + 2, block);
208                         len = read(fd, pkt + 4, 512);
209                         if (len < 0) {
210                                 perror("read");
211                                 err = len;
212                                 goto cleanup;
213                         } else if (!len) {
214                                 done = last_len != 512;
215                         }
216
217                         last_len = len;
218
219                         len = tftp_sendto(sock, pkt, len, &dst);
220                         if (len < 0) {
221                                 err = len;
222                                 goto cleanup;
223                         }
224                 } else {
225                         fprintf(stderr, "Expected ACK(%d), got ", block);
226                         pkt_print(pkt, stderr);
227                         fprintf(stderr, "!\n");
228                         err = 1;
229                         goto cleanup;
230                 }
231
232                 len = tftp_recvfrom(sock, pkt, &dst);
233                 if (len < 0) {
234                         if (len == -2) {
235                                 fprintf(stderr, "Timeout while waiting for ACK(%d).\n", block);
236                         }
237                         err = len;
238                         goto cleanup;
239                 }
240         } while(!done);
241
242         err = 0;
243
244 cleanup:
245         if (fd >= 0) {
246                 close(fd);
247         }
248
249         if (sock >= 0) {
250                 close(sock);
251         }
252
253         return err;
254 }