Add more robust error checking to deSerialize*String routines
[oweals/minetest.git] / src / util / serialize.cpp
1 /*
2 Minetest
3 Copyright (C) 2010-2013 celeron55, Perttu Ahola <celeron55@gmail.com>
4
5 This program is free software; you can redistribute it and/or modify
6 it under the terms of the GNU Lesser General Public License as published by
7 the Free Software Foundation; either version 2.1 of the License, or
8 (at your option) any later version.
9
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 GNU Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public License along
16 with this program; if not, write to the Free Software Foundation, Inc.,
17 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "serialize.h"
21 #include "pointer.h"
22 #include "porting.h"
23 #include "util/string.h"
24 #include "../exceptions.h"
25 #include "../irrlichttypes.h"
26
27 #include <sstream>
28 #include <iomanip>
29 #include <vector>
30
31 ////
32 //// String
33 ////
34
35 std::string serializeString(const std::string &plain)
36 {
37         std::string s;
38         char buf[2];
39
40         if (plain.size() > 65535)
41                 throw SerializationError("String too long for serializeString");
42
43         writeU16((u8 *)&buf[0], plain.size());
44         s.append(buf, 2);
45
46         s.append(plain);
47         return s;
48 }
49
50 std::string deSerializeString(std::istream &is)
51 {
52         std::string s;
53         char buf[2];
54
55         is.read(buf, 2);
56         if (is.gcount() != 2)
57                 throw SerializationError("deSerializeString: size not read");
58
59         u16 s_size = readU16((u8 *)buf);
60         if (s_size == 0)
61                 return s;
62
63         Buffer<char> buf2(s_size);
64         is.read(&buf2[0], s_size);
65         if (is.gcount() != s_size)
66                 throw SerializationError("deSerializeString: couldn't read all chars");
67
68         s.reserve(s_size);
69         s.append(&buf2[0], s_size);
70         return s;
71 }
72
73 ////
74 //// Wide String
75 ////
76
77 std::string serializeWideString(const std::wstring &plain)
78 {
79         std::string s;
80         char buf[2];
81
82         if (plain.size() > 65535)
83                 throw SerializationError("String too long for serializeString");
84
85         writeU16((u8 *)buf, plain.size());
86         s.append(buf, 2);
87
88         for (u32 i = 0; i < plain.size(); i++) {
89                 writeU16((u8 *)buf, plain[i]);
90                 s.append(buf, 2);
91         }
92         return s;
93 }
94
95 std::wstring deSerializeWideString(std::istream &is)
96 {
97         std::wstring s;
98         char buf[2];
99
100         is.read(buf, 2);
101         if (is.gcount() != 2)
102                 throw SerializationError("deSerializeString: size not read");
103
104         u16 s_size = readU16((u8 *)buf);
105         if (s_size == 0)
106                 return s;
107
108         s.reserve(s_size);
109         for (u32 i = 0; i < s_size; i++) {
110                 is.read(&buf[0], 2);
111                 if (is.gcount() != 2) {
112                         throw SerializationError(
113                                 "deSerializeWideString: couldn't read all chars");
114                 }
115
116                 wchar_t c16 = readU16((u8 *)buf);
117                 s.append(&c16, 1);
118         }
119         return s;
120 }
121
122 ////
123 //// Long String
124 ////
125
126 std::string serializeLongString(const std::string &plain)
127 {
128         char buf[4];
129         writeU32((u8*)&buf[0], plain.size());
130         std::string s;
131         s.append(buf, 4);
132         s.append(plain);
133         return s;
134 }
135
136 std::string deSerializeLongString(std::istream &is)
137 {
138         std::string s;
139         char buf[4];
140
141         is.read(buf, 4);
142         if (is.gcount() != 4)
143                 throw SerializationError("deSerializeLongString: size not read");
144
145         u32 s_size = readU32((u8 *)buf);
146         if (s_size == 0)
147                 return s;
148
149         // We don't really want a remote attacker to force us to allocate 4GB...
150         if (s_size > LONG_STRING_MAX)
151                 throw SerializationError("deSerializeLongString: string too long");
152
153         Buffer<char> buf2(s_size);
154         is.read(&buf2[0], s_size);
155         if (is.gcount() != s_size)
156                 throw SerializationError("deSerializeString: couldn't read all chars");
157
158         s.reserve(s_size);
159         s.append(&buf2[0], s_size);
160         return s;
161 }
162
163 ////
164 //// JSON
165 ////
166
// Wraps `plain` in double quotes and escapes it JSON-style.
// '"', '\\', '/', '\b', '\f', '\n', '\r' and '\t' get their two-character
// escapes. Printable ASCII (32..126) is copied as-is. Every other byte is
// written as \u00XX using its unsigned byte value in lowercase hex.
std::string serializeJsonString(const std::string &plain)
{
	std::ostringstream out(std::ios::binary);
	out << '"';

	for (std::string::const_iterator it = plain.begin(); it != plain.end(); ++it) {
		const char ch = *it;

		const char *escape = NULL;
		switch (ch) {
		case '"':  escape = "\\\""; break;
		case '\\': escape = "\\\\"; break;
		case '/':  escape = "\\/";  break;
		case '\b': escape = "\\b";  break;
		case '\f': escape = "\\f";  break;
		case '\n': escape = "\\n";  break;
		case '\r': escape = "\\r";  break;
		case '\t': escape = "\\t";  break;
		default:   break;
		}

		if (escape != NULL) {
			out << escape;
		} else if (ch >= 32 && ch <= 126) {
			out << ch;
		} else {
			// Cast through u8 so bytes >= 0x80 are not sign-extended.
			u32 code = (u8)ch;
			out << "\\u" << std::hex << std::setw(4)
				<< std::setfill('0') << code;
		}
	}

	out << '"';
	return out.str();
}
215
216 std::string deSerializeJsonString(std::istream &is)
217 {
218         std::ostringstream os(std::ios::binary);
219         char c, c2;
220
221         // Parse initial doublequote
222         is >> c;
223         if (c != '"')
224                 throw SerializationError("JSON string must start with doublequote");
225
226         // Parse characters
227         for (;;) {
228                 c = is.get();
229                 if (is.eof())
230                         throw SerializationError("JSON string ended prematurely");
231
232                 if (c == '"') {
233                         return os.str();
234                 } else if (c == '\\') {
235                         c2 = is.get();
236                         if (is.eof())
237                                 throw SerializationError("JSON string ended prematurely");
238                         switch (c2) {
239                                 case 'b':
240                                         os << '\b';
241                                         break;
242                                 case 'f':
243                                         os << '\f';
244                                         break;
245                                 case 'n':
246                                         os << '\n';
247                                         break;
248                                 case 'r':
249                                         os << '\r';
250                                         break;
251                                 case 't':
252                                         os << '\t';
253                                         break;
254                                 case 'u': {
255                                         int hexnumber;
256                                         char hexdigits[4 + 1];
257
258                                         is.read(hexdigits, 4);
259                                         if (is.eof())
260                                                 throw SerializationError("JSON string ended prematurely");
261                                         hexdigits[4] = 0;
262
263                                         std::istringstream tmp_is(hexdigits, std::ios::binary);
264                                         tmp_is >> std::hex >> hexnumber;
265                                         os << (char)hexnumber;
266                                         break;
267                                 }
268                                 default:
269                                         os << c2;
270                                         break;
271                         }
272                 } else {
273                         os << c;
274                 }
275         }
276
277         return os.str();
278 }
279
280 ////
281 //// String/Struct conversions
282 ////
283
284 bool deSerializeStringToStruct(std::string valstr,
285         std::string format, void *out, size_t olen)
286 {
287         size_t len = olen;
288         std::vector<std::string *> strs_alloced;
289         std::string *str;
290         char *f, *snext;
291         size_t pos;
292
293         char *s = &valstr[0];
294         char *buf = new char[len];
295         char *bufpos = buf;
296
297         char *fmtpos, *fmt = &format[0];
298         while ((f = strtok_r(fmt, ",", &fmtpos)) && s) {
299                 fmt = NULL;
300
301                 bool is_unsigned = false;
302                 int width = 0;
303                 char valtype = *f;
304
305                 width = (int)strtol(f + 1, &f, 10);
306                 if (width && valtype == 's')
307                         valtype = 'i';
308
309                 switch (valtype) {
310                         case 'u':
311                                 is_unsigned = true;
312                                 /* FALLTHROUGH */
313                         case 'i':
314                                 if (width == 16) {
315                                         bufpos += PADDING(bufpos, u16);
316                                         if ((bufpos - buf) + sizeof(u16) <= len) {
317                                                 if (is_unsigned)
318                                                         *(u16 *)bufpos = (u16)strtoul(s, &s, 10);
319                                                 else
320                                                         *(s16 *)bufpos = (s16)strtol(s, &s, 10);
321                                         }
322                                         bufpos += sizeof(u16);
323                                 } else if (width == 32) {
324                                         bufpos += PADDING(bufpos, u32);
325                                         if ((bufpos - buf) + sizeof(u32) <= len) {
326                                                 if (is_unsigned)
327                                                         *(u32 *)bufpos = (u32)strtoul(s, &s, 10);
328                                                 else
329                                                         *(s32 *)bufpos = (s32)strtol(s, &s, 10);
330                                         }
331                                         bufpos += sizeof(u32);
332                                 } else if (width == 64) {
333                                         bufpos += PADDING(bufpos, u64);
334                                         if ((bufpos - buf) + sizeof(u64) <= len) {
335                                                 if (is_unsigned)
336                                                         *(u64 *)bufpos = (u64)strtoull(s, &s, 10);
337                                                 else
338                                                         *(s64 *)bufpos = (s64)strtoll(s, &s, 10);
339                                         }
340                                         bufpos += sizeof(u64);
341                                 }
342                                 s = strchr(s, ',');
343                                 break;
344                         case 'b':
345                                 snext = strchr(s, ',');
346                                 if (snext)
347                                         *snext++ = 0;
348
349                                 bufpos += PADDING(bufpos, bool);
350                                 if ((bufpos - buf) + sizeof(bool) <= len)
351                                         *(bool *)bufpos = is_yes(std::string(s));
352                                 bufpos += sizeof(bool);
353
354                                 s = snext;
355                                 break;
356                         case 'f':
357                                 bufpos += PADDING(bufpos, float);
358                                 if ((bufpos - buf) + sizeof(float) <= len)
359                                         *(float *)bufpos = strtof(s, &s);
360                                 bufpos += sizeof(float);
361
362                                 s = strchr(s, ',');
363                                 break;
364                         case 's':
365                                 while (*s == ' ' || *s == '\t')
366                                         s++;
367                                 if (*s++ != '"') //error, expected string
368                                         goto fail;
369                                 snext = s;
370
371                                 while (snext[0] && !(snext[-1] != '\\' && snext[0] == '"'))
372                                         snext++;
373                                 *snext++ = 0;
374
375                                 bufpos += PADDING(bufpos, std::string *);
376
377                                 str = new std::string(s);
378                                 pos = 0;
379                                 while ((pos = str->find("\\\"", pos)) != std::string::npos)
380                                         str->erase(pos, 1);
381
382                                 if ((bufpos - buf) + sizeof(std::string *) <= len)
383                                         *(std::string **)bufpos = str;
384                                 bufpos += sizeof(std::string *);
385                                 strs_alloced.push_back(str);
386
387                                 s = *snext ? snext + 1 : NULL;
388                                 break;
389                         case 'v':
390                                 while (*s == ' ' || *s == '\t')
391                                         s++;
392                                 if (*s++ != '(') //error, expected vector
393                                         goto fail;
394
395                                 if (width == 2) {
396                                         bufpos += PADDING(bufpos, v2f);
397
398                                         if ((bufpos - buf) + sizeof(v2f) <= len) {
399                                         v2f *v = (v2f *)bufpos;
400                                                 v->X = strtof(s, &s);
401                                                 s++;
402                                                 v->Y = strtof(s, &s);
403                                         }
404
405                                         bufpos += sizeof(v2f);
406                                 } else if (width == 3) {
407                                         bufpos += PADDING(bufpos, v3f);
408                                         if ((bufpos - buf) + sizeof(v3f) <= len) {
409                                                 v3f *v = (v3f *)bufpos;
410                                                 v->X = strtof(s, &s);
411                                                 s++;
412                                                 v->Y = strtof(s, &s);
413                                                 s++;
414                                                 v->Z = strtof(s, &s);
415                                         }
416
417                                         bufpos += sizeof(v3f);
418                                 }
419                                 s = strchr(s, ',');
420                                 break;
421                         default: //error, invalid format specifier
422                                 goto fail;
423                 }
424
425                 if (s && *s == ',')
426                         s++;
427
428                 if ((size_t)(bufpos - buf) > len) //error, buffer too small
429                         goto fail;
430         }
431
432         if (f && *f) { //error, mismatched number of fields and values
433 fail:
434                 for (size_t i = 0; i != strs_alloced.size(); i++)
435                         delete strs_alloced[i];
436                 delete[] buf;
437                 return false;
438         }
439
440         memcpy(out, buf, olen);
441         delete[] buf;
442         return true;
443 }
444
445 // Casts *buf to a signed or unsigned fixed-width integer of 'w' width
446 #define SIGN_CAST(w, buf) (is_unsigned ? *((u##w *) buf) : *((s##w *) buf))
447
448 bool serializeStructToString(std::string *out,
449         std::string format, void *value)
450 {
451         std::ostringstream os;
452         std::string str;
453         char *f;
454         size_t strpos;
455
456         char *bufpos = (char *) value;
457         char *fmtpos, *fmt = &format[0];
458         while ((f = strtok_r(fmt, ",", &fmtpos))) {
459                 fmt = NULL;
460                 bool is_unsigned = false;
461                 int width = 0;
462                 char valtype = *f;
463
464                 width = (int)strtol(f + 1, &f, 10);
465                 if (width && valtype == 's')
466                         valtype = 'i';
467
468                 switch (valtype) {
469                         case 'u':
470                                 is_unsigned = true;
471                                 /* FALLTHROUGH */
472                         case 'i':
473                                 if (width == 16) {
474                                         bufpos += PADDING(bufpos, u16);
475                                         os << SIGN_CAST(16, bufpos);
476                                         bufpos += sizeof(u16);
477                                 } else if (width == 32) {
478                                         bufpos += PADDING(bufpos, u32);
479                                         os << SIGN_CAST(32, bufpos);
480                                         bufpos += sizeof(u32);
481                                 } else if (width == 64) {
482                                         bufpos += PADDING(bufpos, u64);
483                                         os << SIGN_CAST(64, bufpos);
484                                         bufpos += sizeof(u64);
485                                 }
486                                 break;
487                         case 'b':
488                                 bufpos += PADDING(bufpos, bool);
489                                 os << std::boolalpha << *((bool *) bufpos);
490                                 bufpos += sizeof(bool);
491                                 break;
492                         case 'f':
493                                 bufpos += PADDING(bufpos, float);
494                                 os << *((float *) bufpos);
495                                 bufpos += sizeof(float);
496                                 break;
497                         case 's':
498                                 bufpos += PADDING(bufpos, std::string *);
499                                 str = **((std::string **) bufpos);
500
501                                 strpos = 0;
502                                 while ((strpos = str.find('"', strpos)) != std::string::npos) {
503                                         str.insert(strpos, 1, '\\');
504                                         strpos += 2;
505                                 }
506
507                                 os << str;
508                                 bufpos += sizeof(std::string *);
509                                 break;
510                         case 'v':
511                                 if (width == 2) {
512                                         bufpos += PADDING(bufpos, v2f);
513                                         v2f *v = (v2f *) bufpos;
514                                         os << '(' << v->X << ", " << v->Y << ')';
515                                         bufpos += sizeof(v2f);
516                                 } else {
517                                         bufpos += PADDING(bufpos, v3f);
518                                         v3f *v = (v3f *) bufpos;
519                                         os << '(' << v->X << ", " << v->Y << ", " << v->Z << ')';
520                                         bufpos += sizeof(v3f);
521                                 }
522                                 break;
523                         default:
524                                 return false;
525                 }
526                 os << ", ";
527         }
528         *out = os.str();
529
530         // Trim off the trailing comma and space
531         if (out->size() >= 2)
532                 out->resize(out->size() - 2);
533
534         return true;
535 }
536
537 #undef SIGN_CAST
538
539 ////
540 //// Other
541 ////
542
// Returns the lowercase hexadecimal form of `data`, two digits per byte.
// With `insert_spaces`, a single space separates consecutive bytes; there
// is no leading or trailing space.
std::string serializeHexString(const std::string &data, bool insert_spaces)
{
	static const char digits[] = "0123456789abcdef";

	std::string hex;
	hex.reserve(data.size() * (2 + insert_spaces));

	for (std::string::const_iterator it = data.begin(); it != data.end(); ++it) {
		// Separator goes before every byte except the first.
		if (insert_spaces && it != data.begin())
			hex += ' ';

		const u8 value = *it;
		hex += digits[(value >> 4) & 0x0F];
		hex += digits[value & 0x0F];
	}

	return hex;
}