Merge pull request #2235 from dibdot/ovpn-upload
[oweals/luci.git] / modules / luci-base / luasrc / util.lua
1 -- Copyright 2008 Steven Barth <steven@midlink.org>
2 -- Licensed to the public under the Apache License 2.0.
3
4 local io = require "io"
5 local math = require "math"
6 local table = require "table"
7 local debug = require "debug"
8 local ldebug = require "luci.debug"
9 local string = require "string"
10 local coroutine = require "coroutine"
11 local tparser = require "luci.template.parser"
12 local json = require "luci.jsonc"
13 local lhttp = require "lucihttp"
14
15 local _ubus = require "ubus"
16 local _ubus_connection = nil
17
18 local getmetatable, setmetatable = getmetatable, setmetatable
19 local rawget, rawset, unpack, select = rawget, rawset, unpack, select
20 local tostring, type, assert, error = tostring, type, assert, error
21 local ipairs, pairs, next, loadstring = ipairs, pairs, next, loadstring
22 local require, pcall, xpcall = require, pcall, xpcall
23 local collectgarbage, get_memory_limit = collectgarbage, get_memory_limit
24
25 module "luci.util"
26
27 --
28 -- Pythonic string formatting extension
29 --
30 getmetatable("").__mod = function(a, b)
31         local ok, res
32
33         if not b then
34                 return a
35         elseif type(b) == "table" then
36                 local k, _
37                 for k, _ in pairs(b) do if type(b[k]) == "userdata" then b[k] = tostring(b[k]) end end
38
39                 ok, res = pcall(a.format, a, unpack(b))
40                 if not ok then
41                         error(res, 2)
42                 end
43                 return res
44         else
45                 if type(b) == "userdata" then b = tostring(b) end
46
47                 ok, res = pcall(a.format, a, b)
48                 if not ok then
49                         error(res, 2)
50                 end
51                 return res
52         end
53 end
54
55
56 --
57 -- Class helper routines
58 --
59
60 -- Instantiates a class
61 local function _instantiate(class, ...)
62         local inst = setmetatable({}, {__index = class})
63
64         if inst.__init__ then
65                 inst:__init__(...)
66         end
67
68         return inst
69 end
70
71 -- The class object can be instantiated by calling itself.
72 -- Any class functions or shared parameters can be attached to this object.
73 -- Attaching a table to the class object makes this table shared between
74 -- all instances of this class. For object parameters use the __init__ function.
75 -- Classes can inherit member functions and values from a base class.
76 -- Class can be instantiated by calling them. All parameters will be passed
77 -- to the __init__ function of this class - if such a function exists.
78 -- The __init__ function must be used to set any object parameters that are not shared
79 -- with other objects of this class. Any return values will be ignored.
80 function class(base)
81         return setmetatable({}, {
82                 __call  = _instantiate,
83                 __index = base
84         })
85 end
86
87 function instanceof(object, class)
88         local meta = getmetatable(object)
89         while meta and meta.__index do
90                 if meta.__index == class then
91                         return true
92                 end
93                 meta = getmetatable(meta.__index)
94         end
95         return false
96 end
97
98
99 --
100 -- Scope manipulation routines
101 --
102
103 coxpt = setmetatable({}, { __mode = "kv" })
104
105 local tl_meta = {
106         __mode = "k",
107
108         __index = function(self, key)
109                 local t = rawget(self, coxpt[coroutine.running()]
110                  or coroutine.running() or 0)
111                 return t and t[key]
112         end,
113
114         __newindex = function(self, key, value)
115                 local c = coxpt[coroutine.running()] or coroutine.running() or 0
116                 local r = rawget(self, c)
117                 if not r then
118                         rawset(self, c, { [key] = value })
119                 else
120                         r[key] = value
121                 end
122         end
123 }
124
125 -- the current active coroutine. A thread local store is private a table object
126 -- whose values can't be accessed from outside of the running coroutine.
127 function threadlocal(tbl)
128         return setmetatable(tbl or {}, tl_meta)
129 end
130
131
132 --
133 -- Debugging routines
134 --
135
136 function perror(obj)
137         return io.stderr:write(tostring(obj) .. "\n")
138 end
139
140 function dumptable(t, maxdepth, i, seen)
141         i = i or 0
142         seen = seen or setmetatable({}, {__mode="k"})
143
144         for k,v in pairs(t) do
145                 perror(string.rep("\t", i) .. tostring(k) .. "\t" .. tostring(v))
146                 if type(v) == "table" and (not maxdepth or i < maxdepth) then
147                         if not seen[v] then
148                                 seen[v] = true
149                                 dumptable(v, maxdepth, i+1, seen)
150                         else
151                                 perror(string.rep("\t", i) .. "*** RECURSION ***")
152                         end
153                 end
154         end
155 end
156
157
158 --
159 -- String and data manipulation routines
160 --
161
162 function pcdata(value)
163         return value and tparser.pcdata(tostring(value))
164 end
165
166 function urlencode(value)
167         if value ~= nil then
168                 local str = tostring(value)
169                 return lhttp.urlencode(str, lhttp.ENCODE_IF_NEEDED + lhttp.ENCODE_FULL)
170                         or str
171         end
172         return nil
173 end
174
175 function urldecode(value, decode_plus)
176         if value ~= nil then
177                 local flag = decode_plus and lhttp.DECODE_PLUS or 0
178                 local str = tostring(value)
179                 return lhttp.urldecode(str, lhttp.DECODE_IF_NEEDED + flag)
180                         or str
181         end
182         return nil
183 end
184
185 function striptags(value)
186         return value and tparser.striptags(tostring(value))
187 end
188
189 function shellquote(value)
190         return string.format("'%s'", string.gsub(value or "", "'", "'\\''"))
191 end
192
193 -- for bash, ash and similar shells single-quoted strings are taken
194 -- literally except for single quotes (which terminate the string)
195 -- (and the exception noted below for dash (-) at the start of a
196 -- command line parameter).
197 function shellsqescape(value)
198    local res
199    res, _ = string.gsub(value, "'", "'\\''")
200    return res
201 end
202
203 -- bash, ash and other similar shells interpret a dash (-) at the start
204 -- of a command-line parameters as an option indicator regardless of
205 -- whether it is inside a single-quoted string.  It must be backlash
206 -- escaped to resolve this.  This requires in some funky special-case
207 -- handling.  It may actually be a property of the getopt function
208 -- rather than the shell proper.
209 function shellstartsqescape(value)
210    res, _ = string.gsub(value, "^\-", "\\-")
211    res, _ = string.gsub(res, "^-", "\-")
212    return shellsqescape(value)
213 end
214
215 -- containing the resulting substrings. The optional max parameter specifies
216 -- the number of bytes to process, regardless of the actual length of the given
217 -- string. The optional last parameter, regex, specifies whether the separator
218 -- sequence is interpreted as regular expression.
219 --                                      pattern as regular expression (optional, default is false)
220 function split(str, pat, max, regex)
221         pat = pat or "\n"
222         max = max or #str
223
224         local t = {}
225         local c = 1
226
227         if #str == 0 then
228                 return {""}
229         end
230
231         if #pat == 0 then
232                 return nil
233         end
234
235         if max == 0 then
236                 return str
237         end
238
239         repeat
240                 local s, e = str:find(pat, c, not regex)
241                 max = max - 1
242                 if s and max < 0 then
243                         t[#t+1] = str:sub(c)
244                 else
245                         t[#t+1] = str:sub(c, s and s - 1)
246                 end
247                 c = e and e + 1 or #str + 1
248         until not s or max < 0
249
250         return t
251 end
252
253 function trim(str)
254         return (str:gsub("^%s*(.-)%s*$", "%1"))
255 end
256
257 function cmatch(str, pat)
258         local count = 0
259         for _ in str:gmatch(pat) do count = count + 1 end
260         return count
261 end
262
263 -- one token per invocation, the tokens are separated by whitespace. If the
264 -- input value is a table, it is transformed into a string first. A nil value
265 -- will result in a valid iterator which aborts with the first invocation.
266 function imatch(v)
267         if type(v) == "table" then
268                 local k = nil
269                 return function()
270                         k = next(v, k)
271                         return v[k]
272                 end
273
274         elseif type(v) == "number" or type(v) == "boolean" then
275                 local x = true
276                 return function()
277                         if x then
278                                 x = false
279                                 return tostring(v)
280                         end
281                 end
282
283         elseif type(v) == "userdata" or type(v) == "string" then
284                 return tostring(v):gmatch("%S+")
285         end
286
287         return function() end
288 end
289
290 -- value or 0 if the unit is unknown. Upper- or lower case is irrelevant.
291 -- Recognized units are:
292 --      o "y"   - one year   (60*60*24*366)
293 --  o "m"       - one month  (60*60*24*31)
294 --  o "w"       - one week   (60*60*24*7)
295 --  o "d"       - one day    (60*60*24)
296 --  o "h"       - one hour       (60*60)
297 --  o "min"     - one minute (60)
298 --  o "kb"  - one kilobyte (1024)
299 --  o "mb"      - one megabyte (1024*1024)
300 --  o "gb"      - one gigabyte (1024*1024*1024)
301 --  o "kib" - one si kilobyte (1000)
302 --  o "mib"     - one si megabyte (1000*1000)
303 --  o "gib"     - one si gigabyte (1000*1000*1000)
304 function parse_units(ustr)
305
306         local val = 0
307
308         -- unit map
309         local map = {
310                 -- date stuff
311                 y   = 60 * 60 * 24 * 366,
312                 m   = 60 * 60 * 24 * 31,
313                 w   = 60 * 60 * 24 * 7,
314                 d   = 60 * 60 * 24,
315                 h   = 60 * 60,
316                 min = 60,
317
318                 -- storage sizes
319                 kb  = 1024,
320                 mb  = 1024 * 1024,
321                 gb  = 1024 * 1024 * 1024,
322
323                 -- storage sizes (si)
324                 kib = 1000,
325                 mib = 1000 * 1000,
326                 gib = 1000 * 1000 * 1000
327         }
328
329         -- parse input string
330         for spec in ustr:lower():gmatch("[0-9%.]+[a-zA-Z]*") do
331
332                 local num = spec:gsub("[^0-9%.]+$","")
333                 local spn = spec:gsub("^[0-9%.]+", "")
334
335                 if map[spn] or map[spn:sub(1,1)] then
336                         val = val + num * ( map[spn] or map[spn:sub(1,1)] )
337                 else
338                         val = val + num
339                 end
340         end
341
342
343         return val
344 end
345
346 -- also register functions above in the central string class for convenience
347 string.pcdata      = pcdata
348 string.striptags   = striptags
349 string.split       = split
350 string.trim        = trim
351 string.cmatch      = cmatch
352 string.parse_units = parse_units
353
354
355 function append(src, ...)
356         for i, a in ipairs({...}) do
357                 if type(a) == "table" then
358                         for j, v in ipairs(a) do
359                                 src[#src+1] = v
360                         end
361                 else
362                         src[#src+1] = a
363                 end
364         end
365         return src
366 end
367
368 function combine(...)
369         return append({}, ...)
370 end
371
372 function contains(table, value)
373         for k, v in pairs(table) do
374                 if value == v then
375                         return k
376                 end
377         end
378         return false
379 end
380
381 -- Both table are - in fact - merged together.
382 function update(t, updates)
383         for k, v in pairs(updates) do
384                 t[k] = v
385         end
386 end
387
388 function keys(t)
389         local keys = { }
390         if t then
391                 for k, _ in kspairs(t) do
392                         keys[#keys+1] = k
393                 end
394         end
395         return keys
396 end
397
398 function clone(object, deep)
399         local copy = {}
400
401         for k, v in pairs(object) do
402                 if deep and type(v) == "table" then
403                         v = clone(v, deep)
404                 end
405                 copy[k] = v
406         end
407
408         return setmetatable(copy, getmetatable(object))
409 end
410
411
412 -- Serialize the contents of a table value.
413 function _serialize_table(t, seen)
414         assert(not seen[t], "Recursion detected.")
415         seen[t] = true
416
417         local data  = ""
418         local idata = ""
419         local ilen  = 0
420
421         for k, v in pairs(t) do
422                 if type(k) ~= "number" or k < 1 or math.floor(k) ~= k or ( k - #t ) > 3 then
423                         k = serialize_data(k, seen)
424                         v = serialize_data(v, seen)
425                         data = data .. ( #data > 0 and ", " or "" ) ..
426                                 '[' .. k .. '] = ' .. v
427                 elseif k > ilen then
428                         ilen = k
429                 end
430         end
431
432         for i = 1, ilen do
433                 local v = serialize_data(t[i], seen)
434                 idata = idata .. ( #idata > 0 and ", " or "" ) .. v
435         end
436
437         return idata .. ( #data > 0 and #idata > 0 and ", " or "" ) .. data
438 end
439
440 -- with loadstring().
441 function serialize_data(val, seen)
442         seen = seen or setmetatable({}, {__mode="k"})
443
444         if val == nil then
445                 return "nil"
446         elseif type(val) == "number" then
447                 return val
448         elseif type(val) == "string" then
449                 return "%q" % val
450         elseif type(val) == "boolean" then
451                 return val and "true" or "false"
452         elseif type(val) == "function" then
453                 return "loadstring(%q)" % get_bytecode(val)
454         elseif type(val) == "table" then
455                 return "{ " .. _serialize_table(val, seen) .. " }"
456         else
457                 return '"[unhandled data type:' .. type(val) .. ']"'
458         end
459 end
460
461 function restore_data(str)
462         return loadstring("return " .. str)()
463 end
464
465
466 --
467 -- Byte code manipulation routines
468 --
469
470 -- will be stripped before it is returned.
471 function get_bytecode(val)
472         local code
473
474         if type(val) == "function" then
475                 code = string.dump(val)
476         else
477                 code = string.dump( loadstring( "return " .. serialize_data(val) ) )
478         end
479
480         return code -- and strip_bytecode(code)
481 end
482
483 -- numbers and debugging numbers will be discarded. Original version by
484 -- Peter Cawley (http://lua-users.org/lists/lua-l/2008-02/msg01158.html)
485 function strip_bytecode(code)
486         local version, format, endian, int, size, ins, num, lnum = code:byte(5, 12)
487         local subint
488         if endian == 1 then
489                 subint = function(code, i, l)
490                         local val = 0
491                         for n = l, 1, -1 do
492                                 val = val * 256 + code:byte(i + n - 1)
493                         end
494                         return val, i + l
495                 end
496         else
497                 subint = function(code, i, l)
498                         local val = 0
499                         for n = 1, l, 1 do
500                                 val = val * 256 + code:byte(i + n - 1)
501                         end
502                         return val, i + l
503                 end
504         end
505
506         local function strip_function(code)
507                 local count, offset = subint(code, 1, size)
508                 local stripped = { string.rep("\0", size) }
509                 local dirty = offset + count
510                 offset = offset + count + int * 2 + 4
511                 offset = offset + int + subint(code, offset, int) * ins
512                 count, offset = subint(code, offset, int)
513                 for n = 1, count do
514                         local t
515                         t, offset = subint(code, offset, 1)
516                         if t == 1 then
517                                 offset = offset + 1
518                         elseif t == 4 then
519                                 offset = offset + size + subint(code, offset, size)
520                         elseif t == 3 then
521                                 offset = offset + num
522                         elseif t == 254 or t == 9 then
523                                 offset = offset + lnum
524                         end
525                 end
526                 count, offset = subint(code, offset, int)
527                 stripped[#stripped+1] = code:sub(dirty, offset - 1)
528                 for n = 1, count do
529                         local proto, off = strip_function(code:sub(offset, -1))
530                         stripped[#stripped+1] = proto
531                         offset = offset + off - 1
532                 end
533                 offset = offset + subint(code, offset, int) * int + int
534                 count, offset = subint(code, offset, int)
535                 for n = 1, count do
536                         offset = offset + subint(code, offset, size) + size + int * 2
537                 end
538                 count, offset = subint(code, offset, int)
539                 for n = 1, count do
540                         offset = offset + subint(code, offset, size) + size
541                 end
542                 stripped[#stripped+1] = string.rep("\0", int * 3)
543                 return table.concat(stripped), offset
544         end
545
546         return code:sub(1,12) .. strip_function(code:sub(13,-1))
547 end
548
549
550 --
551 -- Sorting iterator functions
552 --
553
554 function _sortiter( t, f )
555         local keys = { }
556
557         local k, v
558         for k, v in pairs(t) do
559                 keys[#keys+1] = k
560         end
561
562         local _pos = 0
563
564         table.sort( keys, f )
565
566         return function()
567                 _pos = _pos + 1
568                 if _pos <= #keys then
569                         return keys[_pos], t[keys[_pos]], _pos
570                 end
571         end
572 end
573
574 -- the provided callback function.
575 function spairs(t,f)
576         return _sortiter( t, f )
577 end
578
579 -- The table pairs are sorted by key.
580 function kspairs(t)
581         return _sortiter( t )
582 end
583
584 -- The table pairs are sorted by value.
585 function vspairs(t)
586         return _sortiter( t, function (a,b) return t[a] < t[b] end )
587 end
588
589
590 --
591 -- System utility functions
592 --
593
594 function bigendian()
595         return string.byte(string.dump(function() end), 7) == 0
596 end
597
598 function exec(command)
599         local pp   = io.popen(command)
600         local data = pp:read("*a")
601         pp:close()
602
603         return data
604 end
605
606 function execi(command)
607         local pp = io.popen(command)
608
609         return pp and function()
610                 local line = pp:read()
611
612                 if not line then
613                         pp:close()
614                 end
615
616                 return line
617         end
618 end
619
620 -- Deprecated
621 function execl(command)
622         local pp   = io.popen(command)
623         local line = ""
624         local data = {}
625
626         while true do
627                 line = pp:read()
628                 if (line == nil) then break end
629                 data[#data+1] = line
630         end
631         pp:close()
632
633         return data
634 end
635
636
637 local ubus_codes = {
638         "INVALID_COMMAND",
639         "INVALID_ARGUMENT",
640         "METHOD_NOT_FOUND",
641         "NOT_FOUND",
642         "NO_DATA",
643         "PERMISSION_DENIED",
644         "TIMEOUT",
645         "NOT_SUPPORTED",
646         "UNKNOWN_ERROR",
647         "CONNECTION_FAILED"
648 }
649
650 local function ubus_return(...)
651         if select('#', ...) == 2 then
652                 local rv, err = select(1, ...), select(2, ...)
653                 if rv == nil and type(err) == "number" then
654                         return nil, err, ubus_codes[err]
655                 end
656         end
657
658         return ...
659 end
660
661 function ubus(object, method, data)
662         if not _ubus_connection then
663                 _ubus_connection = _ubus.connect()
664                 assert(_ubus_connection, "Unable to establish ubus connection")
665         end
666
667         if object and method then
668                 if type(data) ~= "table" then
669                         data = { }
670                 end
671                 return ubus_return(_ubus_connection:call(object, method, data))
672         elseif object then
673                 return _ubus_connection:signatures(object)
674         else
675                 return _ubus_connection:objects()
676         end
677 end
678
679 function serialize_json(x, cb)
680         local js = json.stringify(x)
681         if type(cb) == "function" then
682                 cb(js)
683         else
684                 return js
685         end
686 end
687
688
689 function libpath()
690         return require "nixio.fs".dirname(ldebug.__file__)
691 end
692
693 function checklib(fullpathexe, wantedlib)
694         local fs = require "nixio.fs"
695         local haveldd = fs.access('/usr/bin/ldd')
696         local haveexe = fs.access(fullpathexe)
697         if not haveldd or not haveexe then
698                 return false
699         end
700         local libs = exec(string.format("/usr/bin/ldd %s", shellquote(fullpathexe)))
701         if not libs then
702                 return false
703         end
704         for k, v in ipairs(split(libs)) do
705                 if v:find(wantedlib) then
706                         return true
707                 end
708         end
709         return false
710 end
711
712 -------------------------------------------------------------------------------
713 -- Coroutine safe xpcall and pcall versions
714 --
715 -- Encapsulates the protected calls with a coroutine based loop, so errors can
716 -- be dealed without the usual Lua 5.x pcall/xpcall issues with coroutines
717 -- yielding inside the call to pcall or xpcall.
718 --
719 -- Authors: Roberto Ierusalimschy and Andre Carregal
720 -- Contributors: Thomas Harning Jr., Ignacio BurgueƱo, Fabio Mascarenhas
721 --
722 -- Copyright 2005 - Kepler Project
723 --
724 -- $Id: coxpcall.lua,v 1.13 2008/05/19 19:20:02 mascarenhas Exp $
725 -------------------------------------------------------------------------------
726
727 -------------------------------------------------------------------------------
728 -- Implements xpcall with coroutines
729 -------------------------------------------------------------------------------
730 local coromap = setmetatable({}, { __mode = "k" })
731
732 local function handleReturnValue(err, co, status, ...)
733         if not status then
734                 return false, err(debug.traceback(co, (...)), ...)
735         end
736         if coroutine.status(co) == 'suspended' then
737                 return performResume(err, co, coroutine.yield(...))
738         else
739                 return true, ...
740         end
741 end
742
743 function performResume(err, co, ...)
744         return handleReturnValue(err, co, coroutine.resume(co, ...))
745 end
746
747 local function id(trace, ...)
748         return trace
749 end
750
751 function coxpcall(f, err, ...)
752         local current = coroutine.running()
753         if not current then
754                 if err == id then
755                         return pcall(f, ...)
756                 else
757                         if select("#", ...) > 0 then
758                                 local oldf, params = f, { ... }
759                                 f = function() return oldf(unpack(params)) end
760                         end
761                         return xpcall(f, err)
762                 end
763         else
764                 local res, co = pcall(coroutine.create, f)
765                 if not res then
766                         local newf = function(...) return f(...) end
767                         co = coroutine.create(newf)
768                 end
769                 coromap[co] = current
770                 coxpt[co] = coxpt[current] or current or 0
771                 return performResume(err, co, ...)
772         end
773 end
774
775 function copcall(f, ...)
776         return coxpcall(f, id, ...)
777 end