Linux-libre 5.7.3-gnu
[librecmc/linux-libre.git] / net / netfilter / ipvs / ip_vs_proto_udp.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * ip_vs_proto_udp.c:   UDP load balancing support for IPVS
4  *
5  * Authors:     Wensong Zhang <wensong@linuxvirtualserver.org>
6  *              Julian Anastasov <ja@ssi.bg>
7  *
8  * Changes:     Hans Schillstrom <hans.schillstrom@ericsson.com>
9  *              Network name space (netns) aware.
10  */
11
12 #define KMSG_COMPONENT "IPVS"
13 #define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
14
15 #include <linux/in.h>
16 #include <linux/ip.h>
17 #include <linux/kernel.h>
18 #include <linux/netfilter.h>
19 #include <linux/netfilter_ipv4.h>
20 #include <linux/udp.h>
21 #include <linux/indirect_call_wrapper.h>
22
23 #include <net/ip_vs.h>
24 #include <net/ip.h>
25 #include <net/ip6_checksum.h>
26
/*
 * Prototype only: the definition sits further down in this file, but the
 * SNAT/DNAT handlers that call it are defined before it.
 */
static int udp_csum_check(int af, struct sk_buff *skb,
			  struct ip_vs_protocol *pp);
29
30 static int
31 udp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb,
32                   struct ip_vs_proto_data *pd,
33                   int *verdict, struct ip_vs_conn **cpp,
34                   struct ip_vs_iphdr *iph)
35 {
36         struct ip_vs_service *svc;
37         struct udphdr _udph, *uh;
38         __be16 _ports[2], *ports = NULL;
39
40         if (likely(!ip_vs_iph_icmp(iph))) {
41                 /* IPv6 fragments, only first fragment will hit this */
42                 uh = skb_header_pointer(skb, iph->len, sizeof(_udph), &_udph);
43                 if (uh)
44                         ports = &uh->source;
45         } else {
46                 ports = skb_header_pointer(
47                         skb, iph->len, sizeof(_ports), &_ports);
48         }
49
50         if (!ports) {
51                 *verdict = NF_DROP;
52                 return 0;
53         }
54
55         if (likely(!ip_vs_iph_inverse(iph)))
56                 svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol,
57                                          &iph->daddr, ports[1]);
58         else
59                 svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol,
60                                          &iph->saddr, ports[0]);
61
62         if (svc) {
63                 int ignored;
64
65                 if (ip_vs_todrop(ipvs)) {
66                         /*
67                          * It seems that we are very loaded.
68                          * We have to drop this packet :(
69                          */
70                         *verdict = NF_DROP;
71                         return 0;
72                 }
73
74                 /*
75                  * Let the virtual server select a real server for the
76                  * incoming connection, and create a connection entry.
77                  */
78                 *cpp = ip_vs_schedule(svc, skb, pd, &ignored, iph);
79                 if (!*cpp && ignored <= 0) {
80                         if (!ignored)
81                                 *verdict = ip_vs_leave(svc, skb, pd, iph);
82                         else
83                                 *verdict = NF_DROP;
84                         return 0;
85                 }
86         }
87         /* NF_ACCEPT */
88         return 1;
89 }
90
91
92 static inline void
93 udp_fast_csum_update(int af, struct udphdr *uhdr,
94                      const union nf_inet_addr *oldip,
95                      const union nf_inet_addr *newip,
96                      __be16 oldport, __be16 newport)
97 {
98 #ifdef CONFIG_IP_VS_IPV6
99         if (af == AF_INET6)
100                 uhdr->check =
101                         csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6,
102                                          ip_vs_check_diff2(oldport, newport,
103                                                 ~csum_unfold(uhdr->check))));
104         else
105 #endif
106                 uhdr->check =
107                         csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip,
108                                          ip_vs_check_diff2(oldport, newport,
109                                                 ~csum_unfold(uhdr->check))));
110         if (!uhdr->check)
111                 uhdr->check = CSUM_MANGLED_0;
112 }
113
114 static inline void
115 udp_partial_csum_update(int af, struct udphdr *uhdr,
116                      const union nf_inet_addr *oldip,
117                      const union nf_inet_addr *newip,
118                      __be16 oldlen, __be16 newlen)
119 {
120 #ifdef CONFIG_IP_VS_IPV6
121         if (af == AF_INET6)
122                 uhdr->check =
123                         ~csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6,
124                                          ip_vs_check_diff2(oldlen, newlen,
125                                                 csum_unfold(uhdr->check))));
126         else
127 #endif
128         uhdr->check =
129                 ~csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip,
130                                 ip_vs_check_diff2(oldlen, newlen,
131                                                 csum_unfold(uhdr->check))));
132 }
133
134
135 INDIRECT_CALLABLE_SCOPE int
136 udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp,
137                  struct ip_vs_conn *cp, struct ip_vs_iphdr *iph)
138 {
139         struct udphdr *udph;
140         unsigned int udphoff = iph->len;
141         bool payload_csum = false;
142         int oldlen;
143
144 #ifdef CONFIG_IP_VS_IPV6
145         if (cp->af == AF_INET6 && iph->fragoffs)
146                 return 1;
147 #endif
148         oldlen = skb->len - udphoff;
149
150         /* csum_check requires unshared skb */
151         if (skb_ensure_writable(skb, udphoff + sizeof(*udph)))
152                 return 0;
153
154         if (unlikely(cp->app != NULL)) {
155                 int ret;
156
157                 /* Some checks before mangling */
158                 if (!udp_csum_check(cp->af, skb, pp))
159                         return 0;
160
161                 /*
162                  *      Call application helper if needed
163                  */
164                 if (!(ret = ip_vs_app_pkt_out(cp, skb, iph)))
165                         return 0;
166                 /* ret=2: csum update is needed after payload mangling */
167                 if (ret == 1)
168                         oldlen = skb->len - udphoff;
169                 else
170                         payload_csum = true;
171         }
172
173         udph = (void *)skb_network_header(skb) + udphoff;
174         udph->source = cp->vport;
175
176         /*
177          *      Adjust UDP checksums
178          */
179         if (skb->ip_summed == CHECKSUM_PARTIAL) {
180                 udp_partial_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr,
181                                         htons(oldlen),
182                                         htons(skb->len - udphoff));
183         } else if (!payload_csum && (udph->check != 0)) {
184                 /* Only port and addr are changed, do fast csum update */
185                 udp_fast_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr,
186                                      cp->dport, cp->vport);
187                 if (skb->ip_summed == CHECKSUM_COMPLETE)
188                         skb->ip_summed = cp->app ?
189                                          CHECKSUM_UNNECESSARY : CHECKSUM_NONE;
190         } else {
191                 /* full checksum calculation */
192                 udph->check = 0;
193                 skb->csum = skb_checksum(skb, udphoff, skb->len - udphoff, 0);
194 #ifdef CONFIG_IP_VS_IPV6
195                 if (cp->af == AF_INET6)
196                         udph->check = csum_ipv6_magic(&cp->vaddr.in6,
197                                                       &cp->caddr.in6,
198                                                       skb->len - udphoff,
199                                                       cp->protocol, skb->csum);
200                 else
201 #endif
202                         udph->check = csum_tcpudp_magic(cp->vaddr.ip,
203                                                         cp->caddr.ip,
204                                                         skb->len - udphoff,
205                                                         cp->protocol,
206                                                         skb->csum);
207                 if (udph->check == 0)
208                         udph->check = CSUM_MANGLED_0;
209                 skb->ip_summed = CHECKSUM_UNNECESSARY;
210                 IP_VS_DBG(11, "O-pkt: %s O-csum=%d (+%zd)\n",
211                           pp->name, udph->check,
212                           (char*)&(udph->check) - (char*)udph);
213         }
214         return 1;
215 }
216
217
218 static int
219 udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp,
220                  struct ip_vs_conn *cp, struct ip_vs_iphdr *iph)
221 {
222         struct udphdr *udph;
223         unsigned int udphoff = iph->len;
224         bool payload_csum = false;
225         int oldlen;
226
227 #ifdef CONFIG_IP_VS_IPV6
228         if (cp->af == AF_INET6 && iph->fragoffs)
229                 return 1;
230 #endif
231         oldlen = skb->len - udphoff;
232
233         /* csum_check requires unshared skb */
234         if (skb_ensure_writable(skb, udphoff + sizeof(*udph)))
235                 return 0;
236
237         if (unlikely(cp->app != NULL)) {
238                 int ret;
239
240                 /* Some checks before mangling */
241                 if (!udp_csum_check(cp->af, skb, pp))
242                         return 0;
243
244                 /*
245                  *      Attempt ip_vs_app call.
246                  *      It will fix ip_vs_conn
247                  */
248                 if (!(ret = ip_vs_app_pkt_in(cp, skb, iph)))
249                         return 0;
250                 /* ret=2: csum update is needed after payload mangling */
251                 if (ret == 1)
252                         oldlen = skb->len - udphoff;
253                 else
254                         payload_csum = true;
255         }
256
257         udph = (void *)skb_network_header(skb) + udphoff;
258         udph->dest = cp->dport;
259
260         /*
261          *      Adjust UDP checksums
262          */
263         if (skb->ip_summed == CHECKSUM_PARTIAL) {
264                 udp_partial_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr,
265                                         htons(oldlen),
266                                         htons(skb->len - udphoff));
267         } else if (!payload_csum && (udph->check != 0)) {
268                 /* Only port and addr are changed, do fast csum update */
269                 udp_fast_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr,
270                                      cp->vport, cp->dport);
271                 if (skb->ip_summed == CHECKSUM_COMPLETE)
272                         skb->ip_summed = cp->app ?
273                                          CHECKSUM_UNNECESSARY : CHECKSUM_NONE;
274         } else {
275                 /* full checksum calculation */
276                 udph->check = 0;
277                 skb->csum = skb_checksum(skb, udphoff, skb->len - udphoff, 0);
278 #ifdef CONFIG_IP_VS_IPV6
279                 if (cp->af == AF_INET6)
280                         udph->check = csum_ipv6_magic(&cp->caddr.in6,
281                                                       &cp->daddr.in6,
282                                                       skb->len - udphoff,
283                                                       cp->protocol, skb->csum);
284                 else
285 #endif
286                         udph->check = csum_tcpudp_magic(cp->caddr.ip,
287                                                         cp->daddr.ip,
288                                                         skb->len - udphoff,
289                                                         cp->protocol,
290                                                         skb->csum);
291                 if (udph->check == 0)
292                         udph->check = CSUM_MANGLED_0;
293                 skb->ip_summed = CHECKSUM_UNNECESSARY;
294         }
295         return 1;
296 }
297
298
299 static int
300 udp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp)
301 {
302         struct udphdr _udph, *uh;
303         unsigned int udphoff;
304
305 #ifdef CONFIG_IP_VS_IPV6
306         if (af == AF_INET6)
307                 udphoff = sizeof(struct ipv6hdr);
308         else
309 #endif
310                 udphoff = ip_hdrlen(skb);
311
312         uh = skb_header_pointer(skb, udphoff, sizeof(_udph), &_udph);
313         if (uh == NULL)
314                 return 0;
315
316         if (uh->check != 0) {
317                 switch (skb->ip_summed) {
318                 case CHECKSUM_NONE:
319                         skb->csum = skb_checksum(skb, udphoff,
320                                                  skb->len - udphoff, 0);
321                         /* fall through */
322                 case CHECKSUM_COMPLETE:
323 #ifdef CONFIG_IP_VS_IPV6
324                         if (af == AF_INET6) {
325                                 if (csum_ipv6_magic(&ipv6_hdr(skb)->saddr,
326                                                     &ipv6_hdr(skb)->daddr,
327                                                     skb->len - udphoff,
328                                                     ipv6_hdr(skb)->nexthdr,
329                                                     skb->csum)) {
330                                         IP_VS_DBG_RL_PKT(0, af, pp, skb, 0,
331                                                          "Failed checksum for");
332                                         return 0;
333                                 }
334                         } else
335 #endif
336                                 if (csum_tcpudp_magic(ip_hdr(skb)->saddr,
337                                                       ip_hdr(skb)->daddr,
338                                                       skb->len - udphoff,
339                                                       ip_hdr(skb)->protocol,
340                                                       skb->csum)) {
341                                         IP_VS_DBG_RL_PKT(0, af, pp, skb, 0,
342                                                          "Failed checksum for");
343                                         return 0;
344                                 }
345                         break;
346                 default:
347                         /* No need to checksum. */
348                         break;
349                 }
350         }
351         return 1;
352 }
353
354 static inline __u16 udp_app_hashkey(__be16 port)
355 {
356         return (((__force u16)port >> UDP_APP_TAB_BITS) ^ (__force u16)port)
357                 & UDP_APP_TAB_MASK;
358 }
359
360
361 static int udp_register_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc)
362 {
363         struct ip_vs_app *i;
364         __u16 hash;
365         __be16 port = inc->port;
366         int ret = 0;
367         struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP);
368
369         hash = udp_app_hashkey(port);
370
371         list_for_each_entry(i, &ipvs->udp_apps[hash], p_list) {
372                 if (i->port == port) {
373                         ret = -EEXIST;
374                         goto out;
375                 }
376         }
377         list_add_rcu(&inc->p_list, &ipvs->udp_apps[hash]);
378         atomic_inc(&pd->appcnt);
379
380   out:
381         return ret;
382 }
383
384
385 static void
386 udp_unregister_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc)
387 {
388         struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP);
389
390         atomic_dec(&pd->appcnt);
391         list_del_rcu(&inc->p_list);
392 }
393
394
395 static int udp_app_conn_bind(struct ip_vs_conn *cp)
396 {
397         struct netns_ipvs *ipvs = cp->ipvs;
398         int hash;
399         struct ip_vs_app *inc;
400         int result = 0;
401
402         /* Default binding: bind app only for NAT */
403         if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ)
404                 return 0;
405
406         /* Lookup application incarnations and bind the right one */
407         hash = udp_app_hashkey(cp->vport);
408
409         list_for_each_entry_rcu(inc, &ipvs->udp_apps[hash], p_list) {
410                 if (inc->port == cp->vport) {
411                         if (unlikely(!ip_vs_app_inc_get(inc)))
412                                 break;
413
414                         IP_VS_DBG_BUF(9, "%s(): Binding conn %s:%u->"
415                                       "%s:%u to app %s on port %u\n",
416                                       __func__,
417                                       IP_VS_DBG_ADDR(cp->af, &cp->caddr),
418                                       ntohs(cp->cport),
419                                       IP_VS_DBG_ADDR(cp->af, &cp->vaddr),
420                                       ntohs(cp->vport),
421                                       inc->name, ntohs(inc->port));
422
423                         cp->app = inc;
424                         if (inc->init_conn)
425                                 result = inc->init_conn(inc, cp);
426                         break;
427                 }
428         }
429
430         return result;
431 }
432
433
434 static const int udp_timeouts[IP_VS_UDP_S_LAST+1] = {
435         [IP_VS_UDP_S_NORMAL]            =       5*60*HZ,
436         [IP_VS_UDP_S_LAST]              =       2*HZ,
437 };
438
439 static const char *const udp_state_name_table[IP_VS_UDP_S_LAST+1] = {
440         [IP_VS_UDP_S_NORMAL]            =       "UDP",
441         [IP_VS_UDP_S_LAST]              =       "BUG!",
442 };
443
444 static const char * udp_state_name(int state)
445 {
446         if (state >= IP_VS_UDP_S_LAST)
447                 return "ERR!";
448         return udp_state_name_table[state] ? udp_state_name_table[state] : "?";
449 }
450
451 static void
452 udp_state_transition(struct ip_vs_conn *cp, int direction,
453                      const struct sk_buff *skb,
454                      struct ip_vs_proto_data *pd)
455 {
456         if (unlikely(!pd)) {
457                 pr_err("UDP no ns data\n");
458                 return;
459         }
460
461         cp->timeout = pd->timeout_table[IP_VS_UDP_S_NORMAL];
462         if (direction == IP_VS_DIR_OUTPUT)
463                 ip_vs_control_assure_ct(cp);
464 }
465
466 static int __udp_init(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd)
467 {
468         ip_vs_init_hash_table(ipvs->udp_apps, UDP_APP_TAB_SIZE);
469         pd->timeout_table = ip_vs_create_timeout_table((int *)udp_timeouts,
470                                                         sizeof(udp_timeouts));
471         if (!pd->timeout_table)
472                 return -ENOMEM;
473         return 0;
474 }
475
476 static void __udp_exit(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd)
477 {
478         kfree(pd->timeout_table);
479 }
480
481
482 struct ip_vs_protocol ip_vs_protocol_udp = {
483         .name =                 "UDP",
484         .protocol =             IPPROTO_UDP,
485         .num_states =           IP_VS_UDP_S_LAST,
486         .dont_defrag =          0,
487         .init =                 NULL,
488         .exit =                 NULL,
489         .init_netns =           __udp_init,
490         .exit_netns =           __udp_exit,
491         .conn_schedule =        udp_conn_schedule,
492         .conn_in_get =          ip_vs_conn_in_get_proto,
493         .conn_out_get =         ip_vs_conn_out_get_proto,
494         .snat_handler =         udp_snat_handler,
495         .dnat_handler =         udp_dnat_handler,
496         .state_transition =     udp_state_transition,
497         .state_name =           udp_state_name,
498         .register_app =         udp_register_app,
499         .unregister_app =       udp_unregister_app,
500         .app_conn_bind =        udp_app_conn_bind,
501         .debug_packet =         ip_vs_tcpudp_debug_packet,
502         .timeout_change =       NULL,
503 };