19280d6bbd133c0522acd4b16c3b04cb83f67d6a
[oweals/u-boot_mod.git] / u-boot / net / tftp.c
1 /*
2  *      Copyright 1994, 1995, 2000 Neil Russell.
3  *      (See License)
4  *      Copyright 2000, 2001 DENX Software Engineering, Wolfgang Denk, wd@denx.de
5  */
6
7 #include <common.h>
8 #include <command.h>
9 #include <net.h>
10 #include "tftp.h"
11 #include "bootp.h"
12
13 #undef  ET_DEBUG
14
15 #if defined(CONFIG_CMD_NET)
16
17 DECLARE_GLOBAL_DATA_PTR;
18
19 #define WELL_KNOWN_PORT         69              /* Well known TFTP port # */
20 #define TIMEOUT                         5               /* Seconds to timeout for a lost pkt */
21
22 #ifndef CONFIG_NET_RETRY_COUNT
23         #define TIMEOUT_COUNT   10              /* # of timeouts before giving up */
24 #else
25         #define TIMEOUT_COUNT  (CONFIG_NET_RETRY_COUNT * 2)
26 #endif
27
28 /* (for checking the image size)        */
29 #define HASHES_PER_LINE         40              /* Number of "loading" hashes per line */
30
31 /*
32  *      TFTP operations.
33  */
34 #define TFTP_RRQ        1
35 #define TFTP_WRQ        2
36 #define TFTP_DATA       3
37 #define TFTP_ACK        4
38 #define TFTP_ERROR      5
39 #define TFTP_OACK       6
40
41
42 static int TftpServerPort;                      /* The UDP port at their end */
43 static int TftpOurPort;                         /* The UDP port at our end */
44 static int TftpTimeoutCount;
45 static int TftpState;
46
47 static ulong TftpBlock;                         /* packet sequence number */
48 static ulong TftpLastBlock;                     /* last packet sequence number received */
49 static ulong TftpBlockWrap;                     /* count of sequence number wraparounds */
50 static ulong TftpBlockWrapOffset;       /* memory offset due to wrapping */
51
52 #define STATE_RRQ                       1
53 #define STATE_DATA                      2
54 #define STATE_TOO_LARGE         3
55 #define STATE_BAD_MAGIC         4
56 #define STATE_OACK                      5
57
58 #define TFTP_BLOCK_SIZE         512                             /* default TFTP block size */
59 #define TFTP_SEQUENCE_SIZE      ((ulong)(1<<16))    /* sequence number is 16 bit */
60
61 #define DEFAULT_NAME_LEN        (8 + 4 + 1)
62
63 static char default_filename[DEFAULT_NAME_LEN];
64 static char *tftp_filename;
65
66 #ifdef CFG_DIRECT_FLASH_TFTP
67         extern flash_info_t flash_info[];
68 #endif
69
70 static __inline__ void store_block(unsigned block, uchar * src, unsigned len){
71         ulong offset = block * TFTP_BLOCK_SIZE + TftpBlockWrapOffset;
72         ulong newsize = offset + len;
73
74 #ifdef CFG_DIRECT_FLASH_TFTP
75         int i, rc = 0;
76
77         for(i=0; i<CFG_MAX_FLASH_BANKS; i++){
78                 /* start address in flash? */
79                 if(load_addr + offset >= flash_info[i].start[0]){
80                         rc = 1;
81                         break;
82                 }
83         }
84
85         if(rc){ /* Flash is destination for this packet */
86                 rc = flash_write((char *)src, (ulong)(load_addr+offset), len);
87
88                 if(rc){
89                         flash_perror(rc);
90                         NetState = NETLOOP_FAIL;
91                         return;
92                 }
93         } else
94 #endif /* CFG_DIRECT_FLASH_TFTP */
95         {
96                 (void)memcpy((void *)(load_addr + offset), src, len);
97         }
98
99         if(NetBootFileXferSize < newsize){
100                 NetBootFileXferSize = newsize;
101         }
102 }
103
104 static void TftpSend(void);
105 static void TftpTimeout(void);
106
107 /**********************************************************************/
108
109 static void TftpSend(void){
110         volatile uchar *pkt;
111         volatile uchar *xp;
112         int len = 0;
113         volatile ushort *s;
114
115         /*
116          *      We will always be sending some sort of packet, so
117          *      cobble together the packet headers now.
118          */
119         pkt = NetTxPacket + NetEthHdrSize() + IP_HDR_SIZE;
120
121         switch(TftpState){
122                 case STATE_RRQ:
123                         xp = pkt;
124                         s = (ushort *)pkt;
125                         *s++ = htons(TFTP_RRQ);
126
127                         pkt = (uchar *)s;
128                         strcpy ((char *)pkt, tftp_filename);
129
130                         pkt += strlen(tftp_filename) + 1;
131                         strcpy ((char *)pkt, "octet");
132
133                         pkt += 5 /*strlen("octet")*/ + 1;
134                         strcpy ((char *)pkt, "timeout");
135
136                         pkt += 7 /*strlen("timeout")*/ + 1;
137                         sprintf((char *)pkt, "%d", TIMEOUT);
138         #ifdef ET_DEBUG
139                         printf("send option \"timeout %s\"\n", (char *)pkt);
140         #endif
141                         pkt += strlen((char *)pkt) + 1;
142                         len = pkt - xp;
143                         break;
144
145                 case STATE_DATA:
146                 case STATE_OACK:
147                         xp = pkt;
148                         s = (ushort *)pkt;
149                         *s++ = htons(TFTP_ACK);
150                         *s++ = htons(TftpBlock);
151                         pkt = (uchar *)s;
152                         len = pkt - xp;
153                         break;
154
155                 case STATE_TOO_LARGE:
156                         xp = pkt;
157                         s = (ushort *)pkt;
158                         *s++ = htons(TFTP_ERROR);
159                         *s++ = htons(3);
160                         pkt = (uchar *)s;
161                         strcpy((char *)pkt, "File too large");
162                         pkt += 14 /*strlen("File too large")*/ + 1;
163                         len = pkt - xp;
164                         break;
165
166                 case STATE_BAD_MAGIC:
167                         xp = pkt;
168                         s = (ushort *)pkt;
169                         *s++ = htons(TFTP_ERROR);
170                         *s++ = htons(2);
171                         pkt = (uchar *)s;
172                         strcpy((char *)pkt, "File has bad magic");
173                         pkt += 18 /*strlen("File has bad magic")*/ + 1;
174                         len = pkt - xp;
175                         break;
176         }
177
178         NetSendUDPPacket(NetServerEther, NetServerIP, TftpServerPort, TftpOurPort, len);
179 }
180
181 static void TftpHandler(uchar * pkt, unsigned dest, unsigned src, unsigned len){
182         bd_t *bd = gd->bd;
183         ushort proto;
184         ushort *s;
185
186         if(dest != TftpOurPort){
187                 return;
188         }
189
190         if(TftpState != STATE_RRQ && src != TftpServerPort){
191                 return;
192         }
193
194         if(len < 2){
195                 return;
196         }
197
198         len -= 2;
199
200         /* warning: don't use increment (++) in ntohs() macros!! */
201         s = (ushort *)pkt;
202         proto = *s++;
203         pkt = (uchar *)s;
204
205         switch(ntohs(proto)){
206                 case TFTP_RRQ:
207                 case TFTP_WRQ:
208                 case TFTP_ACK:
209                         break;
210
211                 default:
212                         break;
213
214                 case TFTP_OACK:
215         #ifdef ET_DEBUG
216                         printf("Got OACK: %s %s\n", pkt, pkt+strlen(pkt)+1);
217         #endif
218                         TftpState = STATE_OACK;
219                         TftpServerPort = src;
220                         TftpSend(); /* Send ACK */
221                         break;
222
223                 // TFTP DATA PACKET
224                 case TFTP_DATA:
225                         if(len < 2){
226                                 return;
227                         }
228
229                         len -= 2;
230                         TftpBlock = ntohs(*(ushort *)pkt);
231
232                         /*
233                          * RFC1350 specifies that the first data packet will
234                          * have sequence number 1. If we receive a sequence
235                          * number of 0 this means that there was a wrap
236                          * around of the (16 bit) counter.
237                          */
238                         if(TftpBlock == 0){
239                                 TftpBlockWrap++;
240                                 TftpBlockWrapOffset += TFTP_BLOCK_SIZE * TFTP_SEQUENCE_SIZE;
241                                 printf("\n         %lu MB received\n         ", TftpBlockWrapOffset>>20);
242                         } else {
243                                 if(((TftpBlock - 1) % 10) == 0){
244                                         putc('#');
245                                 } else if((TftpBlock % (10 * HASHES_PER_LINE)) == 0){
246                                         puts("\n              ");
247                                 }
248                         }
249
250         #ifdef ET_DEBUG
251                         if(TftpState == STATE_RRQ){
252                                 puts("## Error: server did not acknowledge timeout option!\n");
253                         }
254         #endif
255
256                         if(TftpState == STATE_RRQ || TftpState == STATE_OACK){
257                                 /* first block received */
258                                 TftpState = STATE_DATA;
259                                 TftpServerPort = src;
260                                 TftpLastBlock = 0;
261                                 TftpBlockWrap = 0;
262                                 TftpBlockWrapOffset = 0;
263
264                                 if(TftpBlock != 1){     /* Assertion */
265                                         printf("\n## Error: first block is not block 1 (%ld), starting again!\n\n", TftpBlock);
266                                         NetStartAgain();
267                                         break;
268                                 }
269                         }
270
271                         if(TftpBlock == TftpLastBlock){
272                                 /*
273                                  *      Same block again; ignore it.
274                                  */
275                                 break;
276                         }
277
278                         TftpLastBlock = TftpBlock;
279                         NetSetTimeout(TIMEOUT * CFG_HZ, TftpTimeout);
280
281                         store_block(TftpBlock - 1, pkt + 2, len);
282
283                         /*
284                          *      Acknoledge the block just received, which will prompt
285                          *      the server for the next one.
286                          */
287                         TftpSend();
288
289                         if(len < TFTP_BLOCK_SIZE){
290                                 /*
291                                  *      We received the whole thing.  Try to
292                                  *      run it.
293                                  */
294                                 puts("\n\nTFTP transfer complete!\n");
295                                 NetState = NETLOOP_SUCCESS;
296                         }
297
298                         break;
299
300                 case TFTP_ERROR:
301                         printf("\n## Error: '%s' (%d), starting again!\n\n", pkt + 2, ntohs(*(ushort *)pkt));
302                         NetStartAgain();
303                         break;
304         }
305 }
306
307 static void TftpTimeout(void){
308         bd_t *bd = gd->bd;
309
310         if(++TftpTimeoutCount > TIMEOUT_COUNT){
311                 puts("\n\n## Error: retry count exceeded, starting again!\n\n");
312                 NetStartAgain();
313         } else {
314                 puts("T ");
315                 NetSetTimeout(TIMEOUT * CFG_HZ, TftpTimeout);
316                 TftpSend();
317         }
318 }
319
320 void TftpStart(void){
321         bd_t *bd = gd->bd;
322
323 #ifdef CONFIG_TFTP_PORT
324         char *ep; /* Environment pointer */
325 #endif
326
327         if(BootFile[0] == '\0'){
328                 sprintf(default_filename, "%02lX%02lX%02lX%02lX.img", NetOurIP & 0xFF, (NetOurIP >> 8) & 0xFF,  (NetOurIP >> 16) & 0xFF, (NetOurIP >> 24) & 0xFF);
329                 tftp_filename = default_filename;
330
331                 printf("** Warning: no boot file name, using: '%s'\n", tftp_filename);
332         } else {
333                 tftp_filename = BootFile;
334         }
335
336         puts("\nTFTP from IP: ");
337         print_IPaddr(NetServerIP);
338
339         puts("\n      Our IP: ");
340         print_IPaddr(NetOurIP);
341
342         /* Check if we need to send across this subnet */
343         if(NetOurGatewayIP && NetOurSubnetMask){
344             IPaddr_t OurNet     = NetOurIP    & NetOurSubnetMask;
345             IPaddr_t ServerNet  = NetServerIP & NetOurSubnetMask;
346
347             if(OurNet != ServerNet){
348                 puts("\n  Gateway IP: ");
349                 print_IPaddr(NetOurGatewayIP) ;
350             }
351         }
352
353         printf("\n    Filename: '%s'", tftp_filename);
354
355         if(NetBootFileSize){
356                 printf("\n        Size: 0x%x Bytes = ", NetBootFileSize<<9);
357                 print_size(NetBootFileSize<<9, "");
358         }
359
360         printf("\nLoad address: 0x%lx", load_addr);
361
362 #if defined(CONFIG_NET_MULTI)
363         printf("\n       Using: %s", eth_get_name());
364 #endif
365
366         puts("\n\n     Loading: *\b");
367
368         NetSetTimeout(TIMEOUT * CFG_HZ, TftpTimeout);
369         NetSetHandler(TftpHandler);
370
371         TftpServerPort = WELL_KNOWN_PORT;
372         TftpTimeoutCount = 0;
373         TftpState = STATE_RRQ;
374
375         /* Use a pseudo-random port unless a specific port is set */
376         TftpOurPort = 1024 + (get_timer(0) % 3072);
377
378 #ifdef CONFIG_TFTP_PORT
379         if((ep = getenv("tftpdstp")) != NULL){
380                 TftpServerPort = simple_strtol(ep, NULL, 10);
381         }
382         if((ep = getenv("tftpsrcp")) != NULL){
383                 TftpOurPort= simple_strtol(ep, NULL, 10);
384         }
385 #endif
386
387         TftpBlock = 0;
388
389         /* zero out server ether in case the server ip has changed */
390         memset(NetServerEther, 0, 6);
391
392         TftpSend();
393 }
394
395 #endif /* CONFIG_CMD_NET */