shutdown: set passed-via-environment control socket to blocking.
[oweals/dinit.git] / src / shutdown.cc
1 #include <cstddef>
2 #include <cstdio>
3 #include <csignal>
4 #include <unistd.h>
5 #include <cstring>
6 #include <string>
7 #include <iostream>
8
9 #include <sys/reboot.h>
10 #include <sys/types.h>
11 #include <sys/socket.h>
12 #include <sys/un.h>
13 #include <sys/wait.h>
14 #include <sys/stat.h>
15 #include <fcntl.h>
16
17 #include "cpbuffer.h"
18 #include "control-cmds.h"
19 #include "service-constants.h"
20 #include "dinit-client.h"
21
22 #include "dasynq.h"
23
24 // shutdown:  shut down the system
25 // This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
26
27 using loop_t = dasynq::event_loop_n;
28 using rearm = dasynq::rearm;
29 using clock_type = dasynq::clock_type;
30 class subproc_buffer;
31
32 void do_system_shutdown(shutdown_type_t shutdown_type);
33 static void unmount_disks(loop_t &loop, subproc_buffer &sub_buf);
34 static void swap_off(loop_t &loop, subproc_buffer &sub_buf);
35
36 constexpr static int subproc_bufsize = 4096;
37
38 constexpr static char output_lost_msg[] = "[Some output has not been shown due to buffer overflow]\n";
39
40 // A buffer which maintains a series of overflow markers, used for capturing and echoing
41 // subprocess output.
42 class subproc_buffer : private cpbuffer<subproc_bufsize>
43 {
44     using base = cpbuffer<subproc_bufsize>;
45
46     int overflow_marker = -1;
47     int last_overflow = -1;  // last marker in the series
48     const char *overflow_msg_ptr = nullptr;   // current position in overflow message
49     dasynq::event_loop_n &loop;
50     dasynq::event_loop_n::fd_watcher * out_watch;
51
52     public:
53     enum class fill_status
54     {
55         OK,
56         ENDFILE,
57         FULL
58     };
59
60     subproc_buffer(dasynq::event_loop_n &loop_p, int out_fd) : loop(loop_p)
61     {
62         using loop_t = dasynq::event_loop_n;
63         using rearm = dasynq::rearm;
64
65         out_watch = loop_t::fd_watcher::add_watch(loop, out_fd, dasynq::OUT_EVENTS,
66                 [&](loop_t &eloop, int fd, int flags) -> rearm {
67
68             auto fstatus = flush(STDOUT_FILENO);
69             if (fstatus == subproc_buffer::fill_status::ENDFILE) {
70                 return rearm::DISARM;
71             }
72
73             return rearm::REARM;
74         });
75     }
76
77     ~subproc_buffer()
78     {
79         out_watch->deregister(loop);
80     }
81
82     // Fill buffer by reading from file descriptor. Note caller must set overflow marker
83     // if the buffer becomes full and more data is available.
84     fill_status fill(int fd)
85     {
86         int rem = get_free();
87
88         if (rem == 0) {
89             return fill_status::FULL;
90         }
91
92         int read = base::fill(fd, rem);
93         if (read <= 0) {
94             if (read == -1 && errno == EAGAIN) {
95                 return fill_status::OK;
96             }
97             return fill_status::ENDFILE;
98         }
99
100         out_watch->set_enabled(loop, true);
101         return fill_status::OK;
102     }
103
104     // Append a message. If the message will not fit in the buffer, discard it and mark overflow.
105     void append(const char *msg)
106     {
107         out_watch->set_enabled(loop, true);
108         int len = strlen(msg);
109         if (subproc_bufsize - get_length() >= len) {
110             base::append(msg, len);
111         }
112         else {
113             mark_overflow();
114         }
115     }
116
117     // Append the given buffer, which must fit in the remaining space in this buffer.
118     void append(const char *buf, int len)
119     {
120         out_watch->set_enabled(loop, true);
121         base::append(buf, len);
122     }
123
124     int get_free()
125     {
126         return base::get_free();
127     }
128
129     // Write buffer contents out to file descriptor. The descriptor is assumed to be non-blocking.
130     // returns ENDFILE if there is no more content to flush (buffer is now empty) or OK otherwise.
131     fill_status flush(int fd)
132     {
133         int to_write = get_contiguous_length(get_ptr(0));
134
135         if (overflow_marker != -1) {
136             if (overflow_marker == 0) {
137                 // output (remainder of) overflow message
138                 int l = std::strlen(overflow_msg_ptr);
139                 int r = write(fd, overflow_msg_ptr, l);
140                 if (r == l) {
141                     // entire message has been written; next marker is in buffer
142                     int16_t overflow_marker16;
143                     extract(reinterpret_cast<char *>(&overflow_marker16), 0, sizeof(overflow_marker16));
144                     overflow_marker = overflow_marker16;
145                     consume(sizeof(overflow_marker16));
146
147                     // no more overflow markers?
148                     if (overflow_marker == -1) {
149                         last_overflow = -1;
150                     }
151                     return get_length() == 0 ? fill_status::ENDFILE : fill_status::OK;
152                 }
153                 if (r > 0) {
154                     overflow_msg_ptr += r;
155                 }
156                 return fill_status::OK;
157             }
158
159             to_write = std::min(to_write, overflow_marker);
160         }
161
162         int r = write(fd, get_ptr(0), to_write);
163         if (r > 0) {
164             consume(r);
165             if (overflow_marker != -1) {
166                 overflow_marker -= r;
167                 last_overflow -= r;
168                 if (overflow_marker == 0) {
169                     overflow_msg_ptr = output_lost_msg;
170                 }
171             }
172         }
173
174         return get_length() == 0 ? fill_status::ENDFILE : fill_status::OK;
175     }
176
177     // Mark overflow occurred. Call this only when the buffer is full.
178     // The marker is put after the most recent newline in the buffer, if possible, so that whole
179     // lines are retained in the buffer. In some cases marking overflow will not add a new overflow
180     // marker but simply trim the buffer to an existing marker.
181     void mark_overflow()
182     {
183         // Try to find the last newline in the buffer
184         int begin = 0;
185         if (last_overflow != -1) {
186             begin = last_overflow + sizeof(int16_t);
187         }
188         int end = get_length() - 1 - sizeof(int16_t); // -1, then -2 for storage of marker
189
190         int i;
191         for (i = end; i >= begin; i--) {
192             if ((*this)[i] == '\n') break;
193         }
194
195         if (last_overflow != -1 && i < begin) {
196             // No new line after existing marker: trim all beyond that marker, don't
197             // create a new marker:
198             trim_to(last_overflow + sizeof(uint16_t));
199             return;
200         }
201
202         if (i < begin) {
203             // No newline in the whole buffer... we'll put the overflow marker at the end,
204             // on the assumption that it is better to output a partial line than it is to
205             // discard the entire buffer:
206             last_overflow = get_length() - sizeof(int16_t);
207             overflow_marker = last_overflow;
208             int16_t overflow16 = -1;
209             char * overflow16_ptr = reinterpret_cast<char *>(&overflow16);
210             *get_ptr(last_overflow + 0) = overflow16_ptr[0];
211             *get_ptr(last_overflow + 1) = overflow16_ptr[1];
212             return;
213         }
214
215         // We found a newline, put the overflow marker just after it:
216         int new_overflow = i + 1;
217         if (last_overflow != -1) {
218             int16_t new_overflow16 = new_overflow;
219             char * new_overflow16_ptr = reinterpret_cast<char *>(&new_overflow16);
220             *get_ptr(last_overflow + 0) = new_overflow16_ptr[0];
221             *get_ptr(last_overflow + 1) = new_overflow16_ptr[1];
222         }
223         last_overflow = new_overflow;
224         if (overflow_marker == -1) {
225             overflow_marker = last_overflow;
226         }
227
228         int16_t overflow16 = -1;
229         char * overflow16_ptr = reinterpret_cast<char *>(&overflow16);
230         *get_ptr(last_overflow + 0) = overflow16_ptr[0];
231         *get_ptr(last_overflow + 0) = overflow16_ptr[1];
232         trim_to(last_overflow + sizeof(int16_t));
233     }
234 };
235
236
237 int main(int argc, char **argv)
238 {
239     using namespace std;
240     
241     bool show_help = false;
242     bool sys_shutdown = false;
243     bool use_passed_cfd = false;
244     
245     auto shutdown_type = shutdown_type_t::POWEROFF;
246         
247     for (int i = 1; i < argc; i++) {
248         if (argv[i][0] == '-') {
249             if (strcmp(argv[i], "--help") == 0) {
250                 show_help = true;
251                 break;
252             }
253             
254             if (strcmp(argv[i], "--system") == 0) {
255                 sys_shutdown = true;
256             }
257             else if (strcmp(argv[i], "-r") == 0) {
258                 shutdown_type = shutdown_type_t::REBOOT;
259             }
260             else if (strcmp(argv[i], "-h") == 0) {
261                 shutdown_type = shutdown_type_t::HALT;
262             }
263             else if (strcmp(argv[i], "-p") == 0) {
264                 shutdown_type = shutdown_type_t::POWEROFF;
265             }
266             else if (strcmp(argv[i], "--use-passed-cfd") == 0) {
267                 use_passed_cfd = true;
268             }
269             else {
270                 cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
271                 return 1;
272             }
273         }
274         else {
275             // time argument? TODO
276             show_help = true;
277         }
278     }
279
280     if (show_help) {
281         cout << "dinit-shutdown :   shutdown the system" << endl;
282         cout << "  --help           : show this help" << endl;
283         cout << "  -r               : reboot" << endl;
284         cout << "  -h               : halt system" << endl;
285         cout << "  -p               : power down (default)" << endl;
286         cout << "  --use-passed-cfd : use the socket file descriptor identified by the DINIT_CS_FD" << endl;
287         cout << "                     environment variable to communicate with the init daemon." << endl;
288         cout << "  --system         : perform shutdown immediately, instead of issuing shutdown" << endl;
289         cout << "                     command to the init program. Not recommended for use" << endl;
290         cout << "                     by users." << endl;
291         return 1;
292     }
293     
294     if (sys_shutdown) {
295         do_system_shutdown(shutdown_type);
296         return 0;
297     }
298
299     signal(SIGPIPE, SIG_IGN);
300     
301     int socknum = 0;
302     
303     if (use_passed_cfd) {
304         char * dinit_cs_fd_env = getenv("DINIT_CS_FD");
305         if (dinit_cs_fd_env != nullptr) {
306             char * endptr;
307             long int cfdnum = strtol(dinit_cs_fd_env, &endptr, 10);
308             if (endptr != dinit_cs_fd_env) {
309                 socknum = (int) cfdnum;
310                 // Set non-blocking mode:
311                 int sock_flags = fcntl(socknum, F_GETFL, 0);
312                 fcntl(socknum, F_SETFL, sock_flags & ~O_NONBLOCK);
313             }
314             else {
315                 use_passed_cfd = false;
316             }
317         }
318         else {
319             use_passed_cfd = false;
320         }
321     }
322     
323     if (! use_passed_cfd) {
324         socknum = socket(AF_UNIX, SOCK_STREAM, 0);
325         if (socknum == -1) {
326             perror("socket");
327             return 1;
328         }
329         
330         const char *naddr = "/dev/dinitctl";
331         
332         struct sockaddr_un name;
333         name.sun_family = AF_UNIX;
334         strcpy(name.sun_path, naddr);
335         int sunlen = offsetof(struct sockaddr_un, sun_path) + strlen(naddr) + 1; // family, (string), nul
336         
337         int connr = connect(socknum, (struct sockaddr *) &name, sunlen);
338         if (connr == -1) {
339             perror("connect");
340             return 1;
341         }
342     }
343
344     try {
345         cpbuffer_t rbuffer;
346     
347         check_protocol_version(0, 0, rbuffer, socknum);
348
349         // Build buffer;
350         constexpr int bufsize = 2;
351         char buf[bufsize];
352
353         buf[0] = DINIT_CP_SHUTDOWN;
354         buf[1] = static_cast<char>(shutdown_type);
355
356         cout << "Issuing shutdown command..." << endl;
357
358         write_all_x(socknum, buf, bufsize);
359
360         // Wait for ACK/NACK
361     
362         wait_for_reply(rbuffer, socknum);
363         
364         if (rbuffer[0] != DINIT_RP_ACK) {
365             cerr << "shutdown: control socket protocol error" << endl;
366             return 1;
367         }
368     }
369     catch (cp_old_client_exception &e) {
370         std::cerr << "shutdown: too old (server reports newer protocol version)" << std::endl;
371         return 1;
372     }
373     catch (cp_old_server_exception &e) {
374         std::cerr << "shutdown: server too old or protocol error" << std::endl;
375         return 1;
376     }
377     catch (cp_read_exception &e) {
378         cerr << "shutdown: control socket read failure or protocol error" << endl;
379         return 1;
380     }
381     catch (cp_write_exception &e) {
382         cerr << "shutdown: control socket write error: " << std::strerror(e.errcode) << endl;
383         return 1;
384     }
385     
386     while (true) {
387         pause();
388     }
389     
390     return 0;
391 }
392
393 // Actually shut down the system.
394 void do_system_shutdown(shutdown_type_t shutdown_type)
395 {
396     using namespace std;
397     
398     // Mask all signals to prevent death of our parent etc from terminating us
399     sigset_t allsigs;
400     sigfillset(&allsigs);
401     sigprocmask(SIG_SETMASK, &allsigs, nullptr);
402     
403     int reboot_type = RB_AUTOBOOT; // reboot
404 #if defined(RB_POWER_OFF)
405     if (shutdown_type == shutdown_type_t::POWEROFF) reboot_type = RB_POWER_OFF;
406 #endif
407 #if defined(RB_HALT_SYSTEM)
408     if (shutdown_type == shutdown_type_t::HALT) reboot_type = RB_HALT_SYSTEM;
409 #elif defined(RB_HALT)
410     if (shutdown_type == shutdown_type_t::HALT) reboot_type = RB_HALT;
411 #endif
412     
413     // Write to console rather than any terminal, since we lose the terminal it seems:
414     close(STDOUT_FILENO);
415     int consfd = open("/dev/console", O_WRONLY);
416     if (consfd != STDOUT_FILENO) {
417         dup2(consfd, STDOUT_FILENO);
418     }
419     
420     loop_t loop;
421     subproc_buffer sub_buf {loop, STDOUT_FILENO};
422
423     sub_buf.append("Sending TERM/KILL to all processes...\n");
424     
425     // Send TERM/KILL to all (remaining) processes
426     kill(-1, SIGTERM);
427
428     // 1 second delay (while outputting from sub_buf):
429     bool timeout_reached = false;
430     dasynq::time_val timeout {1, 0};
431     dasynq::time_val interval {0,0};
432     loop_t::timer::add_timer(loop, clock_type::MONOTONIC, true /* relative */,
433             timeout.get_timespec(), interval.get_timespec(),
434             [&](loop_t &eloop, int expiry_count) -> rearm {
435
436         timeout_reached = true;
437         return rearm::REMOVE;
438     });
439
440     do {
441       loop.run();
442     } while (! timeout_reached);
443
444     kill(-1, SIGKILL);
445     
446     // perform shutdown
447     sub_buf.append("Turning off swap...\n");
448     swap_off(loop, sub_buf);
449     sub_buf.append("Unmounting disks...\n");
450     unmount_disks(loop, sub_buf);
451     sync();
452     
453     sub_buf.append("Issuing shutdown via kernel...\n");
454     loop.poll();  // give message a chance to get to console
455     reboot(reboot_type);
456 }
457
458 // Watcher for subprocess output.
459 class subproc_out_watch : public loop_t::fd_watcher_impl<subproc_out_watch>
460 {
461     subproc_buffer &sub_buf;
462     bool in_overflow = false;
463
464     rearm read_overflow(int fd)
465     {
466         char buf[128];
467         int r = read(fd, buf, 128);
468         if (r == 0 || (r == -1 && errno != EAGAIN)) {
469             return rearm::NOOP; // leave disarmed
470         }
471         if (r == -1) {
472             return rearm::REARM;
473         }
474
475         // How much space is available?
476         int fr = sub_buf.get_free();
477         for (int b = r - std::min(r, fr); b < r; b++) {
478             if (buf[b] == '\n') {
479                 // Copy the (partial) line into sub_buf and leave overflow mode
480                 sub_buf.append(buf + b, r - b);
481                 in_overflow = false;
482             }
483         }
484         return rearm::REARM;
485     }
486
487     public:
488     subproc_out_watch(subproc_buffer &sub_buf_p) : sub_buf(sub_buf_p) {}
489
490     rearm fd_event(loop_t &, int fd, int flags)
491     {
492         // if current status is reading overflow, read and discard until newline
493         if (in_overflow) {
494             return read_overflow(fd);
495         }
496
497         auto r = sub_buf.fill(fd);
498         if (r == subproc_buffer::fill_status::FULL) {
499             sub_buf.mark_overflow();
500             in_overflow = true;
501             return read_overflow(fd);
502         }
503         else if (r == subproc_buffer::fill_status::ENDFILE) {
504             return rearm::NOOP;
505         }
506
507         return rearm::REARM;  // re-enable watcher
508     }
509 };
510
511 // Run process, put its output through the subprocess buffer
512 //   may throw: std::system_error, std::bad_alloc
513 static void run_process(const char * prog_args[], loop_t &loop, subproc_buffer &sub_buf)
514 {
515     class sp_watcher_t : public loop_t::child_proc_watcher_impl<sp_watcher_t>
516     {
517         public:
518         bool terminated = false;
519
520         rearm status_change(loop_t &, pid_t child, int status)
521         {
522             terminated = true;
523             return rearm::REMOVE;
524         }
525     };
526
527     sp_watcher_t sp_watcher;
528
529     // Create output pipe
530     int pipefds[2];
531     if (dasynq::pipe2(pipefds, O_NONBLOCK) == -1) {
532         // TODO
533         std::cout << "*** pipe2 failed ***" << std::endl;
534     }
535
536     pid_t ch_pid = sp_watcher.fork(loop);
537     if (ch_pid == 0) {
538         // child
539         // Dup output pipe to stdout, stderr
540         dup2(pipefds[1], STDOUT_FILENO);
541         dup2(pipefds[1], STDERR_FILENO);
542         close(pipefds[0]);
543         close(pipefds[1]);
544         execv(prog_args[0], const_cast<char **>(prog_args));
545         puts("Failed to execute subprocess:\n");
546         perror(prog_args[0]);
547         _exit(1);
548     }
549
550     close(pipefds[1]);
551
552     subproc_out_watch owatch {sub_buf};
553     owatch.add_watch(loop, pipefds[0], dasynq::IN_EVENTS);
554
555     do {
556         loop.run();
557     } while (! sp_watcher.terminated);
558
559     owatch.deregister(loop);
560 }
561
562 static void unmount_disks(loop_t &loop, subproc_buffer &sub_buf)
563 {
564     const char * unmount_args[] = { "/bin/umount", "-a", "-r", nullptr };
565     run_process(unmount_args, loop, sub_buf);
566 }
567
568 static void swap_off(loop_t &loop, subproc_buffer &sub_buf)
569 {
570     const char * swapoff_args[] = { "/sbin/swapoff", "-a", nullptr };
571     run_process(swapoff_args, loop, sub_buf);
572 }