Fix "skip to reply packet" method.
[oweals/dinit.git] / src / dinitctl.cc
1 #include <cstdio>
2 #include <cstddef>
3 #include <cstring>
4 #include <string>
5 #include <iostream>
6 #include <system_error>
7
8 #include <sys/types.h>
9 #include <sys/socket.h>
10 #include <sys/un.h>
11 #include <unistd.h>
12 #include <pwd.h>
13
14 #include "control-cmds.h"
15 #include "service-constants.h"
16 #include "cpbuffer.h"
17
18 // dinitctl:  utility to control the Dinit daemon, including starting and stopping of services.
19
20 // This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
21
22 using handle_t = uint32_t;
23
24
25 class ReadCPException
26 {
27     public:
28     int errcode;
29     ReadCPException(int err) : errcode(err) { }
30 };
31
32 static void fillBufferTo(CPBuffer *buf, int fd, int rlength)
33 {
34     int r = buf->fillTo(fd, rlength);
35     if (r == -1) {
36         throw ReadCPException(errno);
37     }
38     else if (r == 0) {
39         throw ReadCPException(0);
40     }
41 }
42
43 static const char * describeState(bool stopped)
44 {
45     return stopped ? "stopped" : "started";
46 }
47
48 static const char * describeVerb(bool stop)
49 {
50     return stop ? "stop" : "start";
51 }
52
53 // Wait for a reply packet, skipping over any information packets
54 // that are received in the meantime.
55 static void wait_for_reply(CPBuffer &rbuffer, int fd)
56 {
57     fillBufferTo(&rbuffer, fd, 1);
58     
59     while (rbuffer[0] >= 100) {
60         // Information packet; discard.
61         fillBufferTo(&rbuffer, fd, 1);
62         int pktlen = (unsigned char) rbuffer[1];
63         
64         rbuffer.consume(1);  // Consume one byte so we'll read one byte of the next packet
65         fillBufferTo(&rbuffer, fd, pktlen);
66         rbuffer.consume(pktlen - 1);
67     }
68 }
69
70
71 // Write *all* the requested buffer and re-try if necessary until
72 // the buffer is written or an unrecoverable error occurs.
73 static int write_all(int fd, const void *buf, size_t count)
74 {
75     const char *cbuf = static_cast<const char *>(buf);
76     int w = 0;
77     while (count > 0) {
78         int r = write(fd, cbuf, count);
79         if (r == -1) {
80             if (errno == EINTR) continue;
81             return r;
82         }
83         w += r;
84         cbuf += r;
85         count -= r;
86     }
87     return w;
88 }
89
90 int main(int argc, char **argv)
91 {
92     using namespace std;
93     
94     bool do_stop = false;
95     bool show_help = argc < 2;
96     char *service_name = nullptr;
97     
98     std::string control_socket_str;
99     const char * control_socket_path = nullptr;
100     
101     bool verbose = true;
102     bool sys_dinit = false;  // communicate with system daemon
103     bool wait_for_service = true;
104     
105     int command = 0;
106     
107     constexpr int START_SERVICE = 1;
108     constexpr int STOP_SERVICE = 2;
109         
110     for (int i = 1; i < argc; i++) {
111         if (argv[i][0] == '-') {
112             if (strcmp(argv[i], "--help") == 0) {
113                 show_help = true;
114                 break;
115             }
116             else if (strcmp(argv[i], "--no-wait") == 0) {
117                 wait_for_service = false;
118             }
119             else if (strcmp(argv[i], "--quiet") == 0) {
120                 verbose = false;
121             }
122             else if (strcmp(argv[i], "--system") == 0 || strcmp(argv[i], "-s") == 0) {
123                 sys_dinit = true;
124             }
125             else {
126                 cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
127                 return 1;
128             }
129         }
130         else if (command == 0) {
131             if (strcmp(argv[i], "start") == 0) {
132                 command = START_SERVICE; 
133             }
134             else if (strcmp(argv[i], "stop") == 0) {
135                 command = STOP_SERVICE;
136             }
137             else {
138                 show_help = true;
139                 break;
140             }
141         }
142         else {
143             // service name
144             service_name = argv[i];
145             // TODO support multiple services (or at least give error if multiple
146             //      services supplied)
147         }
148     }
149     
150     if (service_name == nullptr || command == 0) {
151         show_help = true;
152     }
153
154     if (show_help) {
155         cout << "dinit-start:   start a dinit service" << endl;
156         cout << "  --help           : show this help" << endl;
157         cout << "  --no-wait        : don't wait for service startup/shutdown to complete" << endl;
158         cout << "  --quiet          : suppress output (except errors)" << endl;
159         cout << "  -s, --system     : control system daemon instead of user daemon" << endl;
160         cout << "  <service-name>   : start the named service" << endl;
161         return 1;
162     }
163     
164     do_stop = (command == STOP_SERVICE);
165     
166     control_socket_path = "/dev/dinitctl";
167     
168     if (! sys_dinit) {
169         char * userhome = getenv("HOME");
170         if (userhome == nullptr) {
171             struct passwd * pwuid_p = getpwuid(getuid());
172             if (pwuid_p != nullptr) {
173                 userhome = pwuid_p->pw_dir;
174             }
175         }
176         
177         if (userhome != nullptr) {
178             control_socket_str = userhome;
179             control_socket_str += "/.dinitctl";
180             control_socket_path = control_socket_str.c_str();
181         }
182         else {
183             cerr << "Cannot locate user home directory (set HOME or check /etc/passwd file)" << endl;
184             return 1;
185         }
186     }
187     
188     int socknum = socket(AF_UNIX, SOCK_STREAM, 0);
189     if (socknum == -1) {
190         perror("socket");
191         return 1;
192     }
193
194     struct sockaddr_un * name;
195     uint sockaddr_size = offsetof(struct sockaddr_un, sun_path) + strlen(control_socket_path) + 1;
196     name = (struct sockaddr_un *) malloc(sockaddr_size);
197     if (name == nullptr) {
198         cerr << "dinit-start: out of memory" << endl;
199         return 1;
200     }
201     
202     name->sun_family = AF_UNIX;
203     strcpy(name->sun_path, control_socket_path);
204     
205     int connr = connect(socknum, (struct sockaddr *) name, sockaddr_size);
206     if (connr == -1) {
207         perror("connect");
208         return 1;
209     }
210     
211     // TODO should start by querying protocol version
212     
213     // Build buffer;
214     uint16_t sname_len = strlen(service_name);
215     int bufsize = 3 + sname_len;
216     char * buf = new char[bufsize];
217     
218     buf[0] = DINIT_CP_LOADSERVICE;
219     memcpy(buf + 1, &sname_len, 2);
220     memcpy(buf + 3, service_name, sname_len);
221     
222     int r = write_all(socknum, buf, bufsize);
223     delete [] buf;
224     if (r == -1) {
225         perror("write");
226         return 1;
227     }
228     
229     // Now we expect a reply:
230     // NOTE: should skip over information packets.
231     
232     try {
233         CPBuffer rbuffer;
234         wait_for_reply(rbuffer, socknum);
235         
236         ServiceState state;
237         ServiceState target_state;
238         handle_t handle;
239         
240         if (rbuffer[0] == DINIT_RP_SERVICERECORD) {
241             fillBufferTo(&rbuffer, socknum, 2 + sizeof(handle));
242             rbuffer.extract((char *) &handle, 2, sizeof(handle));
243             state = static_cast<ServiceState>(rbuffer[1]);
244             target_state = static_cast<ServiceState>(rbuffer[2 + sizeof(handle)]);
245             rbuffer.consume(3 + sizeof(handle));
246         }
247         else if (rbuffer[0] == DINIT_RP_NOSERVICE) {
248             cerr << "Failed to find/load service." << endl;
249             return 1;
250         }
251         else {
252             cerr << "Protocol error." << endl;
253             return 1;
254         }
255         
256         ServiceState wanted_state = do_stop ? ServiceState::STOPPED : ServiceState::STARTED;
257         int command = do_stop ? DINIT_CP_STOPSERVICE : DINIT_CP_STARTSERVICE;
258         
259         // Need to issue STOPSERVICE/STARTSERVICE
260         if (target_state != wanted_state) {
261             buf = new char[2 + sizeof(handle)];
262             buf[0] = command;
263             buf[1] = 0;  // don't pin
264             memcpy(buf + 2, &handle, sizeof(handle));
265             r = write_all(socknum, buf, 2 + sizeof(handle));
266             delete buf;
267             
268             if (r == -1) {
269                 perror("write");
270                 return 1;
271             }
272             
273             wait_for_reply(rbuffer, socknum);
274             if (rbuffer[0] != DINIT_RP_ACK) {
275                 cerr << "Protocol error." << endl;
276                 return 1;
277             }
278             rbuffer.consume(1);
279         }
280         
281         if (state == wanted_state) {
282             if (verbose) {
283                 cout << "Service already " << describeState(do_stop) << "." << endl;
284             }
285             return 0; // success!
286         }
287         
288         if (! wait_for_service) {
289             return 0;
290         }
291         
292         ServiceEvent completionEvent;
293         ServiceEvent cancelledEvent;
294         
295         if (do_stop) {
296             completionEvent = ServiceEvent::STOPPED;
297             cancelledEvent = ServiceEvent::STOPCANCELLED;
298         }
299         else {
300             completionEvent = ServiceEvent::STARTED;
301             cancelledEvent = ServiceEvent::STARTCANCELLED;
302         }
303         
304         // Wait until service started:
305         r = rbuffer.fillTo(socknum, 2);
306         while (r > 0) {
307             if (rbuffer[0] >= 100) {
308                 int pktlen = (unsigned char) rbuffer[1];
309                 fillBufferTo(&rbuffer, socknum, pktlen);
310                 
311                 if (rbuffer[0] == DINIT_IP_SERVICEEVENT) {
312                     handle_t ev_handle;
313                     rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle));
314                     ServiceEvent event = static_cast<ServiceEvent>(rbuffer[2 + sizeof(ev_handle)]);
315                     if (ev_handle == handle) {
316                         if (event == completionEvent) {
317                             if (verbose) {
318                                 cout << "Service " << describeState(do_stop) << "." << endl;
319                             }
320                             return 0;
321                         }
322                         else if (event == cancelledEvent) {
323                             if (verbose) {
324                                 cout << "Service " << describeVerb(do_stop) << " cancelled." << endl;
325                             }
326                             return 1;
327                         }
328                         else if (! do_stop && event == ServiceEvent::FAILEDSTART) {
329                             if (verbose) {
330                                 cout << "Service failed to start." << endl;
331                             }
332                             return 1;
333                         }
334                     }
335                 }
336                 
337                 rbuffer.consume(pktlen);
338                 r = rbuffer.fillTo(socknum, 2);
339             }
340             else {
341                 // Not an information packet?
342                 cerr << "protocol error" << endl;
343                 return 1;
344             }
345         }
346         
347         if (r == -1) {
348             perror("read");
349         }
350         else {
351             cerr << "protocol error (connection closed by server)" << endl;
352         }
353         return 1;
354     }
355     catch (ReadCPException &exc) {
356         cerr << "control socket read failure or protocol error" << endl;
357         return 1;
358     }
359     catch (std::bad_alloc &exc) {
360         cerr << "out of memory" << endl;
361         return 1;
362     }
363     
364     return 0;
365 }