Linux-libre 5.7.3-gnu
[librecmc/linux-libre.git] / tools / testing / selftests / bpf / prog_tests / sk_assign.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 // Copyright (c) 2019 Cloudflare
4 // Copyright (c) 2020 Isovalent, Inc.
/*
 * Test that the socket assign program is able to redirect traffic towards a
 * socket, regardless of whether the destination port or address of the
 * traffic matches the port or address that the socket is bound to.
 */
10
11 #define _GNU_SOURCE
12 #include <fcntl.h>
13 #include <signal.h>
14 #include <stdlib.h>
15 #include <unistd.h>
16
17 #include "test_progs.h"
18
/* Port the test server socket binds to. */
#define BIND_PORT 1234
/* Port the client connects to; no socket is bound here ("unbound ports"). */
#define CONNECT_PORT 4321
/* Destination address used by the "addr redir" cases: 192.168.2.3. */
#define TEST_DADDR (0xC0A80203)
/* Current network namespace; opened before unshare() so setns() can restore it. */
#define NS_SELF "/proc/self/ns/net"

/* Timeout applied as SO_RCVTIMEO / SO_SNDTIMEO on the test sockets. */
static const struct timeval timeo_sec = { .tv_sec = 3 };
static const size_t timeo_optlen = sizeof timeo_sec;
/* Set (via WRITE_ONCE) once a subtest fails; remaining subtests are skipped. */
static int stop;
/* Not used directly in this file; presumably consumed by CHECK() from test_progs.h. */
static int duration;
27
28 static bool
29 configure_stack(void)
30 {
31         char tc_cmd[BUFSIZ];
32
33         /* Move to a new networking namespace */
34         if (CHECK_FAIL(unshare(CLONE_NEWNET)))
35                 return false;
36
37         /* Configure necessary links, routes */
38         if (CHECK_FAIL(system("ip link set dev lo up")))
39                 return false;
40         if (CHECK_FAIL(system("ip route add local default dev lo")))
41                 return false;
42         if (CHECK_FAIL(system("ip -6 route add local default dev lo")))
43                 return false;
44
45         /* Load qdisc, BPF program */
46         if (CHECK_FAIL(system("tc qdisc add dev lo clsact")))
47                 return false;
48         sprintf(tc_cmd, "%s %s %s %s", "tc filter add dev lo ingress bpf",
49                        "direct-action object-file ./test_sk_assign.o",
50                        "section classifier/sk_assign_test",
51                        (env.verbosity < VERBOSE_VERY) ? " 2>/dev/null" : "");
52         if (CHECK(system(tc_cmd), "BPF load failed;",
53                   "run with -vv for more info\n"))
54                 return false;
55
56         return true;
57 }
58
59 static int
60 start_server(const struct sockaddr *addr, socklen_t len, int type)
61 {
62         int fd;
63
64         fd = socket(addr->sa_family, type, 0);
65         if (CHECK_FAIL(fd == -1))
66                 goto out;
67         if (CHECK_FAIL(setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec,
68                                   timeo_optlen)))
69                 goto close_out;
70         if (CHECK_FAIL(bind(fd, addr, len) == -1))
71                 goto close_out;
72         if (type == SOCK_STREAM && CHECK_FAIL(listen(fd, 128) == -1))
73                 goto close_out;
74
75         goto out;
76 close_out:
77         close(fd);
78         fd = -1;
79 out:
80         return fd;
81 }
82
83 static int
84 connect_to_server(const struct sockaddr *addr, socklen_t len, int type)
85 {
86         int fd = -1;
87
88         fd = socket(addr->sa_family, type, 0);
89         if (CHECK_FAIL(fd == -1))
90                 goto out;
91         if (CHECK_FAIL(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec,
92                                   timeo_optlen)))
93                 goto close_out;
94         if (CHECK_FAIL(connect(fd, addr, len)))
95                 goto close_out;
96
97         goto out;
98 close_out:
99         close(fd);
100         fd = -1;
101 out:
102         return fd;
103 }
104
/*
 * get_port - return the local port of socket @fd, in network byte order.
 *
 * Returns 0 when getsockname() fails or when the socket's family is neither
 * AF_INET nor AF_INET6; a test failure is recorded in both cases.
 */
static in_port_t
get_port(int fd)
{
	struct sockaddr_storage local;
	socklen_t local_len = sizeof(local);

	if (CHECK_FAIL(getsockname(fd, (struct sockaddr *)&local, &local_len)))
		return 0;

	if (local.ss_family == AF_INET)
		return ((struct sockaddr_in *)&local)->sin_port;
	if (local.ss_family == AF_INET6)
		return ((struct sockaddr_in6 *)&local)->sin6_port;

	CHECK(1, "Invalid address family", "%d\n", local.ss_family);
	return 0;
}
127
/*
 * rcv_msg - receive one message on the server side.
 * @srv_client: accepted connection (SOCK_STREAM) or the bound server socket
 *              itself (SOCK_DGRAM)
 * @type:       socket type of @srv_client
 *
 * The payload is discarded; only its length matters to the caller.
 *
 * Returns the number of bytes received, or -1 on error (errno set).
 */
static ssize_t
rcv_msg(int srv_client, int type)
{
	struct sockaddr_storage ss;
	char buf[BUFSIZ];
	/*
	 * recvfrom() reads *addrlen as the size of the address buffer before
	 * writing back the real length, so it must be initialised; garbage
	 * (e.g. a negative value) makes the call fail with EINVAL.
	 */
	socklen_t slen = sizeof(ss);

	if (type == SOCK_STREAM)
		return read(srv_client, buf, sizeof(buf));

	return recvfrom(srv_client, buf, sizeof(buf), 0,
			(struct sockaddr *)&ss, &slen);
}
141
142 static int
143 run_test(int server_fd, const struct sockaddr *addr, socklen_t len, int type)
144 {
145         int client = -1, srv_client = -1;
146         char buf[] = "testing";
147         in_port_t port;
148         int ret = 1;
149
150         client = connect_to_server(addr, len, type);
151         if (client == -1) {
152                 perror("Cannot connect to server");
153                 goto out;
154         }
155
156         if (type == SOCK_STREAM) {
157                 srv_client = accept(server_fd, NULL, NULL);
158                 if (CHECK_FAIL(srv_client == -1)) {
159                         perror("Can't accept connection");
160                         goto out;
161                 }
162         } else {
163                 srv_client = server_fd;
164         }
165         if (CHECK_FAIL(write(client, buf, sizeof(buf)) != sizeof(buf))) {
166                 perror("Can't write on client");
167                 goto out;
168         }
169         if (CHECK_FAIL(rcv_msg(srv_client, type) != sizeof(buf))) {
170                 perror("Can't read on server");
171                 goto out;
172         }
173
174         port = get_port(srv_client);
175         if (CHECK_FAIL(!port))
176                 goto out;
177         /* SOCK_STREAM is connected via accept(), so the server's local address
178          * will be the CONNECT_PORT rather than the BIND port that corresponds
179          * to the listen socket. SOCK_DGRAM on the other hand is connectionless
180          * so we can't really do the same check there; the server doesn't ever
181          * create a socket with CONNECT_PORT.
182          */
183         if (type == SOCK_STREAM &&
184             CHECK(port != htons(CONNECT_PORT), "Expected", "port %u but got %u",
185                   CONNECT_PORT, ntohs(port)))
186                 goto out;
187         else if (type == SOCK_DGRAM &&
188                  CHECK(port != htons(BIND_PORT), "Expected",
189                        "port %u but got %u", BIND_PORT, ntohs(port)))
190                 goto out;
191
192         ret = 0;
193 out:
194         close(client);
195         if (srv_client != server_fd)
196                 close(srv_client);
197         if (ret)
198                 WRITE_ONCE(stop, 1);
199         return ret;
200 }
201
202 static void
203 prepare_addr(struct sockaddr *addr, int family, __u16 port, bool rewrite_addr)
204 {
205         struct sockaddr_in *addr4;
206         struct sockaddr_in6 *addr6;
207
208         switch (family) {
209         case AF_INET:
210                 addr4 = (struct sockaddr_in *)addr;
211                 memset(addr4, 0, sizeof(*addr4));
212                 addr4->sin_family = family;
213                 addr4->sin_port = htons(port);
214                 if (rewrite_addr)
215                         addr4->sin_addr.s_addr = htonl(TEST_DADDR);
216                 else
217                         addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
218                 break;
219         case AF_INET6:
220                 addr6 = (struct sockaddr_in6 *)addr;
221                 memset(addr6, 0, sizeof(*addr6));
222                 addr6->sin6_family = family;
223                 addr6->sin6_port = htons(port);
224                 addr6->sin6_addr = in6addr_loopback;
225                 if (rewrite_addr)
226                         addr6->sin6_addr.s6_addr32[3] = htonl(TEST_DADDR);
227                 break;
228         default:
229                 fprintf(stderr, "Invalid family %d", family);
230         }
231 }
232
/*
 * struct test_sk_cfg - one sk_assign subtest case.
 * @name:         subtest name passed to test__start_subtest()
 * @family:       AF_INET or AF_INET6
 * @addr:         address storage reused for both the bind address and the
 *                connect address (filled in by prepare_addr())
 * @len:          size of the sockaddr that @addr points to
 * @type:         SOCK_STREAM or SOCK_DGRAM
 * @rewrite_addr: connect to TEST_DADDR instead of loopback
 */
struct test_sk_cfg {
	const char *name;
	int family;
	struct sockaddr *addr;
	socklen_t len;
	int type;
	bool rewrite_addr;
};

/*
 * TEST() - build an initializer for struct test_sk_cfg.
 *
 * Refers to variables named addr4 (struct sockaddr_in) and addr6
 * (struct sockaddr_in6) that must be in scope at the use site. Every argument
 * is parenthesised so that FAMILY may be an arbitrary expression (e.g. a
 * conditional) without changing how it is compared against AF_INET.
 */
#define TEST(NAME, FAMILY, TYPE, REWRITE)				\
{									\
	.name = (NAME),							\
	.family = (FAMILY),						\
	.addr = ((FAMILY) == AF_INET) ? (struct sockaddr *)&addr4	\
				      : (struct sockaddr *)&addr6,	\
	.len = ((FAMILY) == AF_INET) ? sizeof(addr4) : sizeof(addr6),	\
	.type = (TYPE),							\
	.rewrite_addr = (REWRITE),					\
}
252
253 void test_sk_assign(void)
254 {
255         struct sockaddr_in addr4;
256         struct sockaddr_in6 addr6;
257         struct test_sk_cfg tests[] = {
258                 TEST("ipv4 tcp port redir", AF_INET, SOCK_STREAM, false),
259                 TEST("ipv4 tcp addr redir", AF_INET, SOCK_STREAM, true),
260                 TEST("ipv6 tcp port redir", AF_INET6, SOCK_STREAM, false),
261                 TEST("ipv6 tcp addr redir", AF_INET6, SOCK_STREAM, true),
262                 TEST("ipv4 udp port redir", AF_INET, SOCK_DGRAM, false),
263                 TEST("ipv4 udp addr redir", AF_INET, SOCK_DGRAM, true),
264                 TEST("ipv6 udp port redir", AF_INET6, SOCK_DGRAM, false),
265                 TEST("ipv6 udp addr redir", AF_INET6, SOCK_DGRAM, true),
266         };
267         int server = -1;
268         int self_net;
269
270         self_net = open(NS_SELF, O_RDONLY);
271         if (CHECK_FAIL(self_net < 0)) {
272                 perror("Unable to open "NS_SELF);
273                 return;
274         }
275
276         if (!configure_stack()) {
277                 perror("configure_stack");
278                 goto cleanup;
279         }
280
281         for (int i = 0; i < ARRAY_SIZE(tests) && !READ_ONCE(stop); i++) {
282                 struct test_sk_cfg *test = &tests[i];
283                 const struct sockaddr *addr;
284
285                 if (!test__start_subtest(test->name))
286                         continue;
287                 prepare_addr(test->addr, test->family, BIND_PORT, false);
288                 addr = (const struct sockaddr *)test->addr;
289                 server = start_server(addr, test->len, test->type);
290                 if (server == -1)
291                         goto cleanup;
292
293                 /* connect to unbound ports */
294                 prepare_addr(test->addr, test->family, CONNECT_PORT,
295                              test->rewrite_addr);
296                 if (run_test(server, addr, test->len, test->type))
297                         goto close;
298
299                 close(server);
300                 server = -1;
301         }
302
303 close:
304         close(server);
305 cleanup:
306         if (CHECK_FAIL(setns(self_net, CLONE_NEWNET)))
307                 perror("Failed to setns("NS_SELF")");
308         close(self_net);
309 }