Linux-libre 3.6.3-gnu1
[librecmc/linux-libre.git] / net / netfilter / ipset / ip_set_list_set.c
1 /* Copyright (C) 2008-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.
6  */
7
8 /* Kernel module implementing an IP set type: the list:set type */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/skbuff.h>
13 #include <linux/errno.h>
14
15 #include <linux/netfilter/ipset/ip_set.h>
16 #include <linux/netfilter/ipset/ip_set_timeout.h>
17 #include <linux/netfilter/ipset/ip_set_list.h>
18
19 MODULE_LICENSE("GPL");
20 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
21 MODULE_DESCRIPTION("list:set type of IP sets");
22 MODULE_ALIAS("ip_set_list:set");
23
24 /* Member elements without and with timeout */
25 struct set_elem {
26         ip_set_id_t id;
27 };
28
29 struct set_telem {
30         ip_set_id_t id;
31         unsigned long timeout;
32 };
33
34 /* Type structure */
35 struct list_set {
36         size_t dsize;           /* element size */
37         u32 size;               /* size of set list array */
38         u32 timeout;            /* timeout value */
39         struct timer_list gc;   /* garbage collection */
40         struct set_elem members[0]; /* the set members */
41 };
42
43 static inline struct set_elem *
44 list_set_elem(const struct list_set *map, u32 id)
45 {
46         return (struct set_elem *)((void *)map->members + id * map->dsize);
47 }
48
49 static inline struct set_telem *
50 list_set_telem(const struct list_set *map, u32 id)
51 {
52         return (struct set_telem *)((void *)map->members + id * map->dsize);
53 }
54
55 static inline bool
56 list_set_timeout(const struct list_set *map, u32 id)
57 {
58         const struct set_telem *elem = list_set_telem(map, id);
59
60         return ip_set_timeout_test(elem->timeout);
61 }
62
63 static inline bool
64 list_set_expired(const struct list_set *map, u32 id)
65 {
66         const struct set_telem *elem = list_set_telem(map, id);
67
68         return ip_set_timeout_expired(elem->timeout);
69 }
70
71 /* Set list without and with timeout */
72
73 static int
74 list_set_kadt(struct ip_set *set, const struct sk_buff *skb,
75               const struct xt_action_param *par,
76               enum ipset_adt adt, const struct ip_set_adt_opt *opt)
77 {
78         struct list_set *map = set->data;
79         struct set_elem *elem;
80         u32 i;
81         int ret;
82
83         for (i = 0; i < map->size; i++) {
84                 elem = list_set_elem(map, i);
85                 if (elem->id == IPSET_INVALID_ID)
86                         return 0;
87                 if (with_timeout(map->timeout) && list_set_expired(map, i))
88                         continue;
89                 switch (adt) {
90                 case IPSET_TEST:
91                         ret = ip_set_test(elem->id, skb, par, opt);
92                         if (ret > 0)
93                                 return ret;
94                         break;
95                 case IPSET_ADD:
96                         ret = ip_set_add(elem->id, skb, par, opt);
97                         if (ret == 0)
98                                 return ret;
99                         break;
100                 case IPSET_DEL:
101                         ret = ip_set_del(elem->id, skb, par, opt);
102                         if (ret == 0)
103                                 return ret;
104                         break;
105                 default:
106                         break;
107                 }
108         }
109         return -EINVAL;
110 }
111
112 static bool
113 id_eq(const struct list_set *map, u32 i, ip_set_id_t id)
114 {
115         const struct set_elem *elem;
116
117         if (i < map->size) {
118                 elem = list_set_elem(map, i);
119                 return elem->id == id;
120         }
121
122         return 0;
123 }
124
125 static bool
126 id_eq_timeout(const struct list_set *map, u32 i, ip_set_id_t id)
127 {
128         const struct set_elem *elem;
129
130         if (i < map->size) {
131                 elem = list_set_elem(map, i);
132                 return !!(elem->id == id &&
133                           !(with_timeout(map->timeout) &&
134                             list_set_expired(map, i)));
135         }
136
137         return 0;
138 }
139
140 static void
141 list_elem_add(struct list_set *map, u32 i, ip_set_id_t id)
142 {
143         struct set_elem *e;
144
145         for (; i < map->size; i++) {
146                 e = list_set_elem(map, i);
147                 swap(e->id, id);
148                 if (e->id == IPSET_INVALID_ID)
149                         break;
150         }
151 }
152
153 static void
154 list_elem_tadd(struct list_set *map, u32 i, ip_set_id_t id,
155                unsigned long timeout)
156 {
157         struct set_telem *e;
158
159         for (; i < map->size; i++) {
160                 e = list_set_telem(map, i);
161                 swap(e->id, id);
162                 swap(e->timeout, timeout);
163                 if (e->id == IPSET_INVALID_ID)
164                         break;
165         }
166 }
167
168 static int
169 list_set_add(struct list_set *map, u32 i, ip_set_id_t id,
170              unsigned long timeout)
171 {
172         const struct set_elem *e = list_set_elem(map, i);
173
174         if (i == map->size - 1 && e->id != IPSET_INVALID_ID)
175                 /* Last element replaced: e.g. add new,before,last */
176                 ip_set_put_byindex(e->id);
177         if (with_timeout(map->timeout))
178                 list_elem_tadd(map, i, id, ip_set_timeout_set(timeout));
179         else
180                 list_elem_add(map, i, id);
181
182         return 0;
183 }
184
185 static int
186 list_set_del(struct list_set *map, u32 i)
187 {
188         struct set_elem *a = list_set_elem(map, i), *b;
189
190         ip_set_put_byindex(a->id);
191
192         for (; i < map->size - 1; i++) {
193                 b = list_set_elem(map, i + 1);
194                 a->id = b->id;
195                 if (with_timeout(map->timeout))
196                         ((struct set_telem *)a)->timeout =
197                                 ((struct set_telem *)b)->timeout;
198                 a = b;
199                 if (a->id == IPSET_INVALID_ID)
200                         break;
201         }
202         /* Last element */
203         a->id = IPSET_INVALID_ID;
204         return 0;
205 }
206
207 static void
208 cleanup_entries(struct list_set *map)
209 {
210         struct set_telem *e;
211         u32 i;
212
213         for (i = 0; i < map->size; i++) {
214                 e = list_set_telem(map, i);
215                 if (e->id != IPSET_INVALID_ID && list_set_expired(map, i))
216                         list_set_del(map, i);
217         }
218 }
219
220 static int
221 list_set_uadt(struct ip_set *set, struct nlattr *tb[],
222               enum ipset_adt adt, u32 *lineno, u32 flags, bool retried)
223 {
224         struct list_set *map = set->data;
225         bool with_timeout = with_timeout(map->timeout);
226         bool flag_exist = flags & IPSET_FLAG_EXIST;
227         int before = 0;
228         u32 timeout = map->timeout;
229         ip_set_id_t id, refid = IPSET_INVALID_ID;
230         const struct set_elem *elem;
231         struct ip_set *s;
232         u32 i;
233         int ret = 0;
234
235         if (unlikely(!tb[IPSET_ATTR_NAME] ||
236                      !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
237                      !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
238                 return -IPSET_ERR_PROTOCOL;
239
240         if (tb[IPSET_ATTR_LINENO])
241                 *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]);
242
243         id = ip_set_get_byname(nla_data(tb[IPSET_ATTR_NAME]), &s);
244         if (id == IPSET_INVALID_ID)
245                 return -IPSET_ERR_NAME;
246         /* "Loop detection" */
247         if (s->type->features & IPSET_TYPE_NAME) {
248                 ret = -IPSET_ERR_LOOP;
249                 goto finish;
250         }
251
252         if (tb[IPSET_ATTR_CADT_FLAGS]) {
253                 u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
254                 before = f & IPSET_FLAG_BEFORE;
255         }
256
257         if (before && !tb[IPSET_ATTR_NAMEREF]) {
258                 ret = -IPSET_ERR_BEFORE;
259                 goto finish;
260         }
261
262         if (tb[IPSET_ATTR_NAMEREF]) {
263                 refid = ip_set_get_byname(nla_data(tb[IPSET_ATTR_NAMEREF]),
264                                           &s);
265                 if (refid == IPSET_INVALID_ID) {
266                         ret = -IPSET_ERR_NAMEREF;
267                         goto finish;
268                 }
269                 if (!before)
270                         before = -1;
271         }
272         if (tb[IPSET_ATTR_TIMEOUT]) {
273                 if (!with_timeout) {
274                         ret = -IPSET_ERR_TIMEOUT;
275                         goto finish;
276                 }
277                 timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
278         }
279         if (with_timeout && adt != IPSET_TEST)
280                 cleanup_entries(map);
281
282         switch (adt) {
283         case IPSET_TEST:
284                 for (i = 0; i < map->size && !ret; i++) {
285                         elem = list_set_elem(map, i);
286                         if (elem->id == IPSET_INVALID_ID ||
287                             (before != 0 && i + 1 >= map->size))
288                                 break;
289                         else if (with_timeout && list_set_expired(map, i))
290                                 continue;
291                         else if (before > 0 && elem->id == id)
292                                 ret = id_eq_timeout(map, i + 1, refid);
293                         else if (before < 0 && elem->id == refid)
294                                 ret = id_eq_timeout(map, i + 1, id);
295                         else if (before == 0 && elem->id == id)
296                                 ret = 1;
297                 }
298                 break;
299         case IPSET_ADD:
300                 for (i = 0; i < map->size; i++) {
301                         elem = list_set_elem(map, i);
302                         if (elem->id != id)
303                                 continue;
304                         if (!(with_timeout && flag_exist)) {
305                                 ret = -IPSET_ERR_EXIST;
306                                 goto finish;
307                         } else {
308                                 struct set_telem *e = list_set_telem(map, i);
309
310                                 if ((before > 1 &&
311                                      !id_eq(map, i + 1, refid)) ||
312                                     (before < 0 &&
313                                      (i == 0 || !id_eq(map, i - 1, refid)))) {
314                                         ret = -IPSET_ERR_EXIST;
315                                         goto finish;
316                                 }
317                                 e->timeout = ip_set_timeout_set(timeout);
318                                 ip_set_put_byindex(id);
319                                 ret = 0;
320                                 goto finish;
321                         }
322                 }
323                 ret = -IPSET_ERR_LIST_FULL;
324                 for (i = 0; i < map->size && ret == -IPSET_ERR_LIST_FULL; i++) {
325                         elem = list_set_elem(map, i);
326                         if (elem->id == IPSET_INVALID_ID)
327                                 ret = before != 0 ? -IPSET_ERR_REF_EXIST
328                                         : list_set_add(map, i, id, timeout);
329                         else if (elem->id != refid)
330                                 continue;
331                         else if (before > 0)
332                                 ret = list_set_add(map, i, id, timeout);
333                         else if (i + 1 < map->size)
334                                 ret = list_set_add(map, i + 1, id, timeout);
335                 }
336                 break;
337         case IPSET_DEL:
338                 ret = -IPSET_ERR_EXIST;
339                 for (i = 0; i < map->size && ret == -IPSET_ERR_EXIST; i++) {
340                         elem = list_set_elem(map, i);
341                         if (elem->id == IPSET_INVALID_ID) {
342                                 ret = before != 0 ? -IPSET_ERR_REF_EXIST
343                                                   : -IPSET_ERR_EXIST;
344                                 break;
345                         } else if (elem->id == id &&
346                                    (before == 0 ||
347                                     (before > 0 && id_eq(map, i + 1, refid))))
348                                 ret = list_set_del(map, i);
349                         else if (elem->id == refid &&
350                                  before < 0 && id_eq(map, i + 1, id))
351                                 ret = list_set_del(map, i + 1);
352                 }
353                 break;
354         default:
355                 break;
356         }
357
358 finish:
359         if (refid != IPSET_INVALID_ID)
360                 ip_set_put_byindex(refid);
361         if (adt != IPSET_ADD || ret)
362                 ip_set_put_byindex(id);
363
364         return ip_set_eexist(ret, flags) ? 0 : ret;
365 }
366
367 static void
368 list_set_flush(struct ip_set *set)
369 {
370         struct list_set *map = set->data;
371         struct set_elem *elem;
372         u32 i;
373
374         for (i = 0; i < map->size; i++) {
375                 elem = list_set_elem(map, i);
376                 if (elem->id != IPSET_INVALID_ID) {
377                         ip_set_put_byindex(elem->id);
378                         elem->id = IPSET_INVALID_ID;
379                 }
380         }
381 }
382
383 static void
384 list_set_destroy(struct ip_set *set)
385 {
386         struct list_set *map = set->data;
387
388         if (with_timeout(map->timeout))
389                 del_timer_sync(&map->gc);
390         list_set_flush(set);
391         kfree(map);
392
393         set->data = NULL;
394 }
395
396 static int
397 list_set_head(struct ip_set *set, struct sk_buff *skb)
398 {
399         const struct list_set *map = set->data;
400         struct nlattr *nested;
401
402         nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
403         if (!nested)
404                 goto nla_put_failure;
405         if (nla_put_net32(skb, IPSET_ATTR_SIZE, htonl(map->size)) ||
406             (with_timeout(map->timeout) &&
407              nla_put_net32(skb, IPSET_ATTR_TIMEOUT, htonl(map->timeout))) ||
408             nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref - 1)) ||
409             nla_put_net32(skb, IPSET_ATTR_MEMSIZE,
410                           htonl(sizeof(*map) + map->size * map->dsize)))
411                 goto nla_put_failure;
412         ipset_nest_end(skb, nested);
413
414         return 0;
415 nla_put_failure:
416         return -EMSGSIZE;
417 }
418
419 static int
420 list_set_list(const struct ip_set *set,
421               struct sk_buff *skb, struct netlink_callback *cb)
422 {
423         const struct list_set *map = set->data;
424         struct nlattr *atd, *nested;
425         u32 i, first = cb->args[2];
426         const struct set_elem *e;
427
428         atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
429         if (!atd)
430                 return -EMSGSIZE;
431         for (; cb->args[2] < map->size; cb->args[2]++) {
432                 i = cb->args[2];
433                 e = list_set_elem(map, i);
434                 if (e->id == IPSET_INVALID_ID)
435                         goto finish;
436                 if (with_timeout(map->timeout) && list_set_expired(map, i))
437                         continue;
438                 nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
439                 if (!nested) {
440                         if (i == first) {
441                                 nla_nest_cancel(skb, atd);
442                                 return -EMSGSIZE;
443                         } else
444                                 goto nla_put_failure;
445                 }
446                 if (nla_put_string(skb, IPSET_ATTR_NAME,
447                                    ip_set_name_byindex(e->id)))
448                         goto nla_put_failure;
449                 if (with_timeout(map->timeout)) {
450                         const struct set_telem *te =
451                                 (const struct set_telem *) e;
452                         __be32 to = htonl(ip_set_timeout_get(te->timeout));
453                         if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT, to))
454                                 goto nla_put_failure;
455                 }
456                 ipset_nest_end(skb, nested);
457         }
458 finish:
459         ipset_nest_end(skb, atd);
460         /* Set listing finished */
461         cb->args[2] = 0;
462         return 0;
463
464 nla_put_failure:
465         nla_nest_cancel(skb, nested);
466         ipset_nest_end(skb, atd);
467         if (unlikely(i == first)) {
468                 cb->args[2] = 0;
469                 return -EMSGSIZE;
470         }
471         return 0;
472 }
473
474 static bool
475 list_set_same_set(const struct ip_set *a, const struct ip_set *b)
476 {
477         const struct list_set *x = a->data;
478         const struct list_set *y = b->data;
479
480         return x->size == y->size &&
481                x->timeout == y->timeout;
482 }
483
484 static const struct ip_set_type_variant list_set = {
485         .kadt   = list_set_kadt,
486         .uadt   = list_set_uadt,
487         .destroy = list_set_destroy,
488         .flush  = list_set_flush,
489         .head   = list_set_head,
490         .list   = list_set_list,
491         .same_set = list_set_same_set,
492 };
493
494 static void
495 list_set_gc(unsigned long ul_set)
496 {
497         struct ip_set *set = (struct ip_set *) ul_set;
498         struct list_set *map = set->data;
499
500         write_lock_bh(&set->lock);
501         cleanup_entries(map);
502         write_unlock_bh(&set->lock);
503
504         map->gc.expires = jiffies + IPSET_GC_PERIOD(map->timeout) * HZ;
505         add_timer(&map->gc);
506 }
507
508 static void
509 list_set_gc_init(struct ip_set *set)
510 {
511         struct list_set *map = set->data;
512
513         init_timer(&map->gc);
514         map->gc.data = (unsigned long) set;
515         map->gc.function = list_set_gc;
516         map->gc.expires = jiffies + IPSET_GC_PERIOD(map->timeout) * HZ;
517         add_timer(&map->gc);
518 }
519
520 /* Create list:set type of sets */
521
522 static bool
523 init_list_set(struct ip_set *set, u32 size, size_t dsize,
524               unsigned long timeout)
525 {
526         struct list_set *map;
527         struct set_elem *e;
528         u32 i;
529
530         map = kzalloc(sizeof(*map) + size * dsize, GFP_KERNEL);
531         if (!map)
532                 return false;
533
534         map->size = size;
535         map->dsize = dsize;
536         map->timeout = timeout;
537         set->data = map;
538
539         for (i = 0; i < size; i++) {
540                 e = list_set_elem(map, i);
541                 e->id = IPSET_INVALID_ID;
542         }
543
544         return true;
545 }
546
547 static int
548 list_set_create(struct ip_set *set, struct nlattr *tb[], u32 flags)
549 {
550         u32 size = IP_SET_LIST_DEFAULT_SIZE;
551
552         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_SIZE) ||
553                      !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT)))
554                 return -IPSET_ERR_PROTOCOL;
555
556         if (tb[IPSET_ATTR_SIZE])
557                 size = ip_set_get_h32(tb[IPSET_ATTR_SIZE]);
558         if (size < IP_SET_LIST_MIN_SIZE)
559                 size = IP_SET_LIST_MIN_SIZE;
560
561         if (tb[IPSET_ATTR_TIMEOUT]) {
562                 if (!init_list_set(set, size, sizeof(struct set_telem),
563                                    ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT])))
564                         return -ENOMEM;
565
566                 list_set_gc_init(set);
567         } else {
568                 if (!init_list_set(set, size, sizeof(struct set_elem),
569                                    IPSET_NO_TIMEOUT))
570                         return -ENOMEM;
571         }
572         set->variant = &list_set;
573         return 0;
574 }
575
576 static struct ip_set_type list_set_type __read_mostly = {
577         .name           = "list:set",
578         .protocol       = IPSET_PROTOCOL,
579         .features       = IPSET_TYPE_NAME | IPSET_DUMP_LAST,
580         .dimension      = IPSET_DIM_ONE,
581         .family         = NFPROTO_UNSPEC,
582         .revision_min   = 0,
583         .revision_max   = 0,
584         .create         = list_set_create,
585         .create_policy  = {
586                 [IPSET_ATTR_SIZE]       = { .type = NLA_U32 },
587                 [IPSET_ATTR_TIMEOUT]    = { .type = NLA_U32 },
588         },
589         .adt_policy     = {
590                 [IPSET_ATTR_NAME]       = { .type = NLA_STRING,
591                                             .len = IPSET_MAXNAMELEN },
592                 [IPSET_ATTR_NAMEREF]    = { .type = NLA_STRING,
593                                             .len = IPSET_MAXNAMELEN },
594                 [IPSET_ATTR_TIMEOUT]    = { .type = NLA_U32 },
595                 [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
596                 [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 },
597         },
598         .me             = THIS_MODULE,
599 };
600
601 static int __init
602 list_set_init(void)
603 {
604         return ip_set_type_register(&list_set_type);
605 }
606
607 static void __exit
608 list_set_fini(void)
609 {
610         ip_set_type_unregister(&list_set_type);
611 }
612
613 module_init(list_set_init);
614 module_exit(list_set_fini);