Fix error in control protocol implementation in dinitctl/shutdown.
[oweals/dinit.git] / src / dinitctl.cc
1 #include <cstdio>
2 #include <cstddef>
3 #include <cstring>
4 #include <string>
5 #include <iostream>
6 #include <system_error>
7 #include <memory>
8
9 #include <sys/types.h>
10 #include <sys/socket.h>
11 #include <sys/un.h>
12 #include <unistd.h>
13 #include <signal.h>
14 #include <pwd.h>
15
16 #include "control-cmds.h"
17 #include "service-constants.h"
18 #include "cpbuffer.h"
19
20 // dinitctl:  utility to control the Dinit daemon, including starting and stopping of services.
21
22 // This utility communicates with the dinit daemon via a unix stream socket (/dev/initctl, or $HOME/.dinitctl).
23
24 using handle_t = uint32_t;
25
26
27 class read_cp_exception
28 {
29     public:
30     int errcode;
31     read_cp_exception(int err) : errcode(err) { }
32 };
33
34 enum class command_t;
35
36 static int issue_load_service(int socknum, const char *service_name);
37 static int check_load_reply(int socknum, cpbuffer<1024> &rbuffer, handle_t *handle_p, service_state_t *state_p);
38 static int start_stop_service(int socknum, const char *service_name, command_t command, bool do_pin, bool wait_for_service, bool verbose);
39 static int unpin_service(int socknum, const char *service_name, bool verbose);
40 static int list_services(int socknum);
41
42
43 // Fill a circular buffer from a file descriptor, reading at least _rlength_ bytes.
44 // Throws ReadException if the requested number of bytes cannot be read, with:
45 //     errcode = 0   if end of stream (remote end closed)
46 //     errcode = errno   if another error occurred
47 // Note that EINTR is ignored (i.e. the read will be re-tried).
48 static void fillBufferTo(cpbuffer<1024> *buf, int fd, int rlength)
49 {
50     do {
51         int r = buf->fill_to(fd, rlength);
52         if (r == -1) {
53             if (errno != EINTR) {
54                 throw read_cp_exception(errno);
55             }
56         }
57         else if (r == 0) {
58             throw read_cp_exception(0);
59         }
60         else {
61             return;
62         }
63     }
64     while (true);
65 }
66
67 static const char * describeState(bool stopped)
68 {
69     return stopped ? "stopped" : "started";
70 }
71
72 static const char * describeVerb(bool stop)
73 {
74     return stop ? "stop" : "start";
75 }
76
77 // Wait for a reply packet, skipping over any information packets
78 // that are received in the meantime.
79 static void wait_for_reply(cpbuffer<1024> &rbuffer, int fd)
80 {
81     fillBufferTo(&rbuffer, fd, 1);
82     
83     while (rbuffer[0] >= 100) {
84         // Information packet; discard.
85         fillBufferTo(&rbuffer, fd, 2);
86         int pktlen = (unsigned char) rbuffer[1];
87         
88         rbuffer.consume(1);  // Consume one byte so we'll read one byte of the next packet
89         fillBufferTo(&rbuffer, fd, pktlen);
90         rbuffer.consume(pktlen - 1);
91     }
92 }
93
94
95 // Write *all* the requested buffer and re-try if necessary until
96 // the buffer is written or an unrecoverable error occurs.
97 static int write_all(int fd, const void *buf, size_t count)
98 {
99     const char *cbuf = static_cast<const char *>(buf);
100     int w = 0;
101     while (count > 0) {
102         int r = write(fd, cbuf, count);
103         if (r == -1) {
104             if (errno == EINTR) continue;
105             return r;
106         }
107         w += r;
108         cbuf += r;
109         count -= r;
110     }
111     return w;
112 }
113
114
115 enum class command_t {
116     NONE,
117     START_SERVICE,
118     WAKE_SERVICE,
119     STOP_SERVICE,
120     RELEASE_SERVICE,
121     UNPIN_SERVICE,
122     LIST_SERVICES
123 };
124
125 // Entry point.
126 int main(int argc, char **argv)
127 {
128     using namespace std;
129     
130     bool show_help = argc < 2;
131     char *service_name = nullptr;
132     
133     std::string control_socket_str;
134     const char * control_socket_path = nullptr;
135     
136     bool verbose = true;
137     bool sys_dinit = false;  // communicate with system daemon
138     bool wait_for_service = true;
139     bool do_pin = false;
140     
141     command_t command = command_t::NONE;
142         
143     for (int i = 1; i < argc; i++) {
144         if (argv[i][0] == '-') {
145             if (strcmp(argv[i], "--help") == 0) {
146                 show_help = true;
147                 break;
148             }
149             else if (strcmp(argv[i], "--no-wait") == 0) {
150                 wait_for_service = false;
151             }
152             else if (strcmp(argv[i], "--quiet") == 0) {
153                 verbose = false;
154             }
155             else if (strcmp(argv[i], "--system") == 0 || strcmp(argv[i], "-s") == 0) {
156                 sys_dinit = true;
157             }
158             else if (strcmp(argv[i], "--pin") == 0) {
159                 do_pin = true;
160             }
161             else {
162                 return 1;
163             }
164         }
165         else if (command == command_t::NONE) {
166             if (strcmp(argv[i], "start") == 0) {
167                 command = command_t::START_SERVICE; 
168             }
169             else if (strcmp(argv[i], "wake") == 0) {
170                 command = command_t::WAKE_SERVICE;
171             }
172             else if (strcmp(argv[i], "stop") == 0) {
173                 command = command_t::STOP_SERVICE;
174             }
175             else if (strcmp(argv[i], "release") == 0) {
176                 command = command_t::RELEASE_SERVICE;
177             }
178             else if (strcmp(argv[i], "unpin") == 0) {
179                 command = command_t::UNPIN_SERVICE;
180             }
181             else if (strcmp(argv[i], "list") == 0) {
182                 command = command_t::LIST_SERVICES;
183             }
184             else {
185                 show_help = true;
186                 break;
187             }
188         }
189         else {
190             // service name
191             if (service_name != nullptr) {
192                 show_help = true;
193                 break;
194             }
195             service_name = argv[i];
196             // TODO support multiple services
197         }
198     }
199     
200     if (service_name != nullptr && command == command_t::LIST_SERVICES) {
201         show_help = true;
202     }
203     
204     if ((service_name == nullptr && command != command_t::LIST_SERVICES) || command == command_t::NONE) {
205         show_help = true;
206     }
207
208     if (show_help) {
209         cout << "dinitctl:   control Dinit services" << endl;
210         
211         cout << "\nUsage:" << endl;
212         cout << "    dinitctl [options] start [options] <service-name> : start and activate service" << endl;
213         cout << "    dinitctl [options] stop [options] <service-name>  : stop service and cancel explicit activation" << endl;
214         cout << "    dinitctl [options] wake [options] <service-name>  : start but do not mark activated" << endl;
215         cout << "    dinitctl [options] release [options] <service-name> : release activation, stop if no dependents" << endl;
216         cout << "    dinitctl [options] unpin <service-name>           : un-pin the service (after a previous pin)" << endl;
217         cout << "    dinitctl list                                     : list loaded services" << endl;
218         
219         cout << "\nNote: An activated service continues running when its dependents stop." << endl;
220         
221         cout << "\nGeneral options:" << endl;
222         cout << "  -s, --system     : control system daemon instead of user daemon" << endl;
223         cout << "  --quiet          : suppress output (except errors)" << endl;
224         
225         cout << "\nCommand options:" << endl;
226         cout << "  --help           : show this help" << endl;
227         cout << "  --no-wait        : don't wait for service startup/shutdown to complete" << endl;
228         cout << "  --pin            : pin the service in the requested (started/stopped) state" << endl;
229         return 1;
230     }
231     
232     signal(SIGPIPE, SIG_IGN);
233     
234     control_socket_path = "/dev/dinitctl";
235     
236     // Locate control socket
237     if (! sys_dinit) {
238         char * userhome = getenv("HOME");
239         if (userhome == nullptr) {
240             struct passwd * pwuid_p = getpwuid(getuid());
241             if (pwuid_p != nullptr) {
242                 userhome = pwuid_p->pw_dir;
243             }
244         }
245         
246         if (userhome != nullptr) {
247             control_socket_str = userhome;
248             control_socket_str += "/.dinitctl";
249             control_socket_path = control_socket_str.c_str();
250         }
251         else {
252             cerr << "Cannot locate user home directory (set HOME or check /etc/passwd file)" << endl;
253             return 1;
254         }
255     }
256     
257     int socknum = socket(AF_UNIX, SOCK_STREAM, 0);
258     if (socknum == -1) {
259         perror("dinitctl: socket");
260         return 1;
261     }
262
263     struct sockaddr_un * name;
264     uint sockaddr_size = offsetof(struct sockaddr_un, sun_path) + strlen(control_socket_path) + 1;
265     name = (struct sockaddr_un *) malloc(sockaddr_size);
266     if (name == nullptr) {
267         cerr << "dinitctl: Out of memory" << endl;
268         return 1;
269     }
270     
271     name->sun_family = AF_UNIX;
272     strcpy(name->sun_path, control_socket_path);
273     
274     int connr = connect(socknum, (struct sockaddr *) name, sockaddr_size);
275     if (connr == -1) {
276         perror("dinitctl: connect");
277         return 1;
278     }
279     
280     // TODO should start by querying protocol version
281     
282     if (command == command_t::UNPIN_SERVICE) {
283         return unpin_service(socknum, service_name, verbose);
284     }
285     else if (command == command_t::LIST_SERVICES) {
286         return list_services(socknum);
287     }
288
289     return start_stop_service(socknum, service_name, command, do_pin, wait_for_service, verbose);
290 }
291
292 // Start/stop a service
293 static int start_stop_service(int socknum, const char *service_name, command_t command, bool do_pin, bool wait_for_service, bool verbose)
294 {
295     using namespace std;
296
297     bool do_stop = (command == command_t::STOP_SERVICE || command == command_t::RELEASE_SERVICE);
298     
299     if (issue_load_service(socknum, service_name)) {
300         return 1;
301     }
302
303     // Now we expect a reply:
304     
305     try {
306         cpbuffer<1024> rbuffer;
307         wait_for_reply(rbuffer, socknum);
308         
309         service_state_t state;
310         //service_state_t target_state;
311         handle_t handle;
312         
313         if (check_load_reply(socknum, rbuffer, &handle, &state) != 0) {
314             return 0;
315         }
316                 
317         service_state_t wanted_state = do_stop ? service_state_t::STOPPED : service_state_t::STARTED;
318         int pcommand = 0;
319         switch (command) {
320         case command_t::STOP_SERVICE:
321             pcommand = DINIT_CP_STOPSERVICE;
322             break;
323         case command_t::RELEASE_SERVICE:
324             pcommand = DINIT_CP_RELEASESERVICE;
325             break;
326         case command_t::START_SERVICE:
327             pcommand = DINIT_CP_STARTSERVICE;
328             break;
329         case command_t::WAKE_SERVICE:
330             pcommand = DINIT_CP_WAKESERVICE;
331             break;
332         default: ;
333         }
334         
335         // Need to issue STOPSERVICE/STARTSERVICE
336         // We'll do this regardless of the current service state / target state, since issuing
337         // start/stop also sets or clears the "explicitly started" flag on the service.
338         {
339             int r;
340             
341             {
342                 auto buf = new char[2 + sizeof(handle)];
343                 unique_ptr<char[]> ubuf(buf);
344                 
345                 buf[0] = pcommand;
346                 buf[1] = do_pin ? 1 : 0;
347                 memcpy(buf + 2, &handle, sizeof(handle));
348                 r = write_all(socknum, buf, 2 + sizeof(handle));
349             }
350             
351             if (r == -1) {
352                 perror("dinitctl: write");
353                 return 1;
354             }
355             
356             wait_for_reply(rbuffer, socknum);
357             if (rbuffer[0] == DINIT_RP_ALREADYSS) {
358                 bool already = (state == wanted_state);
359                 if (verbose) {
360                     cout << "Service " << (already ? "(already) " : "") << describeState(do_stop) << "." << endl;
361                 }
362                 return 0; // success!
363             }
364             if (rbuffer[0] != DINIT_RP_ACK) {
365                 cerr << "dinitctl: Protocol error." << endl;
366                 return 1;
367             }
368             rbuffer.consume(1);
369         }
370         
371         if (! wait_for_service) {
372             if (verbose) {
373                 cout << "Issued " << describeVerb(do_stop) << " command successfully." << endl;
374             }
375             return 0;
376         }
377         
378         service_event_t completionEvent;
379         service_event_t cancelledEvent;
380         
381         if (do_stop) {
382             completionEvent = service_event_t::STOPPED;
383             cancelledEvent = service_event_t::STOPCANCELLED;
384         }
385         else {
386             completionEvent = service_event_t::STARTED;
387             cancelledEvent = service_event_t::STARTCANCELLED;
388         }
389         
390         // Wait until service started:
391         int r = rbuffer.fill_to(socknum, 2);
392         while (r > 0) {
393             if (rbuffer[0] >= 100) {
394                 int pktlen = (unsigned char) rbuffer[1];
395                 fillBufferTo(&rbuffer, socknum, pktlen);
396                 
397                 if (rbuffer[0] == DINIT_IP_SERVICEEVENT) {
398                     handle_t ev_handle;
399                     rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle));
400                     service_event_t event = static_cast<service_event_t>(rbuffer[2 + sizeof(ev_handle)]);
401                     if (ev_handle == handle) {
402                         if (event == completionEvent) {
403                             if (verbose) {
404                                 cout << "Service " << describeState(do_stop) << "." << endl;
405                             }
406                             return 0;
407                         }
408                         else if (event == cancelledEvent) {
409                             if (verbose) {
410                                 cout << "Service " << describeVerb(do_stop) << " cancelled." << endl;
411                             }
412                             return 1;
413                         }
414                         else if (! do_stop && event == service_event_t::FAILEDSTART) {
415                             if (verbose) {
416                                 cout << "Service failed to start." << endl;
417                             }
418                             return 1;
419                         }
420                     }
421                 }
422                 
423                 rbuffer.consume(pktlen);
424                 r = rbuffer.fill_to(socknum, 2);
425             }
426             else {
427                 // Not an information packet?
428                 cerr << "dinitctl: protocol error" << endl;
429                 return 1;
430             }
431         }
432         
433         if (r == -1) {
434             perror("dinitctl: read");
435         }
436         else {
437             cerr << "protocol error (connection closed by server)" << endl;
438         }
439         return 1;
440     }
441     catch (read_cp_exception &exc) {
442         cerr << "dinitctl: control socket read failure or protocol error" << endl;
443         return 1;
444     }
445     catch (std::bad_alloc &exc) {
446         cerr << "dinitctl: out of memory" << endl;
447         return 1;
448     }
449     
450     return 0;
451 }
452
453 // Issue a "load service" command (DINIT_CP_LOADSERVICE), without waiting for
454 // a response. Returns 1 on failure (with error logged), 0 on success.
455 static int issue_load_service(int socknum, const char *service_name)
456 {
457     // Build buffer;
458     uint16_t sname_len = strlen(service_name);
459     int bufsize = 3 + sname_len;
460     int r;
461     
462     try {
463         std::unique_ptr<char[]> ubuf(new char[bufsize]);
464         auto buf = ubuf.get();
465         
466         buf[0] = DINIT_CP_LOADSERVICE;
467         memcpy(buf + 1, &sname_len, 2);
468         memcpy(buf + 3, service_name, sname_len);
469         
470         r = write_all(socknum, buf, bufsize);
471     }
472     catch (std::bad_alloc &badalloc) {
473         std::cerr << "dinitctl: " << badalloc.what() << std::endl;
474         return 1;
475     }
476     
477     if (r == -1) {
478         perror("dinitctl: write");
479         return 1;
480     }
481     
482     return 0;
483 }
484
485 // Check that a "load service" reply was received, and that the requested service was found.
486 static int check_load_reply(int socknum, cpbuffer<1024> &rbuffer, handle_t *handle_p, service_state_t *state_p)
487 {
488     using namespace std;
489     
490     if (rbuffer[0] == DINIT_RP_SERVICERECORD) {
491         fillBufferTo(&rbuffer, socknum, 2 + sizeof(*handle_p));
492         rbuffer.extract((char *) handle_p, 2, sizeof(*handle_p));
493         if (state_p) *state_p = static_cast<service_state_t>(rbuffer[1]);
494         //target_state = static_cast<service_state_t>(rbuffer[2 + sizeof(handle)]);
495         rbuffer.consume(3 + sizeof(*handle_p));
496         return 0;
497     }
498     else if (rbuffer[0] == DINIT_RP_NOSERVICE) {
499         cerr << "dinitctl: Failed to find/load service." << endl;
500         return 1;
501     }
502     else {
503         cerr << "dinitctl: Protocol error." << endl;
504         return 1;
505     }
506 }
507
508 static int unpin_service(int socknum, const char *service_name, bool verbose)
509 {
510     using namespace std;
511     
512     // Build buffer;
513     if (issue_load_service(socknum, service_name) == 1) {
514         return 1;
515     }
516
517     // Now we expect a reply:
518     
519     try {
520         cpbuffer<1024> rbuffer;
521         wait_for_reply(rbuffer, socknum);
522         
523         handle_t handle;
524         
525         if (check_load_reply(socknum, rbuffer, &handle, nullptr) != 0) {
526             return 1;
527         }
528         
529         // Issue UNPIN command.
530         {
531             int r;
532             
533             {
534                 char *buf = new char[1 + sizeof(handle)];
535                 unique_ptr<char[]> ubuf(buf);
536                 buf[0] = DINIT_CP_UNPINSERVICE;
537                 memcpy(buf + 1, &handle, sizeof(handle));
538                 r = write_all(socknum, buf, 2 + sizeof(handle));
539             }
540             
541             if (r == -1) {
542                 perror("dinitctl: write");
543                 return 1;
544             }
545             
546             wait_for_reply(rbuffer, socknum);
547             if (rbuffer[0] != DINIT_RP_ACK) {
548                 cerr << "dinitctl: Protocol error." << endl;
549                 return 1;
550             }
551             rbuffer.consume(1);
552         }
553     }
554     catch (read_cp_exception &exc) {
555         cerr << "dinitctl: Control socket read failure or protocol error" << endl;
556         return 1;
557     }
558     catch (std::bad_alloc &exc) {
559         cerr << "dinitctl: Out of memory" << endl;
560         return 1;
561     }
562     
563     if (verbose) {
564         cout << "Service unpinned." << endl;
565     }
566     return 0;
567 }
568
569 static int list_services(int socknum)
570 {
571     using namespace std;
572     
573     try {
574         char cmdbuf[] = { (char)DINIT_CP_LISTSERVICES };
575         int r = write_all(socknum, cmdbuf, 1);
576         
577         if (r == -1) {
578             perror("dinitctl: write");
579             return 1;
580         }
581         
582         cpbuffer<1024> rbuffer;
583         wait_for_reply(rbuffer, socknum);
584         while (rbuffer[0] == DINIT_RP_SVCINFO) {
585             fillBufferTo(&rbuffer, socknum, 8);
586             int nameLen = rbuffer[1];
587             service_state_t current = static_cast<service_state_t>(rbuffer[2]);
588             service_state_t target = static_cast<service_state_t>(rbuffer[3]);
589             
590             fillBufferTo(&rbuffer, socknum, nameLen + 8);
591             
592             char *name_ptr = rbuffer.get_ptr(8);
593             int clength = std::min(rbuffer.get_contiguous_length(name_ptr), nameLen);
594             
595             string name = string(name_ptr, clength);
596             name.append(rbuffer.get_buf_base(), nameLen - clength);
597             
598             cout << "[";
599             
600             cout << (target  == service_state_t::STARTED ? "{" : " ");
601             cout << (current == service_state_t::STARTED ? "+" : " ");
602             cout << (target  == service_state_t::STARTED ? "}" : " ");
603             
604             if (current == service_state_t::STARTING) {
605                 cout << "<<";
606             }
607             else if (current == service_state_t::STOPPING) {
608                 cout << ">>";
609             }
610             else {
611                 cout << "  ";
612             }
613             
614             cout << (target  == service_state_t::STOPPED ? "{" : " ");
615             cout << (current == service_state_t::STOPPED ? "-" : " ");
616             cout << (target  == service_state_t::STOPPED ? "}" : " ");
617             
618             cout << "] " << name << endl;
619             
620             rbuffer.consume(8 + nameLen);
621             wait_for_reply(rbuffer, socknum);
622         }
623         
624         if (rbuffer[0] != DINIT_RP_LISTDONE) {
625             cerr << "dinitctl: Control socket protocol error" << endl;
626             return 1;
627         }
628     }
629     catch (read_cp_exception &exc) {
630         cerr << "dinitctl: Control socket read failure or protocol error" << endl;
631         return 1;
632     }
633     catch (std::bad_alloc &exc) {
634         cerr << "dinitctl: Out of memory" << endl;
635         return 1;
636     }
637     
638     return 0;
639 }