2020-01-21 15:44:02 +00:00
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
import re
|
|
|
|
import os
|
|
|
|
import socket
|
|
|
|
import select
|
|
|
|
import subprocess
|
|
|
|
from threading import Thread, Event
|
|
|
|
import ttfw_idf
|
|
|
|
import ssl
|
|
|
|
|
|
|
|
|
|
|
|
def _path(f):
    """Return the path of file *f* located in the same directory as this script.

    The directory is derived from the symlink-resolved location of this
    module, so the result does not depend on the current working directory.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(script_dir, f)
|
|
|
|
|
|
|
|
|
|
|
|
def set_server_cert_cn(ip):
    """Create the server certificate ``srv.crt`` with its Common Name set to *ip*.

    Two openssl commands are executed in order:
      1. ``openssl req``  - builds ``srv.csr`` from ``server.key`` with subject ``/CN=<ip>``
      2. ``openssl x509`` - signs that request with ``ca.crt``/``ca.key`` into ``srv.crt``
         (valid for 360 days)

    All files are resolved next to this script via ``_path``.

    :param ip: value placed into the certificate subject as ``/CN=<ip>``
    :raises subprocess.CalledProcessError: if an openssl command exits with a non-zero status
    """
    arg_list = [
        ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),
         '-subj', "/CN={}".format(ip), '-new'],
        ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
         '-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360'],
    ]
    for args in arg_list:
        # check_call() itself raises CalledProcessError on a non-zero exit status,
        # so no additional return-code check is needed here
        subprocess.check_call(args)
|
|
|
|
|
|
|
|
|
|
|
|
def get_my_ip():
    """Return the local IPv4 address used for outgoing traffic.

    A UDP socket is "connected" to an arbitrary address and the local end of
    that socket is read back. If anything goes wrong, the loopback address
    '127.0.0.1' is returned instead. The socket is always closed.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # doesn't even have to be reachable
        probe.connect(('10.255.255.255', 1))
        return probe.getsockname()[0]
    except Exception:
        return '127.0.0.1'
    finally:
        probe.close()
|
|
|
|
|
|
|
|
|
|
|
|
# Simple server for mqtt over TLS connection
|
|
|
|
class TlsServer:
    """Simple single-connection TLS server that answers one MQTT CONNECT message.

    Used as a context manager: ``__enter__`` binds the port and starts a
    background thread which accepts one connection; ``__exit__`` stops the
    thread and closes the sockets. After the ``with`` body (or during it) the
    last SSL error and the negotiated ALPN protocol can be queried.
    """

    def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
        """Create the listening socket (not bound yet).

        :param port: TCP port to bind in ``__enter__``
        :param client_cert: if True, a client certificate signed by ``ca.crt`` is required
        :param refuse_connection: if True, the MQTT CONNECT is answered with return code 5
        :param use_alpn: if True, ALPN protocols "mymqtt" and "http/1.1" are offered
        """
        self.port = port
        self.socket = socket.socket()
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.settimeout(10.0)
        self.shutdown = Event()
        self.client_cert = client_cert
        self.refuse_connection = refuse_connection
        self.ssl_error = None
        self.use_alpn = use_alpn
        self.negotiated_protocol = None
        # Must exist before the server thread runs: __exit__ inspects it even when
        # accept() never produced a connection (e.g. it failed with a non-SSL error)
        self.conn = None

    def __enter__(self):
        """Bind and listen on the configured port and start the server thread."""
        try:
            self.socket.bind(('', self.port))
        except socket.error as e:
            print("Bind failed:{}".format(e))
            # __exit__ is not invoked when __enter__ raises, so release the socket here
            self.socket.close()
            raise

        self.socket.listen(1)
        self.server_thread = Thread(target=self.run_server)
        self.server_thread.start()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Signal shutdown, wait for the server thread and close all sockets."""
        self.shutdown.set()
        self.server_thread.join()
        self.socket.close()
        if self.conn is not None:
            self.conn.close()

    def get_last_ssl_error(self):
        """Return the text of the SSLError caught by the server thread, or None."""
        return self.ssl_error

    def get_negotiated_protocol(self):
        """Return the ALPN protocol selected for the accepted connection, or None."""
        return self.negotiated_protocol

    def run_server(self):
        """Thread body: wrap the listening socket in TLS, accept one client and serve it.

        Only ssl.SSLError is handled here (stored in ``ssl_error``); any other
        exception propagates out of the thread.
        """
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        if self.client_cert:
            context.verify_mode = ssl.CERT_REQUIRED
            context.load_verify_locations(cafile=_path("ca.crt"))
        context.load_cert_chain(certfile=_path("srv.crt"), keyfile=_path("server.key"))
        if self.use_alpn:
            context.set_alpn_protocols(["mymqtt", "http/1.1"])
        self.socket = context.wrap_socket(self.socket, server_side=True)
        try:
            self.conn, address = self.socket.accept()  # accept new connection
            # NOTE(review): this re-applies the timeout to the listening socket, not to
            # the accepted connection - confirm that this is intended
            self.socket.settimeout(10.0)
            print(" - connection from: {}".format(address))
            if self.use_alpn:
                self.negotiated_protocol = self.conn.selected_alpn_protocol()
                print(" - negotiated_protocol: {}".format(self.negotiated_protocol))
            self.handle_conn()
        except ssl.SSLError as e:
            self.conn = None
            self.ssl_error = str(e)
            print(" - SSLError: {}".format(str(e)))

    def handle_conn(self):
        """Poll the accepted connection (1 s select timeout) until shutdown is set.

        When the connection becomes readable the MQTT CONNECT is processed;
        socket errors are printed and re-raised.
        """
        while not self.shutdown.is_set():
            r, w, e = select.select([self.conn], [], [], 1)
            try:
                if self.conn in r:
                    self.process_mqtt_connect()

            except socket.error as err:
                print(" - error: {}".format(err))
                raise

    def process_mqtt_connect(self):
        """Read one message and reply to it if it is the expected MQTT CONNECT.

        The first 8 received bytes must equal 10 18 00 04 4d 51 54 54 (the tail
        being ASCII "MQTT"). The reply is 20 02 00 00 (ACK), or 20 02 00 05 when
        ``refuse_connection`` is set. Any other message raises an Exception.
        The shutdown event is always set afterwards.
        """
        try:
            data = bytearray(self.conn.recv(1024))
            message = ''.join(format(x, '02x') for x in data)
            if message[0:16] == '101800044d515454':
                if self.refuse_connection is False:
                    print(" - received mqtt connect, sending ACK")
                    self.conn.send(bytearray.fromhex("20020000"))
                else:
                    # injecting connection not authorized error
                    print(" - received mqtt connect, sending NAK")
                    self.conn.send(bytearray.fromhex("20020005"))
            else:
                raise Exception(" - error process_mqtt_connect unexpected connect received: {}".format(message))
        finally:
            # stop the server after the connect message in happy flow, or if any exception occur
            self.shutdown.set()
|
|
|
|
|
|
|
|
|
2020-02-07 10:10:04 +00:00
|
|
|
@ttfw_idf.idf_custom_test(env_tag="Example_WIFI", group="test-apps")
def test_app_protocol_mqtt_publish_connect(env, extra_data):
    """
    steps:
      1. join AP
      2. connect to uri specified in the config
      3. send and receive data

    In detail: the test app's binary size is logged and checked, the test case
    ids are read from sdkconfig, the DUT is started and, once it reports an
    IPv4 address, a server certificate is generated for this host's IP. Each
    connection case is then run against a local TlsServer on port 2222 and the
    DUT output (and, where relevant, the server-side SSL error or ALPN
    protocol) is verified.
    """
    dut1 = env.get_dut("mqtt_publish_connect_test", "tools/test_apps/protocols/mqtt/publish_connect_test", dut_class=ttfw_idf.ESP32DUT)
    # check and log bin size
    binary_file = os.path.join(dut1.app.binary_path, "mqtt_publish_connect_test.bin")
    bin_size = os.path.getsize(binary_file)
    ttfw_idf.log_performance("mqtt_publish_connect_test_bin_size", "{}KB".format(bin_size // 1024))
    # NOTE(review): the key ends in "_bin_size_vin_size", which looks like a typo, but it is
    # a runtime identifier - confirm against the performance definitions before renaming
    ttfw_idf.check_performance("mqtt_publish_connect_test_bin_size_vin_size", bin_size // 1024, dut1.TARGET)
    # Look for test case symbolic names
    cases = {}
    try:
        # read the configuration once instead of once per case
        sdkconfig = dut1.app.get_sdkconfig()
        for i in ["CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT",
                  "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT",
                  "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH",
                  "CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT",
                  "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT",
                  "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD",
                  "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT",
                  "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN"]:
            cases[i] = sdkconfig[i]
    except Exception:
        print('ENV_TEST_FAILURE: Some mandatory test case not found in sdkconfig')
        raise

    dut1.start_app()
    esp_ip = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30)
    print("Got IP={}".format(esp_ip[0]))
    #
    # start connection test
    ip = get_my_ip()
    set_server_cert_cn(ip)
    server_port = 2222

    def start_case(case, desc):
        """Tell the DUT to run *case* against the local server and return the case id."""
        print("Starting {}: {}".format(case, desc))
        case_id = cases[case]
        dut1.write("conn {} {} {}".format(ip, server_port, case_id))
        dut1.expect("Test case:{} started".format(case_id))
        return case_id

    def check_server_ssl_error(server, expected):
        """Raise unless the SSL error recorded by *server* contains *expected*."""
        last_error = server.get_last_ssl_error()
        # get_last_ssl_error() returns None when the server recorded no SSL error;
        # report that as a test failure rather than failing on the `in` operator
        if last_error is None or expected not in last_error:
            raise Exception("Unexpected ssl error from the server {}".format(last_error))

    for case in ["CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT", "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT", "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT"]:
        # All these cases connect to the server with no server verification or with server only verification
        with TlsServer(server_port):
            test_nr = start_case(case, "default server - expect to connect normally")
            dut1.expect("MQTT_EVENT_CONNECTED: Test={}".format(test_nr), timeout=30)
        with TlsServer(server_port, refuse_connection=True):
            test_nr = start_case(case, "ssl shall connect, but mqtt sends connect refusal")
            dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
            dut1.expect("MQTT ERROR: 0x5")  # expecting 0x5 ... connection not authorized error
        with TlsServer(server_port, client_cert=True) as s:
            test_nr = start_case(case, "server with client verification - handshake error since client presents no client certificate")
            dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
            dut1.expect("ESP-TLS ERROR: 0x8010")  # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
            check_server_ssl_error(s, "PEER_DID_NOT_RETURN_A_CERTIFICATE")

    for case in ["CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH", "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD"]:
        # These cases connect to server with both server and client verification (client key might be password protected)
        with TlsServer(server_port, client_cert=True):
            test_nr = start_case(case, "server with client verification - expect to connect normally")
            dut1.expect("MQTT_EVENT_CONNECTED: Test={}".format(test_nr), timeout=30)

    case = "CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT"
    with TlsServer(server_port) as s:
        test_nr = start_case(case, "invalid server certificate on default server - expect ssl handshake error")
        dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
        dut1.expect("ESP-TLS ERROR: 0x8010")  # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
        check_server_ssl_error(s, "alert unknown ca")

    case = "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT"
    with TlsServer(server_port, client_cert=True) as s:
        test_nr = start_case(case, "Invalid client certificate on server with client verification - expect ssl handshake error")
        dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
        dut1.expect("ESP-TLS ERROR: 0x8010")  # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
        check_server_ssl_error(s, "CERTIFICATE_VERIFY_FAILED")

    for case in ["CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT", "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN"]:
        with TlsServer(server_port, use_alpn=True) as s:
            test_nr = start_case(case, "server with alpn - expect connect, check resolved protocol")
            dut1.expect("MQTT_EVENT_CONNECTED: Test={}".format(test_nr), timeout=30)
            if case == "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT" and s.get_negotiated_protocol() is None:
                print(" - client with alpn off, no negotiated protocol: OK")
            elif case == "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN" and s.get_negotiated_protocol() == "mymqtt":
                print(" - client with alpn on, negotiated protocol resolved: OK")
            else:
                raise Exception("Unexpected negotiated protocol {}".format(s.get_negotiated_protocol()))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the test directly.
    # NOTE(review): called without arguments - presumably the ttfw_idf decorator
    # supplies env/extra_data; confirm against the test framework.
    test_app_protocol_mqtt_publish_connect()
|