diff --git a/components/esp_websocket_client/esp_websocket_client.c b/components/esp_websocket_client/esp_websocket_client.c index 08d2ba21f..031b25876 100644 --- a/components/esp_websocket_client/esp_websocket_client.c +++ b/components/esp_websocket_client/esp_websocket_client.c @@ -63,6 +63,8 @@ typedef struct { void *user_context; int network_timeout_ms; char *subprotocol; + char *user_agent; + char *headers; } websocket_config_storage_t; typedef enum { @@ -179,6 +181,16 @@ static esp_err_t esp_websocket_client_set_config(esp_websocket_client_handle_t c cfg->subprotocol = strdup(config->subprotocol); ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->subprotocol, return ESP_ERR_NO_MEM); } + if (config->user_agent) { + free(cfg->user_agent); + cfg->user_agent = strdup(config->user_agent); + ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->user_agent, return ESP_ERR_NO_MEM); + } + if (config->headers) { + free(cfg->headers); + cfg->headers = strdup(config->headers); + ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->headers, return ESP_ERR_NO_MEM); + } cfg->network_timeout_ms = WEBSOCKET_NETWORK_TIMEOUT_MS; cfg->user_context = config->user_context; @@ -207,6 +219,8 @@ static esp_err_t esp_websocket_client_destroy_config(esp_websocket_client_handle free(cfg->username); free(cfg->password); free(cfg->subprotocol); + free(cfg->user_agent); + free(cfg->headers); memset(cfg, 0, sizeof(websocket_config_storage_t)); free(client->config); client->config = NULL; @@ -221,6 +235,12 @@ static void set_websocket_transport_optional_settings(esp_websocket_client_handl if (trans && client->config->subprotocol) { esp_transport_ws_set_subprotocol(trans, client->config->subprotocol); } + if (trans && client->config->user_agent) { + esp_transport_ws_set_user_agent(trans, client->config->user_agent); + } + if (trans && client->config->headers) { + esp_transport_ws_set_headers(trans, client->config->headers); + } } esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_client_config_t *config) diff --git a/components/esp_websocket_client/include/esp_websocket_client.h b/components/esp_websocket_client/include/esp_websocket_client.h index 3dbe1df9a..1add2c0f8 100644 --- a/components/esp_websocket_client/include/esp_websocket_client.h +++ b/components/esp_websocket_client/include/esp_websocket_client.h @@ -92,6 +92,8 @@ typedef struct { const char *cert_pem; /*!< SSL Certification, PEM format as string, if the client requires to verify server */ esp_websocket_transport_t transport; /*!< Websocket transport type, see `esp_websocket_transport_t */ char *subprotocol; /*!< Websocket subprotocol */ + char *user_agent; /*!< Websocket user-agent */ + char *headers; /*!< Websocket additional headers */ } esp_websocket_client_config_t; /** diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 9920e1bff..7251e92e2 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -50,6 +50,30 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path); */ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol); +/** + * @brief Set websocket user-agent header + * + * @param t websocket transport handle + * @param sub_protocol user-agent string + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent); + +/** + * @brief Set websocket additional headers + * + * @param t websocket transport handle + * @param sub_protocol additional header strings each terminated with \r\n + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers); + /** * @brief Sends websocket raw message with custom opcode and payload * diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 637d4d617..c6ed255cc 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -31,6 +31,8 @@ typedef struct { char *path; char *buffer; char *sub_protocol; + char *user_agent; + char *headers; uint8_t read_opcode; esp_transport_handle_t parent; } transport_ws_t; @@ -96,24 +98,27 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int // Size of base64 coded string is equal '((input_size * 4) / 3) + (input_size / 96) + 6' including Z-term unsigned char client_key[28] = {0}; + const char *user_agent_ptr = (ws->user_agent)?(ws->user_agent):"ESP32 Websocket Client"; + size_t outlen = 0; mbedtls_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key)); int len = snprintf(ws->buffer, DEFAULT_WS_BUFFER, "GET %s HTTP/1.1\r\n" "Connection: Upgrade\r\n" "Host: %s:%d\r\n" + "User-Agent: %s\r\n" "Upgrade: websocket\r\n" "Sec-WebSocket-Version: 13\r\n" - "Sec-WebSocket-Key: %s\r\n" - "User-Agent: ESP32 Websocket Client\r\n", + "Sec-WebSocket-Key: %s\r\n", ws->path, - host, port, + host, port, user_agent_ptr, client_key); if (len <= 0 || len >= DEFAULT_WS_BUFFER) { ESP_LOGE(TAG, "Error in request generation, %d", len); return -1; } if (ws->sub_protocol) { + ESP_LOGD(TAG, "sub_protocol: %s", ws->sub_protocol); int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "Sec-WebSocket-Protocol: %s\r\n", ws->sub_protocol); len += r; if (r <= 0 || len >= DEFAULT_WS_BUFFER) { @@ -122,6 +127,16 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int return -1; } } + if (ws->headers) { + ESP_LOGD(TAG, "headers: %s", ws->headers); + int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "%s", ws->headers); + len += r; + if (r <= 0 || len >= DEFAULT_WS_BUFFER) { + ESP_LOGE(TAG, "Error in request generation" + "(strncpy of headers returned %d, desired request len: %d, buffer size: %d", r, len, DEFAULT_WS_BUFFER); + return -1; + } + } int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "\r\n"); len += r; if (r <= 0 || len >= DEFAULT_WS_BUFFER) { @@ -233,7 +248,7 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const for (i = 0; i < len; ++i) { buffer[i] = (buffer[i] ^ mask[i % 4]); } - } + } return ret; } @@ -352,6 +367,8 @@ static esp_err_t ws_destroy(esp_transport_handle_t t) free(ws->buffer); free(ws->path); free(ws->sub_protocol); + free(ws->user_agent); + free(ws->headers); free(ws); return 0; } @@ -409,6 +426,46 @@ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char return ESP_OK; } +esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->user_agent) { + free(ws->user_agent); + } + if (user_agent == NULL) { + ws->user_agent = NULL; + return ESP_OK; + } + ws->user_agent = strdup(user_agent); + if (ws->user_agent == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + +esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->headers) { + free(ws->headers); + } + if (headers == NULL) { + ws->headers = NULL; + return ESP_OK; + } + ws->headers = strdup(headers); + if (ws->headers == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t);