Define constants for supported control protocol versions.
[oweals/dinit.git] / src / shutdown.cc
1 #include <cstddef>
2 #include <cstdio>
3 #include <csignal>
4 #include <unistd.h>
5 #include <cstring>
6 #include <string>
7 #include <iostream>
8
9 #include <sys/reboot.h>
10 #include <sys/types.h>
11 #include <sys/socket.h>
12 #include <sys/un.h>
13 #include <sys/wait.h>
14 #include <sys/stat.h>
15 #include <fcntl.h>
16
17 #include "cpbuffer.h"
18 #include "control-cmds.h"
19 #include "service-constants.h"
20 #include "dinit-client.h"
21
22 #include "dasynq.h"
23
24 // shutdown:  shut down the system
25 // This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
26
27 static constexpr uint16_t min_cp_version = 1;
28 static constexpr uint16_t max_cp_version = 1;
29
30 using loop_t = dasynq::event_loop_n;
31 using rearm = dasynq::rearm;
32 using clock_type = dasynq::clock_type;
33 class subproc_buffer;
34
35 void do_system_shutdown(shutdown_type_t shutdown_type);
36 static void unmount_disks(loop_t &loop, subproc_buffer &sub_buf);
37 static void swap_off(loop_t &loop, subproc_buffer &sub_buf);
38
39 constexpr static int subproc_bufsize = 4096;
40
41 constexpr static char output_lost_msg[] = "[Some output has not been shown due to buffer overflow]\n";
42
43 // A buffer which maintains a series of overflow markers, used for capturing and echoing
44 // subprocess output.
45 class subproc_buffer : private cpbuffer<subproc_bufsize>
46 {
47     using base = cpbuffer<subproc_bufsize>;
48
49     int overflow_marker = -1;
50     int last_overflow = -1;  // last marker in the series
51     const char *overflow_msg_ptr = nullptr;   // current position in overflow message
52     dasynq::event_loop_n &loop;
53     dasynq::event_loop_n::fd_watcher * out_watch;
54
55     public:
56     enum class fill_status
57     {
58         OK,
59         ENDFILE,
60         FULL
61     };
62
63     subproc_buffer(dasynq::event_loop_n &loop_p, int out_fd) : loop(loop_p)
64     {
65         using loop_t = dasynq::event_loop_n;
66         using rearm = dasynq::rearm;
67
68         out_watch = loop_t::fd_watcher::add_watch(loop, out_fd, dasynq::OUT_EVENTS,
69                 [&](loop_t &eloop, int fd, int flags) -> rearm {
70
71             auto fstatus = flush(STDOUT_FILENO);
72             if (fstatus == subproc_buffer::fill_status::ENDFILE) {
73                 return rearm::DISARM;
74             }
75
76             return rearm::REARM;
77         });
78     }
79
80     ~subproc_buffer()
81     {
82         out_watch->deregister(loop);
83     }
84
85     // Fill buffer by reading from file descriptor. Note caller must set overflow marker
86     // if the buffer becomes full and more data is available.
87     fill_status fill(int fd)
88     {
89         int rem = get_free();
90
91         if (rem == 0) {
92             return fill_status::FULL;
93         }
94
95         int read = base::fill(fd, rem);
96         if (read <= 0) {
97             if (read == -1 && errno == EAGAIN) {
98                 return fill_status::OK;
99             }
100             return fill_status::ENDFILE;
101         }
102
103         out_watch->set_enabled(loop, true);
104         return fill_status::OK;
105     }
106
107     // Append a message. If the message will not fit in the buffer, discard it and mark overflow.
108     void append(const char *msg)
109     {
110         out_watch->set_enabled(loop, true);
111         int len = strlen(msg);
112         if (subproc_bufsize - get_length() >= len) {
113             base::append(msg, len);
114         }
115         else {
116             mark_overflow();
117         }
118     }
119
120     // Append the given buffer, which must fit in the remaining space in this buffer.
121     void append(const char *buf, int len)
122     {
123         out_watch->set_enabled(loop, true);
124         base::append(buf, len);
125     }
126
127     int get_free()
128     {
129         return base::get_free();
130     }
131
132     // Write buffer contents out to file descriptor. The descriptor is assumed to be non-blocking.
133     // returns ENDFILE if there is no more content to flush (buffer is now empty) or OK otherwise.
134     fill_status flush(int fd)
135     {
136         int to_write = get_contiguous_length(get_ptr(0));
137
138         if (overflow_marker != -1) {
139             if (overflow_marker == 0) {
140                 // output (remainder of) overflow message
141                 int l = std::strlen(overflow_msg_ptr);
142                 int r = write(fd, overflow_msg_ptr, l);
143                 if (r == l) {
144                     // entire message has been written; next marker is in buffer
145                     int16_t overflow_marker16;
146                     extract(reinterpret_cast<char *>(&overflow_marker16), 0, sizeof(overflow_marker16));
147                     overflow_marker = overflow_marker16;
148                     consume(sizeof(overflow_marker16));
149
150                     // no more overflow markers?
151                     if (overflow_marker == -1) {
152                         last_overflow = -1;
153                     }
154                     return get_length() == 0 ? fill_status::ENDFILE : fill_status::OK;
155                 }
156                 if (r > 0) {
157                     overflow_msg_ptr += r;
158                 }
159                 return fill_status::OK;
160             }
161
162             to_write = std::min(to_write, overflow_marker);
163         }
164
165         int r = write(fd, get_ptr(0), to_write);
166         if (r > 0) {
167             consume(r);
168             if (overflow_marker != -1) {
169                 overflow_marker -= r;
170                 last_overflow -= r;
171                 if (overflow_marker == 0) {
172                     overflow_msg_ptr = output_lost_msg;
173                 }
174             }
175         }
176
177         return get_length() == 0 ? fill_status::ENDFILE : fill_status::OK;
178     }
179
180     // Mark overflow occurred. Call this only when the buffer is full.
181     // The marker is put after the most recent newline in the buffer, if possible, so that whole
182     // lines are retained in the buffer. In some cases marking overflow will not add a new overflow
183     // marker but simply trim the buffer to an existing marker.
184     void mark_overflow()
185     {
186         // Try to find the last newline in the buffer
187         int begin = 0;
188         if (last_overflow != -1) {
189             begin = last_overflow + sizeof(int16_t);
190         }
191         int end = get_length() - 1 - sizeof(int16_t); // -1, then -2 for storage of marker
192
193         int i;
194         for (i = end; i >= begin; i--) {
195             if ((*this)[i] == '\n') break;
196         }
197
198         if (last_overflow != -1 && i < begin) {
199             // No new line after existing marker: trim all beyond that marker, don't
200             // create a new marker:
201             trim_to(last_overflow + sizeof(uint16_t));
202             return;
203         }
204
205         if (i < begin) {
206             // No newline in the whole buffer... we'll put the overflow marker at the end,
207             // on the assumption that it is better to output a partial line than it is to
208             // discard the entire buffer:
209             last_overflow = get_length() - sizeof(int16_t);
210             overflow_marker = last_overflow;
211             int16_t overflow16 = -1;
212             char * overflow16_ptr = reinterpret_cast<char *>(&overflow16);
213             *get_ptr(last_overflow + 0) = overflow16_ptr[0];
214             *get_ptr(last_overflow + 1) = overflow16_ptr[1];
215             return;
216         }
217
218         // We found a newline, put the overflow marker just after it:
219         int new_overflow = i + 1;
220         if (last_overflow != -1) {
221             int16_t new_overflow16 = new_overflow;
222             char * new_overflow16_ptr = reinterpret_cast<char *>(&new_overflow16);
223             *get_ptr(last_overflow + 0) = new_overflow16_ptr[0];
224             *get_ptr(last_overflow + 1) = new_overflow16_ptr[1];
225         }
226         last_overflow = new_overflow;
227         if (overflow_marker == -1) {
228             overflow_marker = last_overflow;
229         }
230
231         int16_t overflow16 = -1;
232         char * overflow16_ptr = reinterpret_cast<char *>(&overflow16);
233         *get_ptr(last_overflow + 0) = overflow16_ptr[0];
234         *get_ptr(last_overflow + 0) = overflow16_ptr[1];
235         trim_to(last_overflow + sizeof(int16_t));
236     }
237 };
238
239
240 int main(int argc, char **argv)
241 {
242     using namespace std;
243     
244     bool show_help = false;
245     bool sys_shutdown = false;
246     bool use_passed_cfd = false;
247     
248     auto shutdown_type = shutdown_type_t::POWEROFF;
249         
250     for (int i = 1; i < argc; i++) {
251         if (argv[i][0] == '-') {
252             if (strcmp(argv[i], "--help") == 0) {
253                 show_help = true;
254                 break;
255             }
256             
257             if (strcmp(argv[i], "--system") == 0) {
258                 sys_shutdown = true;
259             }
260             else if (strcmp(argv[i], "-r") == 0) {
261                 shutdown_type = shutdown_type_t::REBOOT;
262             }
263             else if (strcmp(argv[i], "-h") == 0) {
264                 shutdown_type = shutdown_type_t::HALT;
265             }
266             else if (strcmp(argv[i], "-p") == 0) {
267                 shutdown_type = shutdown_type_t::POWEROFF;
268             }
269             else if (strcmp(argv[i], "--use-passed-cfd") == 0) {
270                 use_passed_cfd = true;
271             }
272             else {
273                 cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
274                 return 1;
275             }
276         }
277         else {
278             // time argument? TODO
279             show_help = true;
280         }
281     }
282
283     if (show_help) {
284         cout << "dinit-shutdown :   shutdown the system" << endl;
285         cout << "  --help           : show this help" << endl;
286         cout << "  -r               : reboot" << endl;
287         cout << "  -h               : halt system" << endl;
288         cout << "  -p               : power down (default)" << endl;
289         cout << "  --use-passed-cfd : use the socket file descriptor identified by the DINIT_CS_FD" << endl;
290         cout << "                     environment variable to communicate with the init daemon." << endl;
291         cout << "  --system         : perform shutdown immediately, instead of issuing shutdown" << endl;
292         cout << "                     command to the init program. Not recommended for use" << endl;
293         cout << "                     by users." << endl;
294         return 1;
295     }
296     
297     if (sys_shutdown) {
298         do_system_shutdown(shutdown_type);
299         return 0;
300     }
301
302     signal(SIGPIPE, SIG_IGN);
303     
304     int socknum = 0;
305     
306     if (use_passed_cfd) {
307         char * dinit_cs_fd_env = getenv("DINIT_CS_FD");
308         if (dinit_cs_fd_env != nullptr) {
309             char * endptr;
310             long int cfdnum = strtol(dinit_cs_fd_env, &endptr, 10);
311             if (endptr != dinit_cs_fd_env) {
312                 socknum = (int) cfdnum;
313                 // Set non-blocking mode:
314                 int sock_flags = fcntl(socknum, F_GETFL, 0);
315                 fcntl(socknum, F_SETFL, sock_flags & ~O_NONBLOCK);
316             }
317             else {
318                 use_passed_cfd = false;
319             }
320         }
321         else {
322             use_passed_cfd = false;
323         }
324     }
325     
326     if (! use_passed_cfd) {
327         socknum = socket(AF_UNIX, SOCK_STREAM, 0);
328         if (socknum == -1) {
329             perror("socket");
330             return 1;
331         }
332         
333         const char *naddr = "/dev/dinitctl";
334         
335         struct sockaddr_un name;
336         name.sun_family = AF_UNIX;
337         strcpy(name.sun_path, naddr);
338         int sunlen = offsetof(struct sockaddr_un, sun_path) + strlen(naddr) + 1; // family, (string), nul
339         
340         int connr = connect(socknum, (struct sockaddr *) &name, sunlen);
341         if (connr == -1) {
342             perror("connect");
343             return 1;
344         }
345     }
346
347     try {
348         cpbuffer_t rbuffer;
349     
350         check_protocol_version(min_cp_version, max_cp_version, rbuffer, socknum);
351
352         // Build buffer;
353         constexpr int bufsize = 2;
354         char buf[bufsize];
355
356         buf[0] = DINIT_CP_SHUTDOWN;
357         buf[1] = static_cast<char>(shutdown_type);
358
359         cout << "Issuing shutdown command..." << endl;
360
361         write_all_x(socknum, buf, bufsize);
362
363         // Wait for ACK/NACK
364     
365         wait_for_reply(rbuffer, socknum);
366         
367         if (rbuffer[0] != DINIT_RP_ACK) {
368             cerr << "shutdown: control socket protocol error" << endl;
369             return 1;
370         }
371     }
372     catch (cp_old_client_exception &e) {
373         std::cerr << "shutdown: too old (server reports newer protocol version)" << std::endl;
374         return 1;
375     }
376     catch (cp_old_server_exception &e) {
377         std::cerr << "shutdown: server too old or protocol error" << std::endl;
378         return 1;
379     }
380     catch (cp_read_exception &e) {
381         cerr << "shutdown: control socket read failure or protocol error" << endl;
382         return 1;
383     }
384     catch (cp_write_exception &e) {
385         cerr << "shutdown: control socket write error: " << std::strerror(e.errcode) << endl;
386         return 1;
387     }
388     
389     while (true) {
390         pause();
391     }
392     
393     return 0;
394 }
395
396 // Actually shut down the system.
397 void do_system_shutdown(shutdown_type_t shutdown_type)
398 {
399     using namespace std;
400     
401     // Mask all signals to prevent death of our parent etc from terminating us
402     sigset_t allsigs;
403     sigfillset(&allsigs);
404     sigprocmask(SIG_SETMASK, &allsigs, nullptr);
405     
406     int reboot_type = RB_AUTOBOOT; // reboot
407 #if defined(RB_POWER_OFF)
408     if (shutdown_type == shutdown_type_t::POWEROFF) reboot_type = RB_POWER_OFF;
409 #endif
410 #if defined(RB_HALT_SYSTEM)
411     if (shutdown_type == shutdown_type_t::HALT) reboot_type = RB_HALT_SYSTEM;
412 #elif defined(RB_HALT)
413     if (shutdown_type == shutdown_type_t::HALT) reboot_type = RB_HALT;
414 #endif
415     
416     // Write to console rather than any terminal, since we lose the terminal it seems:
417     close(STDOUT_FILENO);
418     int consfd = open("/dev/console", O_WRONLY);
419     if (consfd != STDOUT_FILENO) {
420         dup2(consfd, STDOUT_FILENO);
421     }
422     
423     loop_t loop;
424     subproc_buffer sub_buf {loop, STDOUT_FILENO};
425
426     sub_buf.append("Sending TERM/KILL to all processes...\n");
427     
428     // Send TERM/KILL to all (remaining) processes
429     kill(-1, SIGTERM);
430
431     // 1 second delay (while outputting from sub_buf):
432     bool timeout_reached = false;
433     dasynq::time_val timeout {1, 0};
434     dasynq::time_val interval {0,0};
435     loop_t::timer::add_timer(loop, clock_type::MONOTONIC, true /* relative */,
436             timeout.get_timespec(), interval.get_timespec(),
437             [&](loop_t &eloop, int expiry_count) -> rearm {
438
439         timeout_reached = true;
440         return rearm::REMOVE;
441     });
442
443     do {
444       loop.run();
445     } while (! timeout_reached);
446
447     kill(-1, SIGKILL);
448     
449     // perform shutdown
450     sub_buf.append("Turning off swap...\n");
451     swap_off(loop, sub_buf);
452     sub_buf.append("Unmounting disks...\n");
453     unmount_disks(loop, sub_buf);
454     sync();
455     
456     sub_buf.append("Issuing shutdown via kernel...\n");
457     loop.poll();  // give message a chance to get to console
458     reboot(reboot_type);
459 }
460
461 // Watcher for subprocess output.
462 class subproc_out_watch : public loop_t::fd_watcher_impl<subproc_out_watch>
463 {
464     subproc_buffer &sub_buf;
465     bool in_overflow = false;
466
467     rearm read_overflow(int fd)
468     {
469         char buf[128];
470         int r = read(fd, buf, 128);
471         if (r == 0 || (r == -1 && errno != EAGAIN)) {
472             return rearm::NOOP; // leave disarmed
473         }
474         if (r == -1) {
475             return rearm::REARM;
476         }
477
478         // How much space is available?
479         int fr = sub_buf.get_free();
480         for (int b = r - std::min(r, fr); b < r; b++) {
481             if (buf[b] == '\n') {
482                 // Copy the (partial) line into sub_buf and leave overflow mode
483                 sub_buf.append(buf + b, r - b);
484                 in_overflow = false;
485             }
486         }
487         return rearm::REARM;
488     }
489
490     public:
491     subproc_out_watch(subproc_buffer &sub_buf_p) : sub_buf(sub_buf_p) {}
492
493     rearm fd_event(loop_t &, int fd, int flags)
494     {
495         // if current status is reading overflow, read and discard until newline
496         if (in_overflow) {
497             return read_overflow(fd);
498         }
499
500         auto r = sub_buf.fill(fd);
501         if (r == subproc_buffer::fill_status::FULL) {
502             sub_buf.mark_overflow();
503             in_overflow = true;
504             return read_overflow(fd);
505         }
506         else if (r == subproc_buffer::fill_status::ENDFILE) {
507             return rearm::NOOP;
508         }
509
510         return rearm::REARM;  // re-enable watcher
511     }
512 };
513
514 // Run process, put its output through the subprocess buffer
515 //   may throw: std::system_error, std::bad_alloc
516 static void run_process(const char * prog_args[], loop_t &loop, subproc_buffer &sub_buf)
517 {
518     class sp_watcher_t : public loop_t::child_proc_watcher_impl<sp_watcher_t>
519     {
520         public:
521         bool terminated = false;
522
523         rearm status_change(loop_t &, pid_t child, int status)
524         {
525             terminated = true;
526             return rearm::REMOVE;
527         }
528     };
529
530     sp_watcher_t sp_watcher;
531
532     // Create output pipe
533     int pipefds[2];
534     if (dasynq::pipe2(pipefds, O_NONBLOCK) == -1) {
535         // TODO
536         std::cout << "*** pipe2 failed ***" << std::endl;
537     }
538
539     pid_t ch_pid = sp_watcher.fork(loop);
540     if (ch_pid == 0) {
541         // child
542         // Dup output pipe to stdout, stderr
543         dup2(pipefds[1], STDOUT_FILENO);
544         dup2(pipefds[1], STDERR_FILENO);
545         close(pipefds[0]);
546         close(pipefds[1]);
547         execv(prog_args[0], const_cast<char **>(prog_args));
548         puts("Failed to execute subprocess:\n");
549         perror(prog_args[0]);
550         _exit(1);
551     }
552
553     close(pipefds[1]);
554
555     subproc_out_watch owatch {sub_buf};
556     owatch.add_watch(loop, pipefds[0], dasynq::IN_EVENTS);
557
558     do {
559         loop.run();
560     } while (! sp_watcher.terminated);
561
562     owatch.deregister(loop);
563 }
564
565 static void unmount_disks(loop_t &loop, subproc_buffer &sub_buf)
566 {
567     const char * unmount_args[] = { "/bin/umount", "-a", "-r", nullptr };
568     run_process(unmount_args, loop, sub_buf);
569 }
570
571 static void swap_off(loop_t &loop, subproc_buffer &sub_buf)
572 {
573     const char * swapoff_args[] = { "/sbin/swapoff", "-a", nullptr };
574     run_process(swapoff_args, loop, sub_buf);
575 }