From 857916285d399ec822e351910cf26389db245917 Mon Sep 17 00:00:00 2001 From: DJ2LS Date: Sun, 21 Jan 2024 20:34:01 +0100 Subject: [PATCH] changed dispatcher to a data type handler --- modem/arq_data_type_handler.py | 83 +++++++++++++++++++++++++++ modem/arq_received_data_dispatcher.py | 35 ----------- modem/arq_session.py | 10 ++-- modem/arq_session_irs.py | 15 +++-- modem/arq_session_iss.py | 18 ++++-- modem/command.py | 3 + modem/command_arq_raw.py | 9 ++- modem/data_frame_factory.py | 10 ++-- modem/frame_handler.py | 1 - tests/test_arq_session.py | 4 +- tests/test_data_dispatcher.py | 35 ----------- tests/test_data_type_handler.py | 37 ++++++++++++ 12 files changed, 165 insertions(+), 95 deletions(-) create mode 100644 modem/arq_data_type_handler.py delete mode 100644 modem/arq_received_data_dispatcher.py delete mode 100644 tests/test_data_dispatcher.py create mode 100644 tests/test_data_type_handler.py diff --git a/modem/arq_data_type_handler.py b/modem/arq_data_type_handler.py new file mode 100644 index 00000000..8b724844 --- /dev/null +++ b/modem/arq_data_type_handler.py @@ -0,0 +1,83 @@ +# File: arq_data_type_handler.py + +import structlog +import lzma +import gzip + +class ARQDataTypeHandler: + def __init__(self): + self.logger = structlog.get_logger(type(self).__name__) + self.handlers = { + "raw": { + 'prepare': self.prepare_raw, + 'handle': self.handle_raw + }, + "raw_lzma": { + 'prepare': self.prepare_raw_lzma, + 'handle': self.handle_raw_lzma + }, + "raw_gzip": { + 'prepare': self.prepare_raw_gzip, + 'handle': self.handle_raw_gzip + }, + "p2pmsg_lzma": { + 'prepare': self.prepare_p2pmsg_lzma, + 'handle': self.handle_p2pmsg_lzma + }, + } + + def dispatch(self, type_byte: int, data: bytearray): + endpoint_name = list(self.handlers.keys())[type_byte] + if endpoint_name in self.handlers and 'handle' in self.handlers[endpoint_name]: + return self.handlers[endpoint_name]['handle'](data) + else: + self.log(f"Unknown handling endpoint: {endpoint_name}", isWarning=True) + + def prepare(self, data: bytearray, endpoint_name="raw" ): + if endpoint_name in self.handlers and 'prepare' in self.handlers[endpoint_name]: + return self.handlers[endpoint_name]['prepare'](data), list(self.handlers.keys()).index(endpoint_name) + else: + self.log(f"Unknown preparation endpoint: {endpoint_name}", isWarning=True) + + def log(self, message, isWarning=False): + msg = f"[{type(self).__name__}]: {message}" + logger = self.logger.warn if isWarning else self.logger.info + logger(msg) + + def prepare_raw(self, data): + self.log(f"Preparing uncompressed data: {len(data)} Bytes") + return data + + def handle_raw(self, data): + self.log(f"Handling uncompressed data: {len(data)} Bytes") + return data + + def prepare_raw_lzma(self, data): + compressed_data = lzma.compress(data) + self.log(f"Preparing LZMA compressed data: {len(data)} Bytes >>> {len(compressed_data)} Bytes") + return compressed_data + + def handle_raw_lzma(self, data): + decompressed_data = lzma.decompress(data) + self.log(f"Handling LZMA compressed data: {len(decompressed_data)} Bytes from {len(data)} Bytes") + return decompressed_data + + def prepare_raw_gzip(self, data): + compressed_data = gzip.compress(data) + self.log(f"Preparing GZIP compressed data: {len(data)} Bytes >>> {len(compressed_data)} Bytes") + return compressed_data + + def handle_raw_gzip(self, data): + decompressed_data = gzip.decompress(data) + self.log(f"Handling GZIP compressed data: {len(decompressed_data)} Bytes from {len(data)} Bytes") + return decompressed_data + + def 
prepare_p2pmsg_lzma(self, data): + compressed_data = lzma.compress(data) + self.log(f"Preparing LZMA compressed P2PMSG data: {len(data)} Bytes >>> {len(compressed_data)} Bytes") + return compressed_data + + def handle_p2pmsg_lzma(self, data): + decompressed_data = lzma.decompress(data) + self.log(f"Handling LZMA compressed P2PMSG data: {len(decompressed_data)} Bytes from {len(data)} Bytes") + return decompressed_data diff --git a/modem/arq_received_data_dispatcher.py b/modem/arq_received_data_dispatcher.py deleted file mode 100644 index b8572841..00000000 --- a/modem/arq_received_data_dispatcher.py +++ /dev/null @@ -1,35 +0,0 @@ -# File: arq_received_data_dispatcher.py - -import structlog -from arq_data_formatter import ARQDataFormatter - -class ARQReceivedDataDispatcher: - def __init__(self): - self.logger = structlog.get_logger(type(self).__name__) - self.arq_data_formatter = ARQDataFormatter() - self.endpoints = { - "p2pmsg": self.handle_p2pmsg, - "test": self.handle_test, - } - - def log(self, message, isWarning=False): - msg = f"[{type(self).__name__}]: {message}" - logger = self.logger.warn if isWarning else self.logger.info - logger(msg) - - def dispatch(self, byte_data): - """Use the data formatter to decapsulate and then dispatch data to the appropriate endpoint.""" - type_key, data = self.arq_data_formatter.decapsulate(byte_data) - if type_key in self.endpoints: - self.endpoints[type_key](data) - else: - self.handle_raw(data) - - def handle_p2pmsg(self, data): - self.log(f"Handling p2pmsg: {data}") - - def handle_raw(self, data): - self.log(f"Handling raw data: {data}") - - def handle_test(self, data): - self.log(f"Handling test data: {data}") diff --git a/modem/arq_session.py b/modem/arq_session.py index 71ef28e3..26756d90 100644 --- a/modem/arq_session.py +++ b/modem/arq_session.py @@ -5,7 +5,7 @@ import structlog from event_manager import EventManager from modem_frametypes import FRAME_TYPE import time -from arq_received_data_dispatcher import ARQReceivedDataDispatcher +from arq_data_type_handler import ARQDataTypeHandler class ARQSession(): @@ -46,7 +46,7 @@ class ARQSession(): self.frame_factory = data_frame_factory.DataFrameFactory(self.config) self.event_frame_received = threading.Event() - self.arq_received_data_dispatcher = ARQReceivedDataDispatcher() + self.arq_data_type_handler = ARQDataTypeHandler() self.id = None self.session_started = time.time() self.session_ended = 0 @@ -91,9 +91,9 @@ class ARQSession(): if self.state in self.STATE_TRANSITION: if frame_type in self.STATE_TRANSITION[self.state]: action_name = self.STATE_TRANSITION[self.state][frame_type] - received_data = getattr(self, action_name)(frame) - if received_data: - self.arq_received_data_dispatcher.dispatch(received_data) + received_data, type_byte = getattr(self, action_name)(frame) + if isinstance(received_data, bytearray) and isinstance(type_byte, int): + self.arq_data_type_handler.dispatch(type_byte, received_data) return diff --git a/modem/arq_session_irs.py b/modem/arq_session_irs.py index 7eb7f821..8e0b461f 100644 --- a/modem/arq_session_irs.py +++ b/modem/arq_session_irs.py @@ -69,6 +69,7 @@ class ARQSessionIRS(arq_session.ARQSession): self.state = IRS_State.NEW self.state_enum = IRS_State # needed for access State enum from outside + self.type_byte = None self.total_length = 0 self.total_crc = '' self.received_data = None @@ -115,6 +116,7 @@ class ARQSessionIRS(arq_session.ARQSession): self.launch_transmit_and_wait(ack_frame, self.TIMEOUT_CONNECT, mode=FREEDV_MODE.signalling) if not self.abort: 
self.set_state(IRS_State.OPEN_ACK_SENT) + return None, None def send_info_ack(self, info_frame): # Get session info from ISS @@ -122,6 +124,7 @@ class ARQSessionIRS(arq_session.ARQSession): self.total_length = info_frame['total_length'] self.total_crc = info_frame['total_crc'] self.dx_snr.append(info_frame['snr']) + self.type_byte = info_frame['type'] self.log(f"New transfer of {self.total_length} bytes") self.event_manager.send_arq_session_new(False, self.id, self.dxcall, self.total_length, self.state.name) @@ -135,7 +138,7 @@ class ARQSessionIRS(arq_session.ARQSession): self.launch_transmit_and_wait(info_ack, self.TIMEOUT_CONNECT, mode=FREEDV_MODE.signalling) if not self.abort: self.set_state(IRS_State.INFO_ACK_SENT) - + return None, None def process_incoming_data(self, frame): if frame['offset'] != self.received_bytes: @@ -175,7 +178,7 @@ class ARQSessionIRS(arq_session.ARQSession): # self.transmitted_acks += 1 self.set_state(IRS_State.BURST_REPLY_SENT) self.launch_transmit_and_wait(ack, self.TIMEOUT_DATA, mode=FREEDV_MODE.signalling) - return + return None, None if self.final_crc_matches(): self.log("All data received successfully!") @@ -192,7 +195,8 @@ class ARQSessionIRS(arq_session.ARQSession): self.set_state(IRS_State.ENDED) self.event_manager.send_arq_session_finished( False, self.id, self.dxcall, True, self.state.name, data=self.received_data, statistics=self.calculate_session_statistics()) - return self.received_data + + return self.received_data, self.type_byte else: ack = self.frame_factory.build_arq_burst_ack(self.id, @@ -208,7 +212,7 @@ class ARQSessionIRS(arq_session.ARQSession): self.set_state(IRS_State.FAILED) self.event_manager.send_arq_session_finished( False, self.id, self.dxcall, False, self.state.name, statistics=self.calculate_session_statistics()) - return False + return False, False def calibrate_speed_settings(self): self.speed_level = 0 # for now stay at lowest speed level @@ -231,4 +235,5 @@ class ARQSessionIRS(arq_session.ARQSession): self.launch_transmit_and_wait(stop_ack, self.TIMEOUT_CONNECT, mode=FREEDV_MODE.signalling) self.set_state(IRS_State.ABORTED) self.event_manager.send_arq_session_finished( - False, self.id, self.dxcall, False, self.state.name, statistics=self.calculate_session_statistics()) \ No newline at end of file + False, self.id, self.dxcall, False, self.state.name, statistics=self.calculate_session_statistics()) + return None, None \ No newline at end of file diff --git a/modem/arq_session_iss.py b/modem/arq_session_iss.py index 5edc47e4..14970262 100644 --- a/modem/arq_session_iss.py +++ b/modem/arq_session_iss.py @@ -53,13 +53,13 @@ class ARQSessionISS(arq_session.ARQSession): } } - def __init__(self, config: dict, modem, dxcall: str, data: bytearray, state_manager): + def __init__(self, config: dict, modem, dxcall: str, state_manager, data: bytearray, type_byte: bytes): super().__init__(config, modem, dxcall) self.state_manager = state_manager self.data = data self.total_length = len(data) self.data_crc = '' - + self.type_byte = type_byte self.confirmed_bytes = 0 self.state = ISS_State.NEW @@ -119,11 +119,13 @@ class ARQSessionISS(arq_session.ARQSession): info_frame = self.frame_factory.build_arq_session_info(self.id, self.total_length, helpers.get_crc_32(self.data), - self.snr[0]) + self.snr[0], self.type_byte) self.launch_twr(info_frame, self.TIMEOUT_CONNECT_ACK, self.RETRIES_CONNECT, mode=FREEDV_MODE.signalling) self.set_state(ISS_State.INFO_SENT) + return None, None + def send_data(self, irs_frame): 
self.set_speed_and_frames_per_burst(irs_frame) @@ -137,15 +139,15 @@ class ARQSessionISS(arq_session.ARQSession): # check if we received an abort flag if irs_frame["flag"]["ABORT"]: self.transmission_aborted(irs_frame) - return + return None, None if irs_frame["flag"]["FINAL"]: if self.confirmed_bytes == self.total_length and irs_frame["flag"]["CHECKSUM"]: self.transmission_ended(irs_frame) - return + else: self.transmission_failed() - return + return None, None payload_size = self.get_data_payload_size() burst = [] @@ -158,6 +160,7 @@ class ARQSessionISS(arq_session.ARQSession): burst.append(data_frame) self.launch_twr(burst, self.TIMEOUT_TRANSFER, self.RETRIES_CONNECT, mode='auto') self.set_state(ISS_State.BURST_SENT) + return None, None def transmission_ended(self, irs_frame): # final function for sucessfully ended transmissions @@ -166,6 +169,7 @@ class ARQSessionISS(arq_session.ARQSession): self.log(f"All data transfered! flag_final={irs_frame['flag']['FINAL']}, flag_checksum={irs_frame['flag']['CHECKSUM']}") self.event_manager.send_arq_session_finished(True, self.id, self.dxcall,True, self.state.name, statistics=self.calculate_session_statistics()) self.state_manager.remove_arq_iss_session(self.id) + return None, None def transmission_failed(self, irs_frame=None): # final function for failed transmissions @@ -173,6 +177,7 @@ class ARQSessionISS(arq_session.ARQSession): self.set_state(ISS_State.FAILED) self.log(f"Transmission failed!") self.event_manager.send_arq_session_finished(True, self.id, self.dxcall,False, self.state.name, statistics=self.calculate_session_statistics()) + return None, None def abort_transmission(self, irs_frame=None): # function for starting the abort sequence @@ -202,4 +207,5 @@ class ARQSessionISS(arq_session.ARQSession): self.event_manager.send_arq_session_finished( True, self.id, self.dxcall, False, self.state.name, statistics=self.calculate_session_statistics()) self.state_manager.remove_arq_iss_session(self.id) + return None, None diff --git a/modem/command.py b/modem/command.py index 9bcb76f4..331e3fa8 100644 --- a/modem/command.py +++ b/modem/command.py @@ -3,6 +3,8 @@ import queue from codec2 import FREEDV_MODE import structlog from state_manager import StateManager +from arq_data_type_handler import ARQDataTypeHandler + class TxCommand(): @@ -13,6 +15,7 @@ class TxCommand(): self.event_manager = event_manager self.set_params_from_api(apiParams) self.frame_factory = DataFrameFactory(config) + self.arq_data_type_handler = ARQDataTypeHandler() def set_params_from_api(self, apiParams): pass diff --git a/modem/command_arq_raw.py b/modem/command_arq_raw.py index 7544db71..4d640bd0 100644 --- a/modem/command_arq_raw.py +++ b/modem/command_arq_raw.py @@ -13,13 +13,20 @@ class ARQRawCommand(TxCommand): if not api_validations.validate_freedata_callsign(self.dxcall): self.dxcall = f"{self.dxcall}-0" + try: + self.type = apiParams['type'] + except KeyError: + self.type = "raw" + self.data = base64.b64decode(apiParams['data']) def run(self, event_queue: Queue, modem): self.emit_event(event_queue) self.logger.info(self.log_message()) - iss = ARQSessionISS(self.config, modem, self.dxcall, self.data, self.state_manager) + prepared_data, type_byte = self.arq_data_type_handler.prepare(self.data, self.type) + + iss = ARQSessionISS(self.config, modem, self.dxcall, self.state_manager, prepared_data, type_byte) if iss.id: self.state_manager.register_arq_iss_session(iss) iss.start() diff --git a/modem/data_frame_factory.py b/modem/data_frame_factory.py index 
29c2f460..b62ba11b 100644 --- a/modem/data_frame_factory.py +++ b/modem/data_frame_factory.py @@ -15,7 +15,6 @@ class DataFrameFactory: 'FINAL': 0, # Bit-position for indicating the FINAL state 'ABORT': 1, # Bit-position for indicating the ABORT request 'CHECKSUM': 2, # Bit-position for indicating the CHECKSUM is correct or not - 'ENABLE_COMPRESSION': 3 # Bit-position for indicating compression is enabled } def __init__(self, config): @@ -118,6 +117,7 @@ class DataFrameFactory: "total_crc": 4, "snr": 1, "flag": 1, + "type": 1, } self.template_list[FR_TYPE.ARQ_SESSION_INFO_ACK.value] = { @@ -218,7 +218,7 @@ class DataFrameFactory: elif key in ["session_id", "speed_level", "frames_per_burst", "version", - "offset", "total_length", "state"]: + "offset", "total_length", "state", "type"]: extracted_data[key] = int.from_bytes(data, 'big') elif key in ["snr"]: @@ -350,10 +350,8 @@ class DataFrameFactory: } return self.construct(FR_TYPE.ARQ_SESSION_OPEN_ACK, payload) - def build_arq_session_info(self, session_id: int, total_length: int, total_crc: bytes, snr, flag_compression=False): + def build_arq_session_info(self, session_id: int, total_length: int, total_crc: bytes, snr, type): flag = 0b00000000 - if flag_compression: - flag = helpers.set_flag(flag, 'ENABLE_COMPRESSION', True, self.ARQ_FLAGS) payload = { "session_id": session_id.to_bytes(1, 'big'), @@ -361,6 +359,7 @@ class DataFrameFactory: "total_crc": total_crc, "snr": helpers.snr_to_bytes(1), "flag": flag.to_bytes(1, 'big'), + "type": type.to_bytes(1, 'big'), } return self.construct(FR_TYPE.ARQ_SESSION_INFO, payload) @@ -377,7 +376,6 @@ class DataFrameFactory: } return self.construct(FR_TYPE.ARQ_STOP_ACK, payload) - def build_arq_session_info_ack(self, session_id, total_crc, snr, speed_level, frames_per_burst, flag_final=False, flag_abort=False): flag = 0b00000000 if flag_final: diff --git a/modem/frame_handler.py b/modem/frame_handler.py index 3d454782..d11ba742 100644 --- a/modem/frame_handler.py +++ b/modem/frame_handler.py @@ -31,7 +31,6 @@ class FrameHandler(): def is_frame_for_me(self): call_with_ssid = self.config['STATION']['mycall'] + "-" + str(self.config['STATION']['myssid']) ft = self.details['frame']['frame_type'] - print(self.details) valid = False # Check for callsign checksum if ft in ['ARQ_SESSION_OPEN', 'ARQ_SESSION_OPEN_ACK', 'PING', 'PING_ACK']: diff --git a/tests/test_arq_session.py b/tests/test_arq_session.py index 4bf66ad9..ecfc0b02 100644 --- a/tests/test_arq_session.py +++ b/tests/test_arq_session.py @@ -126,12 +126,13 @@ class TestARQSession(unittest.TestCase): def testARQSessionSmallPayload(self): # set Packet Error Rate (PER) / frame loss probability - self.loss_probability = 50 + self.loss_probability = 0 self.establishChannels() params = { 'dxcall': "XX1XXX-1", 'data': base64.b64encode(bytes("Hello world!", encoding="utf-8")), + 'type': "raw_lzma" } cmd = ARQRawCommand(self.config, self.iss_state_manager, self.iss_event_queue, params) cmd.run(self.iss_event_queue, self.iss_modem) @@ -146,6 +147,7 @@ class TestARQSession(unittest.TestCase): params = { 'dxcall': "XX1XXX-1", 'data': base64.b64encode(np.random.bytes(1000)), + 'type': "raw_lzma" } cmd = ARQRawCommand(self.config, self.iss_state_manager, self.iss_event_queue, params) cmd.run(self.iss_event_queue, self.iss_modem) diff --git a/tests/test_data_dispatcher.py b/tests/test_data_dispatcher.py deleted file mode 100644 index 90b64fa9..00000000 --- a/tests/test_data_dispatcher.py +++ /dev/null @@ -1,35 +0,0 @@ -import sys -sys.path.append('modem') - -import 
unittest -from arq_data_formatter import ARQDataFormatter -from arq_received_data_dispatcher import ARQReceivedDataDispatcher - -class TestDispatcher(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.data_dispatcher = ARQReceivedDataDispatcher() - cls.data_formatter = ARQDataFormatter() - - - def testEncapsulator(self): - message_type = "p2pmsg" - message_data = {"message": "Hello, P2P World!"} - - encapsulated = self.data_formatter.encapsulate(message_data, message_type) - type, decapsulated = self.data_formatter.decapsulate(encapsulated.encode('utf-8')) - self.assertEqual(type, message_type) - self.assertEqual(decapsulated, message_data) - - def testDispatcher(self): - message_type = "test" - message_data = {"message": "Hello, P2P World!"} - - encapsulated = self.data_formatter.encapsulate(message_data, message_type) - self.data_dispatcher.dispatch(encapsulated.encode('utf-8')) - - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_data_type_handler.py b/tests/test_data_type_handler.py new file mode 100644 index 00000000..b7b8cc26 --- /dev/null +++ b/tests/test_data_type_handler.py @@ -0,0 +1,37 @@ +import sys +sys.path.append('modem') + +import unittest +from arq_data_type_handler import ARQDataTypeHandler + +class TestDispatcher(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.arq_data_type_handler = ARQDataTypeHandler() + + + def testDataTypeHandlerRaw(self): + # Example usage + example_data = b"Hello FreeDATA!" + formatted_data, type_byte = self.arq_data_type_handler.prepare(example_data, "raw") + dispatched_data = self.arq_data_type_handler.dispatch(type_byte, formatted_data) + self.assertEqual(example_data, dispatched_data) + + def testDataTypeHandlerLZMA(self): + # Example usage + example_data = b"Hello FreeDATA!" + formatted_data, type_byte = self.arq_data_type_handler.prepare(example_data, "raw_lzma") + dispatched_data = self.arq_data_type_handler.dispatch(type_byte, formatted_data) + self.assertEqual(example_data, dispatched_data) + + def testDataTypeHandlerGZIP(self): + # Example usage + example_data = b"Hello FreeDATA!" + formatted_data, type_byte = self.arq_data_type_handler.prepare(example_data, "raw_gzip") + dispatched_data = self.arq_data_type_handler.dispatch(type_byte, formatted_data) + self.assertEqual(example_data, dispatched_data) + + +if __name__ == '__main__': + unittest.main()
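
For clarity, a minimal sketch of the round trip this patch introduces (not part of the diff above): prepare() on the sending side returns the possibly compressed payload together with the index of the chosen endpoint in the handler table, that index is what build_arq_session_info() packs into the new one-byte "type" field, send_info_ack() on the IRS stores the received byte, and the session's state-transition handler finally hands it to dispatch(), which picks the matching decompression endpoint. The snippet assumes it is run from the repository root with the modem dependencies (e.g. structlog) installed, mirroring the setup of the new unit tests.

    # Illustrative sketch only; follows the API added in modem/arq_data_type_handler.py
    import sys
    sys.path.append('modem')

    from arq_data_type_handler import ARQDataTypeHandler

    handler = ARQDataTypeHandler()

    # Sending side: compress with the "raw_lzma" endpoint; type_byte is the
    # endpoint's index and is carried in the ARQ_SESSION_INFO "type" field.
    payload, type_byte = handler.prepare(b"Hello FreeDATA!", "raw_lzma")

    # Receiving side: the stored type byte selects the matching handler,
    # which decompresses the payload back to the original data.
    restored = handler.dispatch(type_byte, payload)
    assert restored == b"Hello FreeDATA!"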