Linux-libre 5.3.12-gnu
[librecmc/linux-libre.git] / tools / testing / selftests / bpf / test_tcp_check_syncookie_user.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 // Copyright (c) 2019 Cloudflare
4
5 #include <string.h>
6 #include <stdlib.h>
7 #include <unistd.h>
8
9 #include <arpa/inet.h>
10 #include <netinet/in.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13
14 #include <bpf/bpf.h>
15 #include <bpf/libbpf.h>
16
17 #include "bpf_rlimit.h"
18 #include "cgroup_helpers.h"
19
20 static int start_server(const struct sockaddr *addr, socklen_t len)
21 {
22         int fd;
23
24         fd = socket(addr->sa_family, SOCK_STREAM, 0);
25         if (fd == -1) {
26                 log_err("Failed to create server socket");
27                 goto out;
28         }
29
30         if (bind(fd, addr, len) == -1) {
31                 log_err("Failed to bind server socket");
32                 goto close_out;
33         }
34
35         if (listen(fd, 128) == -1) {
36                 log_err("Failed to listen on server socket");
37                 goto close_out;
38         }
39
40         goto out;
41
42 close_out:
43         close(fd);
44         fd = -1;
45 out:
46         return fd;
47 }
48
49 static int connect_to_server(int server_fd)
50 {
51         struct sockaddr_storage addr;
52         socklen_t len = sizeof(addr);
53         int fd = -1;
54
55         if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
56                 log_err("Failed to get server addr");
57                 goto out;
58         }
59
60         fd = socket(addr.ss_family, SOCK_STREAM, 0);
61         if (fd == -1) {
62                 log_err("Failed to create client socket");
63                 goto out;
64         }
65
66         if (connect(fd, (const struct sockaddr *)&addr, len) == -1) {
67                 log_err("Fail to connect to server");
68                 goto close_out;
69         }
70
71         goto out;
72
73 close_out:
74         close(fd);
75         fd = -1;
76 out:
77         return fd;
78 }
79
80 static int get_map_fd_by_prog_id(int prog_id)
81 {
82         struct bpf_prog_info info = {};
83         __u32 info_len = sizeof(info);
84         __u32 map_ids[1];
85         int prog_fd = -1;
86         int map_fd = -1;
87
88         prog_fd = bpf_prog_get_fd_by_id(prog_id);
89         if (prog_fd < 0) {
90                 log_err("Failed to get fd by prog id %d", prog_id);
91                 goto err;
92         }
93
94         info.nr_map_ids = 1;
95         info.map_ids = (__u64)(unsigned long)map_ids;
96
97         if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) {
98                 log_err("Failed to get info by prog fd %d", prog_fd);
99                 goto err;
100         }
101
102         if (!info.nr_map_ids) {
103                 log_err("No maps found for prog fd %d", prog_fd);
104                 goto err;
105         }
106
107         map_fd = bpf_map_get_fd_by_id(map_ids[0]);
108         if (map_fd < 0)
109                 log_err("Failed to get fd by map id %d", map_ids[0]);
110 err:
111         if (prog_fd >= 0)
112                 close(prog_fd);
113         return map_fd;
114 }
115
116 static int run_test(int server_fd, int results_fd)
117 {
118         int client = -1, srv_client = -1;
119         int ret = 0;
120         __u32 key = 0;
121         __u64 value = 0;
122
123         if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) {
124                 log_err("Can't clear results");
125                 goto err;
126         }
127
128         client = connect_to_server(server_fd);
129         if (client == -1)
130                 goto err;
131
132         srv_client = accept(server_fd, NULL, 0);
133         if (srv_client == -1) {
134                 log_err("Can't accept connection");
135                 goto err;
136         }
137
138         if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) {
139                 log_err("Can't lookup result");
140                 goto err;
141         }
142
143         if (value != 1) {
144                 log_err("Didn't match syncookie: %llu", value);
145                 goto err;
146         }
147
148         goto out;
149
150 err:
151         ret = 1;
152 out:
153         close(client);
154         close(srv_client);
155         return ret;
156 }
157
158 int main(int argc, char **argv)
159 {
160         struct sockaddr_in addr4;
161         struct sockaddr_in6 addr6;
162         int server = -1;
163         int server_v6 = -1;
164         int results = -1;
165         int err = 0;
166
167         if (argc < 2) {
168                 fprintf(stderr, "Usage: %s prog_id\n", argv[0]);
169                 exit(1);
170         }
171
172         results = get_map_fd_by_prog_id(atoi(argv[1]));
173         if (results < 0) {
174                 log_err("Can't get map");
175                 goto err;
176         }
177
178         memset(&addr4, 0, sizeof(addr4));
179         addr4.sin_family = AF_INET;
180         addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
181         addr4.sin_port = 0;
182
183         memset(&addr6, 0, sizeof(addr6));
184         addr6.sin6_family = AF_INET6;
185         addr6.sin6_addr = in6addr_loopback;
186         addr6.sin6_port = 0;
187
188         server = start_server((const struct sockaddr *)&addr4, sizeof(addr4));
189         if (server == -1)
190                 goto err;
191
192         server_v6 = start_server((const struct sockaddr *)&addr6,
193                                  sizeof(addr6));
194         if (server_v6 == -1)
195                 goto err;
196
197         if (run_test(server, results))
198                 goto err;
199
200         if (run_test(server_v6, results))
201                 goto err;
202
203         printf("ok\n");
204         goto out;
205 err:
206         err = 1;
207 out:
208         close(server);
209         close(server_v6);
210         close(results);
211         return err;
212 }