import functools
import socket
import ssl
import os
import re
import time
import threading

import Parameter
from PKI import PKIDict
from PKI import PKIItem
from NativeLog import NativeLog


class SerialPortCheckFail(StandardError):
    pass


class SSLHandlerFail(StandardError):
    pass


class PCFail(StandardError):
    pass


class TargetFail(StandardError):
    pass


def ssl_handler_wrapper(handler_type):
    """Decorator factory: log any StandardError raised by the wrapped method and
    re-raise it as PCFail ("PC") or TargetFail ("Target")."""
    if handler_type == "PC":
        exception_type = PCFail
    elif handler_type == "Target":
        exception_type = TargetFail
    else:
        exception_type = None

    def _handle_func(func):
        @functools.wraps(func)
        def _handle_args(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except StandardError as e:
                NativeLog.add_exception_log(e)
                if exception_type is None:
                    # unknown handler type: re-raise the original exception unchanged
                    raise
                raise exception_type(str(e))
        return _handle_args
    return _handle_func


class SerialPort(object):
    """Binds a test-case action object to one named serial port."""

    def __init__(self, tc_action, port_name):
        self.tc_action = tc_action
        self.port_name = port_name

    def flush(self):
        self.tc_action.flush_data(self.port_name)

    def write_line(self, data):
        self.tc_action.serial_write_line(self.port_name, data)

    def check(self, condition, timeout=10):
        if self.tc_action.check_response(self.port_name, condition, timeout) is False:
            raise SerialPortCheckFail("serial port check fail, condition is %s" % condition)

    def read_data(self):
        return self.tc_action.serial_read_data(self.port_name)


class SSLHandler(object):
    """Base class for one SSL endpoint (client or server); subclasses implement the operations."""

    # default timeout for SSL operations, in seconds
    TIMEOUT = 30

    def __init__(self, typ, config, serial_port):
        self.type = typ
        self.config = config
        self.timeout = self.TIMEOUT
        self.serial_port = serial_port
        self.accept_thread = None
        self.data_validation = False

    def set_timeout(self, timeout):
        self.timeout = timeout

    def init_context(self):
        pass

    def connect(self, remote_ip, remote_port, local_ip=0, local_port=0):
        pass

    def listen(self, local_port=0, local_ip=0):
        pass

    def send(self, size, data):
        pass

    def recv(self, length, timeout):
        pass

    def set_data_validation(self, validation):
        pass

    def close(self):
        if self.accept_thread is not None:
            self.accept_thread.exit()
            self.accept_thread.join(5)


class TargetSSLHandler(SSLHandler):
    """SSL endpoint running on the target, controlled by "ssl"/"soc" commands sent over the serial port."""

    def __init__(self, typ, config, serial_port):
        SSLHandler.__init__(self, typ, config, serial_port)
        self.ssl_id = None
        self.server_id = None

    @ssl_handler_wrapper("Target")
    def init_context(self):
        # close all sockets on the target before creating a new context
        self.serial_port.flush()
        self.serial_port.write_line("soc -T")
        self.serial_port.check("+CLOSEALL")

        if self.type == "client":
            version = Parameter.VERSION[self.config["client_version"]]
            fragment = self.config["client_fragment_size"]
            ca = self.config["client_trust_anchor"]
            cert = self.config["client_certificate"]
            key = self.config["client_key"]
            verify_required = 0x01 if self.config["verify_server"] is True else 0x00
            context_type = 1
        else:
            version = Parameter.VERSION[self.config["server_version"]]
            fragment = self.config["server_fragment_size"]
            ca = self.config["server_trust_anchor"]
            cert = self.config["server_certificate"]
            key = self.config["server_key"]
            verify_required = 0x02 if self.config["verify_client"] is True else 0x00
            context_type = 2

        ssl_cmd = "ssl -I -t %u -r %u -v %u -o %u" % (context_type, fragment, version, verify_required)

        # certificates and keys are referenced by their index in the PKI dictionaries
        if ca is not None:
            _index = PKIDict.PKIDict.CERT_DICT[ca]
            ssl_cmd += " -a %d" % _index
        if cert is not None:
            _index = PKIDict.PKIDict.CERT_DICT[cert]
            ssl_cmd += " -c %d" % _index
        if key is not None:
            _index = PKIDict.PKIDict.KEY_DICT[key]
            ssl_cmd += " -k %d" % _index

        # write command and check result
        self.serial_port.flush()
        self.serial_port.write_line(ssl_cmd)
        self.serial_port.check(["+SSL:OK", "AND", "!+SSL:ERROR"])

    @ssl_handler_wrapper("Target")
    def connect(self, remote_ip, remote_port, local_ip=0, local_port=0):
        self.serial_port.flush()
        self.serial_port.write_line("soc -B -t SSL -i %s -p %s" % (local_ip, local_port))
        self.serial_port.check(["OK", "AND", "!ERROR"])

        self.serial_port.flush()
        self.serial_port.write_line("soc -C -s 0 -i %s -p %s" % (remote_ip, remote_port))
        self.serial_port.check(["OK", "AND", "!ERROR"], timeout=30)
        self.ssl_id = 0

    def accept_succeed(self):
        self.ssl_id = 1

    class Accept(threading.Thread):
        """Waits for "+ACCEPT:" on the serial port, then calls succeed_cb once."""

        def __init__(self, serial_port, succeed_cb):
            threading.Thread.__init__(self)
            self.daemon = True
            self.serial_port = serial_port
            self.succeed_cb = succeed_cb
            self.exit_flag = threading.Event()

        def run(self):
            # poll in 1-second slices so exit() takes effect promptly
            while not self.exit_flag.is_set():
                try:
                    self.serial_port.check("+ACCEPT:", timeout=1)
                    self.succeed_cb()
                    break
                except StandardError:
                    pass

        def exit(self):
            self.exit_flag.set()

    @ssl_handler_wrapper("Target")
    def listen(self, local_port=0, local_ip=0):
        self.serial_port.flush()
        self.serial_port.write_line("soc -B -t SSL -i %s -p %s" % (local_ip, local_port))
        self.serial_port.check(["OK", "AND", "!ERROR"])

        self.serial_port.flush()
        self.serial_port.write_line("soc -L -s 0")
        self.serial_port.check(["OK", "AND", "!ERROR"])
        self.server_id = 0

        self.accept_thread = self.Accept(self.serial_port, self.accept_succeed)
        self.accept_thread.start()

    @ssl_handler_wrapper("Target")
    def send(self, size=10, data=None):
        # the target generates the payload itself; only the length is sent
        if data is not None:
            size = len(data)
        self.serial_port.flush()
        self.serial_port.write_line("soc -S -s %s -l %s" % (self.ssl_id, size))
        self.serial_port.check(["OK", "AND", "!ERROR"])

    @ssl_handler_wrapper("Target")
    def recv(self, length, timeout=SSLHandler.TIMEOUT):
        # the target reports every received chunk as "+RECV:<n>,<length>\r\n";
        # sum the reported lengths until at least `length` bytes have arrived
        pattern = re.compile(r"\+RECV:\d+,(\d+)\r\n")
        data_len = 0
        data = ""
        result = False
        time1 = time.time()
        while time.time() - time1 < timeout:
            data += self.serial_port.read_data()
            if self.data_validation is True and "+DATA_ERROR" in data:
                raise SSLHandlerFail("target data validation fail")
            match = pattern.search(data)
            while match is not None:
                data_len += int(match.group(1))
                # drop everything up to and including the matched report
                data = data[match.end():]
                match = pattern.search(data)
            if data_len >= length:
                result = True
                break
        if result is False:
            raise SSLHandlerFail("Target recv fail")

    @ssl_handler_wrapper("Target")
    def set_data_validation(self, validation):
        self.data_validation = validation
        self.serial_port.flush()
        self.serial_port.write_line("soc -V -s %s -o %s" % (self.ssl_id, 1 if validation is True else 0))
        self.serial_port.check(["OK", "AND", "!ERROR"])

    @ssl_handler_wrapper("Target")
    def close(self):
        SSLHandler.close(self)
        self.serial_port.flush()
        self.serial_port.write_line("ssl -D")
        self.serial_port.check(["+SSL:OK", "OR", "+SSL:ERROR"])
        self.serial_port.write_line("soc -T")
        self.serial_port.check("+CLOSEALL")


def calc_hash(index):
    """Expected byte value at stream position `index` when data validation is enabled."""
    index &= 0xffffffff
    return index % 83 + index % 167


def verify_data(data, start_index):
    """Check that every byte of `data` matches calc_hash(), counting from `start_index`."""
    for i, c in enumerate(data):
        if ord(c) != calc_hash(start_index + i):
            NativeLog.add_trace_critical("[Data Validation Error] data sent by target is wrong at index %u."
                                         " Received %x, expected %x"
                                         % (start_index + i, ord(c), calc_hash(start_index + i)))
            return False
    return True


def make_validation_data(length, start_index):
    """Build `length` bytes of validation data, counting from `start_index`."""
    return bytes().join([chr(calc_hash(start_index + i)) for i in range(length)])


class PCSSLHandler(SSLHandler):
    """SSL endpoint running on the PC, implemented with Python's ssl module."""

    PROTOCOL_MAPPING = {
        "SSLv23": ssl.PROTOCOL_SSLv23,
        "SSLv23_2": ssl.PROTOCOL_SSLv23,
        # PROTOCOL_SSLv2/SSLv3 are missing when OpenSSL is built without them;
        # map them to None so the module still imports (init_context then fails)
        "SSLv20": getattr(ssl, "PROTOCOL_SSLv2", None),
        "SSLv30": getattr(ssl, "PROTOCOL_SSLv3", None),
        "TLSv10": ssl.PROTOCOL_TLSv1,
        "TLSv11": ssl.PROTOCOL_TLSv1_1,
        "TLSv12": ssl.PROTOCOL_TLSv1_2,
    }
    CERT_FOLDER = os.path.join(".", "PKI", PKIDict.PKIDict.CERT_FOLDER)
    KEY_FOLDER = os.path.join(".", "PKI", PKIDict.PKIDict.KEY_FOLDER)

    def __init__(self, typ, config, serial_port):
        SSLHandler.__init__(self, typ, config, serial_port)
        self.ssl_context = None
        self.ssl = None
        self.server_sock = None
        self.send_index = 0
        self.recv_index = 0

    class InitContextThread(threading.Thread):
        """Builds an SSLContext; on success stores it in handler.ssl_context, on failure only logs."""

        def __init__(self, handler, version, cipher_suite, ca, cert, key, verify_required, remote_cert):
            threading.Thread.__init__(self)
            self.daemon = True
            self.handler = handler
            self.version = version
            self.cipher_suite = cipher_suite
            self.ca = ca
            self.cert = cert
            self.key = key
            self.verify_required = verify_required
            self.remote_cert = remote_cert

        @staticmethod
        def handle_cert(cert_file, ca_file):
            """For a PEM certificate whose chain contains the CA, return the file names
            ("<name>.pem") of the intermediates between the certificate and the CA;
            otherwise return an empty list."""
            cert = PKIItem.Certificate()
            cert.parse_file(cert_file)
            ca = PKIItem.Certificate()
            ca.parse_file(ca_file)
            if cert.file_encoding == "PEM" and ca.name in cert.cert_chain:
                cert_chain_t = cert.cert_chain[1:cert.cert_chain.index(ca.name)]
                ret = ["%s.pem" % c for c in cert_chain_t]
            else:
                ret = []
            return ret

        def run(self):
            try:
                ssl_context = ssl.SSLContext(self.version)
                # cipher suite
                ssl_context.set_ciphers(self.cipher_suite)
                if self.ca is not None:
                    ssl_context.load_verify_locations(cafile=os.path.join(self.handler.CERT_FOLDER, self.ca))
                    # Workaround: Python's ssl module fails to verify the chain with the CA alone
                    # (cause unknown), so for PEM certificates, which carry their chain, also load
                    # every intermediate certificate between the remote certificate and the CA.
                    if self.remote_cert is not None:
                        cert_chain = self.handle_cert(self.remote_cert, self.ca)
                        for c in cert_chain:
                            NativeLog.add_trace_info("load ca chain %s" % c)
                            ssl_context.load_verify_locations(cafile=os.path.join(self.handler.CERT_FOLDER, c))
                if self.cert is not None:
                    cert = os.path.join(self.handler.CERT_FOLDER, self.cert)
                    key = os.path.join(self.handler.KEY_FOLDER, self.key)
                    ssl_context.load_cert_chain(cert, keyfile=key)
                if self.verify_required is True:
                    ssl_context.verify_mode = ssl.CERT_REQUIRED
                else:
                    ssl_context.verify_mode = ssl.CERT_NONE
                self.handler.ssl_context = ssl_context
            except StandardError as e:
                NativeLog.add_exception_log(e)

    @ssl_handler_wrapper("PC")
    def init_context(self):
        if self.type == "client":
            version = self.PROTOCOL_MAPPING[self.config["client_version"]]
            cipher_suite = Parameter.CIPHER_SUITE[self.config["client_cipher_suite"]]
            ca = self.config["client_trust_anchor"]
            cert = self.config["client_certificate"]
            key = self.config["client_key"]
            verify_required = self.config["verify_server"]
            remote_cert = self.config["server_certificate"]
        else:
            version = self.PROTOCOL_MAPPING[self.config["server_version"]]
            cipher_suite = Parameter.CIPHER_SUITE[self.config["server_cipher_suite"]]
            ca = self.config["server_trust_anchor"]
            cert = self.config["server_certificate"]
            key = self.config["server_key"]
            verify_required = self.config["verify_client"]
            remote_cert = self.config["client_certificate"]

        # build the context in a daemon thread and wait at most 5 seconds for it,
        # so a blocking call cannot hang the test; reset first so a stale context
        # from an earlier call is not mistaken for success
        self.ssl_context = None
        _init_context = self.InitContextThread(self, version, cipher_suite, ca, cert, key,
                                               verify_required, remote_cert)
        _init_context.start()
        _init_context.join(5)
        if self.ssl_context is None:
            raise SSLHandlerFail("Init Context Fail")

    @ssl_handler_wrapper("PC")
    def connect(self, remote_ip, remote_port, local_ip=0, local_port=0):
        sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
        # reuse socket in TIME_WAIT state
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.settimeout(self.timeout)
        # socket.bind() needs a string host, so map the default 0 to "" (any local address)
        sock.bind((local_ip or "", local_port))
        self.ssl = self.ssl_context.wrap_socket(sock)
        self.ssl.connect((remote_ip, remote_port))

    def accept_succeed(self, ssl_new):
        ssl_new.settimeout(self.timeout)
        self.ssl = ssl_new

    class Accept(threading.Thread):
        """Accepts one TCP connection, wraps it as an SSL server socket and passes it to succeed_cb."""

        def __init__(self, server_sock, ssl_context, succeed_cb):
            threading.Thread.__init__(self)
            self.daemon = True
            self.server_sock = server_sock
            self.ssl_context = ssl_context
            self.succeed_cb = succeed_cb
            self.exit_flag = threading.Event()

        def run(self):
            # accept() times out every second (see listen()), so exit() takes effect promptly
            while not self.exit_flag.is_set():
                try:
                    new_socket, addr = self.server_sock.accept()
                    ssl_new = self.ssl_context.wrap_socket(new_socket, server_side=True)
                    self.succeed_cb(ssl_new)
                    break
                except StandardError:
                    pass

        def exit(self):
            self.exit_flag.set()

    @ssl_handler_wrapper("PC")
    def listen(self, local_port=0, local_ip=0):
        self.server_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
        # reuse socket in TIME_WAIT state
        self.server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # short timeout so the accept thread can poll its exit flag
        self.server_sock.settimeout(1)
        # socket.bind() needs a string host, so map the default 0 to "" (any local address)
        self.server_sock.bind((local_ip or "", local_port))
        self.server_sock.listen(5)
        self.accept_thread = self.Accept(self.server_sock, self.ssl_context, self.accept_succeed)
        self.accept_thread.start()

    @ssl_handler_wrapper("PC")
    def send(self, size=10, data=None):
        # sendall() retries until every byte is written; send() may write only part of it
        if data is None:
            self.ssl.sendall(make_validation_data(size, self.send_index))
            if self.data_validation is True:
                self.send_index += size
        else:
            self.ssl.sendall(data)

    @ssl_handler_wrapper("PC")
    def recv(self, length, timeout=SSLHandler.TIMEOUT, data_validation=None):
        # fall back to the flag set by set_data_validation() when not given explicitly
        if data_validation is None:
            data_validation = self.data_validation
        time1 = time.time()
        data_len = 0
        result = False
        while time.time() - time1 < timeout:
            data = self.ssl.read()
            if data_validation is True and len(data) > 0:
                if verify_data(data, self.recv_index) is False:
                    raise SSLHandlerFail("PC data validation fail, index is %s" % self.recv_index)
                self.recv_index += len(data)
            data_len += len(data)
            if data_len >= length:
                result = True
                break
        if result is False:
            raise SSLHandlerFail("PC recv fail")

    def set_data_validation(self, validation):
        self.data_validation = validation

    @ssl_handler_wrapper("PC")
    def close(self):
        SSLHandler.close(self)
        if self.ssl is not None:
            self.ssl.close()
            self.ssl = None
        if self.server_sock is not None:
            self.server_sock.close()
            self.server_sock = None
        # reset rather than delete, so init_context() and listen() still work afterwards
        self.ssl_context = None


def main():
    """Manual check: accept one SSL connection using a hard-coded certificate, key and address."""
    cert_file = (r"D:\workspace\auto_test_script\PKI\Certificate"
                 r"\L2CertRSA512sha1_L1CertRSA512sha1_RootCertRSA512sha1.pem")
    key_file = r"D:\workspace\auto_test_script\PKI\Key\PrivateKey2RSA512.pem"

    ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    # cipher suite
    ssl_context.set_ciphers("AES256-SHA")
    ssl_context.load_cert_chain(cert_file, keyfile=key_file)
    ssl_context.verify_mode = ssl.CERT_NONE

    server_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
    # reuse socket in TIME_WAIT state
    server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_sock.settimeout(100)
    server_sock.bind(("192.168.111.5", 443))
    server_sock.listen(5)
    while True:
        try:
            new_socket, addr = server_sock.accept()
            ssl_new = ssl_context.wrap_socket(new_socket, server_side=True)
            print("server connected")
            break
        except StandardError:
            pass


if __name__ == '__main__':
    main()