Implement SSL_read_ex() and SSL_write_ex() as documented.
authorKurt Roeckx <kurt@roeckx.be>
Sat, 19 Nov 2016 19:15:35 +0000 (20:15 +0100)
committerKurt Roeckx <kurt@roeckx.be>
Thu, 23 Feb 2017 19:40:05 +0000 (20:40 +0100)
Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Richard Levitte <levitte@openssl.org>
GH: #1964

include/openssl/ssl.h
ssl/bio_ssl.c
ssl/ssl_err.c
ssl/ssl_lib.c
ssl/ssl_locl.h

index d1614d3a73340e4ce305e35d64ad3aa9de33f928..b85ccc595dc1ab294d69551497fa63a938766afc 100644 (file)
@@ -2244,8 +2244,10 @@ int ERR_load_SSL_strings(void);
 # define SSL_F_SSL_PARSE_SERVERHELLO_USE_SRTP_EXT         311
 # define SSL_F_SSL_PEEK                                   270
 # define SSL_F_SSL_PEEK_EX                                432
+# define SSL_F_SSL_PEEK_INTERNAL                          521
 # define SSL_F_SSL_READ                                   223
 # define SSL_F_SSL_READ_EX                                434
+# define SSL_F_SSL_READ_INTERNAL                          519
 # define SSL_F_SSL_RENEGOTIATE                            516
 # define SSL_F_SSL_SCAN_CLIENTHELLO_TLSEXT                320
 # define SSL_F_SSL_SCAN_SERVERHELLO_TLSEXT                321
@@ -2284,6 +2286,7 @@ int ERR_load_SSL_strings(void);
 # define SSL_F_SSL_VERIFY_CERT_CHAIN                      207
 # define SSL_F_SSL_WRITE                                  208
 # define SSL_F_SSL_WRITE_EX                               433
+# define SSL_F_SSL_WRITE_INTERNAL                         520
 # define SSL_F_STATE_MACHINE                              353
 # define SSL_F_TLS12_CHECK_PEER_SIGALG                    333
 # define SSL_F_TLS13_CHANGE_CIPHER_STATE                  440
index e48b90fceaa0cfeb7859a8f2cf6ab8df0b16cce7..3d380c83fb6078df5f1a63164a42721539ce93ee 100644 (file)
@@ -103,7 +103,7 @@ static int ssl_read(BIO *b, char *buf, size_t size, size_t *readbytes)
 
     BIO_clear_retry_flags(b);
 
-    ret = SSL_read_ex(ssl, buf, size, readbytes);
+    ret = ssl_read_internal(ssl, buf, size, readbytes);
 
     switch (SSL_get_error(ssl, ret)) {
     case SSL_ERROR_NONE:
@@ -172,7 +172,7 @@ static int ssl_write(BIO *b, const char *buf, size_t size, size_t *written)
 
     BIO_clear_retry_flags(b);
 
-    ret = SSL_write_ex(ssl, buf, size, written);
+    ret = ssl_write_internal(ssl, buf, size, written);
 
     switch (SSL_get_error(ssl, ret)) {
     case SSL_ERROR_NONE:
index bcb9ddb4f729d4cf6f4c780685b5bba44bfa1830..addc3de3f834db12685cc586ae223f343fdb769d 100644 (file)
@@ -203,8 +203,10 @@ static ERR_STRING_DATA SSL_str_functs[] = {
      "ssl_parse_serverhello_use_srtp_ext"},
     {ERR_FUNC(SSL_F_SSL_PEEK), "SSL_peek"},
     {ERR_FUNC(SSL_F_SSL_PEEK_EX), "SSL_peek_ex"},
+    {ERR_FUNC(SSL_F_SSL_PEEK_INTERNAL), "ssl_peek_internal"},
     {ERR_FUNC(SSL_F_SSL_READ), "SSL_read"},
     {ERR_FUNC(SSL_F_SSL_READ_EX), "SSL_read_ex"},
+    {ERR_FUNC(SSL_F_SSL_READ_INTERNAL), "ssl_read_internal"},
     {ERR_FUNC(SSL_F_SSL_RENEGOTIATE), "SSL_renegotiate"},
     {ERR_FUNC(SSL_F_SSL_SCAN_CLIENTHELLO_TLSEXT),
      "ssl_scan_clienthello_tlsext"},
@@ -252,6 +254,7 @@ static ERR_STRING_DATA SSL_str_functs[] = {
     {ERR_FUNC(SSL_F_SSL_VERIFY_CERT_CHAIN), "ssl_verify_cert_chain"},
     {ERR_FUNC(SSL_F_SSL_WRITE), "SSL_write"},
     {ERR_FUNC(SSL_F_SSL_WRITE_EX), "SSL_write_ex"},
+    {ERR_FUNC(SSL_F_SSL_WRITE_INTERNAL), "ssl_write_internal"},
     {ERR_FUNC(SSL_F_STATE_MACHINE), "state_machine"},
     {ERR_FUNC(SSL_F_TLS12_CHECK_PEER_SIGALG), "tls12_check_peer_sigalg"},
     {ERR_FUNC(SSL_F_TLS13_CHANGE_CIPHER_STATE), "tls13_change_cipher_state"},
index 8304c732ae7673488149d3c1ba3a69d3be48d29d..72c101b5f0b18f5f193c5674eee2909b37802337 100644 (file)
@@ -1532,38 +1532,16 @@ static int ssl_io_intern(void *vargs)
     return -1;
 }
 
-int SSL_read(SSL *s, void *buf, int num)
-{
-    int ret;
-    size_t readbytes;
-
-    if (num < 0) {
-        SSLerr(SSL_F_SSL_READ, SSL_R_BAD_LENGTH);
-        return -1;
-    }
-
-    ret = SSL_read_ex(s, buf, (size_t)num, &readbytes);
-
-    /*
-     * The cast is safe here because ret should be <= INT_MAX because num is
-     * <= INT_MAX
-     */
-    if (ret > 0)
-        ret = (int)readbytes;
-
-    return ret;
-}
-
-int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
+int ssl_read_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
 {
     if (s->handshake_func == NULL) {
-        SSLerr(SSL_F_SSL_READ_EX, SSL_R_UNINITIALIZED);
+        SSLerr(SSL_F_SSL_READ_INTERNAL, SSL_R_UNINITIALIZED);
         return -1;
     }
 
     if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
         s->rwstate = SSL_NOTHING;
-        return (0);
+        return 0;
     }
 
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -1584,17 +1562,17 @@ int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
     }
 }
 
-int SSL_peek(SSL *s, void *buf, int num)
+int SSL_read(SSL *s, void *buf, int num)
 {
     int ret;
     size_t readbytes;
 
     if (num < 0) {
-        SSLerr(SSL_F_SSL_PEEK, SSL_R_BAD_LENGTH);
+        SSLerr(SSL_F_SSL_READ, SSL_R_BAD_LENGTH);
         return -1;
     }
 
-    ret = SSL_peek_ex(s, buf, (size_t)num, &readbytes);
+    ret = ssl_read_internal(s, buf, (size_t)num, &readbytes);
 
     /*
      * The cast is safe here because ret should be <= INT_MAX because num is
@@ -1606,15 +1584,24 @@ int SSL_peek(SSL *s, void *buf, int num)
     return ret;
 }
 
-int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
+int SSL_read_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
+{
+    int ret = ssl_read_internal(s, buf, num, readbytes);
+
+    if (ret < 0)
+        ret = 0;
+    return ret;
+}
+
+static int ssl_peek_internal(SSL *s, void *buf, size_t num, size_t *readbytes)
 {
     if (s->handshake_func == NULL) {
-        SSLerr(SSL_F_SSL_PEEK_EX, SSL_R_UNINITIALIZED);
+        SSLerr(SSL_F_SSL_PEEK_INTERNAL, SSL_R_UNINITIALIZED);
         return -1;
     }
 
     if (s->shutdown & SSL_RECEIVED_SHUTDOWN) {
-        return (0);
+        return 0;
     }
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
         struct ssl_async_args args;
@@ -1634,39 +1621,49 @@ int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
     }
 }
 
-int SSL_write(SSL *s, const void *buf, int num)
+int SSL_peek(SSL *s, void *buf, int num)
 {
     int ret;
-    size_t written;
+    size_t readbytes;
 
     if (num < 0) {
-        SSLerr(SSL_F_SSL_WRITE, SSL_R_BAD_LENGTH);
+        SSLerr(SSL_F_SSL_PEEK, SSL_R_BAD_LENGTH);
         return -1;
     }
 
-    ret = SSL_write_ex(s, buf, (size_t)num, &written);
+    ret = ssl_peek_internal(s, buf, (size_t)num, &readbytes);
 
     /*
      * The cast is safe here because ret should be <= INT_MAX because num is
      * <= INT_MAX
      */
     if (ret > 0)
-        ret = (int)written;
+        ret = (int)readbytes;
 
     return ret;
 }
 
-int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
+
+int SSL_peek_ex(SSL *s, void *buf, size_t num, size_t *readbytes)
+{
+    int ret = ssl_peek_internal(s, buf, num, readbytes);
+
+    if (ret < 0)
+        ret = 0;
+    return ret;
+}
+
+int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written)
 {
     if (s->handshake_func == NULL) {
-        SSLerr(SSL_F_SSL_WRITE_EX, SSL_R_UNINITIALIZED);
+        SSLerr(SSL_F_SSL_WRITE_INTERNAL, SSL_R_UNINITIALIZED);
         return -1;
     }
 
     if (s->shutdown & SSL_SENT_SHUTDOWN) {
         s->rwstate = SSL_NOTHING;
-        SSLerr(SSL_F_SSL_WRITE_EX, SSL_R_PROTOCOL_IS_SHUTDOWN);
-        return (-1);
+        SSLerr(SSL_F_SSL_WRITE_INTERNAL, SSL_R_PROTOCOL_IS_SHUTDOWN);
+        return -1;
     }
 
     if ((s->mode & SSL_MODE_ASYNC) && ASYNC_get_current_job() == NULL) {
@@ -1687,6 +1684,37 @@ int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
     }
 }
 
+int SSL_write(SSL *s, const void *buf, int num)
+{
+    int ret;
+    size_t written;
+
+    if (num < 0) {
+        SSLerr(SSL_F_SSL_WRITE, SSL_R_BAD_LENGTH);
+        return -1;
+    }
+
+    ret = ssl_write_internal(s, buf, (size_t)num, &written);
+
+    /*
+     * The cast is safe here because ret should be <= INT_MAX because num is
+     * <= INT_MAX
+     */
+    if (ret > 0)
+        ret = (int)written;
+
+    return ret;
+}
+
+int SSL_write_ex(SSL *s, const void *buf, size_t num, size_t *written)
+{
+    int ret = ssl_write_internal(s, buf, num, written);
+
+    if (ret < 0)
+        ret = 0;
+    return ret;
+}
+
 int SSL_shutdown(SSL *s)
 {
     /*
index 89eeb453531817ca2df8a5c5c9f17d2a161d1eaa..26464a584d28edeec69c2e16272064d2e08d4c6e 100644 (file)
@@ -1979,6 +1979,8 @@ static ossl_inline int ssl_has_cert(const SSL *s, int idx)
 
 # ifndef OPENSSL_UNIT_TEST
 
+__owur int ssl_read_internal(SSL *s, void *buf, size_t num, size_t *readbytes);
+__owur int ssl_write_internal(SSL *s, const void *buf, size_t num, size_t *written);
 void ssl_clear_cipher_ctx(SSL *s);
 int ssl_clear_bad_session(SSL *s);
 __owur CERT *ssl_cert_new(void);
@@ -2384,7 +2386,7 @@ void custom_exts_free(custom_ext_methods *exts);
 
 void ssl_comp_free_compression_methods_int(void);
 
-# else
+# else /* OPENSSL_UNIT_TEST */
 
 #  define ssl_init_wbio_buffer SSL_test_functions()->p_ssl_init_wbio_buffer
 #  define ssl3_setup_buffers SSL_test_functions()->p_ssl3_setup_buffers