diff --git a/components/openssl/include/internal/ssl_cert.h b/components/openssl/include/internal/ssl_cert.h index 6441aaf52..86cf31ad5 100644 --- a/components/openssl/include/internal/ssl_cert.h +++ b/components/openssl/include/internal/ssl_cert.h @@ -21,6 +21,15 @@ #include "ssl_types.h" +/** + * @brief create a certification object include private key object according to input certification + * + * @param ic - input certification point + * + * @return certification object point + */ +CERT *__ssl_cert_new(CERT *ic); + /** * @brief create a certification object include private key object * diff --git a/components/openssl/include/internal/ssl_methods.h b/components/openssl/include/internal/ssl_methods.h index a20b7c768..9fd9ce906 100644 --- a/components/openssl/include/internal/ssl_methods.h +++ b/components/openssl/include/internal/ssl_methods.h @@ -69,14 +69,12 @@ #define IMPLEMENT_X509_METHOD(func_name, \ new, \ free, \ - load, \ - unload) \ + load) \ const X509_METHOD* func_name(void) { \ static const X509_METHOD func_name##_data LOCAL_ATRR = { \ new, \ free, \ - load, \ - unload, \ + load \ }; \ return &func_name##_data; \ } @@ -84,14 +82,12 @@ #define IMPLEMENT_PKEY_METHOD(func_name, \ new, \ free, \ - load, \ - unload) \ + load) \ const PKEY_METHOD* func_name(void) { \ static const PKEY_METHOD func_name##_data LOCAL_ATRR = { \ new, \ free, \ - load, \ - unload, \ + load \ }; \ return &func_name##_data; \ } diff --git a/components/openssl/include/internal/ssl_pkey.h b/components/openssl/include/internal/ssl_pkey.h index 5b7f341de..f4da04168 100644 --- a/components/openssl/include/internal/ssl_pkey.h +++ b/components/openssl/include/internal/ssl_pkey.h @@ -21,6 +21,15 @@ #include "ssl_types.h" +/** + * @brief create a private key object according to input private key + * + * @param ipk - input private key point + * + * @return new private key object point + */ +EVP_PKEY* __EVP_PKEY_new(EVP_PKEY *ipk); + /** * @brief create a private key object * diff --git a/components/openssl/include/internal/ssl_types.h b/components/openssl/include/internal/ssl_types.h index 34249ea05..c571865c1 100644 --- a/components/openssl/include/internal/ssl_types.h +++ b/components/openssl/include/internal/ssl_types.h @@ -196,12 +196,8 @@ struct ssl_st /* shut things down(0x01 : sent, 0x02 : received) */ int shutdown; - int crt_reload; - CERT *cert; - int ca_reload; - X509 *client_CA; SSL_CTX *ctx; @@ -274,24 +270,20 @@ struct ssl_method_func_st { struct x509_method_st { - int (*x509_new)(X509 *x); + int (*x509_new)(X509 *x, X509 *m_x); void (*x509_free)(X509 *x); int (*x509_load)(X509 *x, const unsigned char *buf, int len); - - void (*x509_unload)(X509 *x); }; struct pkey_method_st { - int (*pkey_new)(EVP_PKEY *pkey); + int (*pkey_new)(EVP_PKEY *pkey, EVP_PKEY *m_pkey); void (*pkey_free)(EVP_PKEY *pkey); int (*pkey_load)(EVP_PKEY *pkey, const unsigned char *buf, int len); - - void (*pkey_unload)(EVP_PKEY *pkey); }; typedef int (*next_proto_cb)(SSL *ssl, unsigned char **out, diff --git a/components/openssl/include/internal/ssl_x509.h b/components/openssl/include/internal/ssl_x509.h index 5dac46137..2c72980b0 100644 --- a/components/openssl/include/internal/ssl_x509.h +++ b/components/openssl/include/internal/ssl_x509.h @@ -24,6 +24,15 @@ DEFINE_STACK_OF(X509_NAME) +/** + * @brief create a X509 certification object according to input X509 certification + * + * @param ix - input X509 certification point + * + * @return new X509 certification object point + */ +X509* __X509_new(X509 *ix); + /** * @brief create a X509 certification object * diff --git a/components/openssl/include/platform/ssl_pm.h b/components/openssl/include/platform/ssl_pm.h index 47a7331b7..cf1d21379 100644 --- a/components/openssl/include/platform/ssl_pm.h +++ b/components/openssl/include/platform/ssl_pm.h @@ -42,16 +42,13 @@ OSSL_HANDSHAKE_STATE ssl_pm_get_state(const SSL *ssl); void ssl_pm_set_bufflen(SSL *ssl, int len); -int x509_pm_new(X509 *x); +int x509_pm_new(X509 *x, X509 *m_x); void x509_pm_free(X509 *x); int x509_pm_load(X509 *x, const unsigned char *buffer, int len); -void x509_pm_unload(X509 *x); -void x509_pm_start_ca(X509 *x); -int pkey_pm_new(EVP_PKEY *pkey); -void pkey_pm_free(EVP_PKEY *pkey); -int pkey_pm_load(EVP_PKEY *pkey, const unsigned char *buffer, int len); -void pkey_pm_unload(EVP_PKEY *pkey); +int pkey_pm_new(EVP_PKEY *pk, EVP_PKEY *m_pk); +void pkey_pm_free(EVP_PKEY *pk); +int pkey_pm_load(EVP_PKEY *pk, const unsigned char *buffer, int len); long ssl_pm_get_verify_result(const SSL *ssl); diff --git a/components/openssl/library/ssl_cert.c b/components/openssl/library/ssl_cert.c index fd05bc831..e4fd4d778 100644 --- a/components/openssl/library/ssl_cert.c +++ b/components/openssl/library/ssl_cert.c @@ -19,23 +19,34 @@ #include "ssl_port.h" /** - * @brief create a certification object include private key object + * @brief create a certification object according to input certification */ -CERT *ssl_cert_new(void) +CERT *__ssl_cert_new(CERT *ic) { CERT *cert; + X509 *ix; + EVP_PKEY *ipk; + cert = ssl_zalloc(sizeof(CERT)); if (!cert) SSL_RET(failed1, "ssl_zalloc\n"); - cert->pkey = EVP_PKEY_new(); - if (!cert->pkey) - SSL_RET(failed2, "EVP_PKEY_new\n"); + if (ic) { + ipk = ic->pkey; + ix = ic->x509; + } else { + ipk = NULL; + ix = NULL; + } - cert->x509 = X509_new(); + cert->pkey = __EVP_PKEY_new(ipk); + if (!cert->pkey) + SSL_RET(failed2, "__EVP_PKEY_new\n"); + + cert->x509 = __X509_new(ix); if (!cert->x509) - SSL_RET(failed3, "X509_new\n"); + SSL_RET(failed3, "__X509_new\n"); return cert; @@ -47,6 +58,14 @@ failed1: return NULL; } +/** + * @brief create a certification object include private key object + */ +CERT *ssl_cert_new(void) +{ + return __ssl_cert_new(NULL); +} + /** * @brief free a certification object */ diff --git a/components/openssl/library/ssl_lib.c b/components/openssl/library/ssl_lib.c index 06bbe270c..b82d54cd2 100644 --- a/components/openssl/library/ssl_lib.c +++ b/components/openssl/library/ssl_lib.c @@ -158,11 +158,11 @@ SSL_CTX* SSL_CTX_new(const SSL_METHOD *method) CERT *cert; X509 *client_ca; - if (!method) SSL_RET(go_failed1, "method\n"); + if (!method) SSL_RET(go_failed1, "method:NULL\n"); client_ca = X509_new(); if (!client_ca) - SSL_RET(go_failed1, "sk_X509_NAME_new_null\n"); + SSL_RET(go_failed1, "X509_new\n"); cert = ssl_cert_new(); if (!cert) @@ -170,7 +170,7 @@ SSL_CTX* SSL_CTX_new(const SSL_METHOD *method) ctx = (SSL_CTX *)ssl_zalloc(sizeof(SSL_CTX)); if (!ctx) - SSL_RET(go_failed3, "ssl_ctx_new:ctx\n"); + SSL_RET(go_failed3, "ssl_zalloc:ctx\n"); ctx->method = method; ctx->client_CA = client_ca; @@ -244,15 +244,15 @@ SSL *SSL_new(SSL_CTX *ctx) ssl->session = SSL_SESSION_new(); if (!ssl->session) - SSL_RET(failed2, "ssl_zalloc\n"); + SSL_RET(failed2, "SSL_SESSION_new\n"); - ssl->cert = ssl_cert_new(); + ssl->cert = __ssl_cert_new(ctx->cert); if (!ssl->cert) - SSL_RET(failed3, "ssl_cert_new\n"); + SSL_RET(failed3, "__ssl_cert_new\n"); - ssl->client_CA = X509_new(); + ssl->client_CA = __X509_new(ctx->client_CA); if (!ssl->client_CA) - SSL_RET(failed4, "ssl_cert_new\n"); + SSL_RET(failed4, "__X509_new\n"); ssl->ctx = ctx; ssl->method = ctx->method; diff --git a/components/openssl/library/ssl_methods.c b/components/openssl/library/ssl_methods.c index 042d670ab..e363b5e46 100644 --- a/components/openssl/library/ssl_methods.c +++ b/components/openssl/library/ssl_methods.c @@ -72,11 +72,11 @@ IMPLEMENT_SSL_METHOD(SSL3_VERSION, -1, TLS_method_func, SSLv3_method); */ IMPLEMENT_X509_METHOD(X509_method, x509_pm_new, x509_pm_free, - x509_pm_load, x509_pm_unload); + x509_pm_load); /** * @brief get private key object method */ IMPLEMENT_PKEY_METHOD(EVP_PKEY_method, pkey_pm_new, pkey_pm_free, - pkey_pm_load, pkey_pm_unload); + pkey_pm_load); diff --git a/components/openssl/library/ssl_pkey.c b/components/openssl/library/ssl_pkey.c index 6891b69eb..573b1f2e8 100644 --- a/components/openssl/library/ssl_pkey.c +++ b/components/openssl/library/ssl_pkey.c @@ -20,20 +20,24 @@ #include "ssl_port.h" /** - * @brief create a private key object + * @brief create a private key object according to input private key */ -EVP_PKEY* EVP_PKEY_new(void) +EVP_PKEY* __EVP_PKEY_new(EVP_PKEY *ipk) { int ret; EVP_PKEY *pkey; pkey = ssl_zalloc(sizeof(EVP_PKEY)); if (!pkey) - SSL_RET(failed1, "ssl_malloc\n"); + SSL_RET(failed1, "ssl_zalloc\n"); - pkey->method = EVP_PKEY_method(); + if (ipk) { + pkey->method = ipk->method; + } else { + pkey->method = EVP_PKEY_method(); + } - ret = EVP_PKEY_METHOD_CALL(new, pkey); + ret = EVP_PKEY_METHOD_CALL(new, pkey, ipk); if (ret) SSL_RET(failed2, "EVP_PKEY_METHOD_CALL\n"); @@ -45,6 +49,14 @@ failed1: return NULL; } +/** + * @brief create a private key object + */ +EVP_PKEY* EVP_PKEY_new(void) +{ + return __EVP_PKEY_new(NULL); +} + /** * @brief free a private key object */ @@ -105,6 +117,9 @@ int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey) SSL_ASSERT(ctx); SSL_ASSERT(pkey); + if (ctx->cert->pkey == pkey) + return 1; + if (ctx->cert->pkey) EVP_PKEY_free(ctx->cert->pkey); @@ -118,12 +133,13 @@ int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey) */ int SSL_use_PrivateKey(SSL *ssl, EVP_PKEY *pkey) { - SSL_ASSERT(ctx); + SSL_ASSERT(ssl); SSL_ASSERT(pkey); - if (!ssl->ca_reload) - ssl->ca_reload = 1; - else + if (ssl->cert->pkey == pkey) + return 1; + + if (ssl->cert->pkey) EVP_PKEY_free(ssl->cert->pkey); ssl->cert->pkey = pkey; @@ -138,20 +154,20 @@ int SSL_CTX_use_PrivateKey_ASN1(int type, SSL_CTX *ctx, const unsigned char *d, long len) { int ret; - EVP_PKEY *pkey; + EVP_PKEY *pk; - pkey = d2i_PrivateKey(0, &ctx->cert->pkey, &d, len); - if (!pkey) + pk = d2i_PrivateKey(0, NULL, &d, len); + if (!pk) SSL_RET(failed1, "d2i_PrivateKey\n"); - ret = SSL_CTX_use_PrivateKey(ctx, pkey); + ret = SSL_CTX_use_PrivateKey(ctx, pk); if (!ret) SSL_RET(failed2, "SSL_CTX_use_PrivateKey\n"); return 1; failed2: - EVP_PKEY_free(pkey); + EVP_PKEY_free(pk); failed1: return 0; } @@ -163,44 +179,20 @@ int SSL_use_PrivateKey_ASN1(int type, SSL *ssl, const unsigned char *d, long len) { int ret; - int reload; - EVP_PKEY *pkey; - CERT *cert; - CERT *old_cert; + EVP_PKEY *pk; - if (!ssl->crt_reload) { - cert = ssl_cert_new(); - if (!cert) - SSL_RET(failed1, "ssl_cert_new\n"); + pk = d2i_PrivateKey(0, NULL, &d, len); + if (!pk) + SSL_RET(failed1, "d2i_PrivateKey\n"); - old_cert = ssl->cert ; - ssl->cert = cert; - - ssl->crt_reload = 1; - - reload = 1; - } else { - reload = 0; - } - - pkey = d2i_PrivateKey(0, &ssl->cert->pkey, &d, len); - if (!pkey) - SSL_RET(failed2, "d2i_PrivateKey\n"); - - ret = SSL_use_PrivateKey(ssl, pkey); + ret = SSL_use_PrivateKey(ssl, pk); if (!ret) - SSL_RET(failed3, "SSL_use_PrivateKey\n"); + SSL_RET(failed2, "SSL_use_PrivateKey\n"); return 1; -failed3: - EVP_PKEY_free(pkey); failed2: - if (reload) { - ssl->cert = old_cert; - ssl_cert_free(cert); - ssl->crt_reload = 0; - } + EVP_PKEY_free(pk); failed1: return 0; } diff --git a/components/openssl/library/ssl_x509.c b/components/openssl/library/ssl_x509.c index c3fa0b307..b57cc0dfb 100644 --- a/components/openssl/library/ssl_x509.c +++ b/components/openssl/library/ssl_x509.c @@ -19,9 +19,9 @@ #include "ssl_port.h" /** - * @brief create a X509 certification object + * @brief create a X509 certification object according to input X509 certification */ -X509* X509_new(void) +X509* __X509_new(X509 *ix) { int ret; X509 *x; @@ -30,9 +30,12 @@ X509* X509_new(void) if (!x) SSL_RET(failed1, "ssl_malloc\n"); - x->method = X509_method(); + if (ix) + x->method = ix->method; + else + x->method = X509_method(); - ret = X509_METHOD_CALL(new, x); + ret = X509_METHOD_CALL(new, x, ix); if (ret) SSL_RET(failed2, "x509_new\n"); @@ -44,6 +47,14 @@ failed1: return NULL; } +/** + * @brief create a X509 certification object + */ +X509* X509_new(void) +{ + return __X509_new(NULL); +} + /** * @brief free a X509 certification object */ @@ -78,7 +89,7 @@ X509* d2i_X509(X509 **cert, const unsigned char *buffer, long len) ret = X509_METHOD_CALL(load, x, buffer, len); if (ret) - SSL_RET(failed2, "X509_METHOD_CALL\n"); + SSL_RET(failed2, "x509_load\n"); return x; @@ -97,8 +108,10 @@ int SSL_CTX_add_client_CA(SSL_CTX *ctx, X509 *x) SSL_ASSERT(ctx); SSL_ASSERT(x); - if (ctx->client_CA) - X509_free(ctx->client_CA); + if (ctx->client_CA == x) + return 1; + + X509_free(ctx->client_CA); ctx->client_CA = x; @@ -113,10 +126,10 @@ int SSL_add_client_CA(SSL *ssl, X509 *x) SSL_ASSERT(ssl); SSL_ASSERT(x); - if (!ssl->ca_reload) - ssl->ca_reload = 1; - else - X509_free(ssl->client_CA); + if (ssl->client_CA == x) + return 1; + + X509_free(ssl->client_CA); ssl->client_CA = x; @@ -131,6 +144,11 @@ int SSL_CTX_use_certificate(SSL_CTX *ctx, X509 *x) SSL_ASSERT(ctx); SSL_ASSERT(x); + if (ctx->cert->x509 == x) + return 1; + + X509_free(ctx->cert->x509); + ctx->cert->x509 = x; return 1; @@ -141,9 +159,14 @@ int SSL_CTX_use_certificate(SSL_CTX *ctx, X509 *x) */ int SSL_use_certificate(SSL *ssl, X509 *x) { - SSL_ASSERT(ctx); + SSL_ASSERT(ssl); SSL_ASSERT(x); + if (ssl->cert->x509 == x) + return 1; + + X509_free(ssl->cert->x509); + ssl->cert->x509 = x; return 1; @@ -166,20 +189,20 @@ int SSL_CTX_use_certificate_ASN1(SSL_CTX *ctx, int len, const unsigned char *d) { int ret; - X509 *cert; + X509 *x; - cert = d2i_X509(&ctx->cert->x509, d, len); - if (!cert) + x = d2i_X509(NULL, d, len); + if (!x) SSL_RET(failed1, "d2i_X509\n"); - ret = SSL_CTX_use_certificate(ctx, cert); + ret = SSL_CTX_use_certificate(ctx, x); if (!ret) SSL_RET(failed2, "SSL_CTX_use_certificate\n"); return 1; failed2: - X509_free(cert); + X509_free(x); failed1: return 0; } @@ -193,42 +216,20 @@ int SSL_use_certificate_ASN1(SSL *ssl, int len, int ret; int reload; X509 *x; - CERT *cert; - CERT *old_cert; + int m = 0; - if (!ssl->crt_reload) { - cert = ssl_cert_new(); - if (!cert) - SSL_RET(failed1, "ssl_cert_new\n"); - - old_cert = ssl->cert ; - ssl->cert = cert; - - ssl->crt_reload = 1; - - reload = 1; - } else { - reload = 0; - } - - x = d2i_X509(&ssl->cert->x509, d, len); + x = d2i_X509(NULL, d, len); if (!x) - SSL_RET(failed2, "d2i_X509\n"); + SSL_RET(failed1, "d2i_X509\n"); ret = SSL_use_certificate(ssl, x); if (!ret) - SSL_RET(failed3, "SSL_use_certificate\n"); + SSL_RET(failed2, "SSL_use_certificate\n"); return 1; -failed3: - X509_free(x); failed2: - if (reload) { - ssl->cert = old_cert; - ssl_cert_free(cert); - ssl->crt_reload = 0; - } + X509_free(x); failed1: return 0; } diff --git a/components/openssl/platform/ssl_pm.c b/components/openssl/platform/ssl_pm.c index 311c3a4b6..9f5290cc5 100644 --- a/components/openssl/platform/ssl_pm.c +++ b/components/openssl/platform/ssl_pm.c @@ -78,14 +78,6 @@ int ssl_pm_new(SSL *ssl) const SSL_METHOD *method = ssl->method; - struct x509_pm *ctx_ca = (struct x509_pm *)ssl->ctx->client_CA->x509_pm; - struct x509_pm *ctx_crt = (struct x509_pm *)ssl->ctx->cert->x509->x509_pm; - struct pkey_pm *ctx_pkey = (struct pkey_pm *)ssl->ctx->cert->pkey->pkey_pm; - - struct x509_pm *ssl_ca = (struct x509_pm *)ssl->client_CA->x509_pm; - struct x509_pm *ssl_crt = (struct x509_pm *)ssl->cert->x509->x509_pm; - struct pkey_pm *ssl_pkey = (struct pkey_pm *)ssl->cert->pkey->pkey_pm; - ssl_pm = ssl_zalloc(sizeof(struct ssl_pm)); if (!ssl_pm) SSL_ERR(ret, failed1, "ssl_zalloc\n"); @@ -134,10 +126,6 @@ int ssl_pm_new(SSL *ssl) ssl->ssl_pm = ssl_pm; - ssl_ca->ex_crt = ctx_ca->x509_crt; - ssl_crt->ex_crt = ctx_crt->x509_crt; - ssl_pkey->ex_pkey = ctx_pkey->pkey; - return 0; failed3: @@ -376,7 +364,7 @@ OSSL_HANDSHAKE_STATE ssl_pm_get_state(const SSL *ssl) return state; } -int x509_pm_new(X509 *x) +int x509_pm_new(X509 *x, X509 *m_x) { struct x509_pm *x509_pm; @@ -386,13 +374,19 @@ int x509_pm_new(X509 *x) x->x509_pm = x509_pm; + if (m_x) { + struct x509_pm *m_x509_pm = (struct x509_pm *)m_x->x509_pm; + + x509_pm->ex_crt = m_x509_pm->x509_crt; + } + return 0; failed1: return -1; } -void x509_pm_unload(X509 *x) +void x509_pm_free(X509 *x) { struct x509_pm *x509_pm = (struct x509_pm *)x->x509_pm; @@ -402,11 +396,6 @@ void x509_pm_unload(X509 *x) ssl_free(x509_pm->x509_crt); x509_pm->x509_crt = NULL; } -} - -void x509_pm_free(X509 *x) -{ - x509_pm_unload(x); ssl_free(x->x509_pm); x->x509_pm = NULL; @@ -450,7 +439,7 @@ failed1: return -1; } -int pkey_pm_new(EVP_PKEY *pkey) +int pkey_pm_new(EVP_PKEY *pk, EVP_PKEY *m_pkey) { struct pkey_pm *pkey_pm; @@ -458,14 +447,20 @@ int pkey_pm_new(EVP_PKEY *pkey) if (!pkey_pm) return -1; - pkey->pkey_pm = pkey_pm; + pk->pkey_pm = pkey_pm; + + if (m_pkey) { + struct pkey_pm *m_pkey_pm = (struct pkey_pm *)m_pkey->pkey_pm; + + pkey_pm->ex_pkey = m_pkey_pm->pkey; + } return 0; } -void pkey_pm_unload(EVP_PKEY *pkey) +void pkey_pm_free(EVP_PKEY *pk) { - struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm; + struct pkey_pm *pkey_pm = (struct pkey_pm *)pk->pkey_pm; if (pkey_pm->pkey) { mbedtls_pk_free(pkey_pm->pkey); @@ -473,21 +468,16 @@ void pkey_pm_unload(EVP_PKEY *pkey) ssl_free(pkey_pm->pkey); pkey_pm->pkey = NULL; } + + ssl_free(pk->pkey_pm); + pk->pkey_pm = NULL; } -void pkey_pm_free(EVP_PKEY *pkey) -{ - pkey_pm_unload(pkey); - - ssl_free(pkey->pkey_pm); - pkey->pkey_pm = NULL; -} - -int pkey_pm_load(EVP_PKEY *pkey, const unsigned char *buffer, int len) +int pkey_pm_load(EVP_PKEY *pk, const unsigned char *buffer, int len) { int ret; unsigned char *load_buf; - struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm; + struct pkey_pm *pkey_pm = (struct pkey_pm *)pk->pkey_pm; if (!pkey_pm->pkey) { pkey_pm->pkey = ssl_malloc(sizeof(mbedtls_pk_context));