Make sure to not write a partial command packet.
[oweals/dinit.git] / src / dinitctl.cc
1 #include <cstdio>
2 #include <cstddef>
3 #include <cstring>
4 #include <string>
5 #include <iostream>
6 #include <system_error>
7
8 #include <sys/types.h>
9 #include <sys/socket.h>
10 #include <sys/un.h>
11 #include <unistd.h>
12 #include <pwd.h>
13
14 #include "control-cmds.h"
15 #include "service-constants.h"
16 #include "cpbuffer.h"
17
18 // dinitctl:  utility to control the Dinit daemon, including starting and stopping of services.
19
20 // This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
21
22 using handle_t = uint32_t;
23
24
25 class ReadCPException
26 {
27     public:
28     int errcode;
29     ReadCPException(int err) : errcode(err) { }
30 };
31
32 static void fillBufferTo(CPBuffer *buf, int fd, int rlength)
33 {
34     int r = buf->fillTo(fd, rlength);
35     if (r == -1) {
36         throw ReadCPException(errno);
37     }
38     else if (r == 0) {
39         throw ReadCPException(0);
40     }
41 }
42
43 static const char * describeState(bool stopped)
44 {
45     return stopped ? "stopped" : "started";
46 }
47
48 static const char * describeVerb(bool stop)
49 {
50     return stop ? "stop" : "start";
51 }
52
53 // Write *all* the requested buffer and re-try if necessary until
54 // the buffer is written or an unrecoverable error occurs.
55 static int write_all(int fd, const void *buf, size_t count)
56 {
57     const char *cbuf = static_cast<const char *>(buf);
58     int w = 0;
59     while (count > 0) {
60         int r = write(fd, cbuf, count);
61         if (r == -1) {
62             if (errno == EINTR) continue;
63             return r;
64         }
65         w += r;
66         cbuf += r;
67         count -= r;
68     }
69     return w;
70 }
71
72 int main(int argc, char **argv)
73 {
74     using namespace std;
75     
76     bool do_stop = false;
77     bool show_help = argc < 2;
78     char *service_name = nullptr;
79     
80     std::string control_socket_str;
81     const char * control_socket_path = nullptr;
82     
83     bool verbose = true;
84     bool sys_dinit = false;  // communicate with system daemon
85     bool wait_for_service = true;
86     
87     int command = 0;
88     
89     constexpr int START_SERVICE = 1;
90     constexpr int STOP_SERVICE = 2;
91         
92     for (int i = 1; i < argc; i++) {
93         if (argv[i][0] == '-') {
94             if (strcmp(argv[i], "--help") == 0) {
95                 show_help = true;
96                 break;
97             }
98             else if (strcmp(argv[i], "--no-wait") == 0) {
99                 wait_for_service = false;
100             }
101             else if (strcmp(argv[i], "--quiet") == 0) {
102                 verbose = false;
103             }
104             else if (strcmp(argv[i], "--system") == 0 || strcmp(argv[i], "-s") == 0) {
105                 sys_dinit = true;
106             }
107             else {
108                 cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
109                 return 1;
110             }
111         }
112         else if (command == 0) {
113             if (strcmp(argv[i], "start") == 0) {
114                 command = START_SERVICE; 
115             }
116             else if (strcmp(argv[i], "stop") == 0) {
117                 command = STOP_SERVICE;
118             }
119             else {
120                 show_help = true;
121                 break;
122             }
123         }
124         else {
125             // service name
126             service_name = argv[i];
127             // TODO support multiple services (or at least give error if multiple
128             //      services supplied)
129         }
130     }
131     
132     if (service_name == nullptr || command == 0) {
133         show_help = true;
134     }
135
136     if (show_help) {
137         cout << "dinit-start:   start a dinit service" << endl;
138         cout << "  --help           : show this help" << endl;
139         cout << "  --no-wait        : don't wait for service startup/shutdown to complete" << endl;
140         cout << "  --quiet          : suppress output (except errors)" << endl;
141         cout << "  -s, --system     : control system daemon instead of user daemon" << endl;
142         cout << "  <service-name>   : start the named service" << endl;
143         return 1;
144     }
145     
146     do_stop = (command == STOP_SERVICE);
147     
148     control_socket_path = "/dev/dinitctl";
149     
150     if (! sys_dinit) {
151         char * userhome = getenv("HOME");
152         if (userhome == nullptr) {
153             struct passwd * pwuid_p = getpwuid(getuid());
154             if (pwuid_p != nullptr) {
155                 userhome = pwuid_p->pw_dir;
156             }
157         }
158         
159         if (userhome != nullptr) {
160             control_socket_str = userhome;
161             control_socket_str += "/.dinitctl";
162             control_socket_path = control_socket_str.c_str();
163         }
164         else {
165             cerr << "Cannot locate user home directory (set HOME or check /etc/passwd file)" << endl;
166             return 1;
167         }
168     }
169     
170     int socknum = socket(AF_UNIX, SOCK_STREAM, 0);
171     if (socknum == -1) {
172         perror("socket");
173         return 1;
174     }
175
176     struct sockaddr_un * name;
177     uint sockaddr_size = offsetof(struct sockaddr_un, sun_path) + strlen(control_socket_path) + 1;
178     name = (struct sockaddr_un *) malloc(sockaddr_size);
179     if (name == nullptr) {
180         cerr << "dinit-start: out of memory" << endl;
181         return 1;
182     }
183     
184     name->sun_family = AF_UNIX;
185     strcpy(name->sun_path, control_socket_path);
186     
187     int connr = connect(socknum, (struct sockaddr *) name, sockaddr_size);
188     if (connr == -1) {
189         perror("connect");
190         return 1;
191     }
192     
193     // TODO should start by querying protocol version
194     
195     // Build buffer;
196     uint16_t sname_len = strlen(service_name);
197     int bufsize = 3 + sname_len;
198     char * buf = new char[bufsize];
199     
200     buf[0] = DINIT_CP_LOADSERVICE;
201     memcpy(buf + 1, &sname_len, 2);
202     memcpy(buf + 3, service_name, sname_len);
203     
204     int r = write_all(socknum, buf, bufsize);
205     delete [] buf;
206     if (r == -1) {
207         perror("write");
208         return 1;
209     }
210     
211     // Now we expect a reply:
212     // NOTE: should skip over information packets.
213     
214     try {
215         CPBuffer rbuffer;
216         fillBufferTo(&rbuffer, socknum, 1);
217         
218         ServiceState state;
219         ServiceState target_state;
220         handle_t handle;
221         
222         if (rbuffer[0] == DINIT_RP_SERVICERECORD) {
223             fillBufferTo(&rbuffer, socknum, 2 + sizeof(handle));
224             rbuffer.extract((char *) &handle, 2, sizeof(handle));
225             state = static_cast<ServiceState>(rbuffer[1]);
226             target_state = static_cast<ServiceState>(rbuffer[2 + sizeof(handle)]);
227             rbuffer.consume(3 + sizeof(handle));
228         }
229         else if (rbuffer[0] == DINIT_RP_NOSERVICE) {
230             cerr << "Failed to find/load service." << endl;
231             return 1;
232         }
233         else {
234             cerr << "Protocol error." << endl;
235             return 1;
236         }
237         
238         ServiceState wanted_state = do_stop ? ServiceState::STOPPED : ServiceState::STARTED;
239         int command = do_stop ? DINIT_CP_STOPSERVICE : DINIT_CP_STARTSERVICE;
240         
241         // Need to issue STOPSERVICE/STARTSERVICE
242         if (target_state != wanted_state) {
243             buf = new char[2 + sizeof(handle)];
244             buf[0] = command;
245             buf[1] = 0;  // don't pin
246             memcpy(buf + 2, &handle, sizeof(handle));
247             r = write(socknum, buf, 2 + sizeof(handle));
248             delete buf;
249         }
250         
251         if (state == wanted_state) {
252             if (verbose) {
253                 cout << "Service already " << describeState(do_stop) << "." << endl;
254             }
255             return 0; // success!
256         }
257         
258         if (! wait_for_service) {
259             return 0;
260         }
261         
262         ServiceEvent completionEvent;
263         ServiceEvent cancelledEvent;
264         
265         if (do_stop) {
266             completionEvent = ServiceEvent::STOPPED;
267             cancelledEvent = ServiceEvent::STOPCANCELLED;
268         }
269         else {
270             completionEvent = ServiceEvent::STARTED;
271             cancelledEvent = ServiceEvent::STARTCANCELLED;
272         }
273         
274         // Wait until service started:
275         r = rbuffer.fillTo(socknum, 2);
276         while (r > 0) {
277             if (rbuffer[0] >= 100) {
278                 int pktlen = (unsigned char) rbuffer[1];
279                 fillBufferTo(&rbuffer, socknum, pktlen);
280                 
281                 if (rbuffer[0] == DINIT_IP_SERVICEEVENT) {
282                     handle_t ev_handle;
283                     rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle));
284                     ServiceEvent event = static_cast<ServiceEvent>(rbuffer[2 + sizeof(ev_handle)]);
285                     if (ev_handle == handle) {
286                         if (event == completionEvent) {
287                             if (verbose) {
288                                 cout << "Service " << describeState(do_stop) << "." << endl;
289                             }
290                             return 0;
291                         }
292                         else if (event == cancelledEvent) {
293                             if (verbose) {
294                                 cout << "Service " << describeVerb(do_stop) << " cancelled." << endl;
295                             }
296                             return 1;
297                         }
298                         else if (! do_stop && event == ServiceEvent::FAILEDSTART) {
299                             if (verbose) {
300                                 cout << "Service failed to start." << endl;
301                             }
302                             return 1;
303                         }
304                     }
305                 }
306             }
307             else {
308                 // Not an information packet?
309                 cerr << "protocol error" << endl;
310                 return 1;
311             }
312         }
313         
314         if (r == -1) {
315             perror("read");
316         }
317         else {
318             cerr << "protocol error (connection closed by server)" << endl;
319         }
320         return 1;
321     }
322     catch (ReadCPException &exc) {
323         cerr << "control socket read failure or protocol error" << endl;
324         return 1;
325     }
326     catch (std::bad_alloc &exc) {
327         cerr << "out of memory" << endl;
328         return 1;
329     }
330     
331     return 0;
332 }