a04598fbfc94fbf763d241d5df3fb41b554f7ea0
[oweals/dinit.git] / src / shutdown.cc
1 #include <cstddef>
2 #include <cstdio>
3 #include <csignal>
4 #include <cstring>
5 #include <string>
6 #include <iostream>
7 #include <exception>
8
9 #include <sys/reboot.h>
10 #include <sys/types.h>
11 #include <sys/socket.h>
12 #include <sys/un.h>
13 #include <sys/wait.h>
14 #include <sys/stat.h>
15 #include <unistd.h>
16 #include <fcntl.h>
17
18 #include "cpbuffer.h"
19 #include "control-cmds.h"
20 #include "service-constants.h"
21 #include "dinit-client.h"
22 #include "dinit-util.h"
23
24 #include "dasynq.h"
25
26 // shutdown:  shut down the system
27 // This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
28
29 static constexpr uint16_t min_cp_version = 1;
30 static constexpr uint16_t max_cp_version = 1;
31
32 using loop_t = dasynq::event_loop_n;
33 using rearm = dasynq::rearm;
34 using clock_type = dasynq::clock_type;
35 class subproc_buffer;
36
37 void do_system_shutdown(shutdown_type_t shutdown_type);
38 static void unmount_disks(loop_t &loop, subproc_buffer &sub_buf);
39 static void swap_off(loop_t &loop, subproc_buffer &sub_buf);
40
41 constexpr static int subproc_bufsize = 4096;
42
43 constexpr static char output_lost_msg[] = "[Some output has not been shown due to buffer overflow]\n";
44
45 // A buffer which maintains a series of overflow markers, used for capturing and echoing
46 // subprocess output.
47 class subproc_buffer : private cpbuffer<subproc_bufsize>
48 {
49     using base = cpbuffer<subproc_bufsize>;
50
51     int overflow_marker = -1;
52     int last_overflow = -1;  // last marker in the series
53     const char *overflow_msg_ptr = nullptr;   // current position in overflow message
54     dasynq::event_loop_n &loop;
55     dasynq::event_loop_n::fd_watcher * out_watch;
56
57     public:
58     enum class fill_status
59     {
60         OK,
61         ENDFILE,
62         FULL
63     };
64
65     subproc_buffer(dasynq::event_loop_n &loop_p, int out_fd) : loop(loop_p)
66     {
67         using loop_t = dasynq::event_loop_n;
68         using rearm = dasynq::rearm;
69
70         out_watch = loop_t::fd_watcher::add_watch(loop, out_fd, dasynq::OUT_EVENTS,
71                 [&](loop_t &eloop, int fd, int flags) -> rearm {
72
73             auto fstatus = flush(STDOUT_FILENO);
74             if (fstatus == subproc_buffer::fill_status::ENDFILE) {
75                 return rearm::DISARM;
76             }
77
78             return rearm::REARM;
79         });
80     }
81
82     ~subproc_buffer()
83     {
84         out_watch->deregister(loop);
85     }
86
87     // Fill buffer by reading from file descriptor. Note caller must set overflow marker
88     // if the buffer becomes full and more data is available.
89     fill_status fill(int fd)
90     {
91         int rem = get_free();
92
93         if (rem == 0) {
94             return fill_status::FULL;
95         }
96
97         int read = base::fill(fd, rem);
98         if (read <= 0) {
99             if (read == -1 && errno == EAGAIN) {
100                 return fill_status::OK;
101             }
102             return fill_status::ENDFILE;
103         }
104
105         out_watch->set_enabled(loop, true);
106         return fill_status::OK;
107     }
108
109     // Append a message. If the message will not fit in the buffer, discard it and mark overflow.
110     void append(const char *msg)
111     {
112         out_watch->set_enabled(loop, true);
113         int len = strlen(msg);
114         if (subproc_bufsize - get_length() >= len) {
115             base::append(msg, len);
116         }
117         else {
118             mark_overflow();
119         }
120     }
121
122     // Append the given buffer, which must fit in the remaining space in this buffer.
123     void append(const char *buf, int len)
124     {
125         out_watch->set_enabled(loop, true);
126         base::append(buf, len);
127     }
128
129     int get_free()
130     {
131         return base::get_free();
132     }
133
134     // Write buffer contents out to file descriptor. The descriptor is assumed to be non-blocking.
135     // returns ENDFILE if there is no more content to flush (buffer is now empty) or OK otherwise.
136     fill_status flush(int fd)
137     {
138         int to_write = get_contiguous_length(get_ptr(0));
139
140         if (overflow_marker != -1) {
141             if (overflow_marker == 0) {
142                 // output (remainder of) overflow message
143                 int l = std::strlen(overflow_msg_ptr);
144                 int r = write(fd, overflow_msg_ptr, l);
145                 if (r == l) {
146                     // entire message has been written; next marker is in buffer
147                     int16_t overflow_marker16;
148                     extract(reinterpret_cast<char *>(&overflow_marker16), 0, sizeof(overflow_marker16));
149                     overflow_marker = overflow_marker16;
150                     consume(sizeof(overflow_marker16));
151
152                     // no more overflow markers?
153                     if (overflow_marker == -1) {
154                         last_overflow = -1;
155                     }
156                     return get_length() == 0 ? fill_status::ENDFILE : fill_status::OK;
157                 }
158                 if (r > 0) {
159                     overflow_msg_ptr += r;
160                 }
161                 return fill_status::OK;
162             }
163
164             to_write = std::min(to_write, overflow_marker);
165         }
166
167         int r = write(fd, get_ptr(0), to_write);
168         if (r > 0) {
169             consume(r);
170             if (overflow_marker != -1) {
171                 overflow_marker -= r;
172                 last_overflow -= r;
173                 if (overflow_marker == 0) {
174                     overflow_msg_ptr = output_lost_msg;
175                 }
176             }
177         }
178
179         return get_length() == 0 ? fill_status::ENDFILE : fill_status::OK;
180     }
181
182     // Mark overflow occurred. Call this only when the buffer is full.
183     // The marker is put after the most recent newline in the buffer, if possible, so that whole
184     // lines are retained in the buffer. In some cases marking overflow will not add a new overflow
185     // marker but simply trim the buffer to an existing marker.
186     void mark_overflow()
187     {
188         // Try to find the last newline in the buffer
189         int begin = 0;
190         if (last_overflow != -1) {
191             begin = last_overflow + sizeof(int16_t);
192         }
193         int end = get_length() - 1 - sizeof(int16_t); // -1, then -2 for storage of marker
194
195         int i;
196         for (i = end; i >= begin; i--) {
197             if ((*this)[i] == '\n') break;
198         }
199
200         if (last_overflow != -1 && i < begin) {
201             // No new line after existing marker: trim all beyond that marker, don't
202             // create a new marker:
203             trim_to(last_overflow + sizeof(uint16_t));
204             return;
205         }
206
207         if (i < begin) {
208             // No newline in the whole buffer... we'll put the overflow marker at the end,
209             // on the assumption that it is better to output a partial line than it is to
210             // discard the entire buffer:
211             last_overflow = get_length() - sizeof(int16_t);
212             overflow_marker = last_overflow;
213             int16_t overflow16 = -1;
214             char * overflow16_ptr = reinterpret_cast<char *>(&overflow16);
215             *get_ptr(last_overflow + 0) = overflow16_ptr[0];
216             *get_ptr(last_overflow + 1) = overflow16_ptr[1];
217             return;
218         }
219
220         // We found a newline, put the overflow marker just after it:
221         int new_overflow = i + 1;
222         if (last_overflow != -1) {
223             int16_t new_overflow16 = new_overflow;
224             char * new_overflow16_ptr = reinterpret_cast<char *>(&new_overflow16);
225             *get_ptr(last_overflow + 0) = new_overflow16_ptr[0];
226             *get_ptr(last_overflow + 1) = new_overflow16_ptr[1];
227         }
228         last_overflow = new_overflow;
229         if (overflow_marker == -1) {
230             overflow_marker = last_overflow;
231         }
232
233         int16_t overflow16 = -1;
234         char * overflow16_ptr = reinterpret_cast<char *>(&overflow16);
235         *get_ptr(last_overflow + 0) = overflow16_ptr[0];
236         *get_ptr(last_overflow + 0) = overflow16_ptr[1];
237         trim_to(last_overflow + sizeof(int16_t));
238     }
239 };
240
241
242 int main(int argc, char **argv)
243 {
244     using namespace std;
245     
246     bool show_help = false;
247     bool sys_shutdown = false;
248     bool use_passed_cfd = false;
249     
250     auto shutdown_type = shutdown_type_t::POWEROFF;
251
252     const char *execname = base_name(argv[0]);
253     if (strcmp(execname, "reboot") == 0) {
254         shutdown_type = shutdown_type_t::REBOOT;
255     }
256         
257     for (int i = 1; i < argc; i++) {
258         if (argv[i][0] == '-') {
259             if (strcmp(argv[i], "--help") == 0) {
260                 show_help = true;
261                 break;
262             }
263             
264             if (strcmp(argv[i], "--system") == 0) {
265                 sys_shutdown = true;
266             }
267             else if (strcmp(argv[i], "-r") == 0) {
268                 shutdown_type = shutdown_type_t::REBOOT;
269             }
270             else if (strcmp(argv[i], "-h") == 0) {
271                 shutdown_type = shutdown_type_t::HALT;
272             }
273             else if (strcmp(argv[i], "-p") == 0) {
274                 shutdown_type = shutdown_type_t::POWEROFF;
275             }
276             else if (strcmp(argv[i], "--use-passed-cfd") == 0) {
277                 use_passed_cfd = true;
278             }
279             else {
280                 cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
281                 return 1;
282             }
283         }
284         else {
285             // time argument? TODO
286             show_help = true;
287         }
288     }
289
290     if (show_help) {
291         cout << execname << " :   shutdown the system\n"
292                 "  --help           : show this help\n"
293                 "  -r               : reboot\n"
294                 "  -h               : halt system\n"
295                 "  -p               : power down (default)\n"
296                 "  --use-passed-cfd : use the socket file descriptor identified by the DINIT_CS_FD\n"
297                 "                     environment variable to communicate with the init daemon.\n"
298                 "  --system         : perform shutdown immediately, instead of issuing shutdown\n"
299                 "                     command to the init program. Not recommended for use\n"
300                 "                     by users.\n";
301         return 1;
302     }
303     
304     if (sys_shutdown) {
305         do_system_shutdown(shutdown_type);
306         return 0;
307     }
308
309     signal(SIGPIPE, SIG_IGN);
310     
311     int socknum = 0;
312     
313     if (use_passed_cfd) {
314         char * dinit_cs_fd_env = getenv("DINIT_CS_FD");
315         if (dinit_cs_fd_env != nullptr) {
316             char * endptr;
317             long int cfdnum = strtol(dinit_cs_fd_env, &endptr, 10);
318             if (endptr != dinit_cs_fd_env) {
319                 socknum = (int) cfdnum;
320                 // Set non-blocking mode:
321                 int sock_flags = fcntl(socknum, F_GETFL, 0);
322                 fcntl(socknum, F_SETFL, sock_flags & ~O_NONBLOCK);
323             }
324             else {
325                 use_passed_cfd = false;
326             }
327         }
328         else {
329             use_passed_cfd = false;
330         }
331     }
332     
333     if (! use_passed_cfd) {
334         socknum = socket(AF_UNIX, SOCK_STREAM, 0);
335         if (socknum == -1) {
336             perror("socket");
337             return 1;
338         }
339         
340         const char *naddr = "/dev/dinitctl";
341         
342         struct sockaddr_un name;
343         name.sun_family = AF_UNIX;
344         strcpy(name.sun_path, naddr);
345         int sunlen = offsetof(struct sockaddr_un, sun_path) + strlen(naddr) + 1; // family, (string), nul
346         
347         int connr = connect(socknum, (struct sockaddr *) &name, sunlen);
348         if (connr == -1) {
349             perror("connect");
350             return 1;
351         }
352     }
353
354     try {
355         cpbuffer_t rbuffer;
356     
357         check_protocol_version(min_cp_version, max_cp_version, rbuffer, socknum);
358
359         // Build buffer;
360         constexpr int bufsize = 2;
361         char buf[bufsize];
362
363         buf[0] = DINIT_CP_SHUTDOWN;
364         buf[1] = static_cast<char>(shutdown_type);
365
366         cout << "Issuing shutdown command..." << endl;
367
368         write_all_x(socknum, buf, bufsize);
369
370         // Wait for ACK/NACK
371     
372         wait_for_reply(rbuffer, socknum);
373         
374         if (rbuffer[0] != DINIT_RP_ACK) {
375             cerr << "shutdown: control socket protocol error" << endl;
376             return 1;
377         }
378     }
379     catch (cp_old_client_exception &e) {
380         std::cerr << "shutdown: too old (server reports newer protocol version)" << std::endl;
381         return 1;
382     }
383     catch (cp_old_server_exception &e) {
384         std::cerr << "shutdown: server too old or protocol error" << std::endl;
385         return 1;
386     }
387     catch (cp_read_exception &e) {
388         cerr << "shutdown: control socket read failure or protocol error" << endl;
389         return 1;
390     }
391     catch (cp_write_exception &e) {
392         cerr << "shutdown: control socket write error: " << std::strerror(e.errcode) << endl;
393         return 1;
394     }
395     
396     while (true) {
397         pause();
398     }
399     
400     return 0;
401 }
402
403 // Actually shut down the system.
404 void do_system_shutdown(shutdown_type_t shutdown_type)
405 {
406     using namespace std;
407     
408     // Mask all signals to prevent death of our parent etc from terminating us
409     sigset_t allsigs;
410     sigfillset(&allsigs);
411     sigprocmask(SIG_SETMASK, &allsigs, nullptr);
412     
413     int reboot_type = RB_AUTOBOOT; // reboot
414 #if defined(RB_POWER_OFF)
415     if (shutdown_type == shutdown_type_t::POWEROFF) reboot_type = RB_POWER_OFF;
416 #endif
417 #if defined(RB_HALT_SYSTEM)
418     if (shutdown_type == shutdown_type_t::HALT) reboot_type = RB_HALT_SYSTEM;
419 #elif defined(RB_HALT)
420     if (shutdown_type == shutdown_type_t::HALT) reboot_type = RB_HALT;
421 #endif
422     
423     // Write to console rather than any terminal, since we lose the terminal it seems:
424     close(STDOUT_FILENO);
425     int consfd = open("/dev/console", O_WRONLY);
426     if (consfd != STDOUT_FILENO) {
427         dup2(consfd, STDOUT_FILENO);
428     }
429     
430     loop_t loop;
431     subproc_buffer sub_buf {loop, STDOUT_FILENO};
432
433     sub_buf.append("Sending TERM/KILL to all processes...\n");
434     
435     // Send TERM/KILL to all (remaining) processes
436     kill(-1, SIGTERM);
437
438     // 1 second delay (while outputting from sub_buf):
439     bool timeout_reached = false;
440     dasynq::time_val timeout {1, 0};
441     dasynq::time_val interval {0,0};
442     loop_t::timer::add_timer(loop, clock_type::MONOTONIC, true /* relative */,
443             timeout.get_timespec(), interval.get_timespec(),
444             [&](loop_t &eloop, int expiry_count) -> rearm {
445
446         timeout_reached = true;
447         return rearm::REMOVE;
448     });
449
450     do {
451       loop.run();
452     } while (! timeout_reached);
453
454     kill(-1, SIGKILL);
455     
456     // perform shutdown
457     sub_buf.append("Turning off swap...\n");
458     swap_off(loop, sub_buf);
459     sub_buf.append("Unmounting disks...\n");
460     unmount_disks(loop, sub_buf);
461     sync();
462     
463     sub_buf.append("Issuing shutdown via kernel...\n");
464     loop.poll();  // give message a chance to get to console
465     reboot(reboot_type);
466 }
467
468 // Watcher for subprocess output.
469 class subproc_out_watch : public loop_t::fd_watcher_impl<subproc_out_watch>
470 {
471     subproc_buffer &sub_buf;
472     bool in_overflow = false;
473
474     rearm read_overflow(int fd)
475     {
476         char buf[128];
477         int r = read(fd, buf, 128);
478         if (r == 0 || (r == -1 && errno != EAGAIN)) {
479             return rearm::NOOP; // leave disarmed
480         }
481         if (r == -1) {
482             return rearm::REARM;
483         }
484
485         // How much space is available?
486         int fr = sub_buf.get_free();
487         for (int b = r - std::min(r, fr); b < r; b++) {
488             if (buf[b] == '\n') {
489                 // Copy the (partial) line into sub_buf and leave overflow mode
490                 sub_buf.append(buf + b, r - b);
491                 in_overflow = false;
492             }
493         }
494         return rearm::REARM;
495     }
496
497     public:
498     subproc_out_watch(subproc_buffer &sub_buf_p) : sub_buf(sub_buf_p) {}
499
500     rearm fd_event(loop_t &, int fd, int flags)
501     {
502         // if current status is reading overflow, read and discard until newline
503         if (in_overflow) {
504             return read_overflow(fd);
505         }
506
507         auto r = sub_buf.fill(fd);
508         if (r == subproc_buffer::fill_status::FULL) {
509             sub_buf.mark_overflow();
510             in_overflow = true;
511             return read_overflow(fd);
512         }
513         else if (r == subproc_buffer::fill_status::ENDFILE) {
514             return rearm::NOOP;
515         }
516
517         return rearm::REARM;  // re-enable watcher
518     }
519 };
520
521 // Run process, put its output through the subprocess buffer
522 //   may throw: std::system_error, std::bad_alloc
523 static void run_process(const char * prog_args[], loop_t &loop, subproc_buffer &sub_buf)
524 {
525     class sp_watcher_t : public loop_t::child_proc_watcher_impl<sp_watcher_t>
526     {
527         public:
528         bool terminated = false;
529
530         rearm status_change(loop_t &, pid_t child, int status)
531         {
532             terminated = true;
533             return rearm::REMOVE;
534         }
535     };
536
537     sp_watcher_t sp_watcher;
538
539     // Create output pipe
540     bool have_pipe = true;
541     int pipefds[2];
542     if (dasynq::pipe2(pipefds, O_NONBLOCK) == -1) {
543         sub_buf.append("Warning: ");
544         sub_buf.append(prog_args[0]);
545         sub_buf.append(": could not create pipe for subprocess output\n");
546         have_pipe = false;
547         // Note, we proceed and let the sub-process run with our stdout/stderr.
548     }
549
550     subproc_out_watch owatch {sub_buf};
551
552     if (have_pipe) {
553         close(pipefds[1]);
554         try {
555             owatch.add_watch(loop, pipefds[0], dasynq::IN_EVENTS);
556         }
557         catch (...) {
558             // failed to create the watcher for the subprocess output; again, let it run with
559             // our stdout/stderr
560             sub_buf.append("Warning: could not create output watch for subprocess\n");
561             close(pipefds[0]);
562             have_pipe = false;
563         }
564     }
565
566     // If we've buffered any messages/output, give them a chance to go out now:
567     loop.poll();
568
569     pid_t ch_pid = sp_watcher.fork(loop);
570     if (ch_pid == 0) {
571         // child
572         // Dup output pipe to stdout, stderr
573         if (have_pipe) {
574             dup2(pipefds[1], STDOUT_FILENO);
575             dup2(pipefds[1], STDERR_FILENO);
576             close(pipefds[0]);
577             close(pipefds[1]);
578         }
579         execv(prog_args[0], const_cast<char **>(prog_args));
580         puts("Failed to execute subprocess: ");
581         perror(prog_args[0]);
582         _exit(1);
583     }
584
585     do {
586         loop.run();
587     } while (! sp_watcher.terminated);
588
589     if (have_pipe) {
590         owatch.deregister(loop);
591     }
592 }
593
594 static void unmount_disks(loop_t &loop, subproc_buffer &sub_buf)
595 {
596     try {
597         const char * unmount_args[] = { "/bin/umount", "-a", "-r", nullptr };
598         run_process(unmount_args, loop, sub_buf);
599     }
600     catch (std::exception &e) {
601         sub_buf.append("Couldn't fork for umount: ");
602         sub_buf.append(e.what());
603         sub_buf.append("\n");
604     }
605 }
606
607 static void swap_off(loop_t &loop, subproc_buffer &sub_buf)
608 {
609     try {
610         const char * swapoff_args[] = { "/sbin/swapoff", "-a", nullptr };
611         run_process(swapoff_args, loop, sub_buf);
612     }
613     catch (std::exception &e) {
614         sub_buf.append("Couldn't fork for swapoff: ");
615         sub_buf.append(e.what());
616         sub_buf.append("\n");
617     }
618 }