Linux-libre 3.18.37-gnu
[librecmc/linux-libre.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 static DEFINE_SPINLOCK(fou_lock);
20 static LIST_HEAD(fou_list);
21
22 struct fou {
23         struct socket *sock;
24         u8 protocol;
25         u16 port;
26         struct udp_offload udp_offloads;
27         struct list_head list;
28         struct rcu_head rcu;
29 };
30
31 struct fou_cfg {
32         u16 type;
33         u8 protocol;
34         struct udp_port_cfg udp_config;
35 };
36
37 static inline struct fou *fou_from_sock(struct sock *sk)
38 {
39         return sk->sk_user_data;
40 }
41
42 static int fou_udp_encap_recv_deliver(struct sk_buff *skb,
43                                       u8 protocol, size_t len)
44 {
45         struct iphdr *iph = ip_hdr(skb);
46
47         /* Remove 'len' bytes from the packet (UDP header and
48          * FOU header if present), modify the protocol to the one
49          * we found, and then call rcv_encap.
50          */
51         iph->tot_len = htons(ntohs(iph->tot_len) - len);
52         __skb_pull(skb, len);
53         skb_postpull_rcsum(skb, udp_hdr(skb), len);
54         skb_reset_transport_header(skb);
55
56         return -protocol;
57 }
58
59 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
60 {
61         struct fou *fou = fou_from_sock(sk);
62
63         if (!fou)
64                 return 1;
65
66         return fou_udp_encap_recv_deliver(skb, fou->protocol,
67                                           sizeof(struct udphdr));
68 }
69
70 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
71 {
72         struct fou *fou = fou_from_sock(sk);
73         size_t len;
74         struct guehdr *guehdr;
75         struct udphdr *uh;
76
77         if (!fou)
78                 return 1;
79
80         len = sizeof(struct udphdr) + sizeof(struct guehdr);
81         if (!pskb_may_pull(skb, len))
82                 goto drop;
83
84         uh = udp_hdr(skb);
85         guehdr = (struct guehdr *)&uh[1];
86
87         len += guehdr->hlen << 2;
88         if (!pskb_may_pull(skb, len))
89                 goto drop;
90
91         uh = udp_hdr(skb);
92         guehdr = (struct guehdr *)&uh[1];
93
94         if (guehdr->version != 0)
95                 goto drop;
96
97         if (guehdr->flags) {
98                 /* No support yet */
99                 goto drop;
100         }
101
102         return fou_udp_encap_recv_deliver(skb, guehdr->next_hdr, len);
103 drop:
104         kfree_skb(skb);
105         return 0;
106 }
107
108 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
109                                         struct sk_buff *skb)
110 {
111         const struct net_offload *ops;
112         struct sk_buff **pp = NULL;
113         u8 proto = NAPI_GRO_CB(skb)->proto;
114         const struct net_offload **offloads;
115
116         rcu_read_lock();
117         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
118         ops = rcu_dereference(offloads[proto]);
119         if (!ops || !ops->callbacks.gro_receive)
120                 goto out_unlock;
121
122         pp = ops->callbacks.gro_receive(head, skb);
123
124 out_unlock:
125         rcu_read_unlock();
126
127         return pp;
128 }
129
130 static int fou_gro_complete(struct sk_buff *skb, int nhoff)
131 {
132         const struct net_offload *ops;
133         u8 proto = NAPI_GRO_CB(skb)->proto;
134         int err = -ENOSYS;
135         const struct net_offload **offloads;
136
137         udp_tunnel_gro_complete(skb, nhoff);
138
139         rcu_read_lock();
140         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
141         ops = rcu_dereference(offloads[proto]);
142         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
143                 goto out_unlock;
144
145         err = ops->callbacks.gro_complete(skb, nhoff);
146
147 out_unlock:
148         rcu_read_unlock();
149
150         return err;
151 }
152
153 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
154                                         struct sk_buff *skb)
155 {
156         const struct net_offload **offloads;
157         const struct net_offload *ops;
158         struct sk_buff **pp = NULL;
159         struct sk_buff *p;
160         u8 proto;
161         struct guehdr *guehdr;
162         unsigned int hlen, guehlen;
163         unsigned int off;
164         int flush = 1;
165
166         off = skb_gro_offset(skb);
167         hlen = off + sizeof(*guehdr);
168         guehdr = skb_gro_header_fast(skb, off);
169         if (skb_gro_header_hard(skb, hlen)) {
170                 guehdr = skb_gro_header_slow(skb, hlen, off);
171                 if (unlikely(!guehdr))
172                         goto out;
173         }
174
175         proto = guehdr->next_hdr;
176
177         rcu_read_lock();
178         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
179         ops = rcu_dereference(offloads[proto]);
180         if (WARN_ON(!ops || !ops->callbacks.gro_receive))
181                 goto out_unlock;
182
183         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
184
185         hlen = off + guehlen;
186         if (skb_gro_header_hard(skb, hlen)) {
187                 guehdr = skb_gro_header_slow(skb, hlen, off);
188                 if (unlikely(!guehdr))
189                         goto out_unlock;
190         }
191
192         flush = 0;
193
194         for (p = *head; p; p = p->next) {
195                 const struct guehdr *guehdr2;
196
197                 if (!NAPI_GRO_CB(p)->same_flow)
198                         continue;
199
200                 guehdr2 = (struct guehdr *)(p->data + off);
201
202                 /* Compare base GUE header to be equal (covers
203                  * hlen, version, next_hdr, and flags.
204                  */
205                 if (guehdr->word != guehdr2->word) {
206                         NAPI_GRO_CB(p)->same_flow = 0;
207                         continue;
208                 }
209
210                 /* Compare optional fields are the same. */
211                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
212                                            guehdr->hlen << 2)) {
213                         NAPI_GRO_CB(p)->same_flow = 0;
214                         continue;
215                 }
216         }
217
218         skb_gro_pull(skb, guehlen);
219
220         /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
221         skb_gro_postpull_rcsum(skb, guehdr, guehlen);
222
223         pp = ops->callbacks.gro_receive(head, skb);
224
225 out_unlock:
226         rcu_read_unlock();
227 out:
228         NAPI_GRO_CB(skb)->flush |= flush;
229
230         return pp;
231 }
232
233 static int gue_gro_complete(struct sk_buff *skb, int nhoff)
234 {
235         const struct net_offload **offloads;
236         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
237         const struct net_offload *ops;
238         unsigned int guehlen;
239         u8 proto;
240         int err = -ENOENT;
241
242         proto = guehdr->next_hdr;
243
244         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
245
246         rcu_read_lock();
247         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
248         ops = rcu_dereference(offloads[proto]);
249         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
250                 goto out_unlock;
251
252         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
253
254 out_unlock:
255         rcu_read_unlock();
256         return err;
257 }
258
259 static int fou_add_to_port_list(struct fou *fou)
260 {
261         struct fou *fout;
262
263         spin_lock(&fou_lock);
264         list_for_each_entry(fout, &fou_list, list) {
265                 if (fou->port == fout->port) {
266                         spin_unlock(&fou_lock);
267                         return -EALREADY;
268                 }
269         }
270
271         list_add(&fou->list, &fou_list);
272         spin_unlock(&fou_lock);
273
274         return 0;
275 }
276
277 static void fou_release(struct fou *fou)
278 {
279         struct socket *sock = fou->sock;
280         struct sock *sk = sock->sk;
281
282         udp_del_offload(&fou->udp_offloads);
283
284         list_del(&fou->list);
285
286         /* Remove hooks into tunnel socket */
287         sk->sk_user_data = NULL;
288
289         sock_release(sock);
290
291         kfree_rcu(fou, rcu);
292 }
293
294 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
295 {
296         udp_sk(sk)->encap_rcv = fou_udp_recv;
297         fou->protocol = cfg->protocol;
298         fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
299         fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
300         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
301         fou->udp_offloads.ipproto = cfg->protocol;
302
303         return 0;
304 }
305
306 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
307 {
308         udp_sk(sk)->encap_rcv = gue_udp_recv;
309         fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
310         fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
311         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
312
313         return 0;
314 }
315
316 static int fou_create(struct net *net, struct fou_cfg *cfg,
317                       struct socket **sockp)
318 {
319         struct fou *fou = NULL;
320         int err;
321         struct socket *sock = NULL;
322         struct sock *sk;
323
324         /* Open UDP socket */
325         err = udp_sock_create(net, &cfg->udp_config, &sock);
326         if (err < 0)
327                 goto error;
328
329         /* Allocate FOU port structure */
330         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
331         if (!fou) {
332                 err = -ENOMEM;
333                 goto error;
334         }
335
336         sk = sock->sk;
337
338         fou->port = cfg->udp_config.local_udp_port;
339
340         /* Initial for fou type */
341         switch (cfg->type) {
342         case FOU_ENCAP_DIRECT:
343                 err = fou_encap_init(sk, fou, cfg);
344                 if (err)
345                         goto error;
346                 break;
347         case FOU_ENCAP_GUE:
348                 err = gue_encap_init(sk, fou, cfg);
349                 if (err)
350                         goto error;
351                 break;
352         default:
353                 err = -EINVAL;
354                 goto error;
355         }
356
357         udp_sk(sk)->encap_type = 1;
358         udp_encap_enable();
359
360         sk->sk_user_data = fou;
361         fou->sock = sock;
362
363         udp_set_convert_csum(sk, true);
364
365         sk->sk_allocation = GFP_ATOMIC;
366
367         if (cfg->udp_config.family == AF_INET) {
368                 err = udp_add_offload(&fou->udp_offloads);
369                 if (err)
370                         goto error;
371         }
372
373         err = fou_add_to_port_list(fou);
374         if (err)
375                 goto error;
376
377         if (sockp)
378                 *sockp = sock;
379
380         return 0;
381
382 error:
383         kfree(fou);
384         if (sock)
385                 sock_release(sock);
386
387         return err;
388 }
389
390 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
391 {
392         struct fou *fou;
393         u16 port = cfg->udp_config.local_udp_port;
394         int err = -EINVAL;
395
396         spin_lock(&fou_lock);
397         list_for_each_entry(fou, &fou_list, list) {
398                 if (fou->port == port) {
399                         udp_del_offload(&fou->udp_offloads);
400                         fou_release(fou);
401                         err = 0;
402                         break;
403                 }
404         }
405         spin_unlock(&fou_lock);
406
407         return err;
408 }
409
410 static struct genl_family fou_nl_family = {
411         .id             = GENL_ID_GENERATE,
412         .hdrsize        = 0,
413         .name           = FOU_GENL_NAME,
414         .version        = FOU_GENL_VERSION,
415         .maxattr        = FOU_ATTR_MAX,
416         .netnsok        = true,
417 };
418
419 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
420         [FOU_ATTR_PORT] = { .type = NLA_U16, },
421         [FOU_ATTR_AF] = { .type = NLA_U8, },
422         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
423         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
424 };
425
426 static int parse_nl_config(struct genl_info *info,
427                            struct fou_cfg *cfg)
428 {
429         memset(cfg, 0, sizeof(*cfg));
430
431         cfg->udp_config.family = AF_INET;
432
433         if (info->attrs[FOU_ATTR_AF]) {
434                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
435
436                 if (family != AF_INET && family != AF_INET6)
437                         return -EINVAL;
438
439                 cfg->udp_config.family = family;
440         }
441
442         if (info->attrs[FOU_ATTR_PORT]) {
443                 u16 port = nla_get_u16(info->attrs[FOU_ATTR_PORT]);
444
445                 cfg->udp_config.local_udp_port = port;
446         }
447
448         if (info->attrs[FOU_ATTR_IPPROTO])
449                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
450
451         if (info->attrs[FOU_ATTR_TYPE])
452                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
453
454         return 0;
455 }
456
457 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
458 {
459         struct fou_cfg cfg;
460         int err;
461
462         err = parse_nl_config(info, &cfg);
463         if (err)
464                 return err;
465
466         return fou_create(&init_net, &cfg, NULL);
467 }
468
469 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
470 {
471         struct fou_cfg cfg;
472
473         parse_nl_config(info, &cfg);
474
475         return fou_destroy(&init_net, &cfg);
476 }
477
478 static const struct genl_ops fou_nl_ops[] = {
479         {
480                 .cmd = FOU_CMD_ADD,
481                 .doit = fou_nl_cmd_add_port,
482                 .policy = fou_nl_policy,
483                 .flags = GENL_ADMIN_PERM,
484         },
485         {
486                 .cmd = FOU_CMD_DEL,
487                 .doit = fou_nl_cmd_rm_port,
488                 .policy = fou_nl_policy,
489                 .flags = GENL_ADMIN_PERM,
490         },
491 };
492
493 static int __init fou_init(void)
494 {
495         int ret;
496
497         ret = genl_register_family_with_ops(&fou_nl_family,
498                                             fou_nl_ops);
499
500         return ret;
501 }
502
503 static void __exit fou_fini(void)
504 {
505         struct fou *fou, *next;
506
507         genl_unregister_family(&fou_nl_family);
508
509         /* Close all the FOU sockets */
510
511         spin_lock(&fou_lock);
512         list_for_each_entry_safe(fou, next, &fou_list, list)
513                 fou_release(fou);
514         spin_unlock(&fou_lock);
515 }
516
517 module_init(fou_init);
518 module_exit(fou_fini);
519 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
520 MODULE_LICENSE("GPL");