Merge branch 'bugfix/ws_rcv_exceed_buf_v4.0' into 'release/v4.0'
tcp_transport/ws_client: websockets now correctly handle messages longer than buffer (backport v4.0) See merge request espressif/esp-idf!7755
This commit is contained in:
commit
2a467d17bd
6 changed files with 288 additions and 95 deletions
|
@ -91,9 +91,11 @@ struct esp_websocket_client {
|
|||
char *tx_buffer;
|
||||
int buffer_size;
|
||||
ws_transport_opcodes_t last_opcode;
|
||||
int payload_len;
|
||||
int payload_offset;
|
||||
};
|
||||
|
||||
static uint64_t _tick_get_ms()
|
||||
static uint64_t _tick_get_ms(void)
|
||||
{
|
||||
return esp_timer_get_time()/1000;
|
||||
}
|
||||
|
@ -112,6 +114,8 @@ static esp_err_t esp_websocket_client_dispatch_event(esp_websocket_client_handle
|
|||
event_data.data_ptr = data;
|
||||
event_data.data_len = data_len;
|
||||
event_data.op_code = client->last_opcode;
|
||||
event_data.payload_len = client->payload_len;
|
||||
event_data.payload_offset = client->payload_offset;
|
||||
|
||||
if ((err = esp_event_post_to(client->event_handle,
|
||||
WEBSOCKET_EVENTS, event,
|
||||
|
@ -426,10 +430,38 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
|
|||
return ESP_OK;
|
||||
}
|
||||
|
||||
static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client)
|
||||
{
|
||||
int rlen;
|
||||
client->payload_offset = 0;
|
||||
do {
|
||||
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
|
||||
if (rlen < 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
esp_websocket_client_abort_connection(client);
|
||||
return ESP_FAIL;
|
||||
}
|
||||
client->payload_len = esp_transport_ws_get_read_payload_len(client->transport);
|
||||
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
|
||||
|
||||
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
|
||||
|
||||
client->payload_offset += rlen;
|
||||
} while (client->payload_offset < client->payload_len);
|
||||
|
||||
// if a PING message received -> send out the PONG, this will not work for PING messages with payload longer than buffer len
|
||||
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
|
||||
const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer;
|
||||
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, client->payload_len,
|
||||
client->config->network_timeout_ms);
|
||||
}
|
||||
|
||||
return ESP_OK;
|
||||
}
|
||||
|
||||
static void esp_websocket_client_task(void *pv)
|
||||
{
|
||||
const int lock_timeout = portMAX_DELAY;
|
||||
int rlen;
|
||||
esp_websocket_client_handle_t client = (esp_websocket_client_handle_t) pv;
|
||||
client->run = true;
|
||||
|
||||
|
@ -480,31 +512,20 @@ static void esp_websocket_client_task(void *pv)
|
|||
ESP_LOGD(TAG, "Sending PING...");
|
||||
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING, NULL, 0, client->config->network_timeout_ms);
|
||||
}
|
||||
|
||||
if (read_select == 0) {
|
||||
ESP_LOGV(TAG, "Read poll timeout: skipping esp_transport_read()...");
|
||||
break;
|
||||
}
|
||||
client->ping_tick_ms = _tick_get_ms();
|
||||
|
||||
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
|
||||
if (rlen < 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
if (esp_websocket_client_recv(client) == ESP_FAIL) {
|
||||
ESP_LOGE(TAG, "Error receive data");
|
||||
esp_websocket_client_abort_connection(client);
|
||||
break;
|
||||
}
|
||||
if (rlen >= 0) {
|
||||
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
|
||||
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
|
||||
// if a PING message received -> send out the PONG
|
||||
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
|
||||
const char *data = (rlen == 0) ? NULL : client->rx_buffer;
|
||||
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, rlen,
|
||||
client->config->network_timeout_ms);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case WEBSOCKET_STATE_WAIT_TIMEOUT:
|
||||
|
||||
if (!client->config->auto_reconnect) {
|
||||
client->run = false;
|
||||
break;
|
||||
|
@ -617,7 +638,8 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c
|
|||
}
|
||||
memcpy(client->tx_buffer, data + widx, need_write);
|
||||
// send with ws specific way and specific opcode
|
||||
wlen = esp_transport_ws_send_raw(client->transport, opcode, (char *)client->tx_buffer, need_write, timeout);
|
||||
wlen = esp_transport_ws_send_raw(client->transport, opcode, (char *)client->tx_buffer, need_write,
|
||||
(timeout==portMAX_DELAY)? -1 : timeout * portTICK_PERIOD_MS);
|
||||
if (wlen <= 0) {
|
||||
ret = wlen;
|
||||
ESP_LOGE(TAG, "Network error: esp_transport_write() returned %d, errno=%d", ret, errno);
|
||||
|
@ -643,7 +665,8 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client)
|
|||
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
|
||||
esp_websocket_event_id_t event,
|
||||
esp_event_handler_t event_handler,
|
||||
void* event_handler_arg) {
|
||||
void *event_handler_arg)
|
||||
{
|
||||
if (client == NULL) {
|
||||
return ESP_ERR_INVALID_ARG;
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct esp_websocket_client* esp_websocket_client_handle_t;
|
||||
typedef struct esp_websocket_client *esp_websocket_client_handle_t;
|
||||
|
||||
ESP_EVENT_DECLARE_BASE(WEBSOCKET_EVENTS); // declaration of the task events family
|
||||
|
||||
|
@ -48,11 +48,13 @@ typedef enum {
|
|||
* @brief Websocket event data
|
||||
*/
|
||||
typedef struct {
|
||||
const char *data_ptr; /*!< Data pointer */
|
||||
int data_len; /*!< Data length */
|
||||
uint8_t op_code; /*!< Received opcode */
|
||||
esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */
|
||||
void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */
|
||||
const char *data_ptr; /*!< Data pointer */
|
||||
int data_len; /*!< Data length */
|
||||
uint8_t op_code; /*!< Received opcode */
|
||||
esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */
|
||||
void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */
|
||||
int payload_len; /*!< Total payload length, payloads exceeding buffer will be posted through multiple events */
|
||||
int payload_offset; /*!< Actual offset for the data associated with this event */
|
||||
} esp_websocket_event_data_t;
|
||||
|
||||
/**
|
||||
|
@ -204,7 +206,7 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client);
|
|||
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
|
||||
esp_websocket_event_id_t event,
|
||||
esp_event_handler_t event_handler,
|
||||
void* event_handler_arg);
|
||||
void *event_handler_arg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
typedef enum ws_transport_opcodes {
|
||||
WS_TRANSPORT_OPCODES_CONT = 0x00,
|
||||
WS_TRANSPORT_OPCODES_TEXT = 0x01,
|
||||
WS_TRANSPORT_OPCODES_BINARY = 0x02,
|
||||
WS_TRANSPORT_OPCODES_CLOSE = 0x08,
|
||||
|
@ -81,6 +82,16 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o
|
|||
*/
|
||||
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t);
|
||||
|
||||
/**
|
||||
* @brief Returns payload length of the last received data
|
||||
*
|
||||
* @param t websocket transport handle
|
||||
*
|
||||
* @return
|
||||
* - Number of bytes in the payload
|
||||
*/
|
||||
int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -28,11 +28,19 @@ static const char *TAG = "TRANSPORT_WS";
|
|||
#define MAX_WEBSOCKET_HEADER_SIZE 16
|
||||
#define WS_RESPONSE_OK 101
|
||||
|
||||
|
||||
typedef struct {
|
||||
uint8_t opcode;
|
||||
char mask_key[4]; /*!< Mask key for this payload */
|
||||
int payload_len; /*!< Total length of the payload */
|
||||
int bytes_remaining; /*!< Bytes left to read of the payload */
|
||||
} ws_transport_frame_state_t;
|
||||
|
||||
typedef struct {
|
||||
char *path;
|
||||
char *buffer;
|
||||
char *sub_protocol;
|
||||
uint8_t read_opcode;
|
||||
ws_transport_frame_state_t frame_state;
|
||||
esp_transport_handle_t parent;
|
||||
} transport_ws_t;
|
||||
|
||||
|
@ -44,6 +52,11 @@ static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode)
|
|||
static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_handle_t t)
|
||||
{
|
||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||
|
||||
/* Reading parts of a frame directly will disrupt the WS internal frame state,
|
||||
reset bytes_remaining to prepare for reading a new frame */
|
||||
ws->frame_state.bytes_remaining = 0;
|
||||
|
||||
return ws->parent;
|
||||
}
|
||||
|
||||
|
@ -234,7 +247,7 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const
|
|||
for (i = 0; i < len; ++i) {
|
||||
buffer[i] = (buffer[i] ^ mask[i % 4]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -261,12 +274,46 @@ static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeou
|
|||
return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms);
|
||||
}
|
||||
|
||||
static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
|
||||
|
||||
static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
|
||||
{
|
||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||
|
||||
int bytes_to_read;
|
||||
int rlen = 0;
|
||||
|
||||
if (ws->frame_state.bytes_remaining > len) {
|
||||
ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", ws->frame_state.bytes_remaining, len);
|
||||
bytes_to_read = len;
|
||||
|
||||
} else {
|
||||
bytes_to_read = ws->frame_state.bytes_remaining;
|
||||
}
|
||||
|
||||
// Receive and process payload
|
||||
if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
ws->frame_state.bytes_remaining -= rlen;
|
||||
|
||||
if (ws->frame_state.mask_key) {
|
||||
for (int i = 0; i < bytes_to_read; i++) {
|
||||
buffer[i] = (buffer[i] ^ ws->frame_state.mask_key[i % 4]);
|
||||
}
|
||||
}
|
||||
return rlen;
|
||||
}
|
||||
|
||||
|
||||
/* Read and parse the WS header, determine length of payload */
|
||||
static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
|
||||
{
|
||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||
int payload_len;
|
||||
|
||||
char ws_header[MAX_WEBSOCKET_HEADER_SIZE];
|
||||
char *data_ptr = ws_header, mask, *mask_key = NULL;
|
||||
char *data_ptr = ws_header, mask;
|
||||
int rlen;
|
||||
int poll_read;
|
||||
if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) {
|
||||
|
@ -275,16 +322,17 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_
|
|||
|
||||
// Receive and process header first (based on header size)
|
||||
int header = 2;
|
||||
int mask_len = 4;
|
||||
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
ws->read_opcode = (*data_ptr & 0x0F);
|
||||
ws->frame_state.opcode = (*data_ptr & 0x0F);
|
||||
data_ptr ++;
|
||||
mask = ((*data_ptr >> 7) & 0x01);
|
||||
payload_len = (*data_ptr & 0x7F);
|
||||
data_ptr++;
|
||||
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->read_opcode, mask, payload_len);
|
||||
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->frame_state.opcode, mask, payload_len);
|
||||
if (payload_len == 126) {
|
||||
// headerLen += 2;
|
||||
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
|
||||
|
@ -308,25 +356,45 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_
|
|||
}
|
||||
}
|
||||
|
||||
if (payload_len > len) {
|
||||
ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", payload_len, len);
|
||||
payload_len = len;
|
||||
}
|
||||
|
||||
// Then receive and process payload
|
||||
if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
mask_key = buffer;
|
||||
data_ptr = buffer + 4;
|
||||
for (int i = 0; i < payload_len; i++) {
|
||||
buffer[i] = (data_ptr[i] ^ mask_key[i % 4]);
|
||||
// Read and store mask
|
||||
if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
memcpy(ws->frame_state.mask_key, buffer, mask_len);
|
||||
} else {
|
||||
memset(ws->frame_state.mask_key, 0, mask_len);
|
||||
}
|
||||
|
||||
ws->frame_state.payload_len = payload_len;
|
||||
ws->frame_state.bytes_remaining = payload_len;
|
||||
|
||||
return payload_len;
|
||||
}
|
||||
|
||||
static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
|
||||
{
|
||||
int rlen = 0;
|
||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||
|
||||
// If message exceeds buffer len then subsequent reads will skip reading header and read whatever is left of the payload
|
||||
if (ws->frame_state.bytes_remaining <= 0) {
|
||||
if ( (rlen = ws_read_header(t, buffer, len, timeout_ms)) <= 0) {
|
||||
// If something when wrong then we prepare for reading a new header
|
||||
ws->frame_state.bytes_remaining = 0;
|
||||
return rlen;
|
||||
}
|
||||
}
|
||||
return payload_len;
|
||||
if (ws->frame_state.payload_len) {
|
||||
if ( (rlen = ws_read_payload(t, buffer, len, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error reading payload data");
|
||||
ws->frame_state.bytes_remaining = 0;
|
||||
return rlen;
|
||||
}
|
||||
}
|
||||
|
||||
return rlen;
|
||||
}
|
||||
|
||||
static int ws_poll_read(esp_transport_handle_t t, int timeout_ms)
|
||||
|
@ -413,5 +481,13 @@ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char
|
|||
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
|
||||
{
|
||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||
return ws->read_opcode;
|
||||
return ws->frame_state.opcode;
|
||||
}
|
||||
|
||||
int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t)
|
||||
{
|
||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||
return ws->frame_state.payload_len;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -3,10 +3,13 @@ from __future__ import unicode_literals
|
|||
import re
|
||||
import os
|
||||
import socket
|
||||
import select
|
||||
import hashlib
|
||||
import base64
|
||||
from threading import Thread
|
||||
|
||||
import queue
|
||||
import random
|
||||
import string
|
||||
from threading import Thread, Event
|
||||
import ttfw_idf
|
||||
|
||||
|
||||
|
@ -30,7 +33,10 @@ class Websocket:
|
|||
def __init__(self, port):
|
||||
self.port = port
|
||||
self.socket = socket.socket()
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.settimeout(10.0)
|
||||
self.send_q = queue.Queue()
|
||||
self.shutdown = Event()
|
||||
|
||||
def __enter__(self):
|
||||
try:
|
||||
|
@ -43,23 +49,27 @@ class Websocket:
|
|||
self.server_thread = Thread(target=self.run_server)
|
||||
self.server_thread.start()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.shutdown.set()
|
||||
self.server_thread.join()
|
||||
self.socket.close()
|
||||
self.conn.close()
|
||||
|
||||
def run_server(self):
|
||||
self.conn, address = self.socket.accept() # accept new connection
|
||||
self.conn.settimeout(10.0)
|
||||
self.socket.settimeout(10.0)
|
||||
|
||||
print("Connection from: {}".format(address))
|
||||
|
||||
self.establish_connection()
|
||||
|
||||
# Echo data until client closes connection
|
||||
self.echo_data()
|
||||
print("WS established")
|
||||
# Handle connection until client closes it, will echo any data received and send data from send_q queue
|
||||
self.handle_conn()
|
||||
|
||||
def establish_connection(self):
|
||||
while True:
|
||||
while not self.shutdown.is_set():
|
||||
try:
|
||||
# receive data stream. it won't accept data packet greater than 1024 bytes
|
||||
data = self.conn.recv(1024).decode()
|
||||
|
@ -70,6 +80,7 @@ class Websocket:
|
|||
if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
|
||||
self.handshake(data)
|
||||
return
|
||||
|
||||
except socket.error as err:
|
||||
print("Unable to establish a websocket connection: {}, {}".format(err))
|
||||
raise
|
||||
|
@ -94,26 +105,46 @@ class Websocket:
|
|||
|
||||
self.conn.send(resp.encode())
|
||||
|
||||
def echo_data(self):
|
||||
while(True):
|
||||
def handle_conn(self):
|
||||
while not self.shutdown.is_set():
|
||||
r,w,e = select.select([self.conn], [], [], 1)
|
||||
try:
|
||||
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
|
||||
if not header:
|
||||
# exit if data is not received
|
||||
return
|
||||
if self.conn in r:
|
||||
self.echo_data()
|
||||
|
||||
# Remove mask bit
|
||||
payload_len = ~(1 << 7) & header[1]
|
||||
if not self.send_q.empty():
|
||||
self._send_data_(self.send_q.get())
|
||||
|
||||
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
|
||||
frame = header + payload
|
||||
|
||||
decoded_payload = self.decode_frame(frame)
|
||||
|
||||
echo_frame = self.encode_frame(decoded_payload)
|
||||
self.conn.send(echo_frame)
|
||||
except socket.error as err:
|
||||
print("Stopped echoing data: {}".format(err))
|
||||
raise
|
||||
|
||||
def echo_data(self):
|
||||
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
|
||||
if not header:
|
||||
# exit if socket closed by peer
|
||||
return
|
||||
|
||||
# Remove mask bit
|
||||
payload_len = ~(1 << 7) & header[1]
|
||||
|
||||
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
|
||||
|
||||
if not payload:
|
||||
# exit if socket closed by peer
|
||||
return
|
||||
frame = header + payload
|
||||
|
||||
decoded_payload = self.decode_frame(frame)
|
||||
print("Sending echo...")
|
||||
self._send_data_(decoded_payload)
|
||||
|
||||
def _send_data_(self, data):
|
||||
frame = self.encode_frame(data)
|
||||
self.conn.send(frame)
|
||||
|
||||
def send_data(self, data):
|
||||
self.send_q.put(data.encode())
|
||||
|
||||
def decode_frame(self, frame):
|
||||
# Mask out MASK bit from payload length, this len is only valid for short messages (<126)
|
||||
|
@ -133,8 +164,18 @@ class Websocket:
|
|||
# Set FIN = 1 and OP_CODE = 1 (text)
|
||||
header = (1 << 7) | (1 << 0)
|
||||
|
||||
frame = bytearray(header)
|
||||
frame.append(len(payload))
|
||||
frame = bytearray([header])
|
||||
payload_len = len(payload)
|
||||
|
||||
# If payload len is longer than 125 then the next 16 bits are used to encode length
|
||||
if payload_len > 125:
|
||||
frame.append(126)
|
||||
frame.append(payload_len >> 8)
|
||||
frame.append(0xFF & payload_len)
|
||||
|
||||
else:
|
||||
frame.append(payload_len)
|
||||
|
||||
frame += payload
|
||||
|
||||
return frame
|
||||
|
@ -143,8 +184,27 @@ class Websocket:
|
|||
def test_echo(dut):
|
||||
dut.expect("WEBSOCKET_EVENT_CONNECTED")
|
||||
for i in range(0, 10):
|
||||
dut.expect(re.compile(r"Received=hello (\d)"))
|
||||
dut.expect("Websocket Stopped")
|
||||
dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
|
||||
print("All echos received")
|
||||
|
||||
|
||||
def test_recv_long_msg(dut, websocket, msg_len, repeats):
|
||||
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
|
||||
|
||||
for _ in range(repeats):
|
||||
websocket.send_data(send_msg)
|
||||
|
||||
recv_msg = ''
|
||||
while len(recv_msg) < msg_len:
|
||||
# Filter out color encoding
|
||||
match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
|
||||
recv_msg += match
|
||||
|
||||
if recv_msg == send_msg:
|
||||
print("Sent message and received message are equal")
|
||||
else:
|
||||
raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
|
||||
\nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
|
||||
|
||||
|
||||
@ttfw_idf.idf_example_test(env_tag="Example_WIFI")
|
||||
|
@ -178,12 +238,14 @@ def test_examples_protocol_websocket(env, extra_data):
|
|||
|
||||
if uri_from_stdin:
|
||||
server_port = 4455
|
||||
with Websocket(server_port):
|
||||
with Websocket(server_port) as ws:
|
||||
uri = "ws://{}:{}".format(get_my_ip(), server_port)
|
||||
print("DUT connecting to {}".format(uri))
|
||||
dut1.expect("Please enter uri of websocket endpoint", timeout=30)
|
||||
dut1.write(uri)
|
||||
test_echo(dut1)
|
||||
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
|
||||
test_recv_long_msg(dut1, ws, 2000, 3)
|
||||
|
||||
else:
|
||||
print("DUT connecting to {}".format(uri))
|
||||
|
|
|
@ -17,16 +17,27 @@
|
|||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include "freertos/task.h"
|
||||
#include "freertos/semphr.h"
|
||||
#include "freertos/event_groups.h"
|
||||
|
||||
|
||||
#include "esp_log.h"
|
||||
#include "esp_websocket_client.h"
|
||||
#include "esp_event.h"
|
||||
#include "esp_event_loop.h"
|
||||
|
||||
#define NO_DATA_TIMEOUT_SEC 10
|
||||
|
||||
static const char *TAG = "WEBSOCKET";
|
||||
|
||||
static TimerHandle_t shutdown_signal_timer;
|
||||
static SemaphoreHandle_t shutdown_sema;
|
||||
|
||||
static void shutdown_signaler(TimerHandle_t xTimer)
|
||||
{
|
||||
ESP_LOGI(TAG, "No data received for %d seconds, signaling shutdown", NO_DATA_TIMEOUT_SEC);
|
||||
xSemaphoreGive(shutdown_sema);
|
||||
}
|
||||
|
||||
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
|
||||
static void get_string(char *line, size_t size)
|
||||
{
|
||||
|
@ -50,20 +61,23 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i
|
|||
{
|
||||
esp_websocket_event_data_t *data = (esp_websocket_event_data_t *)event_data;
|
||||
switch (event_id) {
|
||||
case WEBSOCKET_EVENT_CONNECTED:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED");
|
||||
break;
|
||||
case WEBSOCKET_EVENT_DISCONNECTED:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED");
|
||||
break;
|
||||
case WEBSOCKET_EVENT_DATA:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
|
||||
ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
|
||||
ESP_LOGW(TAG, "Received=%.*s\r\n", data->data_len, (char*)data->data_ptr);
|
||||
break;
|
||||
case WEBSOCKET_EVENT_ERROR:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR");
|
||||
break;
|
||||
case WEBSOCKET_EVENT_CONNECTED:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED");
|
||||
break;
|
||||
case WEBSOCKET_EVENT_DISCONNECTED:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED");
|
||||
break;
|
||||
case WEBSOCKET_EVENT_DATA:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
|
||||
ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
|
||||
ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
|
||||
ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset);
|
||||
|
||||
xTimerReset(shutdown_signal_timer, portMAX_DELAY);
|
||||
break;
|
||||
case WEBSOCKET_EVENT_ERROR:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,7 +85,11 @@ static void websocket_app_start(void)
|
|||
{
|
||||
esp_websocket_client_config_t websocket_cfg = {};
|
||||
|
||||
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
|
||||
shutdown_signal_timer = xTimerCreate("Websocket shutdown timer", NO_DATA_TIMEOUT_SEC * 1000 / portTICK_PERIOD_MS,
|
||||
pdFALSE, NULL, shutdown_signaler);
|
||||
shutdown_sema = xSemaphoreCreateBinary();
|
||||
|
||||
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
|
||||
char line[128];
|
||||
|
||||
ESP_LOGI(TAG, "Please enter uri of websocket endpoint");
|
||||
|
@ -80,10 +98,10 @@ static void websocket_app_start(void)
|
|||
websocket_cfg.uri = line;
|
||||
ESP_LOGI(TAG, "Endpoint uri: %s\n", line);
|
||||
|
||||
#else
|
||||
#else
|
||||
websocket_cfg.uri = CONFIG_WEBSOCKET_URI;
|
||||
|
||||
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
|
||||
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
|
||||
|
||||
ESP_LOGI(TAG, "Connecting to %s...", websocket_cfg.uri);
|
||||
|
||||
|
@ -91,6 +109,7 @@ static void websocket_app_start(void)
|
|||
esp_websocket_register_events(client, WEBSOCKET_EVENT_ANY, websocket_event_handler, (void *)client);
|
||||
|
||||
esp_websocket_client_start(client);
|
||||
xTimerStart(shutdown_signal_timer, portMAX_DELAY);
|
||||
char data[32];
|
||||
int i = 0;
|
||||
while (i < 10) {
|
||||
|
@ -101,8 +120,8 @@ static void websocket_app_start(void)
|
|||
}
|
||||
vTaskDelay(1000 / portTICK_RATE_MS);
|
||||
}
|
||||
// Give server some time to respond before closing
|
||||
vTaskDelay(3000 / portTICK_RATE_MS);
|
||||
|
||||
xSemaphoreTake(shutdown_sema, portMAX_DELAY);
|
||||
esp_websocket_client_stop(client);
|
||||
ESP_LOGI(TAG, "Websocket Stopped");
|
||||
esp_websocket_client_destroy(client);
|
||||
|
|
Loading…
Reference in a new issue