Rename ServiceState to service_state_t.
[oweals/dinit.git] / src / dinitctl.cc
index 3e1782a61a6b1c585934e3b01c750e1d17967e5a..e8c5daa38385fb35b8a361b03f2a2f3e9dbfd6e7 100644 (file)
@@ -4,11 +4,13 @@
 #include <string>
 #include <iostream>
 #include <system_error>
+#include <memory>
 
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/un.h>
 #include <unistd.h>
+#include <signal.h>
 #include <pwd.h>
 
 #include "control-cmds.h"
@@ -17,7 +19,7 @@
 
 // dinitctl:  utility to control the Dinit daemon, including starting and stopping of services.
 
-// This utility communicates with the dinit daemon via a unix socket (/dev/initctl).
+// This utility communicates with the dinit daemon via a unix stream socket (/dev/initctl, or $HOME/.dinitctl).
 
 using handle_t = uint32_t;
 
@@ -29,15 +31,37 @@ class ReadCPException
     ReadCPException(int err) : errcode(err) { }
 };
 
-static void fillBufferTo(CPBuffer *buf, int fd, int rlength)
+enum class Command;
+
+static int issueLoadService(int socknum, const char *service_name);
+static int checkLoadReply(int socknum, cpbuffer<1024> &rbuffer, handle_t *handle_p, service_state_t *state_p);
+static int startStopService(int socknum, const char *service_name, Command command, bool do_pin, bool wait_for_service, bool verbose);
+static int unpinService(int socknum, const char *service_name, bool verbose);
+static int listServices(int socknum);
+
+
+// Fill a circular buffer from a file descriptor, reading at least _rlength_ bytes.
+// Throws ReadException if the requested number of bytes cannot be read, with:
+//     errcode = 0   if end of stream (remote end closed)
+//     errcode = errno   if another error occurred
+// Note that EINTR is ignored (i.e. the read will be re-tried).
+static void fillBufferTo(cpbuffer<1024> *buf, int fd, int rlength)
 {
-    int r = buf->fillTo(fd, rlength);
-    if (r == -1) {
-        throw ReadCPException(errno);
-    }
-    else if (r == 0) {
-        throw ReadCPException(0);
+    do {
+        int r = buf->fill_to(fd, rlength);
+        if (r == -1) {
+            if (errno != EINTR) {
+                throw ReadCPException(errno);
+            }
+        }
+        else if (r == 0) {
+            throw ReadCPException(0);
+        }
+        else {
+            return;
+        }
     }
+    while (true);
 }
 
 static const char * describeState(bool stopped)
@@ -50,11 +74,59 @@ static const char * describeVerb(bool stop)
     return stop ? "stop" : "start";
 }
 
+// Wait for a reply packet, skipping over any information packets
+// that are received in the meantime.
+static void wait_for_reply(cpbuffer<1024> &rbuffer, int fd)
+{
+    fillBufferTo(&rbuffer, fd, 1);
+    
+    while (rbuffer[0] >= 100) {
+        // Information packet; discard.
+        fillBufferTo(&rbuffer, fd, 1);
+        int pktlen = (unsigned char) rbuffer[1];
+        
+        rbuffer.consume(1);  // Consume one byte so we'll read one byte of the next packet
+        fillBufferTo(&rbuffer, fd, pktlen);
+        rbuffer.consume(pktlen - 1);
+    }
+}
+
+
+// Write *all* the requested buffer and re-try if necessary until
+// the buffer is written or an unrecoverable error occurs.
+static int write_all(int fd, const void *buf, size_t count)
+{
+    const char *cbuf = static_cast<const char *>(buf);
+    int w = 0;
+    while (count > 0) {
+        int r = write(fd, cbuf, count);
+        if (r == -1) {
+            if (errno == EINTR) continue;
+            return r;
+        }
+        w += r;
+        cbuf += r;
+        count -= r;
+    }
+    return w;
+}
+
+
+enum class Command {
+    NONE,
+    START_SERVICE,
+    WAKE_SERVICE,
+    STOP_SERVICE,
+    RELEASE_SERVICE,
+    UNPIN_SERVICE,
+    LIST_SERVICES
+};
+
+// Entry point.
 int main(int argc, char **argv)
 {
     using namespace std;
     
-    bool do_stop = false;
     bool show_help = argc < 2;
     char *service_name = nullptr;
     
@@ -64,11 +136,9 @@ int main(int argc, char **argv)
     bool verbose = true;
     bool sys_dinit = false;  // communicate with system daemon
     bool wait_for_service = true;
+    bool do_pin = false;
     
-    int command = 0;
-    
-    constexpr int START_SERVICE = 1;
-    constexpr int STOP_SERVICE = 2;
+    Command command = Command::NONE;
         
     for (int i = 1; i < argc; i++) {
         if (argv[i][0] == '-') {
@@ -85,17 +155,31 @@ int main(int argc, char **argv)
             else if (strcmp(argv[i], "--system") == 0 || strcmp(argv[i], "-s") == 0) {
                 sys_dinit = true;
             }
+            else if (strcmp(argv[i], "--pin") == 0) {
+                do_pin = true;
+            }
             else {
-                cerr << "Unrecognized command-line parameter: " << argv[i] << endl;
                 return 1;
             }
         }
-        else if (command == 0) {
-            if (strcmp(argv[i], "start")) {
-                command = START_SERVICE; 
+        else if (command == Command::NONE) {
+            if (strcmp(argv[i], "start") == 0) {
+                command = Command::START_SERVICE; 
+            }
+            else if (strcmp(argv[i], "wake") == 0) {
+                command = Command::WAKE_SERVICE;
+            }
+            else if (strcmp(argv[i], "stop") == 0) {
+                command = Command::STOP_SERVICE;
+            }
+            else if (strcmp(argv[i], "release") == 0) {
+                command = Command::RELEASE_SERVICE;
+            }
+            else if (strcmp(argv[i], "unpin") == 0) {
+                command = Command::UNPIN_SERVICE;
             }
-            else if (strcmp(argv[i], "stop")) {
-                command = STOP_SERVICE;
+            else if (strcmp(argv[i], "list") == 0) {
+                command = Command::LIST_SERVICES;
             }
             else {
                 show_help = true;
@@ -104,30 +188,52 @@ int main(int argc, char **argv)
         }
         else {
             // service name
+            if (service_name != nullptr) {
+                show_help = true;
+                break;
+            }
             service_name = argv[i];
-            // TODO support multiple services (or at least give error if multiple
-            //      services supplied)
+            // TODO support multiple services
         }
     }
     
-    if (service_name == nullptr || command = 0) {
+    if (service_name != nullptr && command == Command::LIST_SERVICES) {
+        show_help = true;
+    }
+    
+    if ((service_name == nullptr && command != Command::LIST_SERVICES) || command == Command::NONE) {
         show_help = true;
     }
 
     if (show_help) {
-        cout << "dinit-start:   start a dinit service" << endl;
+        cout << "dinitctl:   control Dinit services" << endl;
+        
+        cout << "\nUsage:" << endl;
+        cout << "    dinitctl [options] start [options] <service-name> : start and activate service" << endl;
+        cout << "    dinitctl [options] stop [options] <service-name>  : stop service and cancel explicit activation" << endl;
+        cout << "    dinitctl [options] wake [options] <service-name>  : start but do not mark activated" << endl;
+        cout << "    dinitctl [options] release [options] <service-name> : release activation, stop if no dependents" << endl;
+        cout << "    dinitctl [options] unpin <service-name>           : un-pin the service (after a previous pin)" << endl;
+        cout << "    dinitctl list                                     : list loaded services" << endl;
+        
+        cout << "\nNote: An activated service continues running when its dependents stop." << endl;
+        
+        cout << "\nGeneral options:" << endl;
+        cout << "  -s, --system     : control system daemon instead of user daemon" << endl;
+        cout << "  --quiet          : suppress output (except errors)" << endl;
+        
+        cout << "\nCommand options:" << endl;
         cout << "  --help           : show this help" << endl;
         cout << "  --no-wait        : don't wait for service startup/shutdown to complete" << endl;
-        cout << "  --quiet          : suppress output (except errors)" << endl;
-        cout << "  -s, --system     : control system daemon instead of user daemon" << endl;
-        cout << "  <service-name>   : start the named service" << endl;
+        cout << "  --pin            : pin the service in the requested (started/stopped) state" << endl;
         return 1;
     }
     
-    do_stop = (command == STOP_SERVICE);
+    signal(SIGPIPE, SIG_IGN);
     
     control_socket_path = "/dev/dinitctl";
     
+    // Locate control socket
     if (! sys_dinit) {
         char * userhome = getenv("HOME");
         if (userhome == nullptr) {
@@ -150,7 +256,7 @@ int main(int argc, char **argv)
     
     int socknum = socket(AF_UNIX, SOCK_STREAM, 0);
     if (socknum == -1) {
-        perror("socket");
+        perror("dinitctl: socket");
         return 1;
     }
 
@@ -158,7 +264,7 @@ int main(int argc, char **argv)
     uint sockaddr_size = offsetof(struct sockaddr_un, sun_path) + strlen(control_socket_path) + 1;
     name = (struct sockaddr_un *) malloc(sockaddr_size);
     if (name == nullptr) {
-        cerr << "dinit-start: out of memory" << endl;
+        cerr << "dinitctl: Out of memory" << endl;
         return 1;
     }
     
@@ -167,94 +273,122 @@ int main(int argc, char **argv)
     
     int connr = connect(socknum, (struct sockaddr *) name, sockaddr_size);
     if (connr == -1) {
-        perror("connect");
+        perror("dinitctl: connect");
         return 1;
     }
     
     // TODO should start by querying protocol version
     
-    // Build buffer;
-    uint16_t sname_len = strlen(service_name);
-    int bufsize = 3 + sname_len;
-    char * buf = new char[bufsize];
-    
-    buf[0] = DINIT_CP_LOADSERVICE;
-    memcpy(buf + 1, &sname_len, 2);
-    memcpy(buf + 3, service_name, sname_len);
+    if (command == Command::UNPIN_SERVICE) {
+        return unpinService(socknum, service_name, verbose);
+    }
+    else if (command == Command::LIST_SERVICES) {
+        return listServices(socknum);
+    }
+
+    return startStopService(socknum, service_name, command, do_pin, wait_for_service, verbose);
+}
+
+// Start/stop a service
+static int startStopService(int socknum, const char *service_name, Command command, bool do_pin, bool wait_for_service, bool verbose)
+{
+    using namespace std;
+
+    bool do_stop = (command == Command::STOP_SERVICE || command == Command::RELEASE_SERVICE);
     
-    int r = write(socknum, buf, bufsize);
-    // TODO make sure we write it all
-    delete [] buf;
-    if (r == -1) {
-        perror("write");
+    if (issueLoadService(socknum, service_name)) {
         return 1;
     }
-    
+
     // Now we expect a reply:
-    // NOTE: should skip over information packets.
     
     try {
-        CPBuffer rbuffer;
-        fillBufferTo(&rbuffer, socknum, 1);
+        cpbuffer<1024> rbuffer;
+        wait_for_reply(rbuffer, socknum);
         
-        ServiceState state;
-        ServiceState target_state;
+        service_state_t state;
+        //service_state_t target_state;
         handle_t handle;
         
-        if (rbuffer[0] == DINIT_RP_SERVICERECORD) {
-            fillBufferTo(&rbuffer, socknum, 2 + sizeof(handle));
-            rbuffer.extract((char *) &handle, 2, sizeof(handle));
-            state = static_cast<ServiceState>(rbuffer[1]);
-            target_state = static_cast<ServiceState>(rbuffer[2 + sizeof(handle)]);
-            rbuffer.consume(3 + sizeof(handle));
-        }
-        else if (rbuffer[0] == DINIT_RP_NOSERVICE) {
-            cerr << "Failed to find/load service." << endl;
-            return 1;
+        if (checkLoadReply(socknum, rbuffer, &handle, &state) != 0) {
+            return 0;
         }
-        else {
-            cerr << "Protocol error." << endl;
-            return 1;
+                
+        service_state_t wanted_state = do_stop ? service_state_t::STOPPED : service_state_t::STARTED;
+        int pcommand = 0;
+        switch (command) {
+        case Command::STOP_SERVICE:
+            pcommand = DINIT_CP_STOPSERVICE;
+            break;
+        case Command::RELEASE_SERVICE:
+            pcommand = DINIT_CP_RELEASESERVICE;
+            break;
+        case Command::START_SERVICE:
+            pcommand = DINIT_CP_STARTSERVICE;
+            break;
+        case Command::WAKE_SERVICE:
+            pcommand = DINIT_CP_WAKESERVICE;
+            break;
+        default: ;
         }
         
-        ServiceState wanted_state = do_stop ? ServiceState::STOPPED : ServiceState::STARTED;
-        int command = do_stop ? DINIT_CP_STOPSERVICE : DINIT_CP_STARTSERVICE;
-        
         // Need to issue STOPSERVICE/STARTSERVICE
-        if (target_state != wanted_state) {
-            buf = new char[2 + sizeof(handle)];
-            buf[0] = command;
-            buf[1] = 0;  // don't pin
-            memcpy(buf + 2, &handle, sizeof(handle));
-            r = write(socknum, buf, 2 + sizeof(handle));
-            delete buf;
-        }
-        
-        if (state == wanted_state) {
-            if (verbose) {
-                cout << "Service already " << describeState(do_stop) << "." << endl;
+        // We'll do this regardless of the current service state / target state, since issuing
+        // start/stop also sets or clears the "explicitly started" flag on the service.
+        {
+            int r;
+            
+            {
+                auto buf = new char[2 + sizeof(handle)];
+                unique_ptr<char[]> ubuf(buf);
+                
+                buf[0] = pcommand;
+                buf[1] = do_pin ? 1 : 0;
+                memcpy(buf + 2, &handle, sizeof(handle));
+                r = write_all(socknum, buf, 2 + sizeof(handle));
+            }
+            
+            if (r == -1) {
+                perror("dinitctl: write");
+                return 1;
+            }
+            
+            wait_for_reply(rbuffer, socknum);
+            if (rbuffer[0] == DINIT_RP_ALREADYSS) {
+                bool already = (state == wanted_state);
+                if (verbose) {
+                    cout << "Service " << (already ? "(already) " : "") << describeState(do_stop) << "." << endl;
+                }
+                return 0; // success!
             }
-            return 0; // success!
+            if (rbuffer[0] != DINIT_RP_ACK) {
+                cerr << "dinitctl: Protocol error." << endl;
+                return 1;
+            }
+            rbuffer.consume(1);
         }
         
         if (! wait_for_service) {
+            if (verbose) {
+                cout << "Issued " << describeVerb(do_stop) << " command successfully." << endl;
+            }
             return 0;
         }
         
-        ServiceEvent completionEvent;
-        ServiceEvent cancelledEvent;
+        service_event completionEvent;
+        service_event cancelledEvent;
         
         if (do_stop) {
-            completionEvent = ServiceEvent::STOPPED;
-            cancelledEvent = ServiceEvent::STOPCANCELLED;
+            completionEvent = service_event::STOPPED;
+            cancelledEvent = service_event::STOPCANCELLED;
         }
         else {
-            completionEvent = ServiceEvent::STARTED;
-            cancelledEvent = ServiceEvent::STARTCANCELLED;
+            completionEvent = service_event::STARTED;
+            cancelledEvent = service_event::STARTCANCELLED;
         }
         
         // Wait until service started:
-        r = rbuffer.fillTo(socknum, 2);
+        int r = rbuffer.fill_to(socknum, 2);
         while (r > 0) {
             if (rbuffer[0] >= 100) {
                 int pktlen = (unsigned char) rbuffer[1];
@@ -263,7 +397,7 @@ int main(int argc, char **argv)
                 if (rbuffer[0] == DINIT_IP_SERVICEEVENT) {
                     handle_t ev_handle;
                     rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle));
-                    ServiceEvent event = static_cast<ServiceEvent>(rbuffer[2 + sizeof(ev_handle)]);
+                    service_event event = static_cast<service_event>(rbuffer[2 + sizeof(ev_handle)]);
                     if (ev_handle == handle) {
                         if (event == completionEvent) {
                             if (verbose) {
@@ -277,18 +411,27 @@ int main(int argc, char **argv)
                             }
                             return 1;
                         }
+                        else if (! do_stop && event == service_event::FAILEDSTART) {
+                            if (verbose) {
+                                cout << "Service failed to start." << endl;
+                            }
+                            return 1;
+                        }
                     }
                 }
+                
+                rbuffer.consume(pktlen);
+                r = rbuffer.fill_to(socknum, 2);
             }
             else {
                 // Not an information packet?
-                cerr << "protocol error" << endl;
+                cerr << "dinitctl: protocol error" << endl;
                 return 1;
             }
         }
         
         if (r == -1) {
-            perror("read");
+            perror("dinitctl: read");
         }
         else {
             cerr << "protocol error (connection closed by server)" << endl;
@@ -296,11 +439,198 @@ int main(int argc, char **argv)
         return 1;
     }
     catch (ReadCPException &exc) {
-        cerr << "control socket read failure or protocol error" << endl;
+        cerr << "dinitctl: control socket read failure or protocol error" << endl;
+        return 1;
+    }
+    catch (std::bad_alloc &exc) {
+        cerr << "dinitctl: out of memory" << endl;
+        return 1;
+    }
+    
+    return 0;
+}
+
+// Issue a "load service" command (DINIT_CP_LOADSERVICE), without waiting for
+// a response. Returns 1 on failure (with error logged), 0 on success.
+static int issueLoadService(int socknum, const char *service_name)
+{
+    using namespace std;
+    
+    // Build buffer;
+    uint16_t sname_len = strlen(service_name);
+    int bufsize = 3 + sname_len;
+    int r;
+    
+    {
+        // TODO: new: catch exception
+        unique_ptr<char[]> ubuf(new char[bufsize]);
+        auto buf = ubuf.get();
+        
+        buf[0] = DINIT_CP_LOADSERVICE;
+        memcpy(buf + 1, &sname_len, 2);
+        memcpy(buf + 3, service_name, sname_len);
+        
+        r = write_all(socknum, buf, bufsize);
+    }
+    
+    if (r == -1) {
+        perror("dinitctl: write");
+        return 1;
+    }
+    
+    return 0;
+}
+
+// Check that a "load service" reply was received, and that the requested service was found.
+static int checkLoadReply(int socknum, cpbuffer<1024> &rbuffer, handle_t *handle_p, service_state_t *state_p)
+{
+    using namespace std;
+    
+    if (rbuffer[0] == DINIT_RP_SERVICERECORD) {
+        fillBufferTo(&rbuffer, socknum, 2 + sizeof(*handle_p));
+        rbuffer.extract((char *) handle_p, 2, sizeof(*handle_p));
+        if (state_p) *state_p = static_cast<service_state_t>(rbuffer[1]);
+        //target_state = static_cast<service_state_t>(rbuffer[2 + sizeof(handle)]);
+        rbuffer.consume(3 + sizeof(*handle_p));
+        return 0;
+    }
+    else if (rbuffer[0] == DINIT_RP_NOSERVICE) {
+        cerr << "dinitctl: Failed to find/load service." << endl;
+        return 1;
+    }
+    else {
+        cerr << "dinitctl: Protocol error." << endl;
+        return 1;
+    }
+}
+
+static int unpinService(int socknum, const char *service_name, bool verbose)
+{
+    using namespace std;
+    
+    // Build buffer;
+    if (issueLoadService(socknum, service_name) == 1) {
+        return 1;
+    }
+
+    // Now we expect a reply:
+    
+    try {
+        cpbuffer<1024> rbuffer;
+        wait_for_reply(rbuffer, socknum);
+        
+        handle_t handle;
+        
+        if (checkLoadReply(socknum, rbuffer, &handle, nullptr) != 0) {
+            return 1;
+        }
+        
+        // Issue UNPIN command.
+        {
+            int r;
+            
+            {
+                char *buf = new char[1 + sizeof(handle)];
+                unique_ptr<char[]> ubuf(buf);
+                buf[0] = DINIT_CP_UNPINSERVICE;
+                memcpy(buf + 1, &handle, sizeof(handle));
+                r = write_all(socknum, buf, 2 + sizeof(handle));
+            }
+            
+            if (r == -1) {
+                perror("dinitctl: write");
+                return 1;
+            }
+            
+            wait_for_reply(rbuffer, socknum);
+            if (rbuffer[0] != DINIT_RP_ACK) {
+                cerr << "dinitctl: Protocol error." << endl;
+                return 1;
+            }
+            rbuffer.consume(1);
+        }
+    }
+    catch (ReadCPException &exc) {
+        cerr << "dinitctl: Control socket read failure or protocol error" << endl;
+        return 1;
+    }
+    catch (std::bad_alloc &exc) {
+        cerr << "dinitctl: Out of memory" << endl;
+        return 1;
+    }
+    
+    if (verbose) {
+        cout << "Service unpinned." << endl;
+    }
+    return 0;
+}
+
+static int listServices(int socknum)
+{
+    using namespace std;
+    
+    try {
+        char cmdbuf[] = { (char)DINIT_CP_LISTSERVICES };
+        int r = write_all(socknum, cmdbuf, 1);
+        
+        if (r == -1) {
+            perror("dinitctl: write");
+            return 1;
+        }
+        
+        cpbuffer<1024> rbuffer;
+        wait_for_reply(rbuffer, socknum);
+        while (rbuffer[0] == DINIT_RP_SVCINFO) {
+            fillBufferTo(&rbuffer, socknum, 8);
+            int nameLen = rbuffer[1];
+            service_state_t current = static_cast<service_state_t>(rbuffer[2]);
+            service_state_t target = static_cast<service_state_t>(rbuffer[3]);
+            
+            fillBufferTo(&rbuffer, socknum, nameLen + 8);
+            
+            char *name_ptr = rbuffer.get_ptr(8);
+            int clength = std::min(rbuffer.get_contiguous_length(name_ptr), nameLen);
+            
+            string name = string(name_ptr, clength);
+            name.append(rbuffer.get_buf_base(), nameLen - clength);
+            
+            cout << "[";
+            
+            cout << (target  == service_state_t::STARTED ? "{" : " ");
+            cout << (current == service_state_t::STARTED ? "+" : " ");
+            cout << (target  == service_state_t::STARTED ? "}" : " ");
+            
+            if (current == service_state_t::STARTING) {
+                cout << "<<";
+            }
+            else if (current == service_state_t::STOPPING) {
+                cout << ">>";
+            }
+            else {
+                cout << "  ";
+            }
+            
+            cout << (target  == service_state_t::STOPPED ? "{" : " ");
+            cout << (current == service_state_t::STOPPED ? "-" : " ");
+            cout << (target  == service_state_t::STOPPED ? "}" : " ");
+            
+            cout << "] " << name << endl;
+            
+            rbuffer.consume(8 + nameLen);
+            wait_for_reply(rbuffer, socknum);
+        }
+        
+        if (rbuffer[0] != DINIT_RP_LISTDONE) {
+            cerr << "dinitctl: Control socket protocol error" << endl;
+            return 1;
+        }
+    }
+    catch (ReadCPException &exc) {
+        cerr << "dinitctl: Control socket read failure or protocol error" << endl;
         return 1;
     }
     catch (std::bad_alloc &exc) {
-        cerr << "out of memory" << endl;
+        cerr << "dinitctl: Out of memory" << endl;
         return 1;
     }