From e08729d5fe1467971dcdeb3bc92ab4454bcad6d5 Mon Sep 17 00:00:00 2001 From: Davin McCall Date: Mon, 12 Feb 2018 20:34:56 +0000 Subject: [PATCH] dinitctl: query protocol version at beginning + refactoring. --- src/control.cc | 4 +- src/dinitctl.cc | 565 +++++++++++++++++------------------- src/includes/dinit-client.h | 94 +++++- src/shutdown.cc | 2 +- 4 files changed, 349 insertions(+), 316 deletions(-) diff --git a/src/control.cc b/src/control.cc index 7ead46e..15e7ca2 100644 --- a/src/control.cc +++ b/src/control.cc @@ -17,9 +17,9 @@ bool control_conn_t::process_packet() int pktType = rbuf[0]; if (pktType == DINIT_CP_QUERYVERSION) { // Responds with: - // DINIT_RP_CVERSION, (2 byte) minimum compatible version, (2 byte) maximum compatible version + // DINIT_RP_CVERSION, (2 byte) minimum compatible version, (2 byte) actual version char replyBuf[] = { DINIT_RP_CPVERSION, 0, 0, 0, 0 }; - if (! queue_packet(replyBuf, 1)) return false; + if (! queue_packet(replyBuf, sizeof(replyBuf))) return false; rbuf.consume(1); return true; } diff --git a/src/dinitctl.cc b/src/dinitctl.cc index e224932..359ea55 100644 --- a/src/dinitctl.cc +++ b/src/dinitctl.cc @@ -26,12 +26,13 @@ enum class command_t; static int issue_load_service(int socknum, const char *service_name, bool find_only = false); -static int check_load_reply(int socknum, cpbuffer<1024> &rbuffer, handle_t *handle_p, service_state_t *state_p); -static int start_stop_service(int socknum, const char *service_name, command_t command, bool do_pin, bool wait_for_service, bool verbose); -static int unpin_service(int socknum, const char *service_name, bool verbose); -static int unload_service(int socknum, const char *service_name); -static int list_services(int socknum); -static int shutdown_dinit(int soclknum); +static int check_load_reply(int socknum, cpbuffer_t &, handle_t *handle_p, service_state_t *state_p); +static int start_stop_service(int socknum, cpbuffer_t &, const char *service_name, command_t command, + bool do_pin, bool wait_for_service, bool verbose); +static int unpin_service(int socknum, cpbuffer_t &, const char *service_name, bool verbose); +static int unload_service(int socknum, cpbuffer_t &, const char *service_name); +static int list_services(int socknum, cpbuffer_t &); +static int shutdown_dinit(int soclknum, cpbuffer_t &); static const char * describeState(bool stopped) @@ -56,6 +57,7 @@ enum class command_t { SHUTDOWN }; + // Entry point. int main(int argc, char **argv) { @@ -221,27 +223,49 @@ int main(int argc, char **argv) return 1; } - // TODO should start by querying protocol version - - if (command == command_t::UNPIN_SERVICE) { - return unpin_service(socknum, service_name, verbose); + try { + // Start by querying protocol version: + cpbuffer_t rbuffer; + check_protocol_version(0, 0, rbuffer, socknum); + + if (command == command_t::UNPIN_SERVICE) { + return unpin_service(socknum, rbuffer, service_name, verbose); + } + else if (command == command_t::UNLOAD_SERVICE) { + return unload_service(socknum, rbuffer, service_name); + } + else if (command == command_t::LIST_SERVICES) { + return list_services(socknum, rbuffer); + } + else if (command == command_t::SHUTDOWN) { + return shutdown_dinit(socknum, rbuffer); + } + else { + return start_stop_service(socknum, rbuffer, service_name, command, do_pin, + wait_for_service, verbose); + } } - else if (command == command_t::UNLOAD_SERVICE) { - return unload_service(socknum, service_name); + catch (cp_old_client_exception &e) { + std::cerr << "dinitctl: too old (server reports newer protocol version)" << std::endl; + return 1; } - else if (command == command_t::LIST_SERVICES) { - return list_services(socknum); + catch (cp_old_server_exception &e) { + std::cerr << "dinitctl: server too old or protocol error" << std::endl; + return 1; } - else if (command == command_t::SHUTDOWN) { - return shutdown_dinit(socknum); + catch (cp_read_exception &e) { + cerr << "dinitctl: control socket read failure or protocol error" << endl; + return 1; } - else { - return start_stop_service(socknum, service_name, command, do_pin, wait_for_service, verbose); + catch (cp_write_exception &e) { + cerr << "dinitctl: control socket write error: " << std::strerror(e.errcode) << endl; + return 1; } } // Start/stop a service -static int start_stop_service(int socknum, const char *service_name, command_t command, bool do_pin, bool wait_for_service, bool verbose) +static int start_stop_service(int socknum, cpbuffer_t &rbuffer, const char *service_name, + command_t command, bool do_pin, bool wait_for_service, bool verbose) { using namespace std; @@ -253,152 +277,139 @@ static int start_stop_service(int socknum, const char *service_name, command_t c // Now we expect a reply: - try { - cpbuffer<1024> rbuffer; - wait_for_reply(rbuffer, socknum); - - service_state_t state; - //service_state_t target_state; - handle_t handle; + wait_for_reply(rbuffer, socknum); + + service_state_t state; + //service_state_t target_state; + handle_t handle; + + if (check_load_reply(socknum, rbuffer, &handle, &state) != 0) { + return 0; + } + + service_state_t wanted_state = do_stop ? service_state_t::STOPPED : service_state_t::STARTED; + int pcommand = 0; + switch (command) { + case command_t::STOP_SERVICE: + pcommand = DINIT_CP_STOPSERVICE; + break; + case command_t::RELEASE_SERVICE: + pcommand = DINIT_CP_RELEASESERVICE; + break; + case command_t::START_SERVICE: + pcommand = DINIT_CP_STARTSERVICE; + break; + case command_t::WAKE_SERVICE: + pcommand = DINIT_CP_WAKESERVICE; + break; + default: ; + } + + // Need to issue STOPSERVICE/STARTSERVICE + // We'll do this regardless of the current service state / target state, since issuing + // start/stop also sets or clears the "explicitly started" flag on the service. + { + int r; - if (check_load_reply(socknum, rbuffer, &handle, &state) != 0) { - return 0; - } - - service_state_t wanted_state = do_stop ? service_state_t::STOPPED : service_state_t::STARTED; - int pcommand = 0; - switch (command) { - case command_t::STOP_SERVICE: - pcommand = DINIT_CP_STOPSERVICE; - break; - case command_t::RELEASE_SERVICE: - pcommand = DINIT_CP_RELEASESERVICE; - break; - case command_t::START_SERVICE: - pcommand = DINIT_CP_STARTSERVICE; - break; - case command_t::WAKE_SERVICE: - pcommand = DINIT_CP_WAKESERVICE; - break; - default: ; + { + auto buf = new char[2 + sizeof(handle)]; + unique_ptr ubuf(buf); + + buf[0] = pcommand; + buf[1] = do_pin ? 1 : 0; + memcpy(buf + 2, &handle, sizeof(handle)); + r = write_all(socknum, buf, 2 + sizeof(handle)); } - // Need to issue STOPSERVICE/STARTSERVICE - // We'll do this regardless of the current service state / target state, since issuing - // start/stop also sets or clears the "explicitly started" flag on the service. - { - int r; - - { - auto buf = new char[2 + sizeof(handle)]; - unique_ptr ubuf(buf); - - buf[0] = pcommand; - buf[1] = do_pin ? 1 : 0; - memcpy(buf + 2, &handle, sizeof(handle)); - r = write_all(socknum, buf, 2 + sizeof(handle)); - } - - if (r == -1) { - perror("dinitctl: write"); - return 1; - } - - wait_for_reply(rbuffer, socknum); - if (rbuffer[0] == DINIT_RP_ALREADYSS) { - bool already = (state == wanted_state); - if (verbose) { - cout << "Service " << (already ? "(already) " : "") << describeState(do_stop) << "." << endl; - } - return 0; // success! - } - if (rbuffer[0] != DINIT_RP_ACK) { - cerr << "dinitctl: Protocol error." << endl; - return 1; - } - rbuffer.consume(1); + if (r == -1) { + perror("dinitctl: write"); + return 1; } - if (! wait_for_service) { + wait_for_reply(rbuffer, socknum); + if (rbuffer[0] == DINIT_RP_ALREADYSS) { + bool already = (state == wanted_state); if (verbose) { - cout << "Issued " << describeVerb(do_stop) << " command successfully." << endl; + cout << "Service " << (already ? "(already) " : "") << describeState(do_stop) << "." << endl; } - return 0; + return 0; // success! } - - service_event_t completionEvent; - service_event_t cancelledEvent; - - if (do_stop) { - completionEvent = service_event_t::STOPPED; - cancelledEvent = service_event_t::STOPCANCELLED; + if (rbuffer[0] != DINIT_RP_ACK) { + cerr << "dinitctl: Protocol error." << endl; + return 1; } - else { - completionEvent = service_event_t::STARTED; - cancelledEvent = service_event_t::STARTCANCELLED; + rbuffer.consume(1); + } + + if (! wait_for_service) { + if (verbose) { + cout << "Issued " << describeVerb(do_stop) << " command successfully." << endl; } - - // Wait until service started: - int r = rbuffer.fill_to(socknum, 2); - while (r > 0) { - if (rbuffer[0] >= 100) { - int pktlen = (unsigned char) rbuffer[1]; - fill_buffer_to(&rbuffer, socknum, pktlen); - - if (rbuffer[0] == DINIT_IP_SERVICEEVENT) { - handle_t ev_handle; - rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle)); - service_event_t event = static_cast(rbuffer[2 + sizeof(ev_handle)]); - if (ev_handle == handle) { - if (event == completionEvent) { - if (verbose) { - cout << "Service " << describeState(do_stop) << "." << endl; - } - return 0; + return 0; + } + + service_event_t completionEvent; + service_event_t cancelledEvent; + + if (do_stop) { + completionEvent = service_event_t::STOPPED; + cancelledEvent = service_event_t::STOPCANCELLED; + } + else { + completionEvent = service_event_t::STARTED; + cancelledEvent = service_event_t::STARTCANCELLED; + } + + // Wait until service started: + int r = rbuffer.fill_to(socknum, 2); + while (r > 0) { + if (rbuffer[0] >= 100) { + int pktlen = (unsigned char) rbuffer[1]; + fill_buffer_to(rbuffer, socknum, pktlen); + + if (rbuffer[0] == DINIT_IP_SERVICEEVENT) { + handle_t ev_handle; + rbuffer.extract((char *) &ev_handle, 2, sizeof(ev_handle)); + service_event_t event = static_cast(rbuffer[2 + sizeof(ev_handle)]); + if (ev_handle == handle) { + if (event == completionEvent) { + if (verbose) { + cout << "Service " << describeState(do_stop) << "." << endl; } - else if (event == cancelledEvent) { - if (verbose) { - cout << "Service " << describeVerb(do_stop) << " cancelled." << endl; - } - return 1; + return 0; + } + else if (event == cancelledEvent) { + if (verbose) { + cout << "Service " << describeVerb(do_stop) << " cancelled." << endl; } - else if (! do_stop && event == service_event_t::FAILEDSTART) { - if (verbose) { - cout << "Service failed to start." << endl; - } - return 1; + return 1; + } + else if (! do_stop && event == service_event_t::FAILEDSTART) { + if (verbose) { + cout << "Service failed to start." << endl; } + return 1; } } - - rbuffer.consume(pktlen); - r = rbuffer.fill_to(socknum, 2); - } - else { - // Not an information packet? - cerr << "dinitctl: protocol error" << endl; - return 1; } - } - - if (r == -1) { - perror("dinitctl: read"); + + rbuffer.consume(pktlen); + r = rbuffer.fill_to(socknum, 2); } else { - cerr << "protocol error (connection closed by server)" << endl; + // Not an information packet? + cerr << "dinitctl: protocol error" << endl; + return 1; } - return 1; } - catch (read_cp_exception &exc) { - cerr << "dinitctl: control socket read failure or protocol error" << endl; - return 1; + + if (r == -1) { + perror("dinitctl: read"); } - catch (std::bad_alloc &exc) { - cerr << "dinitctl: out of memory" << endl; - return 1; + else { + cerr << "protocol error (connection closed by server)" << endl; } - - return 0; + return 1; } // Issue a "load service" command (DINIT_CP_LOADSERVICE), without waiting for @@ -434,12 +445,12 @@ static int issue_load_service(int socknum, const char *service_name, bool find_o } // Check that a "load service" reply was received, and that the requested service was found. -static int check_load_reply(int socknum, cpbuffer<1024> &rbuffer, handle_t *handle_p, service_state_t *state_p) +static int check_load_reply(int socknum, cpbuffer_t &rbuffer, handle_t *handle_p, service_state_t *state_p) { using namespace std; if (rbuffer[0] == DINIT_RP_SERVICERECORD) { - fill_buffer_to(&rbuffer, socknum, 2 + sizeof(*handle_p)); + fill_buffer_to(rbuffer, socknum, 2 + sizeof(*handle_p)); rbuffer.extract((char *) handle_p, 2, sizeof(*handle_p)); if (state_p) *state_p = static_cast(rbuffer[1]); //target_state = static_cast(rbuffer[2 + sizeof(handle)]); @@ -447,16 +458,16 @@ static int check_load_reply(int socknum, cpbuffer<1024> &rbuffer, handle_t *hand return 0; } else if (rbuffer[0] == DINIT_RP_NOSERVICE) { - cerr << "dinitctl: Failed to find/load service." << endl; + cerr << "dinitctl: failed to find/load service." << endl; return 1; } else { - cerr << "dinitctl: Protocol error." << endl; + cerr << "dinitctl: protocol error." << endl; return 1; } } -static int unpin_service(int socknum, const char *service_name, bool verbose) +static int unpin_service(int socknum, cpbuffer_t &rbuffer, const char *service_name, bool verbose) { using namespace std; @@ -467,57 +478,46 @@ static int unpin_service(int socknum, const char *service_name, bool verbose) // Now we expect a reply: - try { - cpbuffer<1024> rbuffer; - wait_for_reply(rbuffer, socknum); + wait_for_reply(rbuffer, socknum); + + handle_t handle; + + if (check_load_reply(socknum, rbuffer, &handle, nullptr) != 0) { + return 1; + } + + // Issue UNPIN command. + { + int r; - handle_t handle; + { + char *buf = new char[1 + sizeof(handle)]; + unique_ptr ubuf(buf); + buf[0] = DINIT_CP_UNPINSERVICE; + memcpy(buf + 1, &handle, sizeof(handle)); + r = write_all(socknum, buf, 2 + sizeof(handle)); + } - if (check_load_reply(socknum, rbuffer, &handle, nullptr) != 0) { + if (r == -1) { + perror("dinitctl: write"); return 1; } - // Issue UNPIN command. - { - int r; - - { - char *buf = new char[1 + sizeof(handle)]; - unique_ptr ubuf(buf); - buf[0] = DINIT_CP_UNPINSERVICE; - memcpy(buf + 1, &handle, sizeof(handle)); - r = write_all(socknum, buf, 2 + sizeof(handle)); - } - - if (r == -1) { - perror("dinitctl: write"); - return 1; - } - - wait_for_reply(rbuffer, socknum); - if (rbuffer[0] != DINIT_RP_ACK) { - cerr << "dinitctl: Protocol error." << endl; - return 1; - } - rbuffer.consume(1); + wait_for_reply(rbuffer, socknum); + if (rbuffer[0] != DINIT_RP_ACK) { + cerr << "dinitctl: protocol error." << endl; + return 1; } + rbuffer.consume(1); } - catch (read_cp_exception &exc) { - cerr << "dinitctl: Control socket read failure or protocol error" << endl; - return 1; - } - catch (std::bad_alloc &exc) { - cerr << "dinitctl: Out of memory" << endl; - return 1; - } - + if (verbose) { cout << "Service unpinned." << endl; } return 0; } -static int unload_service(int socknum, const char *service_name) +static int unload_service(int socknum, cpbuffer_t &rbuffer, const char *service_name) { using namespace std; @@ -528,132 +528,110 @@ static int unload_service(int socknum, const char *service_name) // Now we expect a reply: - try { - cpbuffer<1024> rbuffer; - wait_for_reply(rbuffer, socknum); + wait_for_reply(rbuffer, socknum); - handle_t handle; + handle_t handle; - if (check_load_reply(socknum, rbuffer, &handle, nullptr) != 0) { - return 1; - } + if (check_load_reply(socknum, rbuffer, &handle, nullptr) != 0) { + return 1; + } + + // Issue UNLOAD command. + { + int r; - // Issue UNLOAD command. { - int r; - - { - char *buf = new char[1 + sizeof(handle)]; - unique_ptr ubuf(buf); - buf[0] = DINIT_CP_UNLOADSERVICE; - memcpy(buf + 1, &handle, sizeof(handle)); - r = write_all(socknum, buf, 2 + sizeof(handle)); - } + char *buf = new char[1 + sizeof(handle)]; + unique_ptr ubuf(buf); + buf[0] = DINIT_CP_UNLOADSERVICE; + memcpy(buf + 1, &handle, sizeof(handle)); + r = write_all(socknum, buf, 2 + sizeof(handle)); + } - if (r == -1) { - perror("dinitctl: write"); - return 1; - } + if (r == -1) { + perror("dinitctl: write"); + return 1; + } - wait_for_reply(rbuffer, socknum); - if (rbuffer[0] == DINIT_RP_NAK) { - cerr << "dinitctl: Could not unload service; service not stopped, or is a dependency of " - "other service." << endl; - return 1; - } - if (rbuffer[0] != DINIT_RP_ACK) { - cerr << "dinitctl: Protocol error." << endl; - return 1; - } - rbuffer.consume(1); + wait_for_reply(rbuffer, socknum); + if (rbuffer[0] == DINIT_RP_NAK) { + cerr << "dinitctl: Could not unload service; service not stopped, or is a dependency of " + "other service." << endl; + return 1; } - } - catch (read_cp_exception &exc) { - cerr << "dinitctl: Control socket read failure or protocol error" << endl; - return 1; - } - catch (std::bad_alloc &exc) { - cerr << "dinitctl: Out of memory" << endl; - return 1; + if (rbuffer[0] != DINIT_RP_ACK) { + cerr << "dinitctl: Protocol error." << endl; + return 1; + } + rbuffer.consume(1); } cout << "Service unloaded." << endl; return 0; } -static int list_services(int socknum) +static int list_services(int socknum, cpbuffer_t &rbuffer) { using namespace std; - try { - char cmdbuf[] = { (char)DINIT_CP_LISTSERVICES }; - int r = write_all(socknum, cmdbuf, 1); + char cmdbuf[] = { (char)DINIT_CP_LISTSERVICES }; + int r = write_all(socknum, cmdbuf, 1); + + if (r == -1) { + perror("dinitctl: write"); + return 1; + } + + wait_for_reply(rbuffer, socknum); + while (rbuffer[0] == DINIT_RP_SVCINFO) { + fill_buffer_to(rbuffer, socknum, 8); + int nameLen = rbuffer[1]; + service_state_t current = static_cast(rbuffer[2]); + service_state_t target = static_cast(rbuffer[3]); + + fill_buffer_to(rbuffer, socknum, nameLen + 8); + + char *name_ptr = rbuffer.get_ptr(8); + int clength = std::min(rbuffer.get_contiguous_length(name_ptr), nameLen); + + string name = string(name_ptr, clength); + name.append(rbuffer.get_buf_base(), nameLen - clength); + + cout << "["; + + cout << (target == service_state_t::STARTED ? "{" : " "); + cout << (current == service_state_t::STARTED ? "+" : " "); + cout << (target == service_state_t::STARTED ? "}" : " "); - if (r == -1) { - perror("dinitctl: write"); - return 1; + if (current == service_state_t::STARTING) { + cout << "<<"; } - - cpbuffer<1024> rbuffer; - wait_for_reply(rbuffer, socknum); - while (rbuffer[0] == DINIT_RP_SVCINFO) { - fill_buffer_to(&rbuffer, socknum, 8); - int nameLen = rbuffer[1]; - service_state_t current = static_cast(rbuffer[2]); - service_state_t target = static_cast(rbuffer[3]); - - fill_buffer_to(&rbuffer, socknum, nameLen + 8); - - char *name_ptr = rbuffer.get_ptr(8); - int clength = std::min(rbuffer.get_contiguous_length(name_ptr), nameLen); - - string name = string(name_ptr, clength); - name.append(rbuffer.get_buf_base(), nameLen - clength); - - cout << "["; - - cout << (target == service_state_t::STARTED ? "{" : " "); - cout << (current == service_state_t::STARTED ? "+" : " "); - cout << (target == service_state_t::STARTED ? "}" : " "); - - if (current == service_state_t::STARTING) { - cout << "<<"; - } - else if (current == service_state_t::STOPPING) { - cout << ">>"; - } - else { - cout << " "; - } - - cout << (target == service_state_t::STOPPED ? "{" : " "); - cout << (current == service_state_t::STOPPED ? "-" : " "); - cout << (target == service_state_t::STOPPED ? "}" : " "); - - cout << "] " << name << endl; - - rbuffer.consume(8 + nameLen); - wait_for_reply(rbuffer, socknum); + else if (current == service_state_t::STOPPING) { + cout << ">>"; } - - if (rbuffer[0] != DINIT_RP_LISTDONE) { - cerr << "dinitctl: Control socket protocol error" << endl; - return 1; + else { + cout << " "; } + + cout << (target == service_state_t::STOPPED ? "{" : " "); + cout << (current == service_state_t::STOPPED ? "-" : " "); + cout << (target == service_state_t::STOPPED ? "}" : " "); + + cout << "] " << name << endl; + + rbuffer.consume(8 + nameLen); + wait_for_reply(rbuffer, socknum); } - catch (read_cp_exception &exc) { - cerr << "dinitctl: Control socket read failure or protocol error" << endl; - return 1; - } - catch (std::bad_alloc &exc) { - cerr << "dinitctl: Out of memory" << endl; + + if (rbuffer[0] != DINIT_RP_LISTDONE) { + cerr << "dinitctl: Control socket protocol error" << endl; return 1; } - + return 0; } -static int shutdown_dinit(int socknum) +static int shutdown_dinit(int socknum, cpbuffer_t &rbuffer) { // TODO support no-wait option. using namespace std; @@ -671,17 +649,10 @@ static int shutdown_dinit(int socknum) return 1; } - cpbuffer<1024> rbuffer; - try { - wait_for_reply(rbuffer, socknum); + wait_for_reply(rbuffer, socknum); - if (rbuffer[0] != DINIT_RP_ACK) { - cerr << "dinitctl: Control socket protocol error" << endl; - return 1; - } - } - catch (read_cp_exception &exc) { - cerr << "dinitctl: Control socket read failure or protocol error" << endl; + if (rbuffer[0] != DINIT_RP_ACK) { + cerr << "dinitctl: Control socket protocol error" << endl; return 1; } @@ -694,7 +665,7 @@ static int shutdown_dinit(int socknum) } } } - catch (read_cp_exception &exc) { + catch (cp_read_exception &exc) { // Dinit can terminate before replying: let's assume that happened. // TODO: better check, possibly ensure that dinit actually sends rollback complete before // termination. diff --git a/src/includes/dinit-client.h b/src/includes/dinit-client.h index 34963cf..e40ad6d 100644 --- a/src/includes/dinit-client.h +++ b/src/includes/dinit-client.h @@ -1,33 +1,53 @@ #include +#include // Client library for Dinit clients using handle_t = uint32_t; +using cpbuffer_t = cpbuffer<1024>; -class read_cp_exception +class cp_read_exception { public: int errcode; - read_cp_exception(int err) : errcode(err) { } + cp_read_exception(int err) : errcode(err) { } }; +class cp_write_exception +{ + public: + int errcode; + cp_write_exception(int err) : errcode(err) { } +}; + +class cp_old_client_exception +{ + // no body +}; + +class cp_old_server_exception +{ + // no body +}; + + // Fill a circular buffer from a file descriptor, until it contains at least _rlength_ bytes. -// Throws read_cp_exception if the requested number of bytes cannot be read, with: +// Throws cp_read_exception if the requested number of bytes cannot be read, with: // errcode = 0 if end of stream (remote end closed) // errcode = errno if another error occurred // Note that EINTR is ignored (i.e. the read will be re-tried). -inline void fill_buffer_to(cpbuffer<1024> *buf, int fd, int rlength) +inline void fill_buffer_to(cpbuffer_t &buf, int fd, int rlength) { do { - int r = buf->fill_to(fd, rlength); + int r = buf.fill_to(fd, rlength); if (r == -1) { if (errno != EINTR) { - throw read_cp_exception(errno); + throw cp_read_exception(errno); } } else if (r == 0) { - throw read_cp_exception(0); + throw cp_read_exception(0); } else { return; @@ -37,32 +57,32 @@ inline void fill_buffer_to(cpbuffer<1024> *buf, int fd, int rlength) } // Wait for a reply packet, skipping over any information packets that are received in the meantime. -inline void wait_for_reply(cpbuffer<1024> &rbuffer, int fd) +inline void wait_for_reply(cpbuffer_t &rbuffer, int fd) { - fill_buffer_to(&rbuffer, fd, 1); + fill_buffer_to(rbuffer, fd, 1); while (rbuffer[0] >= 100) { // Information packet; discard. - fill_buffer_to(&rbuffer, fd, 2); + fill_buffer_to(rbuffer, fd, 2); int pktlen = (unsigned char) rbuffer[1]; rbuffer.consume(1); // Consume one byte so we'll read one byte of the next packet - fill_buffer_to(&rbuffer, fd, pktlen); + fill_buffer_to(rbuffer, fd, pktlen); rbuffer.consume(pktlen - 1); } } -// Wait for an info packet. If any other reply packet comes, throw a read_cp_exception. -inline void wait_for_info(cpbuffer<1024> &rbuffer, int fd) +// Wait for an info packet. If any other reply packet comes, throw a cp_read_exception. +inline void wait_for_info(cpbuffer_t &rbuffer, int fd) { - fill_buffer_to(&rbuffer, fd, 2); + fill_buffer_to(rbuffer, fd, 2); if (rbuffer[0] < 100) { - throw read_cp_exception(0); + throw cp_read_exception(0); } int pktlen = (unsigned char) rbuffer[1]; - fill_buffer_to(&rbuffer, fd, pktlen); + fill_buffer_to(rbuffer, fd, pktlen); } // Write *all* the requested buffer and re-try if necessary until @@ -83,3 +103,45 @@ inline int write_all(int fd, const void *buf, size_t count) } return w; } + +// Check the protocol version is compatible with the client. +// minverison - minimum protocol version that client can speak +// version - maximum protocol version that client can speak +// rbuffer, fd - communication buffer and socket +// returns: the actual protocol version +// throws an exception on protocol mismatch or error. +uint16_t check_protocol_version(int minversion, int version, cpbuffer_t &rbuffer, int fd) +{ + constexpr int bufsize = 1; + char buf[bufsize] = { DINIT_CP_QUERYVERSION }; + int r = write_all(fd, buf, bufsize); + if (r == -1) { + throw cp_write_exception(errno); + } + + wait_for_reply(rbuffer, fd); + if (rbuffer[0] != DINIT_RP_CPVERSION) { + throw cp_read_exception{0}; + } + + // DINIT_RP_CVERSION, (2 byte) minimum compatible version, (2 byte) actual version + constexpr int rbufsize = 1 + 2 * sizeof(uint16_t); + fill_buffer_to(rbuffer, fd, rbufsize); + uint16_t rminversion; + uint16_t cpversion; + + rbuffer.extract(reinterpret_cast(&rminversion), 1, sizeof(uint16_t)); + rbuffer.extract(reinterpret_cast(&cpversion), 1 + sizeof(uint16_t), sizeof(uint16_t)); + rbuffer.consume(rbufsize); + + if (rminversion > version) { + // We are too old + throw cp_old_client_exception(); + } + if (cpversion < minversion) { + // Server is too old + throw cp_old_server_exception(); + } + + return cpversion; +} diff --git a/src/shutdown.cc b/src/shutdown.cc index 3f1a131..886826b 100644 --- a/src/shutdown.cc +++ b/src/shutdown.cc @@ -161,7 +161,7 @@ int main(int argc, char **argv) return 1; } } - catch (read_cp_exception &exc) + catch (cp_read_exception &exc) { cerr << "shutdown: control socket read failure or protocol error" << endl; return 1; -- 2.25.1