diff --git a/components/esp-tls/esp_tls.c b/components/esp-tls/esp_tls.c index a0c916a66..39c048196 100644 --- a/components/esp-tls/esp_tls.c +++ b/components/esp-tls/esp_tls.c @@ -551,15 +551,32 @@ static ssize_t tcp_write(esp_tls_t *tls, const char *data, size_t datalen) static ssize_t tls_write(esp_tls_t *tls, const char *data, size_t datalen) { - ssize_t ret = mbedtls_ssl_write(&tls->ssl, (unsigned char*) data, datalen); - if (ret < 0) { - if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { - ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_MBEDTLS, -ret); - ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_ESP, ESP_ERR_MBEDTLS_SSL_WRITE_FAILED); - ESP_LOGE(TAG, "write error :%d:", ret); + size_t written = 0; + size_t write_len = datalen; + while (written < datalen) { + if (write_len > MBEDTLS_SSL_OUT_CONTENT_LEN) { + write_len = MBEDTLS_SSL_OUT_CONTENT_LEN; } + if (datalen > MBEDTLS_SSL_OUT_CONTENT_LEN) { + ESP_LOGD(TAG, "Fragmenting data of excessive size :%d, offset: %d, size %d", datalen, written, write_len); + } + ssize_t ret = mbedtls_ssl_write(&tls->ssl, (unsigned char*) data + written, write_len); + if (ret <= 0) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret != 0) { + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_MBEDTLS, -ret); + ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ERR_TYPE_ESP, ESP_ERR_MBEDTLS_SSL_WRITE_FAILED); + ESP_LOGE(TAG, "write error :%d:", ret); + return ret; + } else { + // Exitting the tls-write process as less than desired datalen are writable + ESP_LOGD(TAG, "mbedtls_ssl_write() returned %d, already written %d, exitting...", ret, written); + return written; + } + } + written += ret; + write_len = datalen - written; } - return ret; + return written; } static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_t *tls) diff --git a/examples/protocols/mqtt/ssl/main/app_main.c b/examples/protocols/mqtt/ssl/main/app_main.c index 355819c95..007465429 100644 --- a/examples/protocols/mqtt/ssl/main/app_main.c +++ b/examples/protocols/mqtt/ssl/main/app_main.c @@ -11,25 +11,16 @@ #include #include #include -#include "esp_wifi.h" #include "esp_system.h" #include "nvs_flash.h" #include "esp_event.h" #include "tcpip_adapter.h" #include "protocol_examples_common.h" -#include "freertos/FreeRTOS.h" -#include "freertos/task.h" -#include "freertos/semphr.h" -#include "freertos/queue.h" - -#include "lwip/sockets.h" -#include "lwip/dns.h" -#include "lwip/netdb.h" - #include "esp_log.h" #include "mqtt_client.h" #include "esp_tls.h" +#include "esp_ota_ops.h" static const char *TAG = "MQTTS_EXAMPLE"; @@ -41,6 +32,20 @@ extern const uint8_t mqtt_eclipse_org_pem_start[] asm("_binary_mqtt_eclipse_or #endif extern const uint8_t mqtt_eclipse_org_pem_end[] asm("_binary_mqtt_eclipse_org_pem_end"); +// +// Note: this function is for testing purposes only publishing the entire active partition +// (to be checked against the original binary) +// +static void send_binary(esp_mqtt_client_handle_t client) +{ + spi_flash_mmap_handle_t out_handle; + const void *binary_address; + const esp_partition_t* partition = esp_ota_get_running_partition(); + esp_partition_mmap(partition, 0, partition->size, SPI_FLASH_MMAP_DATA, &binary_address, &out_handle); + int msg_id = esp_mqtt_client_publish(client, "/topic/binary", binary_address, partition->size, 0, 0); + ESP_LOGI(TAG, "binary sent with msg_id=%d", msg_id); +} + static esp_err_t mqtt_event_handler_cb(esp_mqtt_event_handle_t event) { esp_mqtt_client_handle_t client = event->client; @@ -77,6 +82,10 @@ static esp_err_t mqtt_event_handler_cb(esp_mqtt_event_handle_t event) ESP_LOGI(TAG, "MQTT_EVENT_DATA"); printf("TOPIC=%.*s\r\n", event->topic_len, event->topic); printf("DATA=%.*s\r\n", event->data_len, event->data); + if (strncmp(event->data, "send binary please", event->data_len) == 0) { + ESP_LOGI(TAG, "Sending the binary"); + send_binary(client); + } break; case MQTT_EVENT_ERROR: ESP_LOGI(TAG, "MQTT_EVENT_ERROR"); @@ -121,6 +130,7 @@ void app_main(void) ESP_LOGI(TAG, "[APP] IDF version: %s", esp_get_idf_version()); esp_log_level_set("*", ESP_LOG_INFO); + esp_log_level_set("esp-tls", ESP_LOG_VERBOSE); esp_log_level_set("MQTT_CLIENT", ESP_LOG_VERBOSE); esp_log_level_set("MQTT_EXAMPLE", ESP_LOG_VERBOSE); esp_log_level_set("TRANSPORT_TCP", ESP_LOG_VERBOSE); diff --git a/examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py b/examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py index ef9d1a9d6..25e98ac76 100644 --- a/examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py +++ b/examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py @@ -27,6 +27,7 @@ import DUT event_client_connected = Event() event_stop_client = Event() event_client_received_correct = Event() +event_client_received_binary = Event() message_log = "" @@ -45,9 +46,27 @@ def mqtt_client_task(client): # The callback for when a PUBLISH message is received from the server. def on_message(client, userdata, msg): global message_log + global event_client_received_correct + global event_client_received_binary + if msg.topic == "/topic/binary": + binary = userdata + size = os.path.getsize(binary) + print("Receiving binary from esp and comparing with {}, size {}...".format(binary, size)) + with open(binary, "rb") as f: + bin = f.read() + if bin == msg.payload[:size]: + print("...matches!") + event_client_received_binary.set() + return + else: + recv_binary = binary + ".received" + with open(recv_binary, "w") as fw: + fw.write(msg.payload) + raise ValueError('Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary)) payload = msg.payload.decode() if not event_client_received_correct.is_set() and payload == "data": - client.publish("/topic/qos0", "data_to_esp32") + client.subscribe("/topic/binary") + client.publish("/topic/qos0", "send binary please") if msg.topic == "/topic/qos0" and payload == "data": event_client_received_correct.set() message_log += "Received data:" + msg.topic + " " + payload + "\n" @@ -63,6 +82,7 @@ def test_examples_protocol_mqtt_ssl(env, extra_data): 2. Test connects a client to the same broker 3. Test evaluates python client received correct qos0 message 4. Test ESP32 client received correct qos0 message + 5. Test python client receives binary data from running partition and compares it with the binary """ dut1 = env.get_dut("mqtt_ssl", "examples/protocols/mqtt/ssl") # check and log bin size @@ -85,6 +105,7 @@ def test_examples_protocol_mqtt_ssl(env, extra_data): client = mqtt.Client() client.on_connect = on_connect client.on_message = on_message + client.user_data_set(binary_file) client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) @@ -112,7 +133,10 @@ def test_examples_protocol_mqtt_ssl(env, extra_data): if not event_client_received_correct.wait(timeout=30): raise ValueError('Wrong data received, msg log: {}'.format(message_log)) print("Checking esp-client received msg published from py-client...") - dut1.expect(re.compile(r"DATA=data_to_esp32"), timeout=30) + dut1.expect(re.compile(r"DATA=send binary please"), timeout=30) + print("Receiving binary data from running partition...") + if not event_client_received_binary.wait(timeout=30): + raise ValueError('Binary not received within timeout') finally: event_stop_client.set() thread1.join() diff --git a/examples/protocols/mqtt/ssl/sdkconfig.ci b/examples/protocols/mqtt/ssl/sdkconfig.ci index ce328a6b0..b3557c28d 100644 --- a/examples/protocols/mqtt/ssl/sdkconfig.ci +++ b/examples/protocols/mqtt/ssl/sdkconfig.ci @@ -1,2 +1,12 @@ CONFIG_BROKER_URI="mqtts://${EXAMPLE_MQTT_BROKER_SSL}" CONFIG_BROKER_CERTIFICATE_OVERRIDE="${EXAMPLE_MQTT_BROKER_CERTIFICATE}" +CONFIG_MQTT_USE_CUSTOM_CONFIG=y +CONFIG_MQTT_TCP_DEFAULT_PORT=1883 +CONFIG_MQTT_SSL_DEFAULT_PORT=8883 +CONFIG_MQTT_WS_DEFAULT_PORT=80 +CONFIG_MQTT_WSS_DEFAULT_PORT=443 +CONFIG_MQTT_BUFFER_SIZE=16384 +CONFIG_MQTT_TASK_STACK_SIZE=6144 +CONFIG_MBEDTLS_ASYMMETRIC_CONTENT_LEN=y +CONFIG_MBEDTLS_SSL_IN_CONTENT_LEN=16384 +CONFIG_MBEDTLS_SSL_OUT_CONTENT_LEN=4096