From b77f3ed17134fe6bf99d143abb1aec3f2bfac555 Mon Sep 17 00:00:00 2001 From: Matt Caswell Date: Mon, 22 May 2017 12:33:42 +0100 Subject: [PATCH] Convert existing usage of assert() to ossl_assert() in libssl Provides consistent output and approach. Reviewed-by: Tim Hudson (Merged from https://github.com/openssl/openssl/pull/3496) --- ssl/d1_lib.c | 12 +++++++++--- ssl/packet.c | 33 +++++++++++---------------------- ssl/record/rec_layer_d1.c | 4 +--- ssl/record/rec_layer_s3.c | 1 - ssl/record/ssl3_record.c | 4 +--- ssl/record/ssl3_record_tls13.c | 4 +--- ssl/s3_cbc.c | 12 +++++------- ssl/s3_lib.c | 17 ++++++++++++----- ssl/ssl_init.c | 1 - ssl/ssl_lib.c | 17 +++++++++++------ ssl/ssl_locl.h | 10 +++++----- ssl/statem/extensions_clnt.c | 13 ++++++++----- ssl/statem/extensions_cust.c | 6 ++++-- ssl/statem/statem_lib.c | 3 ++- ssl/t1_lib.c | 16 +++++++++++----- 15 files changed, 81 insertions(+), 72 deletions(-) diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c index 448b2eb1d4..150e875162 100644 --- a/ssl/d1_lib.c +++ b/ssl/d1_lib.c @@ -104,7 +104,10 @@ int dtls1_new(SSL *s) } s->d1 = d1; - s->method->ssl_clear(s); + + if (!s->method->ssl_clear(s)) + return 0; + return 1; } @@ -154,7 +157,7 @@ void dtls1_free(SSL *s) s->d1 = NULL; } -void dtls1_clear(SSL *s) +int dtls1_clear(SSL *s) { pqueue *buffered_messages; pqueue *sent_messages; @@ -186,7 +189,8 @@ void dtls1_clear(SSL *s) s->d1->sent_messages = sent_messages; } - ssl3_clear(s); + if (!ssl3_clear(s)) + return 0; if (s->method->version == DTLS_ANY_VERSION) s->version = DTLS_MAX_VERSION; @@ -196,6 +200,8 @@ void dtls1_clear(SSL *s) #endif else s->version = s->method->version; + + return 1; } long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg) diff --git a/ssl/packet.c b/ssl/packet.c index d081f557e7..7c4be4fee5 100644 --- a/ssl/packet.c +++ b/ssl/packet.c @@ -7,7 +7,7 @@ * https://www.openssl.org/source/license.html */ -#include +#include "e_os.h" #include "packet_locl.h" #define DEFAULT_BUF_SIZE 256 @@ -39,8 +39,7 @@ int WPACKET_sub_allocate_bytes__(WPACKET *pkt, size_t len, int WPACKET_reserve_bytes(WPACKET *pkt, size_t len, unsigned char **allocbytes) { /* Internal API, so should not fail */ - assert(pkt->subs != NULL && len != 0); - if (pkt->subs == NULL || len == 0) + if (!ossl_assert(pkt->subs != NULL && len != 0)) return 0; if (pkt->maxsize - pkt->written < len) @@ -120,8 +119,7 @@ int WPACKET_init_static_len(WPACKET *pkt, unsigned char *buf, size_t len, size_t max = maxmaxsize(lenbytes); /* Internal API, so should not fail */ - assert(buf != NULL && len > 0); - if (buf == NULL || len == 0) + if (!ossl_assert(buf != NULL && len > 0)) return 0; pkt->staticbuf = buf; @@ -134,8 +132,7 @@ int WPACKET_init_static_len(WPACKET *pkt, unsigned char *buf, size_t len, int WPACKET_init_len(WPACKET *pkt, BUF_MEM *buf, size_t lenbytes) { /* Internal API, so should not fail */ - assert(buf != NULL); - if (buf == NULL) + if (!ossl_assert(buf != NULL)) return 0; pkt->staticbuf = NULL; @@ -153,8 +150,7 @@ int WPACKET_init(WPACKET *pkt, BUF_MEM *buf) int WPACKET_set_flags(WPACKET *pkt, unsigned int flags) { /* Internal API, so should not fail */ - assert(pkt->subs != NULL); - if (pkt->subs == NULL) + if (!ossl_assert(pkt->subs != NULL)) return 0; pkt->subs->flags = flags; @@ -228,8 +224,7 @@ int WPACKET_fill_lengths(WPACKET *pkt) { WPACKET_SUB *sub; - assert(pkt->subs != NULL); - if (pkt->subs == NULL) + if (!ossl_assert(pkt->subs != NULL)) return 0; for (sub = pkt->subs; sub != NULL; sub = sub->parent) { @@ -278,8 +273,7 @@ int WPACKET_start_sub_packet_len__(WPACKET *pkt, size_t lenbytes) unsigned char *lenchars; /* Internal API, so should not fail */ - assert(pkt->subs != NULL); - if (pkt->subs == NULL) + if (!ossl_assert(pkt->subs != NULL)) return 0; sub = OPENSSL_zalloc(sizeof(*sub)); @@ -314,9 +308,7 @@ int WPACKET_put_bytes__(WPACKET *pkt, unsigned int val, size_t size) unsigned char *data; /* Internal API, so should not fail */ - assert(size <= sizeof(unsigned int)); - - if (size > sizeof(unsigned int) + if (!ossl_assert(size <= sizeof(unsigned int)) || !WPACKET_allocate_bytes(pkt, size, &data) || !put_value(data, val, size)) return 0; @@ -330,8 +322,7 @@ int WPACKET_set_max_size(WPACKET *pkt, size_t maxsize) size_t lenbytes; /* Internal API, so should not fail */ - assert(pkt->subs != NULL); - if (pkt->subs == NULL) + if (!ossl_assert(pkt->subs != NULL)) return 0; /* Find the WPACKET_SUB for the top level */ @@ -394,8 +385,7 @@ int WPACKET_sub_memcpy__(WPACKET *pkt, const void *src, size_t len, int WPACKET_get_total_written(WPACKET *pkt, size_t *written) { /* Internal API, so should not fail */ - assert(written != NULL); - if (written == NULL) + if (!ossl_assert(written != NULL)) return 0; *written = pkt->written; @@ -406,8 +396,7 @@ int WPACKET_get_total_written(WPACKET *pkt, size_t *written) int WPACKET_get_length(WPACKET *pkt, size_t *len) { /* Internal API, so should not fail */ - assert(pkt->subs != NULL && len != NULL); - if (pkt->subs == NULL || len == NULL) + if (!ossl_assert(pkt->subs != NULL && len != NULL)) return 0; *len = pkt->written - pkt->subs->pwritten; diff --git a/ssl/record/rec_layer_d1.c b/ssl/record/rec_layer_d1.c index 879a9b039c..9f80050f01 100644 --- a/ssl/record/rec_layer_d1.c +++ b/ssl/record/rec_layer_d1.c @@ -14,7 +14,6 @@ #include #include #include "record_locl.h" -#include #include "../packet_locl.h" int DTLS_RECORD_LAYER_new(RECORD_LAYER *rl) @@ -645,8 +644,7 @@ int dtls1_read_bytes(SSL *s, int type, int *recvd_type, unsigned char *buf, * (which is tested for at the top of this function) then init must be * finished */ - assert(SSL_is_init_finished(s)); - if (!SSL_is_init_finished(s)) { + if (!ossl_assert(SSL_is_init_finished(s))) { al = SSL_AD_INTERNAL_ERROR; SSLerr(SSL_F_DTLS1_READ_BYTES, ERR_R_INTERNAL_ERROR); goto f_err; diff --git a/ssl/record/rec_layer_s3.c b/ssl/record/rec_layer_s3.c index dabb02cf1b..fbabdf48c5 100644 --- a/ssl/record/rec_layer_s3.c +++ b/ssl/record/rec_layer_s3.c @@ -8,7 +8,6 @@ */ #include -#include #include #include #define USE_SOCKETS diff --git a/ssl/record/ssl3_record.c b/ssl/record/ssl3_record.c index bba0cc0f0b..bd27290aa8 100644 --- a/ssl/record/ssl3_record.c +++ b/ssl/record/ssl3_record.c @@ -7,7 +7,6 @@ * https://www.openssl.org/source/license.html */ -#include #include "../ssl_locl.h" #include "internal/constant_time_locl.h" #include @@ -436,8 +435,7 @@ int ssl3_get_record(SSL *s) unsigned char *mac; /* TODO(size_t): convert this to do size_t properly */ imac_size = EVP_MD_CTX_size(s->read_hash); - assert(imac_size >= 0 && imac_size <= EVP_MAX_MD_SIZE); - if (imac_size < 0 || imac_size > EVP_MAX_MD_SIZE) { + if (!ossl_assert(imac_size >= 0 && imac_size <= EVP_MAX_MD_SIZE)) { al = SSL_AD_INTERNAL_ERROR; SSLerr(SSL_F_SSL3_GET_RECORD, ERR_LIB_EVP); goto f_err; diff --git a/ssl/record/ssl3_record_tls13.c b/ssl/record/ssl3_record_tls13.c index 9e65852f9d..61e209254e 100644 --- a/ssl/record/ssl3_record_tls13.c +++ b/ssl/record/ssl3_record_tls13.c @@ -7,7 +7,6 @@ * https://www.openssl.org/source/license.html */ -#include #include "../ssl_locl.h" #include "record_locl.h" @@ -64,8 +63,7 @@ int tls13_enc(SSL *s, SSL3_RECORD *recs, size_t n_recs, int sending) * To get here we must have selected a ciphersuite - otherwise ctx would * be NULL */ - assert(s->s3->tmp.new_cipher != NULL); - if (s->s3->tmp.new_cipher == NULL) + if (!ossl_assert(s->s3->tmp.new_cipher != NULL)) return -1; alg_enc = s->s3->tmp.new_cipher->algorithm_enc; } diff --git a/ssl/s3_cbc.c b/ssl/s3_cbc.c index f8d7aed3e1..0981360e0b 100644 --- a/ssl/s3_cbc.c +++ b/ssl/s3_cbc.c @@ -7,7 +7,6 @@ * https://www.openssl.org/source/license.html */ -#include #include "internal/constant_time_locl.h" #include "ssl_locl.h" @@ -229,15 +228,14 @@ int ssl3_cbc_digest_record(const EVP_MD_CTX *ctx, * ssl3_cbc_record_digest_supported should have been called first to * check that the hash function is supported. */ - assert(0); - if (md_out_size) + if (md_out_size != NULL) *md_out_size = 0; - return 0; + return ossl_assert(0); } - if (!ossl_assert(md_length_size <= MAX_HASH_BIT_COUNT_BYTES - && md_block_size <= MAX_HASH_BLOCK_SIZE - && md_size <= EVP_MAX_MD_SIZE)) + if (!ossl_assert(md_length_size <= MAX_HASH_BIT_COUNT_BYTES) + || !ossl_assert(md_block_size <= MAX_HASH_BLOCK_SIZE) + || !ossl_assert(md_size <= EVP_MAX_MD_SIZE)) return 0; header_length = 13; diff --git a/ssl/s3_lib.c b/ssl/s3_lib.c index 5a8f9f46ac..2165f62a7c 100644 --- a/ssl/s3_lib.c +++ b/ssl/s3_lib.c @@ -48,7 +48,6 @@ */ #include -#include #include #include "ssl_locl.h" #include @@ -2914,7 +2913,10 @@ int ssl3_new(SSL *s) if (!SSL_SRP_CTX_init(s)) goto err; #endif - s->method->ssl_clear(s); + + if (!s->method->ssl_clear(s)) + return 0; + return 1; err: return 0; @@ -2950,7 +2952,7 @@ void ssl3_free(SSL *s) s->s3 = NULL; } -void ssl3_clear(SSL *s) +int ssl3_clear(SSL *s) { ssl3_cleanup_key_block(s); OPENSSL_free(s->s3->tmp.ctype); @@ -2972,7 +2974,8 @@ void ssl3_clear(SSL *s) /* NULL/zero-out everything in the s3 struct */ memset(s->s3, 0, sizeof(*s->s3)); - ssl_free_wbio_buffer(s); + if (!ssl_free_wbio_buffer(s)) + return 0; s->version = SSL3_VERSION; @@ -2981,6 +2984,8 @@ void ssl3_clear(SSL *s) s->ext.npn = NULL; s->ext.npn_len = 0; #endif + + return 1; } #ifndef OPENSSL_NO_SRP @@ -4038,7 +4043,9 @@ int ssl_fill_hello_random(SSL *s, int server, unsigned char *result, size_t len, } #ifndef OPENSSL_NO_TLS13DOWNGRADE if (ret) { - assert(sizeof(tls11downgrade) < len && sizeof(tls12downgrade) < len); + if (!ossl_assert(sizeof(tls11downgrade) < len) + || !ossl_assert(sizeof(tls12downgrade) < len)) + return 0; if (dgrd == DOWNGRADE_TO_1_2) memcpy(result + len - sizeof(tls12downgrade), tls12downgrade, sizeof(tls12downgrade)); diff --git a/ssl/ssl_init.c b/ssl/ssl_init.c index 1d9b799658..478a48e9d6 100644 --- a/ssl/ssl_init.c +++ b/ssl/ssl_init.c @@ -12,7 +12,6 @@ #include "internal/err.h" #include #include -#include #include "ssl_locl.h" #include "internal/thread_once.h" diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index cba03bdc15..028b69da08 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c @@ -39,7 +39,6 @@ * OTHERWISE. */ -#include #include #include "ssl_locl.h" #include @@ -493,8 +492,10 @@ int SSL_clear(SSL *s) s->method = s->ctx->method; if (!s->method->ssl_new(s)) return 0; - } else - s->method->ssl_clear(s); + } else { + if (!s->method->ssl_clear(s)) + return 0; + } RECORD_LAYER_clear(&s->rlayer); @@ -981,6 +982,7 @@ void SSL_free(SSL *s) dane_final(&s->dane); CRYPTO_free_ex_data(CRYPTO_EX_INDEX_SSL, s, &s->ex_data); + /* Ignore return value */ ssl_free_wbio_buffer(s); BIO_free_all(s->wbio); @@ -3529,16 +3531,19 @@ int ssl_init_wbio_buffer(SSL *s) return 1; } -void ssl_free_wbio_buffer(SSL *s) +int ssl_free_wbio_buffer(SSL *s) { /* callers ensure s is never null */ if (s->bbio == NULL) - return; + return 1; s->wbio = BIO_pop(s->wbio); - assert(s->wbio != NULL); + if (!ossl_assert(s->wbio != NULL)) + return 0; BIO_free(s->bbio); s->bbio = NULL; + + return 1; } void SSL_CTX_set_quiet_shutdown(SSL_CTX *ctx, int mode) diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h index b0932b0bc6..fe7f7b4e0f 100644 --- a/ssl/ssl_locl.h +++ b/ssl/ssl_locl.h @@ -452,7 +452,7 @@ struct ssl_method_st { unsigned flags; unsigned long mask; int (*ssl_new) (SSL *s); - void (*ssl_clear) (SSL *s); + int (*ssl_clear) (SSL *s); void (*ssl_free) (SSL *s); int (*ssl_accept) (SSL *s); int (*ssl_connect) (SSL *s); @@ -2181,7 +2181,7 @@ __owur int ssl3_read(SSL *s, void *buf, size_t len, size_t *readbytes); __owur int ssl3_peek(SSL *s, void *buf, size_t len, size_t *readbytes); __owur int ssl3_write(SSL *s, const void *buf, size_t len, size_t *written); __owur int ssl3_shutdown(SSL *s); -void ssl3_clear(SSL *s); +int ssl3_clear(SSL *s); __owur long ssl3_ctrl(SSL *s, int cmd, long larg, void *parg); __owur long ssl3_ctx_ctrl(SSL_CTX *s, int cmd, long larg, void *parg); __owur long ssl3_callback_ctrl(SSL *s, int cmd, void (*fp) (void)); @@ -2246,20 +2246,20 @@ __owur int dtls1_query_mtu(SSL *s); __owur int tls1_new(SSL *s); void tls1_free(SSL *s); -void tls1_clear(SSL *s); +int tls1_clear(SSL *s); long tls1_ctrl(SSL *s, int cmd, long larg, void *parg); long tls1_callback_ctrl(SSL *s, int cmd, void (*fp) (void)); __owur int dtls1_new(SSL *s); void dtls1_free(SSL *s); -void dtls1_clear(SSL *s); +int dtls1_clear(SSL *s); long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg); __owur int dtls1_shutdown(SSL *s); __owur int dtls1_dispatch_alert(SSL *s); __owur int ssl_init_wbio_buffer(SSL *s); -void ssl_free_wbio_buffer(SSL *s); +int ssl_free_wbio_buffer(SSL *s); __owur int tls1_change_cipher_state(SSL *s, int which); __owur int tls1_setup_key_block(SSL *s); diff --git a/ssl/statem/extensions_clnt.c b/ssl/statem/extensions_clnt.c index c5f8d3d1e5..8aa795e997 100644 --- a/ssl/statem/extensions_clnt.c +++ b/ssl/statem/extensions_clnt.c @@ -7,7 +7,6 @@ * https://www.openssl.org/source/license.html */ -#include #include #include "../ssl_locl.h" #include "statem_locl.h" @@ -541,8 +540,7 @@ static int add_key_share(SSL *s, WPACKET *pkt, unsigned int curve_id) size_t encodedlen; if (s->s3->tmp.pkey != NULL) { - assert(s->hello_retry_request); - if (!s->hello_retry_request) { + if (!ossl_assert(s->hello_retry_request)) { SSLerr(SSL_F_ADD_KEY_SHARE, ERR_R_INTERNAL_ERROR); return 0; } @@ -923,8 +921,13 @@ int tls_parse_stoc_renegotiate(SSL *s, PACKET *pkt, unsigned int context, const unsigned char *data; /* Check for logic errors */ - assert(expected_len == 0 || s->s3->previous_client_finished_len != 0); - assert(expected_len == 0 || s->s3->previous_server_finished_len != 0); + if (!ossl_assert(expected_len == 0 + || s->s3->previous_client_finished_len != 0) + || !ossl_assert(expected_len == 0 + || s->s3->previous_server_finished_len != 0)) { + *al = SSL_AD_INTERNAL_ERROR; + return 0; + } /* Parse the length byte */ if (!PACKET_get_1_len(pkt, &ilen)) { diff --git a/ssl/statem/extensions_cust.c b/ssl/statem/extensions_cust.c index e06fa9d1d7..cd63d04b00 100644 --- a/ssl/statem/extensions_cust.c +++ b/ssl/statem/extensions_cust.c @@ -9,7 +9,6 @@ /* Custom extension utility functions */ -#include #include #include "../ssl_locl.h" #include "statem_locl.h" @@ -217,7 +216,10 @@ int custom_ext_add(SSL *s, int context, WPACKET *pkt, X509 *x, size_t chainidx, /* * We can't send duplicates: code logic should prevent this. */ - assert((meth->ext_flags & SSL_EXT_FLAG_SENT) == 0); + if (!ossl_assert((meth->ext_flags & SSL_EXT_FLAG_SENT) == 0)) { + *al = SSL_AD_INTERNAL_ERROR; + return 0; + } /* * Indicate extension has been sent: this is both a sanity check to * ensure we don't send duplicate extensions and indicates that it diff --git a/ssl/statem/statem_lib.c b/ssl/statem/statem_lib.c index 7aedc41691..fbf5a3cc69 100644 --- a/ssl/statem/statem_lib.c +++ b/ssl/statem/statem_lib.c @@ -995,7 +995,8 @@ WORK_STATE tls_finish_handshake(SSL *s, WORK_STATE wst, int clearbufs) BUF_MEM_free(s->init_buf); s->init_buf = NULL; } - ssl_free_wbio_buffer(s); + if (!ssl_free_wbio_buffer(s)) + return WORK_ERROR; s->init_num = 0; } diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c index 232bb41fe0..c185a09e9c 100644 --- a/ssl/t1_lib.c +++ b/ssl/t1_lib.c @@ -101,9 +101,11 @@ long tls1_default_timeout(void) int tls1_new(SSL *s) { if (!ssl3_new(s)) - return (0); - s->method->ssl_clear(s); - return (1); + return 0; + if (!s->method->ssl_clear(s)) + return 0; + + return 1; } void tls1_free(SSL *s) @@ -112,13 +114,17 @@ void tls1_free(SSL *s) ssl3_free(s); } -void tls1_clear(SSL *s) +int tls1_clear(SSL *s) { - ssl3_clear(s); + if (!ssl3_clear(s)) + return 0; + if (s->method->version == TLS_ANY_VERSION) s->version = TLS_MAX_VERSION; else s->version = s->method->version; + + return 1; } #ifndef OPENSSL_NO_EC -- 2.25.1