Merge pull request #3517 from Ansuel/ubus_short
[oweals/luci.git] / modules / luci-base / luasrc / util.lua
1 -- Copyright 2008 Steven Barth <steven@midlink.org>
2 -- Licensed to the public under the Apache License 2.0.
3
4 local io = require "io"
5 local math = require "math"
6 local table = require "table"
7 local debug = require "debug"
8 local ldebug = require "luci.debug"
9 local string = require "string"
10 local coroutine = require "coroutine"
11 local tparser = require "luci.template.parser"
12 local json = require "luci.jsonc"
13 local lhttp = require "lucihttp"
14
15 local _ubus = require "ubus"
16 local _ubus_connection = nil
17
18 local getmetatable, setmetatable = getmetatable, setmetatable
19 local rawget, rawset, unpack, select = rawget, rawset, unpack, select
20 local tostring, type, assert, error = tostring, type, assert, error
21 local ipairs, pairs, next, loadstring = ipairs, pairs, next, loadstring
22 local require, pcall, xpcall = require, pcall, xpcall
23 local collectgarbage, get_memory_limit = collectgarbage, get_memory_limit
24
25 module "luci.util"
26
27 --
28 -- Pythonic string formatting extension
29 --
30 getmetatable("").__mod = function(a, b)
31         local ok, res
32
33         if not b then
34                 return a
35         elseif type(b) == "table" then
36                 local k, _
37                 for k, _ in pairs(b) do if type(b[k]) == "userdata" then b[k] = tostring(b[k]) end end
38
39                 ok, res = pcall(a.format, a, unpack(b))
40                 if not ok then
41                         error(res, 2)
42                 end
43                 return res
44         else
45                 if type(b) == "userdata" then b = tostring(b) end
46
47                 ok, res = pcall(a.format, a, b)
48                 if not ok then
49                         error(res, 2)
50                 end
51                 return res
52         end
53 end
54
55
56 --
57 -- Class helper routines
58 --
59
60 -- Instantiates a class
61 local function _instantiate(class, ...)
62         local inst = setmetatable({}, {__index = class})
63
64         if inst.__init__ then
65                 inst:__init__(...)
66         end
67
68         return inst
69 end
70
71 -- The class object can be instantiated by calling itself.
72 -- Any class functions or shared parameters can be attached to this object.
73 -- Attaching a table to the class object makes this table shared between
74 -- all instances of this class. For object parameters use the __init__ function.
75 -- Classes can inherit member functions and values from a base class.
76 -- Class can be instantiated by calling them. All parameters will be passed
77 -- to the __init__ function of this class - if such a function exists.
78 -- The __init__ function must be used to set any object parameters that are not shared
79 -- with other objects of this class. Any return values will be ignored.
80 function class(base)
81         return setmetatable({}, {
82                 __call  = _instantiate,
83                 __index = base
84         })
85 end
86
87 function instanceof(object, class)
88         local meta = getmetatable(object)
89         while meta and meta.__index do
90                 if meta.__index == class then
91                         return true
92                 end
93                 meta = getmetatable(meta.__index)
94         end
95         return false
96 end
97
98
99 --
100 -- Scope manipulation routines
101 --
102
103 coxpt = setmetatable({}, { __mode = "kv" })
104
105 local tl_meta = {
106         __mode = "k",
107
108         __index = function(self, key)
109                 local t = rawget(self, coxpt[coroutine.running()]
110                  or coroutine.running() or 0)
111                 return t and t[key]
112         end,
113
114         __newindex = function(self, key, value)
115                 local c = coxpt[coroutine.running()] or coroutine.running() or 0
116                 local r = rawget(self, c)
117                 if not r then
118                         rawset(self, c, { [key] = value })
119                 else
120                         r[key] = value
121                 end
122         end
123 }
124
125 -- the current active coroutine. A thread local store is private a table object
126 -- whose values can't be accessed from outside of the running coroutine.
127 function threadlocal(tbl)
128         return setmetatable(tbl or {}, tl_meta)
129 end
130
131
132 --
133 -- Debugging routines
134 --
135
136 function perror(obj)
137         return io.stderr:write(tostring(obj) .. "\n")
138 end
139
140 function dumptable(t, maxdepth, i, seen)
141         i = i or 0
142         seen = seen or setmetatable({}, {__mode="k"})
143
144         for k,v in pairs(t) do
145                 perror(string.rep("\t", i) .. tostring(k) .. "\t" .. tostring(v))
146                 if type(v) == "table" and (not maxdepth or i < maxdepth) then
147                         if not seen[v] then
148                                 seen[v] = true
149                                 dumptable(v, maxdepth, i+1, seen)
150                         else
151                                 perror(string.rep("\t", i) .. "*** RECURSION ***")
152                         end
153                 end
154         end
155 end
156
157
158 --
159 -- String and data manipulation routines
160 --
161
162 function pcdata(value)
163         return value and tparser.pcdata(tostring(value))
164 end
165
166 function urlencode(value)
167         if value ~= nil then
168                 local str = tostring(value)
169                 return lhttp.urlencode(str, lhttp.ENCODE_IF_NEEDED + lhttp.ENCODE_FULL)
170                         or str
171         end
172         return nil
173 end
174
175 function urldecode(value, decode_plus)
176         if value ~= nil then
177                 local flag = decode_plus and lhttp.DECODE_PLUS or 0
178                 local str = tostring(value)
179                 return lhttp.urldecode(str, lhttp.DECODE_IF_NEEDED + flag)
180                         or str
181         end
182         return nil
183 end
184
185 function striptags(value)
186         return value and tparser.striptags(tostring(value))
187 end
188
189 function shellquote(value)
190         return string.format("'%s'", string.gsub(value or "", "'", "'\\''"))
191 end
192
193 -- for bash, ash and similar shells single-quoted strings are taken
194 -- literally except for single quotes (which terminate the string)
195 -- (and the exception noted below for dash (-) at the start of a
196 -- command line parameter).
197 function shellsqescape(value)
198    local res
199    res, _ = string.gsub(value, "'", "'\\''")
200    return res
201 end
202
203 -- bash, ash and other similar shells interpret a dash (-) at the start
204 -- of a command-line parameters as an option indicator regardless of
205 -- whether it is inside a single-quoted string.  It must be backlash
206 -- escaped to resolve this.  This requires in some funky special-case
207 -- handling.  It may actually be a property of the getopt function
208 -- rather than the shell proper.
209 function shellstartsqescape(value)
210    res, _ = string.gsub(value, "^%-", "\\-")
211    return shellsqescape(res)
212 end
213
214 -- containing the resulting substrings. The optional max parameter specifies
215 -- the number of bytes to process, regardless of the actual length of the given
216 -- string. The optional last parameter, regex, specifies whether the separator
217 -- sequence is interpreted as regular expression.
218 --                                      pattern as regular expression (optional, default is false)
219 function split(str, pat, max, regex)
220         pat = pat or "\n"
221         max = max or #str
222
223         local t = {}
224         local c = 1
225
226         if #str == 0 then
227                 return {""}
228         end
229
230         if #pat == 0 then
231                 return nil
232         end
233
234         if max == 0 then
235                 return str
236         end
237
238         repeat
239                 local s, e = str:find(pat, c, not regex)
240                 max = max - 1
241                 if s and max < 0 then
242                         t[#t+1] = str:sub(c)
243                 else
244                         t[#t+1] = str:sub(c, s and s - 1)
245                 end
246                 c = e and e + 1 or #str + 1
247         until not s or max < 0
248
249         return t
250 end
251
252 function trim(str)
253         return (str:gsub("^%s*(.-)%s*$", "%1"))
254 end
255
256 function cmatch(str, pat)
257         local count = 0
258         for _ in str:gmatch(pat) do count = count + 1 end
259         return count
260 end
261
262 -- one token per invocation, the tokens are separated by whitespace. If the
263 -- input value is a table, it is transformed into a string first. A nil value
264 -- will result in a valid iterator which aborts with the first invocation.
265 function imatch(v)
266         if type(v) == "table" then
267                 local k = nil
268                 return function()
269                         k = next(v, k)
270                         return v[k]
271                 end
272
273         elseif type(v) == "number" or type(v) == "boolean" then
274                 local x = true
275                 return function()
276                         if x then
277                                 x = false
278                                 return tostring(v)
279                         end
280                 end
281
282         elseif type(v) == "userdata" or type(v) == "string" then
283                 return tostring(v):gmatch("%S+")
284         end
285
286         return function() end
287 end
288
289 -- value or 0 if the unit is unknown. Upper- or lower case is irrelevant.
290 -- Recognized units are:
291 --      o "y"   - one year   (60*60*24*366)
292 --  o "m"       - one month  (60*60*24*31)
293 --  o "w"       - one week   (60*60*24*7)
294 --  o "d"       - one day    (60*60*24)
295 --  o "h"       - one hour       (60*60)
296 --  o "min"     - one minute (60)
297 --  o "kb"  - one kilobyte (1024)
298 --  o "mb"      - one megabyte (1024*1024)
299 --  o "gb"      - one gigabyte (1024*1024*1024)
300 --  o "kib" - one si kilobyte (1000)
301 --  o "mib"     - one si megabyte (1000*1000)
302 --  o "gib"     - one si gigabyte (1000*1000*1000)
303 function parse_units(ustr)
304
305         local val = 0
306
307         -- unit map
308         local map = {
309                 -- date stuff
310                 y   = 60 * 60 * 24 * 366,
311                 m   = 60 * 60 * 24 * 31,
312                 w   = 60 * 60 * 24 * 7,
313                 d   = 60 * 60 * 24,
314                 h   = 60 * 60,
315                 min = 60,
316
317                 -- storage sizes
318                 kb  = 1024,
319                 mb  = 1024 * 1024,
320                 gb  = 1024 * 1024 * 1024,
321
322                 -- storage sizes (si)
323                 kib = 1000,
324                 mib = 1000 * 1000,
325                 gib = 1000 * 1000 * 1000
326         }
327
328         -- parse input string
329         for spec in ustr:lower():gmatch("[0-9%.]+[a-zA-Z]*") do
330
331                 local num = spec:gsub("[^0-9%.]+$","")
332                 local spn = spec:gsub("^[0-9%.]+", "")
333
334                 if map[spn] or map[spn:sub(1,1)] then
335                         val = val + num * ( map[spn] or map[spn:sub(1,1)] )
336                 else
337                         val = val + num
338                 end
339         end
340
341
342         return val
343 end
344
345 -- also register functions above in the central string class for convenience
346 string.pcdata      = pcdata
347 string.striptags   = striptags
348 string.split       = split
349 string.trim        = trim
350 string.cmatch      = cmatch
351 string.parse_units = parse_units
352
353
354 function append(src, ...)
355         for i, a in ipairs({...}) do
356                 if type(a) == "table" then
357                         for j, v in ipairs(a) do
358                                 src[#src+1] = v
359                         end
360                 else
361                         src[#src+1] = a
362                 end
363         end
364         return src
365 end
366
367 function combine(...)
368         return append({}, ...)
369 end
370
371 function contains(table, value)
372         for k, v in pairs(table) do
373                 if value == v then
374                         return k
375                 end
376         end
377         return false
378 end
379
380 -- Both table are - in fact - merged together.
381 function update(t, updates)
382         for k, v in pairs(updates) do
383                 t[k] = v
384         end
385 end
386
387 function keys(t)
388         local keys = { }
389         if t then
390                 for k, _ in kspairs(t) do
391                         keys[#keys+1] = k
392                 end
393         end
394         return keys
395 end
396
397 function clone(object, deep)
398         local copy = {}
399
400         for k, v in pairs(object) do
401                 if deep and type(v) == "table" then
402                         v = clone(v, deep)
403                 end
404                 copy[k] = v
405         end
406
407         return setmetatable(copy, getmetatable(object))
408 end
409
410
411 -- Serialize the contents of a table value.
412 function _serialize_table(t, seen)
413         assert(not seen[t], "Recursion detected.")
414         seen[t] = true
415
416         local data  = ""
417         local idata = ""
418         local ilen  = 0
419
420         for k, v in pairs(t) do
421                 if type(k) ~= "number" or k < 1 or math.floor(k) ~= k or ( k - #t ) > 3 then
422                         k = serialize_data(k, seen)
423                         v = serialize_data(v, seen)
424                         data = data .. ( #data > 0 and ", " or "" ) ..
425                                 '[' .. k .. '] = ' .. v
426                 elseif k > ilen then
427                         ilen = k
428                 end
429         end
430
431         for i = 1, ilen do
432                 local v = serialize_data(t[i], seen)
433                 idata = idata .. ( #idata > 0 and ", " or "" ) .. v
434         end
435
436         return idata .. ( #data > 0 and #idata > 0 and ", " or "" ) .. data
437 end
438
439 -- with loadstring().
440 function serialize_data(val, seen)
441         seen = seen or setmetatable({}, {__mode="k"})
442
443         if val == nil then
444                 return "nil"
445         elseif type(val) == "number" then
446                 return val
447         elseif type(val) == "string" then
448                 return "%q" % val
449         elseif type(val) == "boolean" then
450                 return val and "true" or "false"
451         elseif type(val) == "function" then
452                 return "loadstring(%q)" % get_bytecode(val)
453         elseif type(val) == "table" then
454                 return "{ " .. _serialize_table(val, seen) .. " }"
455         else
456                 return '"[unhandled data type:' .. type(val) .. ']"'
457         end
458 end
459
460 function restore_data(str)
461         return loadstring("return " .. str)()
462 end
463
464
465 --
466 -- Byte code manipulation routines
467 --
468
469 -- will be stripped before it is returned.
470 function get_bytecode(val)
471         local code
472
473         if type(val) == "function" then
474                 code = string.dump(val)
475         else
476                 code = string.dump( loadstring( "return " .. serialize_data(val) ) )
477         end
478
479         return code -- and strip_bytecode(code)
480 end
481
482 -- numbers and debugging numbers will be discarded. Original version by
483 -- Peter Cawley (http://lua-users.org/lists/lua-l/2008-02/msg01158.html)
484 function strip_bytecode(code)
485         local version, format, endian, int, size, ins, num, lnum = code:byte(5, 12)
486         local subint
487         if endian == 1 then
488                 subint = function(code, i, l)
489                         local val = 0
490                         for n = l, 1, -1 do
491                                 val = val * 256 + code:byte(i + n - 1)
492                         end
493                         return val, i + l
494                 end
495         else
496                 subint = function(code, i, l)
497                         local val = 0
498                         for n = 1, l, 1 do
499                                 val = val * 256 + code:byte(i + n - 1)
500                         end
501                         return val, i + l
502                 end
503         end
504
505         local function strip_function(code)
506                 local count, offset = subint(code, 1, size)
507                 local stripped = { string.rep("\0", size) }
508                 local dirty = offset + count
509                 offset = offset + count + int * 2 + 4
510                 offset = offset + int + subint(code, offset, int) * ins
511                 count, offset = subint(code, offset, int)
512                 for n = 1, count do
513                         local t
514                         t, offset = subint(code, offset, 1)
515                         if t == 1 then
516                                 offset = offset + 1
517                         elseif t == 4 then
518                                 offset = offset + size + subint(code, offset, size)
519                         elseif t == 3 then
520                                 offset = offset + num
521                         elseif t == 254 or t == 9 then
522                                 offset = offset + lnum
523                         end
524                 end
525                 count, offset = subint(code, offset, int)
526                 stripped[#stripped+1] = code:sub(dirty, offset - 1)
527                 for n = 1, count do
528                         local proto, off = strip_function(code:sub(offset, -1))
529                         stripped[#stripped+1] = proto
530                         offset = offset + off - 1
531                 end
532                 offset = offset + subint(code, offset, int) * int + int
533                 count, offset = subint(code, offset, int)
534                 for n = 1, count do
535                         offset = offset + subint(code, offset, size) + size + int * 2
536                 end
537                 count, offset = subint(code, offset, int)
538                 for n = 1, count do
539                         offset = offset + subint(code, offset, size) + size
540                 end
541                 stripped[#stripped+1] = string.rep("\0", int * 3)
542                 return table.concat(stripped), offset
543         end
544
545         return code:sub(1,12) .. strip_function(code:sub(13,-1))
546 end
547
548
549 --
550 -- Sorting iterator functions
551 --
552
553 function _sortiter( t, f )
554         local keys = { }
555
556         local k, v
557         for k, v in pairs(t) do
558                 keys[#keys+1] = k
559         end
560
561         local _pos = 0
562
563         table.sort( keys, f )
564
565         return function()
566                 _pos = _pos + 1
567                 if _pos <= #keys then
568                         return keys[_pos], t[keys[_pos]], _pos
569                 end
570         end
571 end
572
573 -- the provided callback function.
574 function spairs(t,f)
575         return _sortiter( t, f )
576 end
577
578 -- The table pairs are sorted by key.
579 function kspairs(t)
580         return _sortiter( t )
581 end
582
583 -- The table pairs are sorted by value.
584 function vspairs(t)
585         return _sortiter( t, function (a,b) return t[a] < t[b] end )
586 end
587
588
589 --
590 -- System utility functions
591 --
592
593 function bigendian()
594         return string.byte(string.dump(function() end), 7) == 0
595 end
596
597 function exec(command)
598         local pp   = io.popen(command)
599         local data = pp:read("*a")
600         pp:close()
601
602         return data
603 end
604
605 function execi(command)
606         local pp = io.popen(command)
607
608         return pp and function()
609                 local line = pp:read()
610
611                 if not line then
612                         pp:close()
613                 end
614
615                 return line
616         end
617 end
618
619 -- Deprecated
620 function execl(command)
621         local pp   = io.popen(command)
622         local line = ""
623         local data = {}
624
625         while true do
626                 line = pp:read()
627                 if (line == nil) then break end
628                 data[#data+1] = line
629         end
630         pp:close()
631
632         return data
633 end
634
635
636 local ubus_codes = {
637         "INVALID_COMMAND",
638         "INVALID_ARGUMENT",
639         "METHOD_NOT_FOUND",
640         "NOT_FOUND",
641         "NO_DATA",
642         "PERMISSION_DENIED",
643         "TIMEOUT",
644         "NOT_SUPPORTED",
645         "UNKNOWN_ERROR",
646         "CONNECTION_FAILED"
647 }
648
649 local function ubus_return(...)
650         if select('#', ...) == 2 then
651                 local rv, err = select(1, ...), select(2, ...)
652                 if rv == nil and type(err) == "number" then
653                         return nil, err, ubus_codes[err]
654                 end
655         end
656
657         return ...
658 end
659
660 function ubus(object, method, data)
661         if not _ubus_connection then
662                 _ubus_connection = _ubus.connect()
663                 assert(_ubus_connection, "Unable to establish ubus connection")
664         end
665
666         if object and method then
667                 if type(data) ~= "table" then
668                         data = { }
669                 end
670                 return ubus_return(_ubus_connection:call(object, method, data))
671         elseif object then
672                 return _ubus_connection:signatures(object)
673         else
674                 return _ubus_connection:objects()
675         end
676 end
677
678 function serialize_json(x, cb)
679         local js = json.stringify(x)
680         if type(cb) == "function" then
681                 cb(js)
682         else
683                 return js
684         end
685 end
686
687
688 function libpath()
689         return require "nixio.fs".dirname(ldebug.__file__)
690 end
691
692 function checklib(fullpathexe, wantedlib)
693         local fs = require "nixio.fs"
694         local haveldd = fs.access('/usr/bin/ldd')
695         local haveexe = fs.access(fullpathexe)
696         if not haveldd or not haveexe then
697                 return false
698         end
699         local libs = exec(string.format("/usr/bin/ldd %s", shellquote(fullpathexe)))
700         if not libs then
701                 return false
702         end
703         for k, v in ipairs(split(libs)) do
704                 if v:find(wantedlib) then
705                         return true
706                 end
707         end
708         return false
709 end
710
711 -------------------------------------------------------------------------------
712 -- Coroutine safe xpcall and pcall versions
713 --
714 -- Encapsulates the protected calls with a coroutine based loop, so errors can
715 -- be dealed without the usual Lua 5.x pcall/xpcall issues with coroutines
716 -- yielding inside the call to pcall or xpcall.
717 --
718 -- Authors: Roberto Ierusalimschy and Andre Carregal
719 -- Contributors: Thomas Harning Jr., Ignacio BurgueƱo, Fabio Mascarenhas
720 --
721 -- Copyright 2005 - Kepler Project
722 --
723 -- $Id: coxpcall.lua,v 1.13 2008/05/19 19:20:02 mascarenhas Exp $
724 -------------------------------------------------------------------------------
725
726 -------------------------------------------------------------------------------
727 -- Implements xpcall with coroutines
728 -------------------------------------------------------------------------------
729 local coromap = setmetatable({}, { __mode = "k" })
730
731 local function handleReturnValue(err, co, status, ...)
732         if not status then
733                 return false, err(debug.traceback(co, (...)), ...)
734         end
735         if coroutine.status(co) == 'suspended' then
736                 return performResume(err, co, coroutine.yield(...))
737         else
738                 return true, ...
739         end
740 end
741
742 function performResume(err, co, ...)
743         return handleReturnValue(err, co, coroutine.resume(co, ...))
744 end
745
746 local function id(trace, ...)
747         return trace
748 end
749
750 function coxpcall(f, err, ...)
751         local current = coroutine.running()
752         if not current then
753                 if err == id then
754                         return pcall(f, ...)
755                 else
756                         if select("#", ...) > 0 then
757                                 local oldf, params = f, { ... }
758                                 f = function() return oldf(unpack(params)) end
759                         end
760                         return xpcall(f, err)
761                 end
762         else
763                 local res, co = pcall(coroutine.create, f)
764                 if not res then
765                         local newf = function(...) return f(...) end
766                         co = coroutine.create(newf)
767                 end
768                 coromap[co] = current
769                 coxpt[co] = coxpt[current] or current or 0
770                 return performResume(err, co, ...)
771         end
772 end
773
774 function copcall(f, ...)
775         return coxpcall(f, id, ...)
776 end