7399cb559fd7fb70624dba64b0791cf63c6987fe
[oweals/dinit.git] / src / dinitctl.cc
1 #include <cstdio>
2 #include <cstddef>
3 #include <cstring>
4 #include <string>
5 #include <iostream>
6 #include <system_error>
7
8 #include <sys/types.h>
9 #include <sys/socket.h>
10 #include <sys/un.h>
11 #include <unistd.h>
12 #include <pwd.h>
13
14 #include "control-cmds.h"
15 #include "service-constants.h"
16 #include "cpbuffer.h"
17
18 // dinitctl:  utility to control the Dinit daemon, including starting and stopping of services.
19
20 // This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
21
22 using handle_t = uint32_t;
23
24
25 class ReadCPException
26 {
27     public:
28     int errcode;
29     ReadCPException(int err) : errcode(err) { }
30 };
31
32 static void fillBufferTo(CPBuffer *buf, int fd, int rlength)
33 {
34     int r = buf->fillTo(fd, rlength);
35     if (r == -1) {
36         throw ReadCPException(errno);
37     }
38     else if (r == 0) {
39         throw ReadCPException(0);
40     }
41 }
42
43 static const char * describeState(bool stopped)
44 {
45     return stopped ? "stopped" : "started";
46 }
47
48 static const char * describeVerb(bool stop)
49 {
50     return stop ? "stop" : "start";
51 }
52
53 // Wait for a reply packet, skipping over any information packets
54 // that are received in the meantime.
55 static void wait_for_reply(CPBuffer &rbuffer, int fd)
56 {
57     fillBufferTo(&rbuffer, fd, 1);
58     
59     while (rbuffer[0] >= 100) {
60         // Information packet; discard.
61         fillBufferTo(&rbuffer, fd, 1);
62         int pktlen = (unsigned char) rbuffer[1];
63         fillBufferTo(&rbuffer, fd, pktlen);
64         rbuffer.consume(pktlen);
65     }
66 }
67
68
69 // Write *all* the requested buffer and re-try if necessary until
70 // the buffer is written or an unrecoverable error occurs.
71 static int write_all(int fd, const void *buf, size_t count)
72 {
73     const char *cbuf = static_cast<const char *>(buf);
74     int w = 0;
75     while (count > 0) {
76         int r = write(fd, cbuf, count);
77         if (r == -1) {
78             if (errno == EINTR) continue;
79             return r;
80         }
81         w += r;
82         cbuf += r;
83         count -= r;
84     }
85     return w;
86 }
87
88 int main(int argc, char **argv)
89 {
90     using namespace std;
91     
92     bool do_stop = false;
93     bool show_help = argc < 2;
94     char *service_name = nullptr;
95     
96     std::string control_socket_str;
97     const char * control_socket_path = nullptr;
98     
99     bool verbose = true;
100     bool sys_dinit = false;  // communicate with system daemon
101     bool wait_for_service = true;
102     
103     int command = 0;
104     
105     constexpr int START_SERVICE = 1;
106     constexpr int STOP_SERVICE = 2;
107         
108     for (int i = 1; i < argc; i++) {
109         if (argv[i][0] == '-') {
110             if (strcmp(argv[i], "--help") == 0) {
111                 show_help = true;
112                 break;
113             }
114             else if (strcmp(argv[i], "--no-wait") == 0) {
115                 wait_for_service = false;
116             }
117             else if (strcmp(argv[i], "--quiet") == 0) {
118                 verbose = false;
119             }
120             else if (strcmp(argv[i], "--system") == 0 || strcmp(argv[i], "-s") == 0) {
121                 sys_dinit = true;
122             }
123             else {
124                 cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
125                 return 1;
126             }
127         }
128         else if (command == 0) {
129             if (strcmp(argv[i], "start") == 0) {
130                 command = START_SERVICE; 
131             }
132             else if (strcmp(argv[i], "stop") == 0) {
133                 command = STOP_SERVICE;
134             }
135             else {
136                 show_help = true;
137                 break;
138             }
139         }
140         else {
141             // service name
142             service_name = argv[i];
143             // TODO support multiple services (or at least give error if multiple
144             //      services supplied)
145         }
146     }
147     
148     if (service_name == nullptr || command == 0) {
149         show_help = true;
150     }
151
152     if (show_help) {
153         cout << "dinit-start:   start a dinit service" << endl;
154         cout << "  --help           : show this help" << endl;
155         cout << "  --no-wait        : don't wait for service startup/shutdown to complete" << endl;
156         cout << "  --quiet          : suppress output (except errors)" << endl;
157         cout << "  -s, --system     : control system daemon instead of user daemon" << endl;
158         cout << "  <service-name>   : start the named service" << endl;
159         return 1;
160     }
161     
162     do_stop = (command == STOP_SERVICE);
163     
164     control_socket_path = "/dev/dinitctl";
165     
166     if (! sys_dinit) {
167         char * userhome = getenv("HOME");
168         if (userhome == nullptr) {
169             struct passwd * pwuid_p = getpwuid(getuid());
170             if (pwuid_p != nullptr) {
171                 userhome = pwuid_p->pw_dir;
172             }
173         }
174         
175         if (userhome != nullptr) {
176             control_socket_str = userhome;
177             control_socket_str += "/.dinitctl";
178             control_socket_path = control_socket_str.c_str();
179         }
180         else {
181             cerr << "Cannot locate user home directory (set HOME or check /etc/passwd file)" << endl;
182             return 1;
183         }
184     }
185     
186     int socknum = socket(AF_UNIX, SOCK_STREAM, 0);
187     if (socknum == -1) {
188         perror("socket");
189         return 1;
190     }
191
192     struct sockaddr_un * name;
193     uint sockaddr_size = offsetof(struct sockaddr_un, sun_path) + strlen(control_socket_path) + 1;
194     name = (struct sockaddr_un *) malloc(sockaddr_size);
195     if (name == nullptr) {
196         cerr << "dinit-start: out of memory" << endl;
197         return 1;
198     }
199     
200     name->sun_family = AF_UNIX;
201     strcpy(name->sun_path, control_socket_path);
202     
203     int connr = connect(socknum, (struct sockaddr *) name, sockaddr_size);
204     if (connr == -1) {
205         perror("connect");
206         return 1;
207     }
208     
209     // TODO should start by querying protocol version
210     
211     // Build buffer;
212     uint16_t sname_len = strlen(service_name);
213     int bufsize = 3 + sname_len;
214     char * buf = new char[bufsize];
215     
216     buf[0] = DINIT_CP_LOADSERVICE;
217     memcpy(buf + 1, &sname_len, 2);
218     memcpy(buf + 3, service_name, sname_len);
219     
220     int r = write_all(socknum, buf, bufsize);
221     delete [] buf;
222     if (r == -1) {
223         perror("write");
224         return 1;
225     }
226     
227     // Now we expect a reply:
228     // NOTE: should skip over information packets.
229     
230     try {
231         CPBuffer rbuffer;
232         wait_for_reply(rbuffer, socknum);
233         
234         ServiceState state;
235         ServiceState target_state;
236         handle_t handle;
237         
238         if (rbuffer[0] == DINIT_RP_SERVICERECORD) {
239             fillBufferTo(&rbuffer, socknum, 2 + sizeof(handle));
240             rbuffer.extract((char *) &handle, 2, sizeof(handle));
241             state = static_cast<ServiceState>(rbuffer[1]);
242             target_state = static_cast<ServiceState>(rbuffer[2 + sizeof(handle)]);
243             rbuffer.consume(3 + sizeof(handle));
244         }
245         else if (rbuffer[0] == DINIT_RP_NOSERVICE) {
246             cerr << "Failed to find/load service." << endl;
247             return 1;
248         }
249         else {
250             cerr << "Protocol error." << endl;
251             return 1;
252         }
253         
254         ServiceState wanted_state = do_stop ? ServiceState::STOPPED : ServiceState::STARTED;
255         int command = do_stop ? DINIT_CP_STOPSERVICE : DINIT_CP_STARTSERVICE;
256         
257         // Need to issue STOPSERVICE/STARTSERVICE
258         if (target_state != wanted_state) {
259             buf = new char[2 + sizeof(handle)];
260             buf[0] = command;
261             buf[1] = 0;  // don't pin
262             memcpy(buf + 2, &handle, sizeof(handle));
263             r = write_all(socknum, buf, 2 + sizeof(handle));
264             delete buf;
265             
266             if (r == -1) {
267                 perror("write");
268                 return 1;
269             }
270             
271             wait_for_reply(rbuffer, socknum);
272             if (rbuffer[0] != DINIT_RP_ACK) {
273                 cerr << "Protocol error." << endl;
274                 return 1;
275             }
276             rbuffer.consume(1);
277         }
278         
279         if (state == wanted_state) {
280             if (verbose) {
281                 cout << "Service already " << describeState(do_stop) << "." << endl;
282             }
283             return 0; // success!
284         }
285         
286         if (! wait_for_service) {
287             return 0;
288         }
289         
290         ServiceEvent completionEvent;
291         ServiceEvent cancelledEvent;
292         
293         if (do_stop) {
294             completionEvent = ServiceEvent::STOPPED;
295             cancelledEvent = ServiceEvent::STOPCANCELLED;
296         }
297         else {
298             completionEvent = ServiceEvent::STARTED;
299             cancelledEvent = ServiceEvent::STARTCANCELLED;
300         }
301         
302         // Wait until service started:
303         r = rbuffer.fillTo(socknum, 2);
304         while (r > 0) {
305             if (rbuffer[0] >= 100) {
306                 int pktlen = (unsigned char) rbuffer[1];
307                 fillBufferTo(&rbuffer, socknum, pktlen);
308                 
309                 if (rbuffer[0] == DINIT_IP_SERVICEEVENT) {
310                     handle_t ev_handle;
311                     rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle));
312                     ServiceEvent event = static_cast<ServiceEvent>(rbuffer[2 + sizeof(ev_handle)]);
313                     if (ev_handle == handle) {
314                         if (event == completionEvent) {
315                             if (verbose) {
316                                 cout << "Service " << describeState(do_stop) << "." << endl;
317                             }
318                             return 0;
319                         }
320                         else if (event == cancelledEvent) {
321                             if (verbose) {
322                                 cout << "Service " << describeVerb(do_stop) << " cancelled." << endl;
323                             }
324                             return 1;
325                         }
326                         else if (! do_stop && event == ServiceEvent::FAILEDSTART) {
327                             if (verbose) {
328                                 cout << "Service failed to start." << endl;
329                             }
330                             return 1;
331                         }
332                     }
333                 }
334                 
335                 rbuffer.consume(pktlen);
336                 r = rbuffer.fillTo(socknum, 2);
337             }
338             else {
339                 // Not an information packet?
340                 cerr << "protocol error" << endl;
341                 return 1;
342             }
343         }
344         
345         if (r == -1) {
346             perror("read");
347         }
348         else {
349             cerr << "protocol error (connection closed by server)" << endl;
350         }
351         return 1;
352     }
353     catch (ReadCPException &exc) {
354         cerr << "control socket read failure or protocol error" << endl;
355         return 1;
356     }
357     catch (std::bad_alloc &exc) {
358         cerr << "out of memory" << endl;
359         return 1;
360     }
361     
362     return 0;
363 }