Refactor some packet-building for readability and safety
authorDavin McCall <davmac@davmac.org>
Mon, 29 Jul 2019 10:41:23 +0000 (20:41 +1000)
committerDavin McCall <davmac@davmac.org>
Mon, 29 Jul 2019 10:41:23 +0000 (20:41 +1000)
src/dinitctl.cc
src/includes/dinit-client.h

index 9404335b98f732c9244fff7bf10e2a5634662b30..f260bf0a9d0aa5d1a4f4708357b010e7845b8f31 100644 (file)
@@ -446,12 +446,11 @@ static bool load_service(int socknum, cpbuffer_t &rbuffer, const char *name, han
 // Get the service name for a given handle, by querying the daemon.
 static std::string get_service_name(int socknum, cpbuffer_t &rbuffer, handle_t handle)
 {
-    char buf[2 + sizeof(handle)];
-    buf[0] = DINIT_CP_QUERYSERVICENAME;
-    buf[1] = 0;
-    memcpy(buf + 2, &handle, sizeof(handle));
-
-    write_all_x(socknum, buf, sizeof(buf));
+    auto m = membuf()
+            .append((char) DINIT_CP_QUERYSERVICENAME)
+            .append((char) 0)
+            .append(handle);
+    write_all_x(socknum, m);
 
     wait_for_reply(rbuffer, socknum);
 
@@ -530,15 +529,17 @@ static int start_stop_service(int socknum, cpbuffer_t &rbuffer, const char *serv
     // We'll do this regardless of the current service state / target state, since issuing
     // start/stop also sets or clears the "explicitly started" flag on the service.
     {
-        char buf[2 + sizeof(handle)];
-        buf[0] = pcommand;
-        buf[1] = (do_pin ? 1 : 0) | ((pcommand == DINIT_CP_STOPSERVICE && !do_force) ? 2 : 0);
+        char flags = (do_pin ? 1 : 0) | ((pcommand == DINIT_CP_STOPSERVICE && !do_force) ? 2 : 0);
         if (command == command_t::RESTART_SERVICE) {
-            buf[1] |= 4;
+            flags |= 4;
         }
-        memcpy(buf + 2, &handle, sizeof(handle));
-        write_all_x(socknum, buf, 2 + sizeof(handle));
-        
+
+        auto m = membuf()
+                .append((char) pcommand)
+                .append(flags)
+                .append(handle);
+        write_all_x(socknum, m);
+
         wait_for_reply(rbuffer, socknum);
         auto reply_pkt_h = rbuffer[0];
         rbuffer.consume(1); // consume header
@@ -709,10 +710,10 @@ static int unpin_service(int socknum, cpbuffer_t &rbuffer, const char *service_n
     
     // Issue UNPIN command.
     {
-        char buf[1 + sizeof(handle)];
-        buf[0] = DINIT_CP_UNPINSERVICE;
-        memcpy(buf + 1, &handle, sizeof(handle));
-        write_all_x(socknum, buf, sizeof(buf));
+        auto m = membuf()
+                .append<char>(DINIT_CP_UNPINSERVICE)
+                .append(handle);
+        write_all_x(socknum, m);
         
         wait_for_reply(rbuffer, socknum);
         if (rbuffer[0] != DINIT_RP_ACK) {
@@ -751,10 +752,10 @@ static int unload_service(int socknum, cpbuffer_t &rbuffer, const char *service_
 
     // Issue UNLOAD command.
     {
-        char buf[1 + sizeof(handle)];
-        buf[0] = DINIT_CP_UNLOADSERVICE;
-        memcpy(buf + 1, &handle, sizeof(handle));
-        write_all_x(socknum, buf, 2 + sizeof(handle));
+        auto m = membuf()
+                .append<char>(DINIT_CP_UNLOADSERVICE)
+                .append(handle);
+        write_all_x(socknum, m);
 
         wait_for_reply(rbuffer, socknum);
         if (rbuffer[0] == DINIT_RP_NAK) {
@@ -891,7 +892,6 @@ static int add_remove_dependency(int socknum, cpbuffer_t &rbuffer, bool add,
 {
     using namespace std;
 
-
     handle_t from_handle;
     handle_t to_handle;
 
@@ -900,11 +900,12 @@ static int add_remove_dependency(int socknum, cpbuffer_t &rbuffer, bool add,
         return 1;
     }
 
-    constexpr int pktsize = 2 + sizeof(handle_t) * 2;
-    char cmdbuf[pktsize] = { add ? (char)DINIT_CP_ADD_DEP : (char)DINIT_CP_REM_DEP, (char)dep_type};
-    memcpy(cmdbuf + 2, &from_handle, sizeof(from_handle));
-    memcpy(cmdbuf + 2 + sizeof(from_handle), &to_handle, sizeof(to_handle));
-    write_all_x(socknum, cmdbuf, pktsize);
+    auto m = membuf()
+            .append<char>(add ? (char)DINIT_CP_ADD_DEP : (char)DINIT_CP_REM_DEP)
+            .append(dep_type)
+            .append(from_handle)
+            .append(to_handle);
+    write_all_x(socknum, m);
 
     wait_for_reply(rbuffer, socknum);
 
@@ -926,14 +927,10 @@ static int shutdown_dinit(int socknum, cpbuffer_t &rbuffer)
     // TODO support no-wait option.
     using namespace std;
 
-    // Build buffer;
-    constexpr int bufsize = 2;
-    char buf[bufsize];
-
-    buf[0] = DINIT_CP_SHUTDOWN;
-    buf[1] = static_cast<char>(shutdown_type_t::HALT);
-
-    write_all_x(socknum, buf, bufsize);
+    auto m = membuf()
+            .append<char>(DINIT_CP_SHUTDOWN)
+            .append(static_cast<char>(shutdown_type_t::HALT));
+    write_all_x(socknum, m);
 
     wait_for_reply(rbuffer, socknum);
 
index 6fdf9b9cf714ba54e4dea078adea5814cc225991..4df4e61e89d7c0a0d34a9364bae34a4b0f069cb1 100644 (file)
@@ -34,6 +34,84 @@ class cp_old_server_exception
 };
 
 
+
+// static_membuf: a buffer of a fixed size (N) with one additional value (of type T). Don't use this
+// directly, construct via membuf.
+template <int N, typename T> class static_membuf
+{
+    public:
+    static constexpr int size() { return N + sizeof(T); }
+
+    private:
+    char buf[size()];
+
+    public:
+    static_membuf(char (&prevbuf)[N], const T &val)
+    {
+        memcpy(buf, prevbuf, N);
+        memcpy(buf + N, &val, sizeof(val));
+    }
+
+    const char *data() const { return buf; }
+
+    template <typename U> static_membuf<N+sizeof(T), U> append(const U &u)
+    {
+        return static_membuf<N+sizeof(T), U>{buf, u};
+    }
+
+    void output(char *out)
+    {
+        memcpy(out, buf, size());
+    }
+};
+
+// static_membuf specialisation for N = 0. Don't use this directly, construct via membuf.
+template <typename T> class static_membuf<0, T>
+{
+    public:
+    static constexpr int size() { return sizeof(T); }
+
+    private:
+    char buf[size()];
+
+    public:
+    static_membuf(const T &val)
+    {
+        memcpy(buf, &val, sizeof(val));
+    }
+
+    const char *data() { return buf; }
+
+    template <typename U> static_membuf<sizeof(T), U> append(const U &u)
+    {
+        return static_membuf<sizeof(T), U>{buf, u};
+    }
+
+    void output(char *out)
+    {
+        memcpy(out, buf, size());
+    }
+};
+
+// "membuf" class provides a compile-time allocated buffer that we can add items to one-by-one. This is
+// much safer than working with raw buffers and calculating offsets and sizes by hand (and with a decent
+// compiler the end result is just as efficient).
+//
+// To use:
+//     auto m = membuf().append(value1).append(value2).append(value3);
+// Then:
+//     m.size() - returns total size of the buffer (sizeof(value1)+...)
+//     m.data() - returns a 'const char *' to the buffer contents
+class membuf
+{
+    public:
+
+    template <typename U> static_membuf<0, U> append(const U &u)
+    {
+        return static_membuf<0, U>(u);
+    }
+};
+
 // Fill a circular buffer from a file descriptor, until it contains at least _rlength_ bytes.
 // Throws cp_read_exception if the requested number of bytes cannot be read, with:
 //     errcode = 0   if end of stream (remote end closed)
@@ -114,6 +192,12 @@ inline void write_all_x(int fd, const void *buf, size_t count)
     }
 }
 
+// Write all the requested buffer (eg membuf) and throw an exception on failure.
+template <typename Buf> inline void write_all_x(int fd, const Buf &b)
+{
+    write_all_x(fd, b.data(), b.size());
+}
+
 // Check the protocol version is compatible with the client.
 //   minversion - minimum protocol version that client can speak
 //   version - maximum protocol version that client can speak