Add some tests for the new custom extensions API
authorMatt Caswell <matt@openssl.org>
Wed, 5 Apr 2017 16:29:47 +0000 (17:29 +0100)
committerMatt Caswell <matt@openssl.org>
Fri, 7 Apr 2017 12:41:04 +0000 (13:41 +0100)
Reviewed-by: Rich Salz <rsalz@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/3139)

test/sslapitest.c

index ed804c9da39a813f27b748558ec40815d768ed6c..0e2bdbefaff2a2c2b22f07bed2456b9ab6c64cad 100644 (file)
@@ -2250,6 +2250,313 @@ static int test_early_data_tls1_2(int idx)
 # endif
 #endif
 
+static int clntaddoldcb = 0;
+static int clntparseoldcb = 0;
+static int srvaddoldcb = 0;
+static int srvparseoldcb = 0;
+static int clntaddnewcb = 0;
+static int clntparsenewcb = 0;
+static int srvaddnewcb = 0;
+static int srvparsenewcb = 0;
+
+#define TEST_EXT_TYPE1  0xff00
+
+static int old_add_cb(SSL *s, unsigned int ext_type, const unsigned char **out,
+                      size_t *outlen, int *al, void *add_arg)
+{
+    int *server = (int *)add_arg;
+    unsigned char *data;
+
+    if (SSL_is_server(s))
+        srvaddoldcb++;
+    else
+        clntaddoldcb++;
+
+    if (*server != SSL_is_server(s))
+        return -1;
+
+    data = OPENSSL_malloc(sizeof(char));
+    if (data == NULL)
+        return -1;
+
+    *data = 1;
+    *out = data;
+    *outlen = sizeof(char);
+
+    return 1;
+}
+
+static void old_free_cb(SSL *s, unsigned int ext_type, const unsigned char *out,
+                        void *add_arg)
+{
+    OPENSSL_free((unsigned char *)out);
+}
+
+static int old_parse_cb(SSL *s, unsigned int ext_type, const unsigned char *in,
+                        size_t inlen, int *al, void *parse_arg)
+{
+    int *server = (int *)parse_arg;
+
+    if (SSL_is_server(s))
+        srvparseoldcb++;
+    else
+        clntparseoldcb++;
+
+    if (*server != SSL_is_server(s))
+        return -1;
+
+    if (inlen != sizeof(char) || *in != 1)
+        return -1;
+
+    return 1;
+}
+
+static int new_add_cb(SSL *s, unsigned int ext_type, unsigned int context,
+                      const unsigned char **out, size_t *outlen, X509 *x,
+                      size_t chainidx, int *al, void *add_arg)
+{
+    int *server = (int *)add_arg;
+    unsigned char *data;
+
+    if (SSL_is_server(s))
+        srvaddnewcb++;
+    else
+        clntaddnewcb++;
+
+    if (*server != SSL_is_server(s))
+        return -1;
+
+    data = OPENSSL_malloc(sizeof(char));
+    if (data == NULL)
+        return -1;
+
+    *data = 1;
+    *out = data;
+    *outlen = sizeof(char);
+
+    return 1;
+}
+
+static void new_free_cb(SSL *s, unsigned int ext_type, unsigned int context,
+                        const unsigned char *out, void *add_arg)
+{
+    OPENSSL_free((unsigned char *)out);
+}
+
+static int new_parse_cb(SSL *s, unsigned int ext_type, unsigned int context,
+                        const unsigned char *in, size_t inlen, X509 *x,
+                        size_t chainidx, int *al, void *parse_arg)
+{
+    int *server = (int *)parse_arg;
+
+    if (SSL_is_server(s))
+        srvparsenewcb++;
+    else
+        clntparsenewcb++;
+
+    if (*server != SSL_is_server(s))
+        return -1;
+
+    if (inlen != sizeof(char) || *in != 1)
+        return -1;
+
+    return 1;
+}
+/*
+ * Custom call back tests.
+ * Test 0: Old style callbacks in TLSv1.2
+ * Test 1: New style callbacks in TLSv1.2
+ * Test 2: New style callbacks in TLSv1.3. Extensions in CH and EE
+ * Test 3: New style callbacks in TLSv1.3. Extensions in CH, SH, EE, Cert + NST
+ */
+static int test_custom_exts(int tst) {
+    SSL_CTX *cctx = NULL, *sctx = NULL;
+    SSL *clientssl = NULL, *serverssl = NULL;
+    int testresult = 0;
+    static int server = 1;
+    static int client = 0;
+    SSL_SESSION *sess = NULL;
+    unsigned int context;
+
+    /* Reset callback counters */
+    clntaddoldcb = clntparseoldcb = srvaddoldcb = srvparseoldcb = 0;
+    clntaddnewcb = clntparsenewcb = srvaddnewcb = srvparsenewcb = 0;
+
+    if (!create_ssl_ctx_pair(TLS_server_method(), TLS_client_method(), &sctx,
+                             &cctx, cert, privkey)) {
+        printf("Unable to create SSL_CTX pair\n");
+        return 0;
+    }
+
+    if (tst < 2) {
+        SSL_CTX_set_options(cctx, SSL_OP_NO_TLSv1_3);
+        SSL_CTX_set_options(sctx, SSL_OP_NO_TLSv1_3);
+    }
+
+    if (tst == 3) {
+        context = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_2_SERVER_HELLO
+                  | SSL_EXT_TLS1_3_SERVER_HELLO
+                  | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS
+                  | SSL_EXT_TLS1_3_CERTIFICATE
+                  | SSL_EXT_TLS1_3_NEW_SESSION_TICKET;
+    } else {
+        context = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_2_SERVER_HELLO
+                  | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS;
+    }
+
+    /* Create a client side custom extension */
+    if (tst == 0) {
+        if (!SSL_CTX_add_client_custom_ext(cctx, TEST_EXT_TYPE1, old_add_cb,
+                                           old_free_cb, &client, old_parse_cb,
+                                           &client)) {
+            printf("Unable to create old style client side custom extension\n");
+            return 0;
+        }
+    } else {
+        if (!SSL_CTX_add_custom_ext(cctx, TEST_EXT_TYPE1, context, new_add_cb,
+                                    new_free_cb, &client, new_parse_cb,
+                                    &client)) {
+            printf("Unable to create new style client side custom extension\n");
+            return 0;
+        }
+    }
+
+    /* Should not be able to add duplicates */
+    if (SSL_CTX_add_client_custom_ext(cctx, TEST_EXT_TYPE1, old_add_cb,
+                                      old_free_cb, &client, old_parse_cb,
+                                      &client)) {
+        printf("Unexpected success adding duplicate client custom extension\n");
+        return 0;
+    }
+    if (SSL_CTX_add_custom_ext(cctx, TEST_EXT_TYPE1, context, new_add_cb,
+                               new_free_cb, &client, new_parse_cb, &client)) {
+        printf("Unexpected success adding duplicate client custom extension\n");
+        return 0;
+    }
+
+    /* Create a server side custom extension */
+    if (tst == 0) {
+        if (!SSL_CTX_add_server_custom_ext(sctx, TEST_EXT_TYPE1, old_add_cb,
+                                           old_free_cb, &server, old_parse_cb,
+                                           &server)) {
+            printf("Unable to create old style server side custom extension\n");
+            return 0;
+        }
+    } else {
+        if (!SSL_CTX_add_custom_ext(sctx, TEST_EXT_TYPE1, context, new_add_cb,
+                                    new_free_cb, &server, new_parse_cb,
+                                    &server)) {
+            printf("Unable to create new style server side custom extension\n");
+            return 0;
+        }
+    }
+
+    /* Should not be able to add duplicates */
+    if (SSL_CTX_add_server_custom_ext(sctx, TEST_EXT_TYPE1, old_add_cb,
+                                      old_free_cb, &server, old_parse_cb,
+                                      &server)) {
+        printf("Unexpected success adding duplicate server custom extension\n");
+        return 0;
+    }
+    if (SSL_CTX_add_custom_ext(sctx, TEST_EXT_TYPE1, context, new_add_cb,
+                               new_free_cb, &server, new_parse_cb, &server)) {
+        printf("Unexpected success adding duplicate server custom extension\n");
+        return 0;
+    }
+
+    if (!create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL, NULL)) {
+        printf("Unable to create SSL objects\n");
+        goto end;
+    }
+
+    if (!create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) {
+        printf("Unable to create SSL connection\n");
+        goto end;
+    }
+
+    if (tst == 0) {
+        if (clntaddoldcb != 1 || clntparseoldcb != 1 || srvaddoldcb != 1
+                || srvparseoldcb != 1) {
+            printf("Custom extension callbacks not called\n");
+            goto end;
+        }
+    } else if (tst == 1 || tst == 2) {
+        if (clntaddnewcb != 1 || clntparsenewcb != 1 || srvaddnewcb != 1
+                || srvparsenewcb != 1) {
+            printf("Custom extension callbacks not called\n");
+            goto end;
+        }
+    } else {
+        if (clntaddnewcb != 1 || clntparsenewcb != 4 || srvaddnewcb != 4
+                || srvparsenewcb != 1) {
+            printf("Custom extension callbacks not called\n");
+            goto end;
+        }
+    }
+
+    sess = SSL_get1_session(clientssl);
+
+    SSL_shutdown(clientssl);
+    SSL_shutdown(serverssl);
+
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    serverssl = clientssl = NULL;
+
+    if (!create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL, NULL)) {
+        printf("Unable to create SSL objects (2)\n");
+        goto end;
+    }
+
+    if (!SSL_set_session(clientssl, sess)) {
+        printf("Failed setting session\n");
+        goto end;
+    }
+
+    if (!create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE)) {
+        printf("Unable to create SSL connection (2)\n");
+        goto end;
+    }
+
+    /*
+     * For a resumed session we expect to add the ClientHello extension. For the
+     * old style callbacks we ignore it on the server side because they set
+     * SSL_EXT_IGNORE_ON_RESUMPTION. The new style callbacks do not ignore
+     * them.
+     */
+    if (tst == 0) {
+        if (clntaddoldcb != 2 || clntparseoldcb != 1 || srvaddoldcb != 1
+                || srvparseoldcb != 1) {
+            printf("Unexpected custom extension callback calls\n");
+            goto end;
+        }
+    } else if (tst == 1 || tst == 2) {
+        if (clntaddnewcb != 2 || clntparsenewcb != 2 || srvaddnewcb != 2
+                || srvparsenewcb != 2) {
+            printf("Unexpected custom extension callback calls\n");
+            goto end;
+        }
+    } else {
+        /* No Certificate message extensions in the resumption handshake */
+        if (clntaddnewcb != 2 || clntparsenewcb != 7 || srvaddnewcb != 7
+                || srvparsenewcb != 2) {
+            printf("Unexpected custom extension callback calls\n");
+            goto end;
+        }
+    }
+
+    testresult = 1;
+
+end:
+    SSL_SESSION_free(sess);
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    SSL_CTX_free(sctx);
+    SSL_CTX_free(cctx);
+
+    return testresult;
+}
+
 int test_main(int argc, char *argv[])
 {
     int testresult = 1;
@@ -2295,6 +2602,7 @@ int test_main(int argc, char *argv[])
     ADD_ALL_TESTS(test_early_data_tls1_2, 2);
 # endif
 #endif
+    ADD_ALL_TESTS(test_custom_exts, 4);
 
     testresult = run_tests(argv[0]);