Merge branch 'bugfix/ws_rcv_exceed_buf' into 'master'

tcp_transport/ws_client: websockets now correctly handle messages longer than buffer

Closes IDF-1084 and IDF-1083

See merge request espressif/esp-idf!6740
This commit is contained in:
Angus Gratton 2020-01-09 13:18:31 +08:00
commit 4ece6eedae
6 changed files with 288 additions and 95 deletions

View file

@ -93,6 +93,8 @@ struct esp_websocket_client {
char *tx_buffer;
int buffer_size;
ws_transport_opcodes_t last_opcode;
int payload_len;
int payload_offset;
};
static uint64_t _tick_get_ms(void)
@ -101,19 +103,20 @@ static uint64_t _tick_get_ms(void)
}
static esp_err_t esp_websocket_client_dispatch_event(esp_websocket_client_handle_t client,
esp_websocket_event_id_t event,
const char *data,
int data_len)
esp_websocket_event_id_t event,
const char *data,
int data_len)
{
esp_err_t err;
esp_websocket_event_data_t event_data;
event_data.client = client;
event_data.user_context = client->config->user_context;
event_data.data_ptr = data;
event_data.data_len = data_len;
event_data.op_code = client->last_opcode;
event_data.payload_len = client->payload_len;
event_data.payload_offset = client->payload_offset;
if ((err = esp_event_post_to(client->event_handle,
WEBSOCKET_EVENTS, event,
@ -446,10 +449,38 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
return ESP_OK;
}
static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client)
{
int rlen;
client->payload_offset = 0;
do {
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
if (rlen < 0) {
ESP_LOGE(TAG, "Error read data");
esp_websocket_client_abort_connection(client);
return ESP_FAIL;
}
client->payload_len = esp_transport_ws_get_read_payload_len(client->transport);
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
client->payload_offset += rlen;
} while (client->payload_offset < client->payload_len);
// if a PING message received -> send out the PONG, this will not work for PING messages with payload longer than buffer len
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer;
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, client->payload_len,
client->config->network_timeout_ms);
}
return ESP_OK;
}
static void esp_websocket_client_task(void *pv)
{
const int lock_timeout = portMAX_DELAY;
int rlen;
esp_websocket_client_handle_t client = (esp_websocket_client_handle_t) pv;
client->run = true;
@ -506,22 +537,11 @@ static void esp_websocket_client_task(void *pv)
}
client->ping_tick_ms = _tick_get_ms();
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
if (rlen < 0) {
ESP_LOGE(TAG, "Error read data");
if (esp_websocket_client_recv(client) == ESP_FAIL) {
ESP_LOGE(TAG, "Error receive data");
esp_websocket_client_abort_connection(client);
break;
}
if (rlen >= 0) {
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
// if a PING message received -> send out the PONG
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
const char *data = (rlen == 0) ? NULL : client->rx_buffer;
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, rlen,
client->config->network_timeout_ms);
}
}
break;
case WEBSOCKET_STATE_WAIT_TIMEOUT:
@ -663,7 +683,8 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client)
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
esp_websocket_event_id_t event,
esp_event_handler_t event_handler,
void* event_handler_arg) {
void *event_handler_arg)
{
if (client == NULL) {
return ESP_ERR_INVALID_ARG;
}

View file

@ -27,7 +27,7 @@
extern "C" {
#endif
typedef struct esp_websocket_client* esp_websocket_client_handle_t;
typedef struct esp_websocket_client *esp_websocket_client_handle_t;
ESP_EVENT_DECLARE_BASE(WEBSOCKET_EVENTS); // declaration of the task events family
@ -47,11 +47,13 @@ typedef enum {
* @brief Websocket event data
*/
typedef struct {
const char *data_ptr; /*!< Data pointer */
int data_len; /*!< Data length */
uint8_t op_code; /*!< Received opcode */
esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */
void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */
const char *data_ptr; /*!< Data pointer */
int data_len; /*!< Data length */
uint8_t op_code; /*!< Received opcode */
esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */
void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */
int payload_len; /*!< Total payload length, payloads exceeding buffer will be posted through multiple events */
int payload_offset; /*!< Actual offset for the data associated with this event */
} esp_websocket_event_data_t;
/**
@ -205,7 +207,7 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client);
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
esp_websocket_event_id_t event,
esp_event_handler_t event_handler,
void* event_handler_arg);
void *event_handler_arg);
#ifdef __cplusplus
}

View file

@ -14,6 +14,7 @@ extern "C" {
#endif
typedef enum ws_transport_opcodes {
WS_TRANSPORT_OPCODES_CONT = 0x00,
WS_TRANSPORT_OPCODES_TEXT = 0x01,
WS_TRANSPORT_OPCODES_BINARY = 0x02,
WS_TRANSPORT_OPCODES_CLOSE = 0x08,
@ -105,6 +106,16 @@ int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t o
*/
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t);
/**
* @brief Returns payload length of the last received data
*
* @param t websocket transport handle
*
* @return
* - Number of bytes in the payload
*/
int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t);
#ifdef __cplusplus
}

View file

@ -2,7 +2,6 @@
#include <string.h>
#include <ctype.h>
#include <sys/random.h>
#include "esp_log.h"
#include "esp_transport.h"
#include "esp_transport_tcp.h"
@ -15,11 +14,13 @@ static const char *TAG = "TRANSPORT_WS";
#define DEFAULT_WS_BUFFER (1024)
#define WS_FIN 0x80
#define WS_OPCODE_CONT 0x00
#define WS_OPCODE_TEXT 0x01
#define WS_OPCODE_BINARY 0x02
#define WS_OPCODE_CLOSE 0x08
#define WS_OPCODE_PING 0x09
#define WS_OPCODE_PONG 0x0a
// Second byte
#define WS_MASK 0x80
#define WS_SIZE16 126
@ -27,13 +28,21 @@ static const char *TAG = "TRANSPORT_WS";
#define MAX_WEBSOCKET_HEADER_SIZE 16
#define WS_RESPONSE_OK 101
typedef struct {
uint8_t opcode;
char mask_key[4]; /*!< Mask key for this payload */
int payload_len; /*!< Total length of the payload */
int bytes_remaining; /*!< Bytes left to read of the payload */
} ws_transport_frame_state_t;
typedef struct {
char *path;
char *buffer;
char *sub_protocol;
char *user_agent;
char *headers;
uint8_t read_opcode;
ws_transport_frame_state_t frame_state;
esp_transport_handle_t parent;
} transport_ws_t;
@ -45,6 +54,11 @@ static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode)
static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
/* Reading parts of a frame directly will disrupt the WS internal frame state,
reset bytes_remaining to prepare for reading a new frame */
ws->frame_state.bytes_remaining = 0;
return ws->parent;
}
@ -275,12 +289,46 @@ static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeou
return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms);
}
static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
int bytes_to_read;
int rlen = 0;
if (ws->frame_state.bytes_remaining > len) {
ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", ws->frame_state.bytes_remaining, len);
bytes_to_read = len;
} else {
bytes_to_read = ws->frame_state.bytes_remaining;
}
// Receive and process payload
if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data");
return rlen;
}
ws->frame_state.bytes_remaining -= rlen;
if (ws->frame_state.mask_key) {
for (int i = 0; i < bytes_to_read; i++) {
buffer[i] = (buffer[i] ^ ws->frame_state.mask_key[i % 4]);
}
}
return rlen;
}
/* Read and parse the WS header, determine length of payload */
static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
int payload_len;
char ws_header[MAX_WEBSOCKET_HEADER_SIZE];
char *data_ptr = ws_header, mask, *mask_key = NULL;
char *data_ptr = ws_header, mask;
int rlen;
int poll_read;
if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) {
@ -289,16 +337,17 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_
// Receive and process header first (based on header size)
int header = 2;
int mask_len = 4;
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data");
return rlen;
}
ws->read_opcode = (*data_ptr & 0x0F);
ws->frame_state.opcode = (*data_ptr & 0x0F);
data_ptr ++;
mask = ((*data_ptr >> 7) & 0x01);
payload_len = (*data_ptr & 0x7F);
data_ptr++;
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->read_opcode, mask, payload_len);
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->frame_state.opcode, mask, payload_len);
if (payload_len == 126) {
// headerLen += 2;
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
@ -322,27 +371,48 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_
}
}
if (payload_len > len) {
ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", payload_len, len);
payload_len = len;
}
// Then receive and process payload
if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data");
return rlen;
}
if (mask) {
mask_key = buffer;
data_ptr = buffer + 4;
for (int i = 0; i < payload_len; i++) {
buffer[i] = (data_ptr[i] ^ mask_key[i % 4]);
// Read and store mask
if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data");
return rlen;
}
memcpy(ws->frame_state.mask_key, buffer, mask_len);
} else {
memset(ws->frame_state.mask_key, 0, mask_len);
}
ws->frame_state.payload_len = payload_len;
ws->frame_state.bytes_remaining = payload_len;
return payload_len;
}
static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
int rlen = 0;
transport_ws_t *ws = esp_transport_get_context_data(t);
// If message exceeds buffer len then subsequent reads will skip reading header and read whatever is left of the payload
if (ws->frame_state.bytes_remaining <= 0) {
if ( (rlen = ws_read_header(t, buffer, len, timeout_ms)) <= 0) {
// If something when wrong then we prepare for reading a new header
ws->frame_state.bytes_remaining = 0;
return rlen;
}
}
if (ws->frame_state.payload_len) {
if ( (rlen = ws_read_payload(t, buffer, len, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error reading payload data");
ws->frame_state.bytes_remaining = 0;
return rlen;
}
}
return rlen;
}
static int ws_poll_read(esp_transport_handle_t t, int timeout_ms)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
@ -469,5 +539,13 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->read_opcode;
return ws->frame_state.opcode;
}
int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t)
{
transport_ws_t *ws = esp_transport_get_context_data(t);
return ws->frame_state.payload_len;
}

View file

@ -3,10 +3,13 @@ from __future__ import unicode_literals
import re
import os
import socket
import select
import hashlib
import base64
from threading import Thread
import queue
import random
import string
from threading import Thread, Event
import ttfw_idf
@ -30,7 +33,10 @@ class Websocket:
def __init__(self, port):
self.port = port
self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.settimeout(10.0)
self.send_q = queue.Queue()
self.shutdown = Event()
def __enter__(self):
try:
@ -43,23 +49,27 @@ class Websocket:
self.server_thread = Thread(target=self.run_server)
self.server_thread.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.shutdown.set()
self.server_thread.join()
self.socket.close()
self.conn.close()
def run_server(self):
self.conn, address = self.socket.accept() # accept new connection
self.conn.settimeout(10.0)
self.socket.settimeout(10.0)
print("Connection from: {}".format(address))
self.establish_connection()
# Echo data until client closes connection
self.echo_data()
print("WS established")
# Handle connection until client closes it, will echo any data received and send data from send_q queue
self.handle_conn()
def establish_connection(self):
while True:
while not self.shutdown.is_set():
try:
# receive data stream. it won't accept data packet greater than 1024 bytes
data = self.conn.recv(1024).decode()
@ -70,6 +80,7 @@ class Websocket:
if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
self.handshake(data)
return
except socket.error as err:
print("Unable to establish a websocket connection: {}, {}".format(err))
raise
@ -94,26 +105,46 @@ class Websocket:
self.conn.send(resp.encode())
def echo_data(self):
while(True):
def handle_conn(self):
while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try:
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if data is not received
return
if self.conn in r:
self.echo_data()
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
if not self.send_q.empty():
self._send_data_(self.send_q.get())
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
frame = header + payload
decoded_payload = self.decode_frame(frame)
echo_frame = self.encode_frame(decoded_payload)
self.conn.send(echo_frame)
except socket.error as err:
print("Stopped echoing data: {}".format(err))
raise
def echo_data(self):
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if socket closed by peer
return
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
if not payload:
# exit if socket closed by peer
return
frame = header + payload
decoded_payload = self.decode_frame(frame)
print("Sending echo...")
self._send_data_(decoded_payload)
def _send_data_(self, data):
frame = self.encode_frame(data)
self.conn.send(frame)
def send_data(self, data):
self.send_q.put(data.encode())
def decode_frame(self, frame):
# Mask out MASK bit from payload length, this len is only valid for short messages (<126)
@ -133,8 +164,18 @@ class Websocket:
# Set FIN = 1 and OP_CODE = 1 (text)
header = (1 << 7) | (1 << 0)
frame = bytearray(header)
frame.append(len(payload))
frame = bytearray([header])
payload_len = len(payload)
# If payload len is longer than 125 then the next 16 bits are used to encode length
if payload_len > 125:
frame.append(126)
frame.append(payload_len >> 8)
frame.append(0xFF & payload_len)
else:
frame.append(payload_len)
frame += payload
return frame
@ -143,8 +184,27 @@ class Websocket:
def test_echo(dut):
dut.expect("WEBSOCKET_EVENT_CONNECTED")
for i in range(0, 10):
dut.expect(re.compile(r"Received=hello (\d)"))
dut.expect("Websocket Stopped")
dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
print("All echos received")
def test_recv_long_msg(dut, websocket, msg_len, repeats):
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
for _ in range(repeats):
websocket.send_data(send_msg)
recv_msg = ''
while len(recv_msg) < msg_len:
# Filter out color encoding
match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
recv_msg += match
if recv_msg == send_msg:
print("Sent message and received message are equal")
else:
raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
\nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
@ttfw_idf.idf_example_test(env_tag="Example_WIFI")
@ -178,12 +238,14 @@ def test_examples_protocol_websocket(env, extra_data):
if uri_from_stdin:
server_port = 4455
with Websocket(server_port):
with Websocket(server_port) as ws:
uri = "ws://{}:{}".format(get_my_ip(), server_port)
print("DUT connecting to {}".format(uri))
dut1.expect("Please enter uri of websocket endpoint", timeout=30)
dut1.write(uri)
test_echo(dut1)
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
test_recv_long_msg(dut1, ws, 2000, 3)
else:
print("DUT connecting to {}".format(uri))

View file

@ -17,15 +17,26 @@
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "freertos/semphr.h"
#include "freertos/event_groups.h"
#include "esp_log.h"
#include "esp_websocket_client.h"
#include "esp_event.h"
#define NO_DATA_TIMEOUT_SEC 10
static const char *TAG = "WEBSOCKET";
static TimerHandle_t shutdown_signal_timer;
static SemaphoreHandle_t shutdown_sema;
static void shutdown_signaler(TimerHandle_t xTimer)
{
ESP_LOGI(TAG, "No data received for %d seconds, signaling shutdown", NO_DATA_TIMEOUT_SEC);
xSemaphoreGive(shutdown_sema);
}
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
static void get_string(char *line, size_t size)
{
@ -49,20 +60,23 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i
{
esp_websocket_event_data_t *data = (esp_websocket_event_data_t *)event_data;
switch (event_id) {
case WEBSOCKET_EVENT_CONNECTED:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED");
break;
case WEBSOCKET_EVENT_DISCONNECTED:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED");
break;
case WEBSOCKET_EVENT_DATA:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
ESP_LOGW(TAG, "Received=%.*s\r\n", data->data_len, (char*)data->data_ptr);
break;
case WEBSOCKET_EVENT_ERROR:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR");
break;
case WEBSOCKET_EVENT_CONNECTED:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED");
break;
case WEBSOCKET_EVENT_DISCONNECTED:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED");
break;
case WEBSOCKET_EVENT_DATA:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset);
xTimerReset(shutdown_signal_timer, portMAX_DELAY);
break;
case WEBSOCKET_EVENT_ERROR:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR");
break;
}
}
@ -70,7 +84,11 @@ static void websocket_app_start(void)
{
esp_websocket_client_config_t websocket_cfg = {};
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
shutdown_signal_timer = xTimerCreate("Websocket shutdown timer", NO_DATA_TIMEOUT_SEC * 1000 / portTICK_PERIOD_MS,
pdFALSE, NULL, shutdown_signaler);
shutdown_sema = xSemaphoreCreateBinary();
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
char line[128];
ESP_LOGI(TAG, "Please enter uri of websocket endpoint");
@ -79,10 +97,10 @@ static void websocket_app_start(void)
websocket_cfg.uri = line;
ESP_LOGI(TAG, "Endpoint uri: %s\n", line);
#else
#else
websocket_cfg.uri = CONFIG_WEBSOCKET_URI;
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
ESP_LOGI(TAG, "Connecting to %s...", websocket_cfg.uri);
@ -90,6 +108,7 @@ static void websocket_app_start(void)
esp_websocket_register_events(client, WEBSOCKET_EVENT_ANY, websocket_event_handler, (void *)client);
esp_websocket_client_start(client);
xTimerStart(shutdown_signal_timer, portMAX_DELAY);
char data[32];
int i = 0;
while (i < 10) {
@ -100,8 +119,8 @@ static void websocket_app_start(void)
}
vTaskDelay(1000 / portTICK_RATE_MS);
}
// Give server some time to respond before closing
vTaskDelay(3000 / portTICK_RATE_MS);
xSemaphoreTake(shutdown_sema, portMAX_DELAY);
esp_websocket_client_stop(client);
ESP_LOGI(TAG, "Websocket Stopped");
esp_websocket_client_destroy(client);