From b56012783c62d125cedfaf6199eb8a862558910f Mon Sep 17 00:00:00 2001 From: Marius Vikhammer Date: Fri, 20 Mar 2020 11:07:07 +0800 Subject: [PATCH] tcp_transport/ws_client: websockets now correctly handle messages longer than buffer transport_ws can now be read multiple times in a row to read frames larger than the buffer. Added reporting of total payload length and offset to the user in websocket_client. Added local example test for long messages. Closes IDF-1083 --- .../esp_websocket_client.c | 59 +++++--- .../include/esp_websocket_client.h | 16 +- .../tcp_transport/include/esp_transport_ws.h | 11 ++ components/tcp_transport/transport_ws.c | 143 ++++++++++++++---- examples/protocols/websocket/README.md | 6 - examples/protocols/websocket/example_test.py | 117 ++++++++++---- .../websocket/main/websocket_example.c | 62 +++++--- 7 files changed, 305 insertions(+), 109 deletions(-) diff --git a/components/esp_websocket_client/esp_websocket_client.c b/components/esp_websocket_client/esp_websocket_client.c index e9b412631..d00bdbbe9 100644 --- a/components/esp_websocket_client/esp_websocket_client.c +++ b/components/esp_websocket_client/esp_websocket_client.c @@ -93,6 +93,8 @@ struct esp_websocket_client { char *tx_buffer; int buffer_size; ws_transport_opcodes_t last_opcode; + int payload_len; + int payload_offset; }; static uint64_t _tick_get_ms(void) @@ -101,19 +103,20 @@ static uint64_t _tick_get_ms(void) } static esp_err_t esp_websocket_client_dispatch_event(esp_websocket_client_handle_t client, - esp_websocket_event_id_t event, - const char *data, - int data_len) + esp_websocket_event_id_t event, + const char *data, + int data_len) { esp_err_t err; esp_websocket_event_data_t event_data; event_data.client = client; event_data.user_context = client->config->user_context; - event_data.data_ptr = data; event_data.data_len = data_len; event_data.op_code = client->last_opcode; + event_data.payload_len = client->payload_len; + event_data.payload_offset = client->payload_offset; if ((err = esp_event_post_to(client->event_handle, WEBSOCKET_EVENTS, event, @@ -446,10 +449,38 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con return ESP_OK; } +static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) +{ + int rlen; + client->payload_offset = 0; + do { + rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms); + if (rlen < 0) { + ESP_LOGE(TAG, "Error read data"); + esp_websocket_client_abort_connection(client); + return ESP_FAIL; + } + client->payload_len = esp_transport_ws_get_read_payload_len(client->transport); + client->last_opcode = esp_transport_ws_get_read_opcode(client->transport); + + esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen); + + client->payload_offset += rlen; + } while (client->payload_offset < client->payload_len); + + // if a PING message received -> send out the PONG, this will not work for PING messages with payload longer than buffer len + if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) { + const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer; + esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, client->payload_len, + client->config->network_timeout_ms); + } + + return ESP_OK; +} + static void esp_websocket_client_task(void *pv) { const int lock_timeout = portMAX_DELAY; - int rlen; esp_websocket_client_handle_t client = (esp_websocket_client_handle_t) pv; client->run = true; @@ -506,22 +537,11 @@ static void esp_websocket_client_task(void *pv) } client->ping_tick_ms = _tick_get_ms(); - rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms); - if (rlen < 0) { - ESP_LOGE(TAG, "Error read data"); + if (esp_websocket_client_recv(client) == ESP_FAIL) { + ESP_LOGE(TAG, "Error receive data"); esp_websocket_client_abort_connection(client); break; } - if (rlen >= 0) { - client->last_opcode = esp_transport_ws_get_read_opcode(client->transport); - esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen); - // if a PING message received -> send out the PONG - if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) { - const char *data = (rlen == 0) ? NULL : client->rx_buffer; - esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, rlen, - client->config->network_timeout_ms); - } - } break; case WEBSOCKET_STATE_WAIT_TIMEOUT: @@ -663,7 +683,8 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client) esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client, esp_websocket_event_id_t event, esp_event_handler_t event_handler, - void* event_handler_arg) { + void *event_handler_arg) +{ if (client == NULL) { return ESP_ERR_INVALID_ARG; } diff --git a/components/esp_websocket_client/include/esp_websocket_client.h b/components/esp_websocket_client/include/esp_websocket_client.h index 0f8c64e31..ae8cc8a4b 100644 --- a/components/esp_websocket_client/include/esp_websocket_client.h +++ b/components/esp_websocket_client/include/esp_websocket_client.h @@ -27,7 +27,7 @@ extern "C" { #endif -typedef struct esp_websocket_client* esp_websocket_client_handle_t; +typedef struct esp_websocket_client *esp_websocket_client_handle_t; ESP_EVENT_DECLARE_BASE(WEBSOCKET_EVENTS); // declaration of the task events family @@ -47,11 +47,13 @@ typedef enum { * @brief Websocket event data */ typedef struct { - const char *data_ptr; /*!< Data pointer */ - int data_len; /*!< Data length */ - uint8_t op_code; /*!< Received opcode */ - esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */ - void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */ + const char *data_ptr; /*!< Data pointer */ + int data_len; /*!< Data length */ + uint8_t op_code; /*!< Received opcode */ + esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */ + void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */ + int payload_len; /*!< Total payload length, payloads exceeding buffer will be posted through multiple events */ + int payload_offset; /*!< Actual offset for the data associated with this event */ } esp_websocket_event_data_t; /** @@ -205,7 +207,7 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client); esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client, esp_websocket_event_id_t event, esp_event_handler_t event_handler, - void* event_handler_arg); + void *event_handler_arg); #ifdef __cplusplus } diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 7251e92e2..5e5405791 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -14,6 +14,7 @@ extern "C" { #endif typedef enum ws_transport_opcodes { + WS_TRANSPORT_OPCODES_CONT = 0x00, WS_TRANSPORT_OPCODES_TEXT = 0x01, WS_TRANSPORT_OPCODES_BINARY = 0x02, WS_TRANSPORT_OPCODES_CLOSE = 0x08, @@ -105,6 +106,16 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o */ ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t); +/** + * @brief Returns payload length of the last received data + * + * @param t websocket transport handle + * + * @return + * - Number of bytes in the payload + */ +int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t); + #ifdef __cplusplus } diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index b0d0eca17..77b5d1b21 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -25,16 +25,24 @@ static const char *TAG = "TRANSPORT_WS"; #define WS_MASK 0x80 #define WS_SIZE16 126 #define WS_SIZE64 127 -#define MAX_WEBSOCKET_HEADER_SIZE 10 +#define MAX_WEBSOCKET_HEADER_SIZE 16 #define WS_RESPONSE_OK 101 + +typedef struct { + uint8_t opcode; + char mask_key[4]; /*!< Mask key for this payload */ + int payload_len; /*!< Total length of the payload */ + int bytes_remaining; /*!< Bytes left to read of the payload */ +} ws_transport_frame_state_t; + typedef struct { char *path; char *buffer; char *sub_protocol; char *user_agent; char *headers; - uint8_t read_opcode; + ws_transport_frame_state_t frame_state; esp_transport_handle_t parent; } transport_ws_t; @@ -46,6 +54,11 @@ static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode) static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t); + + /* Reading parts of a frame directly will disrupt the WS internal frame state, + reset bytes_remaining to prepare for reading a new frame */ + ws->frame_state.bytes_remaining = 0; + return ws->parent; } @@ -89,7 +102,8 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int { transport_ws_t *ws = esp_transport_get_context_data(t); if (esp_transport_connect(ws->parent, host, port, timeout_ms) < 0) { - ESP_LOGE(TAG, "Error connect to ther server"); + ESP_LOGE(TAG, "Error connecting to host %s:%d", host, port); + return -1; } unsigned char random_key[16]; @@ -194,16 +208,25 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const ESP_LOGE(TAG, "Error transport_poll_write"); return poll_write; } - ws_header[header_len++] = opcode; - // NOTE: no support for > 16-bit sized messages - if (len > 125) { + if (len <= 125) { + ws_header[header_len++] = (uint8_t)(len | mask_flag); + } else if (len < 65536) { ws_header[header_len++] = WS_SIZE16 | mask_flag; ws_header[header_len++] = (uint8_t)(len >> 8); ws_header[header_len++] = (uint8_t)(len & 0xFF); } else { - ws_header[header_len++] = (uint8_t)(len | mask_flag); + ws_header[header_len++] = WS_SIZE64 | mask_flag; + /* Support maximum 4 bytes length */ + ws_header[header_len++] = 0; //(uint8_t)((len >> 56) & 0xFF); + ws_header[header_len++] = 0; //(uint8_t)((len >> 48) & 0xFF); + ws_header[header_len++] = 0; //(uint8_t)((len >> 40) & 0xFF); + ws_header[header_len++] = 0; //(uint8_t)((len >> 32) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 24) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 16) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 8) & 0xFF); + ws_header[header_len++] = (uint8_t)((len >> 0) & 0xFF); } if (mask_flag) { @@ -215,6 +238,7 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const buffer[i] = (buffer[i] ^ mask[i % 4]); } } + if (esp_transport_write(ws->parent, ws_header, header_len, timeout_ms) != header_len) { ESP_LOGE(TAG, "Error write header"); return -1; @@ -252,12 +276,46 @@ static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeou return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms); } -static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) + +static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + + int bytes_to_read; + int rlen = 0; + + if (ws->frame_state.bytes_remaining > len) { + ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", ws->frame_state.bytes_remaining, len); + bytes_to_read = len; + + } else { + bytes_to_read = ws->frame_state.bytes_remaining; + } + + // Receive and process payload + if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; + } + ws->frame_state.bytes_remaining -= rlen; + + if (ws->frame_state.mask_key) { + for (int i = 0; i < bytes_to_read; i++) { + buffer[i] = (buffer[i] ^ ws->frame_state.mask_key[i % 4]); + } + } + return rlen; +} + + +/* Read and parse the WS header, determine length of payload */ +static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); int payload_len; + char ws_header[MAX_WEBSOCKET_HEADER_SIZE]; - char *data_ptr = ws_header, mask, *mask_key = NULL; + char *data_ptr = ws_header, mask; int rlen; int poll_read; if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) { @@ -266,16 +324,17 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ // Receive and process header first (based on header size) int header = 2; + int mask_len = 4; if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } - ws->read_opcode = (*data_ptr & 0x0F); + ws->frame_state.opcode = (*data_ptr & 0x0F); data_ptr ++; mask = ((*data_ptr >> 7) & 0x01); payload_len = (*data_ptr & 0x7F); data_ptr++; - ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->read_opcode , mask, payload_len); + ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->frame_state.opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { @@ -299,27 +358,48 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ } } - if (payload_len > len) { - ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", payload_len, len); - payload_len = len; - } - - // Then receive and process payload - if ((rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) { - ESP_LOGE(TAG, "Error read data"); - return rlen; - } - if (mask) { - mask_key = buffer; - data_ptr = buffer + 4; - for (int i = 0; i < payload_len; i++) { - buffer[i] = (data_ptr[i] ^ mask_key[i % 4]); + // Read and store mask + if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error read data"); + return rlen; } + memcpy(ws->frame_state.mask_key, buffer, mask_len); + } else { + memset(ws->frame_state.mask_key, 0, mask_len); } + + ws->frame_state.payload_len = payload_len; + ws->frame_state.bytes_remaining = payload_len; + return payload_len; } +static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) +{ + int rlen = 0; + transport_ws_t *ws = esp_transport_get_context_data(t); + + // If message exceeds buffer len then subsequent reads will skip reading header and read whatever is left of the payload + if (ws->frame_state.bytes_remaining <= 0) { + if ( (rlen = ws_read_header(t, buffer, len, timeout_ms)) <= 0) { + // If something when wrong then we prepare for reading a new header + ws->frame_state.bytes_remaining = 0; + return rlen; + } + } + if (ws->frame_state.payload_len) { + if ( (rlen = ws_read_payload(t, buffer, len, timeout_ms)) <= 0) { + ESP_LOGE(TAG, "Error reading payload data"); + ws->frame_state.bytes_remaining = 0; + return rlen; + } + } + + return rlen; +} + + static int ws_poll_read(esp_transport_handle_t t, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); @@ -355,6 +435,7 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path) ws->path = realloc(ws->path, strlen(path) + 1); strcpy(ws->path, path); } + esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle) { esp_transport_handle_t t = esp_transport_init(); @@ -363,7 +444,7 @@ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handl ws->parent = parent_handle; ws->path = strdup("/"); - ESP_TRANSPORT_MEM_CHECK(TAG, ws->path, { + ESP_TRANSPORT_MEM_CHECK(TAG, ws->path, { free(ws); return NULL; }); @@ -445,5 +526,11 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t); - return ws->read_opcode; + return ws->frame_state.opcode; +} + +int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + return ws->frame_state.payload_len; } \ No newline at end of file diff --git a/examples/protocols/websocket/README.md b/examples/protocols/websocket/README.md index e49bda81a..454ad376f 100644 --- a/examples/protocols/websocket/README.md +++ b/examples/protocols/websocket/README.md @@ -34,12 +34,6 @@ See the Getting Started Guide for full steps to configure and use ESP-IDF to bui ## Example Output ``` -I (482) system_api: Base MAC address is not set, read default base MAC address from BLK0 of EFUSE -I (2492) example_connect: Ethernet Link Up -I (4472) tcpip_adapter: eth ip: 192.168.2.137, mask: 255.255.255.0, gw: 192.168.2.2 -I (4472) example_connect: Connected to Ethernet -I (4472) example_connect: IPv4 address: 192.168.2.137 -I (4472) example_connect: IPv6 address: fe80:0000:0000:0000:bedd:c2ff:fed4:a92b I (4482) WEBSOCKET: Connecting to ws://echo.websocket.org... I (5012) WEBSOCKET: WEBSOCKET_EVENT_CONNECTED I (5492) WEBSOCKET: Sending hello 0000 diff --git a/examples/protocols/websocket/example_test.py b/examples/protocols/websocket/example_test.py index eeb0e6034..d861bd441 100644 --- a/examples/protocols/websocket/example_test.py +++ b/examples/protocols/websocket/example_test.py @@ -2,11 +2,15 @@ from __future__ import print_function from __future__ import unicode_literals import re import os +import sys import socket +import select import hashlib import base64 -import sys -from threading import Thread +import queue +import random +import string +from threading import Thread, Event try: import IDF @@ -20,8 +24,6 @@ except Exception: sys.path.insert(0, test_fw_path) import IDF -import DUT - def get_my_ip(): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -43,7 +45,10 @@ class Websocket: def __init__(self, port): self.port = port self.socket = socket.socket() + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.settimeout(10.0) + self.send_q = queue.Queue() + self.shutdown = Event() def __enter__(self): try: @@ -56,23 +61,27 @@ class Websocket: self.server_thread = Thread(target=self.run_server) self.server_thread.start() + return self + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown.set() self.server_thread.join() self.socket.close() self.conn.close() def run_server(self): self.conn, address = self.socket.accept() # accept new connection - self.conn.settimeout(10.0) + self.socket.settimeout(10.0) + print("Connection from: {}".format(address)) self.establish_connection() - - # Echo data until client closes connection - self.echo_data() + print("WS established") + # Handle connection until client closes it, will echo any data received and send data from send_q queue + self.handle_conn() def establish_connection(self): - while True: + while not self.shutdown.is_set(): try: # receive data stream. it won't accept data packet greater than 1024 bytes data = self.conn.recv(1024).decode() @@ -83,6 +92,7 @@ class Websocket: if "Upgrade: websocket" in data and "Connection: Upgrade" in data: self.handshake(data) return + except socket.error as err: print("Unable to establish a websocket connection: {}, {}".format(err)) raise @@ -107,26 +117,46 @@ class Websocket: self.conn.send(resp.encode()) - def echo_data(self): - while(True): + def handle_conn(self): + while not self.shutdown.is_set(): + r,w,e = select.select([self.conn], [], [], 1) try: - header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL)) - if not header: - # exit if data is not received - return + if self.conn in r: + self.echo_data() - # Remove mask bit - payload_len = ~(1 << 7) & header[1] + if not self.send_q.empty(): + self._send_data_(self.send_q.get()) - payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL)) - frame = header + payload - - decoded_payload = self.decode_frame(frame) - - echo_frame = self.encode_frame(decoded_payload) - self.conn.send(echo_frame) except socket.error as err: print("Stopped echoing data: {}".format(err)) + raise + + def echo_data(self): + header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL)) + if not header: + # exit if socket closed by peer + return + + # Remove mask bit + payload_len = ~(1 << 7) & header[1] + + payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL)) + + if not payload: + # exit if socket closed by peer + return + frame = header + payload + + decoded_payload = self.decode_frame(frame) + print("Sending echo...") + self._send_data_(decoded_payload) + + def _send_data_(self, data): + frame = self.encode_frame(data) + self.conn.send(frame) + + def send_data(self, data): + self.send_q.put(data.encode()) def decode_frame(self, frame): # Mask out MASK bit from payload length, this len is only valid for short messages (<126) @@ -147,7 +177,17 @@ class Websocket: header = (1 << 7) | (1 << 0) frame = bytearray([header]) - frame.append(len(payload)) + payload_len = len(payload) + + # If payload len is longer than 125 then the next 16 bits are used to encode length + if payload_len > 125: + frame.append(126) + frame.append(payload_len >> 8) + frame.append(0xFF & payload_len) + + else: + frame.append(payload_len) + frame += payload return frame @@ -156,8 +196,27 @@ class Websocket: def test_echo(dut): dut.expect("WEBSOCKET_EVENT_CONNECTED") for i in range(0, 10): - dut.expect(re.compile(r"Received=hello (\d)")) - dut.expect("Websocket Stopped") + dut.expect(re.compile(r"Received=hello (\d)"), timeout=30) + print("All echos received") + + +def test_recv_long_msg(dut, websocket, msg_len, repeats): + send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len)) + + for _ in range(repeats): + websocket.send_data(send_msg) + + recv_msg = '' + while len(recv_msg) < msg_len: + # Filter out color encoding + match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0] + recv_msg += match + + if recv_msg == send_msg: + print("Sent message and received message are equal") + else: + raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\ + \nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg))) @IDF.idf_example_test(env_tag="Example_WIFI") @@ -191,12 +250,14 @@ def test_examples_protocol_websocket(env, extra_data): if uri_from_stdin: server_port = 4455 - with Websocket(server_port): + with Websocket(server_port) as ws: uri = "ws://{}:{}".format(get_my_ip(), server_port) print("DUT connecting to {}".format(uri)) dut1.expect("Please enter uri of websocket endpoint", timeout=30) dut1.write(uri) test_echo(dut1) + # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte + test_recv_long_msg(dut1, ws, 2000, 3) else: print("DUT connecting to {}".format(uri)) diff --git a/examples/protocols/websocket/main/websocket_example.c b/examples/protocols/websocket/main/websocket_example.c index d74474914..8b4bf8600 100644 --- a/examples/protocols/websocket/main/websocket_example.c +++ b/examples/protocols/websocket/main/websocket_example.c @@ -1,4 +1,4 @@ -/* ESP Websocket Client Example +/* ESP HTTP Client Example This example code is in the Public Domain (or CC0 licensed, at your option.) @@ -16,6 +16,7 @@ #include "freertos/FreeRTOS.h" #include "freertos/task.h" +#include "freertos/semphr.h" #include "freertos/event_groups.h" @@ -23,11 +24,22 @@ #include "esp_websocket_client.h" #include "esp_event.h" +#define NO_DATA_TIMEOUT_SEC 10 + static const char *TAG = "WEBSOCKET"; static EventGroupHandle_t wifi_event_group; const static int CONNECTED_BIT = BIT0; +static TimerHandle_t shutdown_signal_timer; +static SemaphoreHandle_t shutdown_sema; + +static void shutdown_signaler(TimerHandle_t xTimer) +{ + ESP_LOGI(TAG, "No data received for %d seconds, signaling shutdown", NO_DATA_TIMEOUT_SEC); + xSemaphoreGive(shutdown_sema); +} + #if CONFIG_WEBSOCKET_URI_FROM_STDIN static void get_string(char *line, size_t size) { @@ -51,20 +63,23 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i { esp_websocket_event_data_t *data = (esp_websocket_event_data_t *)event_data; switch (event_id) { - case WEBSOCKET_EVENT_CONNECTED: - ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED"); - break; - case WEBSOCKET_EVENT_DISCONNECTED: - ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED"); - break; - case WEBSOCKET_EVENT_DATA: - ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA"); - ESP_LOGI(TAG, "Received opcode=%d", data->op_code); - ESP_LOGW(TAG, "Received=%.*s\r\n", data->data_len, (char*)data->data_ptr); - break; - case WEBSOCKET_EVENT_ERROR: - ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR"); - break; + case WEBSOCKET_EVENT_CONNECTED: + ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED"); + break; + case WEBSOCKET_EVENT_DISCONNECTED: + ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED"); + break; + case WEBSOCKET_EVENT_DATA: + ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA"); + ESP_LOGI(TAG, "Received opcode=%d", data->op_code); + ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr); + ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset); + + xTimerReset(shutdown_signal_timer, portMAX_DELAY); + break; + case WEBSOCKET_EVENT_ERROR: + ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR"); + break; } } @@ -114,7 +129,11 @@ static void websocket_app_start(void) { esp_websocket_client_config_t websocket_cfg = {}; - #if CONFIG_WEBSOCKET_URI_FROM_STDIN + shutdown_signal_timer = xTimerCreate("Websocket shutdown timer", NO_DATA_TIMEOUT_SEC * 1000 / portTICK_PERIOD_MS, + pdFALSE, NULL, shutdown_signaler); + shutdown_sema = xSemaphoreCreateBinary(); + +#if CONFIG_WEBSOCKET_URI_FROM_STDIN char line[128]; ESP_LOGI(TAG, "Please enter uri of websocket endpoint"); @@ -123,10 +142,10 @@ static void websocket_app_start(void) websocket_cfg.uri = line; ESP_LOGI(TAG, "Endpoint uri: %s\n", line); - #else +#else websocket_cfg.uri = CONFIG_WEBSOCKET_URI; - #endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */ +#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */ ESP_LOGI(TAG, "Connecting to %s...", websocket_cfg.uri); @@ -134,6 +153,7 @@ static void websocket_app_start(void) esp_websocket_register_events(client, WEBSOCKET_EVENT_ANY, websocket_event_handler, (void *)client); esp_websocket_client_start(client); + xTimerStart(shutdown_signal_timer, portMAX_DELAY); char data[32]; int i = 0; while (i < 10) { @@ -144,8 +164,8 @@ static void websocket_app_start(void) } vTaskDelay(1000 / portTICK_RATE_MS); } - // Give server some time to respond before closing - vTaskDelay(3000 / portTICK_RATE_MS); + + xSemaphoreTake(shutdown_sema, portMAX_DELAY); esp_websocket_client_stop(client); ESP_LOGI(TAG, "Websocket Stopped"); esp_websocket_client_destroy(client); @@ -164,4 +184,4 @@ void app_main(void) nvs_flash_init(); wifi_init(); websocket_app_start(); -} \ No newline at end of file +}