diff --git a/components/openssl/include/internal/ssl_types.h b/components/openssl/include/internal/ssl_types.h index 1dc31f5a5..34249ea05 100644 --- a/components/openssl/include/internal/ssl_types.h +++ b/components/openssl/include/internal/ssl_types.h @@ -213,7 +213,7 @@ struct ssl_st /* where we are */ OSSL_STATEM statem; - SSL_SESSION session; + SSL_SESSION *session; int verify_mode; diff --git a/components/openssl/library/ssl_lib.c b/components/openssl/library/ssl_lib.c index a84b89e06..ded30a33a 100644 --- a/components/openssl/library/ssl_lib.c +++ b/components/openssl/library/ssl_lib.c @@ -117,6 +117,38 @@ OSSL_HANDSHAKE_STATE SSL_get_state(const SSL *ssl) return state; } +/** + * @brief create a new SSL session object + */ +SSL_SESSION* SSL_SESSION_new(void) +{ + SSL_SESSION *session; + + session = ssl_zalloc(sizeof(SSL_SESSION)); + if (!session) + SSL_RET(failed1); + + session->peer = X509_new(); + if (!session->peer) + SSL_RET(failed2); + + return session; + +failed2: + ssl_free(session); +failed1: + return NULL; +} + +/** + * @brief free a new SSL session object + */ +void SSL_SESSION_free(SSL_SESSION *session) +{ + X509_free(session->peer); + ssl_free(session); +} + /** * @brief create a SSL context */ @@ -210,6 +242,10 @@ SSL *SSL_new(SSL_CTX *ctx) if (!ssl) SSL_RET(failed1, "ssl_zalloc\n"); + ssl->session = SSL_SESSION_new(); + if (!ssl->session) + SSL_RET(failed2, "ssl_zalloc\n"); + ssl->ctx = ctx; ssl->method = ctx->method; @@ -222,12 +258,14 @@ SSL *SSL_new(SSL_CTX *ctx) ret = SSL_METHOD_CALL(new, ssl); if (ret) - SSL_RET(failed2, "ssl_new\n"); + SSL_RET(failed3, "ssl_new\n"); ssl->rwstate = SSL_NOTHING; return ssl; +failed3: + SSL_SESSION_free(ssl->session); failed2: ssl_free(ssl); failed1: @@ -243,6 +281,8 @@ void SSL_free(SSL *ssl) SSL_METHOD_CALL(free, ssl); + SSL_SESSION_free(ssl->session); + if (ssl->ca_reload) X509_free(ssl->client_CA); @@ -1369,7 +1409,7 @@ long SSL_set_time(SSL *ssl, long t) { SSL_ASSERT(ssl); - ssl->session.time = t; + ssl->session->time = t; return t; } @@ -1381,7 +1421,7 @@ long SSL_set_timeout(SSL *ssl, long t) { SSL_ASSERT(ssl); - ssl->session.timeout = t; + ssl->session->timeout = t; return t; } diff --git a/components/openssl/library/ssl_x509.c b/components/openssl/library/ssl_x509.c index e96511dc4..c3fa0b307 100644 --- a/components/openssl/library/ssl_x509.c +++ b/components/openssl/library/ssl_x509.c @@ -32,7 +32,7 @@ X509* X509_new(void) x->method = X509_method(); - ret = x->method->x509_new(x); + ret = X509_METHOD_CALL(new, x); if (ret) SSL_RET(failed2, "x509_new\n"); @@ -256,5 +256,5 @@ X509 *SSL_get_peer_certificate(const SSL *ssl) { SSL_ASSERT(ssl); - return ssl->session.peer; + return ssl->session->peer; } diff --git a/components/openssl/platform/ssl_pm.c b/components/openssl/platform/ssl_pm.c index 9abfc212e..0cf8f6c0a 100644 --- a/components/openssl/platform/ssl_pm.c +++ b/components/openssl/platform/ssl_pm.c @@ -43,16 +43,16 @@ struct ssl_pm struct x509_pm { - mbedtls_x509_crt x509_crt; + mbedtls_x509_crt *x509_crt; - int load; + mbedtls_x509_crt *ex_crt; }; struct pkey_pm { - mbedtls_pk_context pkey; + mbedtls_pk_context *pkey; - int load; + mbedtls_pk_context *ex_pkey; }; @@ -78,13 +78,9 @@ int ssl_pm_new(SSL *ssl) const SSL_METHOD *method = ssl->method; - ssl->session.peer = ssl_zalloc(sizeof(X509)); - if (!ssl->session.peer) - SSL_ERR(ret, failed1, "ssl_zalloc\n"); - ssl_pm = ssl_zalloc(sizeof(struct ssl_pm)); if (!ssl_pm) - SSL_ERR(ret, failed2, "ssl_zalloc\n"); + SSL_ERR(ret, failed1, "ssl_zalloc\n"); mbedtls_net_init(&ssl_pm->fd); mbedtls_net_init(&ssl_pm->cl_fd); @@ -96,7 +92,7 @@ int ssl_pm_new(SSL *ssl) ret = mbedtls_ctr_drbg_seed(&ssl_pm->ctr_drbg, mbedtls_entropy_func, &ssl_pm->entropy, pers, pers_len); if (ret) - SSL_ERR(ret, failed3, "mbedtls_ctr_drbg_seed:[-0x%x]\n", -ret); + SSL_ERR(ret, failed2, "mbedtls_ctr_drbg_seed:[-0x%x]\n", -ret); if (method->endpoint) { endpoint = MBEDTLS_SSL_IS_SERVER; @@ -105,7 +101,7 @@ int ssl_pm_new(SSL *ssl) } ret = mbedtls_ssl_config_defaults(&ssl_pm->conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret) - SSL_ERR(ret, failed3, "mbedtls_ssl_config_defaults:[-0x%x]\n", -ret); + SSL_ERR(ret, failed2, "mbedtls_ssl_config_defaults:[-0x%x]\n", -ret); if (TLS1_2_VERSION == ssl->version) version = MBEDTLS_SSL_MINOR_VERSION_3; @@ -124,7 +120,7 @@ int ssl_pm_new(SSL *ssl) ret = mbedtls_ssl_setup(&ssl_pm->ssl, &ssl_pm->conf); if (ret) - SSL_ERR(ret, failed4, "mbedtls_ssl_setup:[-0x%x]\n", -ret); + SSL_ERR(ret, failed3, "mbedtls_ssl_setup:[-0x%x]\n", -ret); mbedtls_ssl_set_bio(&ssl_pm->ssl, &ssl_pm->fd, mbedtls_net_send, mbedtls_net_recv, NULL); @@ -132,13 +128,11 @@ int ssl_pm_new(SSL *ssl) return 0; -failed4: +failed3: mbedtls_ssl_config_free(&ssl_pm->conf); mbedtls_ctr_drbg_free(&ssl_pm->ctr_drbg); -failed3: - mbedtls_entropy_free(&ssl_pm->entropy); failed2: - ssl_free(ssl->session.peer); + mbedtls_entropy_free(&ssl_pm->entropy); failed1: return -1; } @@ -155,9 +149,6 @@ void ssl_pm_free(SSL *ssl) mbedtls_ssl_config_free(&ssl_pm->conf); mbedtls_ssl_free(&ssl_pm->ssl); - ssl_free(ssl->session.peer); - ssl->session.peer = NULL; - ssl_free(ssl_pm); ssl->ssl_pm = NULL; } @@ -186,12 +177,12 @@ static int ssl_pm_reload_crt(SSL *ssl) mbedtls_ssl_conf_authmode(&ssl_pm->conf, mode); - if (ca_pm->load) { - mbedtls_ssl_conf_ca_chain(&ssl_pm->conf, &ca_pm->x509_crt, NULL); + if (ca_pm->x509_crt) { + mbedtls_ssl_conf_ca_chain(&ssl_pm->conf, ca_pm->x509_crt, NULL); } - if (pkey_pm->load) { - ret = mbedtls_ssl_conf_own_cert(&ssl_pm->conf, &crt_pm->x509_crt, &pkey_pm->pkey); + if (crt_pm->x509_crt && pkey_pm->pkey) { + ret = mbedtls_ssl_conf_own_cert(&ssl_pm->conf, crt_pm->x509_crt, pkey_pm->pkey); if (ret) return -1; } @@ -217,9 +208,11 @@ int ssl_pm_handshake(SSL *ssl) ssl_speed_up_exit(); if (!mbed_ret) { + struct x509_pm *x509_pm = (struct x509_pm *)ssl->session->peer->x509_pm; + ret = 1; - ssl->session.peer->x509_pm = (struct x509_pm *)mbedtls_ssl_get_peer_cert(&ssl_pm->ssl); + x509_pm->ex_crt = (mbedtls_x509_crt *)mbedtls_ssl_get_peer_cert(&ssl_pm->ssl); } else { ret = 0; SSL_DEBUG(1, "mbedtls_ssl_handshake [-0x%x]\n", -mbed_ret); @@ -234,8 +227,13 @@ int ssl_pm_shutdown(SSL *ssl) struct ssl_pm *ssl_pm = (struct ssl_pm *)ssl->ssl_pm; mbed_ret = mbedtls_ssl_close_notify(&ssl_pm->ssl); - if (!mbed_ret) + if (!mbed_ret) { + struct x509_pm *x509_pm = (struct x509_pm *)ssl->session->peer->x509_pm; + ret = 0; + + x509_pm->ex_crt = NULL; + } else ret = -1; @@ -365,51 +363,26 @@ int x509_pm_new(X509 *x) x509_pm = ssl_zalloc(sizeof(struct x509_pm)); if (!x509_pm) - return -1; + SSL_RET(failed1); x->x509_pm = x509_pm; return 0; + +failed1: + return -1; } void x509_pm_unload(X509 *x) { struct x509_pm *x509_pm = (struct x509_pm *)x->x509_pm; - if (x509_pm->load) - mbedtls_x509_crt_free(&x509_pm->x509_crt); + if (x509_pm->x509_crt) { + mbedtls_x509_crt_free(x509_pm->x509_crt); - x509_pm->load = 0; -} - -int x509_pm_load(X509 *x, const unsigned char *buffer, int len) -{ - int ret; - unsigned char *load_buf; - struct x509_pm *x509_pm = (struct x509_pm *)x->x509_pm; - - load_buf = ssl_malloc(len + 1); - if (!load_buf) - SSL_RET(failed1); - - ssl_memcpy(load_buf, buffer, len); - load_buf[len] = '\0'; - - x509_pm_unload(x); - - mbedtls_x509_crt_init(&x509_pm->x509_crt); - ret = mbedtls_x509_crt_parse(&x509_pm->x509_crt, load_buf, len); - ssl_free(load_buf); - - if (ret) - SSL_RET(failed1, ""); - - x509_pm->load = 1; - - return 0; - -failed1: - return -1; + ssl_free(x509_pm->x509_crt); + x509_pm->x509_crt = NULL; + } } void x509_pm_free(X509 *x) @@ -420,6 +393,44 @@ void x509_pm_free(X509 *x) x->x509_pm = NULL; } +int x509_pm_load(X509 *x, const unsigned char *buffer, int len) +{ + int ret; + unsigned char *load_buf; + struct x509_pm *x509_pm = (struct x509_pm *)x->x509_pm; + + if (!x509_pm->x509_crt) { + x509_pm->x509_crt = ssl_malloc(sizeof(mbedtls_x509_crt)); + if (!x509_pm->x509_crt) + SSL_RET(failed1); + } + + load_buf = ssl_malloc(len + 1); + if (!load_buf) + SSL_RET(failed2); + + ssl_memcpy(load_buf, buffer, len); + load_buf[len] = '\0'; + + if (x509_pm->x509_crt) + mbedtls_x509_crt_free(x509_pm->x509_crt); + + mbedtls_x509_crt_init(x509_pm->x509_crt); + ret = mbedtls_x509_crt_parse(x509_pm->x509_crt, load_buf, len); + ssl_free(load_buf); + + if (ret) + SSL_RET(failed2); + + return 0; + +failed2: + ssl_free(x509_pm->x509_crt); + x509_pm->x509_crt = NULL; +failed1: + return -1; +} + int pkey_pm_new(EVP_PKEY *pkey) { struct pkey_pm *pkey_pm; @@ -437,40 +448,12 @@ void pkey_pm_unload(EVP_PKEY *pkey) { struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm; - if (pkey_pm->load) - mbedtls_pk_free(&pkey_pm->pkey); + if (pkey_pm->pkey) { + mbedtls_pk_free(pkey_pm->pkey); - pkey_pm->load = 0; -} - -int pkey_pm_load(EVP_PKEY *pkey, const unsigned char *buffer, int len) -{ - int ret; - unsigned char *load_buf; - struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm; - - load_buf = ssl_malloc(len + 1); - if (!load_buf) - SSL_RET(failed1); - - ssl_memcpy(load_buf, buffer, len); - load_buf[len] = '\0'; - - pkey_pm_unload(pkey); - - mbedtls_pk_init(&pkey_pm->pkey); - ret = mbedtls_pk_parse_key(&pkey_pm->pkey, load_buf, len, NULL, 0); - ssl_free(load_buf); - - if (ret) - SSL_RET(failed1, ""); - - pkey_pm->load = 1; - - return 0; - -failed1: - return -1; + ssl_free(pkey_pm->pkey); + pkey_pm->pkey = NULL; + } } void pkey_pm_free(EVP_PKEY *pkey) @@ -481,6 +464,46 @@ void pkey_pm_free(EVP_PKEY *pkey) pkey->pkey_pm = NULL; } +int pkey_pm_load(EVP_PKEY *pkey, const unsigned char *buffer, int len) +{ + int ret; + unsigned char *load_buf; + struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm; + + if (!pkey_pm->pkey) { + pkey_pm->pkey = ssl_malloc(sizeof(mbedtls_pk_context)); + if (!pkey_pm->pkey) + SSL_RET(failed1); + } + + load_buf = ssl_malloc(len + 1); + if (!load_buf) + SSL_RET(failed2); + + ssl_memcpy(load_buf, buffer, len); + load_buf[len] = '\0'; + + if (pkey_pm->pkey) + mbedtls_pk_free(pkey_pm->pkey); + + mbedtls_pk_init(pkey_pm->pkey); + ret = mbedtls_pk_parse_key(pkey_pm->pkey, load_buf, len, NULL, 0); + ssl_free(load_buf); + + if (ret) + SSL_RET(failed2); + + return 0; + +failed2: + ssl_free(pkey_pm->pkey); + pkey_pm->pkey = NULL; +failed1: + return -1; +} + + + void ssl_pm_set_bufflen(SSL *ssl, int len) { max_content_len = len;