Linux-libre 5.3.12-gnu
[librecmc/linux-libre.git] / tools / testing / selftests / net / tcp_fastopen_backup_key.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 /*
4  * Test key rotation for TFO.
5  * New keys are 'rotated' in two steps:
6  * 1) Add new key as the 'backup' key 'behind' the primary key
7  * 2) Make new key the primary by swapping the backup and primary keys
8  *
9  * The rotation is done in stages using multiple sockets bound
10  * to the same port via SO_REUSEPORT. This simulates key rotation
11  * behind say a load balancer. We verify that across the rotation
12  * there are no cases in which a cookie is not accepted by verifying
13  * that TcpExtTCPFastOpenPassiveFail remains 0.
14  */
15 #define _GNU_SOURCE
16 #include <arpa/inet.h>
17 #include <errno.h>
18 #include <error.h>
19 #include <stdbool.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <sys/epoll.h>
24 #include <unistd.h>
25 #include <netinet/tcp.h>
26 #include <fcntl.h>
27 #include <time.h>
28
29 #ifndef TCP_FASTOPEN_KEY
30 #define TCP_FASTOPEN_KEY 33
31 #endif
32
33 #define N_LISTEN 10
34 #define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key"
35 #define KEY_LENGTH 16
36
37 #ifndef ARRAY_SIZE
38 #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
39 #endif
40
41 static bool do_ipv6;
42 static bool do_sockopt;
43 static bool do_rotate;
44 static int key_len = KEY_LENGTH;
45 static int rcv_fds[N_LISTEN];
46 static int proc_fd;
47 static const char *IP4_ADDR = "127.0.0.1";
48 static const char *IP6_ADDR = "::1";
49 static const int PORT = 8891;
50
51 static void get_keys(int fd, uint32_t *keys)
52 {
53         char buf[128];
54         socklen_t len = KEY_LENGTH * 2;
55
56         if (do_sockopt) {
57                 if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len))
58                         error(1, errno, "Unable to get key");
59                 return;
60         }
61         lseek(proc_fd, 0, SEEK_SET);
62         if (read(proc_fd, buf, sizeof(buf)) <= 0)
63                 error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY);
64         if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2,
65             keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8)
66                 error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY);
67 }
68
69 static void set_keys(int fd, uint32_t *keys)
70 {
71         char buf[128];
72
73         if (do_sockopt) {
74                 if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys,
75                     key_len))
76                         error(1, errno, "Unable to set key");
77                 return;
78         }
79         if (do_rotate)
80                 snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x",
81                          keys[0], keys[1], keys[2], keys[3], keys[4], keys[5],
82                          keys[6], keys[7]);
83         else
84                 snprintf(buf, 128, "%08x-%08x-%08x-%08x",
85                          keys[0], keys[1], keys[2], keys[3]);
86         lseek(proc_fd, 0, SEEK_SET);
87         if (write(proc_fd, buf, sizeof(buf)) <= 0)
88                 error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY);
89 }
90
91 static void build_rcv_fd(int family, int proto, int *rcv_fds)
92 {
93         struct sockaddr_in  addr4 = {0};
94         struct sockaddr_in6 addr6 = {0};
95         struct sockaddr *addr;
96         int opt = 1, i, sz;
97         int qlen = 100;
98         uint32_t keys[8];
99
100         switch (family) {
101         case AF_INET:
102                 addr4.sin_family = family;
103                 addr4.sin_addr.s_addr = htonl(INADDR_ANY);
104                 addr4.sin_port = htons(PORT);
105                 sz = sizeof(addr4);
106                 addr = (struct sockaddr *)&addr4;
107                 break;
108         case AF_INET6:
109                 addr6.sin6_family = AF_INET6;
110                 addr6.sin6_addr = in6addr_any;
111                 addr6.sin6_port = htons(PORT);
112                 sz = sizeof(addr6);
113                 addr = (struct sockaddr *)&addr6;
114                 break;
115         default:
116                 error(1, 0, "Unsupported family %d", family);
117                 /* clang does not recognize error() above as terminating
118                  * the program, so it complains that saddr, sz are
119                  * not initialized when this code path is taken. Silence it.
120                  */
121                 return;
122         }
123         for (i = 0; i < ARRAY_SIZE(keys); i++)
124                 keys[i] = rand();
125         for (i = 0; i < N_LISTEN; i++) {
126                 rcv_fds[i] = socket(family, proto, 0);
127                 if (rcv_fds[i] < 0)
128                         error(1, errno, "failed to create receive socket");
129                 if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt,
130                                sizeof(opt)))
131                         error(1, errno, "failed to set SO_REUSEPORT");
132                 if (bind(rcv_fds[i], addr, sz))
133                         error(1, errno, "failed to bind receive socket");
134                 if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen,
135                                sizeof(qlen)))
136                         error(1, errno, "failed to set TCP_FASTOPEN");
137                 set_keys(rcv_fds[i], keys);
138                 if (proto == SOCK_STREAM && listen(rcv_fds[i], 10))
139                         error(1, errno, "failed to listen on receive port");
140         }
141 }
142
143 static int connect_and_send(int family, int proto)
144 {
145         struct sockaddr_in  saddr4 = {0};
146         struct sockaddr_in  daddr4 = {0};
147         struct sockaddr_in6 saddr6 = {0};
148         struct sockaddr_in6 daddr6 = {0};
149         struct sockaddr *saddr, *daddr;
150         int fd, sz, ret;
151         char data[1];
152
153         switch (family) {
154         case AF_INET:
155                 saddr4.sin_family = AF_INET;
156                 saddr4.sin_addr.s_addr = htonl(INADDR_ANY);
157                 saddr4.sin_port = 0;
158
159                 daddr4.sin_family = AF_INET;
160                 if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr))
161                         error(1, errno, "inet_pton failed: %s", IP4_ADDR);
162                 daddr4.sin_port = htons(PORT);
163
164                 sz = sizeof(saddr4);
165                 saddr = (struct sockaddr *)&saddr4;
166                 daddr = (struct sockaddr *)&daddr4;
167                 break;
168         case AF_INET6:
169                 saddr6.sin6_family = AF_INET6;
170                 saddr6.sin6_addr = in6addr_any;
171
172                 daddr6.sin6_family = AF_INET6;
173                 if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr))
174                         error(1, errno, "inet_pton failed: %s", IP6_ADDR);
175                 daddr6.sin6_port = htons(PORT);
176
177                 sz = sizeof(saddr6);
178                 saddr = (struct sockaddr *)&saddr6;
179                 daddr = (struct sockaddr *)&daddr6;
180                 break;
181         default:
182                 error(1, 0, "Unsupported family %d", family);
183                 /* clang does not recognize error() above as terminating
184                  * the program, so it complains that saddr, daddr, sz are
185                  * not initialized when this code path is taken. Silence it.
186                  */
187                 return -1;
188         }
189         fd = socket(family, proto, 0);
190         if (fd < 0)
191                 error(1, errno, "failed to create send socket");
192         if (bind(fd, saddr, sz))
193                 error(1, errno, "failed to bind send socket");
194         data[0] = 'a';
195         ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz);
196         if (ret != 1)
197                 error(1, errno, "failed to sendto");
198
199         return fd;
200 }
201
202 static bool is_listen_fd(int fd)
203 {
204         int i;
205
206         for (i = 0; i < N_LISTEN; i++) {
207                 if (rcv_fds[i] == fd)
208                         return true;
209         }
210         return false;
211 }
212
213 static void rotate_key(int fd)
214 {
215         static int iter;
216         static uint32_t new_key[4];
217         uint32_t keys[8];
218         uint32_t tmp_key[4];
219         int i;
220
221         if (iter < N_LISTEN) {
222                 /* first set new key as backups */
223                 if (iter == 0) {
224                         for (i = 0; i < ARRAY_SIZE(new_key); i++)
225                                 new_key[i] = rand();
226                 }
227                 get_keys(fd, keys);
228                 memcpy(keys + 4, new_key, KEY_LENGTH);
229                 set_keys(fd, keys);
230         } else {
231                 /* swap the keys */
232                 get_keys(fd, keys);
233                 memcpy(tmp_key, keys + 4, KEY_LENGTH);
234                 memcpy(keys + 4, keys, KEY_LENGTH);
235                 memcpy(keys, tmp_key, KEY_LENGTH);
236                 set_keys(fd, keys);
237         }
238         if (++iter >= (N_LISTEN * 2))
239                 iter = 0;
240 }
241
242 static void run_one_test(int family)
243 {
244         struct epoll_event ev;
245         int i, send_fd;
246         int n_loops = 10000;
247         int rotate_key_fd = 0;
248         int key_rotate_interval = 50;
249         int fd, epfd;
250         char buf[1];
251
252         build_rcv_fd(family, SOCK_STREAM, rcv_fds);
253         epfd = epoll_create(1);
254         if (epfd < 0)
255                 error(1, errno, "failed to create epoll");
256         ev.events = EPOLLIN;
257         for (i = 0; i < N_LISTEN; i++) {
258                 ev.data.fd = rcv_fds[i];
259                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev))
260                         error(1, errno, "failed to register sock epoll");
261         }
262         while (n_loops--) {
263                 send_fd = connect_and_send(family, SOCK_STREAM);
264                 if (do_rotate && ((n_loops % key_rotate_interval) == 0)) {
265                         rotate_key(rcv_fds[rotate_key_fd]);
266                         if (++rotate_key_fd >= N_LISTEN)
267                                 rotate_key_fd = 0;
268                 }
269                 while (1) {
270                         i = epoll_wait(epfd, &ev, 1, -1);
271                         if (i < 0)
272                                 error(1, errno, "epoll_wait failed");
273                         if (is_listen_fd(ev.data.fd)) {
274                                 fd = accept(ev.data.fd, NULL, NULL);
275                                 if (fd < 0)
276                                         error(1, errno, "failed to accept");
277                                 ev.data.fd = fd;
278                                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev))
279                                         error(1, errno, "failed epoll add");
280                                 continue;
281                         }
282                         i = recv(ev.data.fd, buf, sizeof(buf), 0);
283                         if (i != 1)
284                                 error(1, errno, "failed recv data");
285                         if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL))
286                                 error(1, errno, "failed epoll del");
287                         close(ev.data.fd);
288                         break;
289                 }
290                 close(send_fd);
291         }
292         for (i = 0; i < N_LISTEN; i++)
293                 close(rcv_fds[i]);
294 }
295
296 static void parse_opts(int argc, char **argv)
297 {
298         int c;
299
300         while ((c = getopt(argc, argv, "46sr")) != -1) {
301                 switch (c) {
302                 case '4':
303                         do_ipv6 = false;
304                         break;
305                 case '6':
306                         do_ipv6 = true;
307                         break;
308                 case 's':
309                         do_sockopt = true;
310                         break;
311                 case 'r':
312                         do_rotate = true;
313                         key_len = KEY_LENGTH * 2;
314                         break;
315                 default:
316                         error(1, 0, "%s: parse error", argv[0]);
317                 }
318         }
319 }
320
321 int main(int argc, char **argv)
322 {
323         parse_opts(argc, argv);
324         proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR);
325         if (proc_fd < 0)
326                 error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY);
327         srand(time(NULL));
328         if (do_ipv6)
329                 run_one_test(AF_INET6);
330         else
331                 run_one_test(AF_INET);
332         close(proc_fd);
333         fprintf(stderr, "PASS\n");
334         return 0;
335 }