From 0014462c63c973acb84b2f20d120a575b49dbd2e Mon Sep 17 00:00:00 2001 From: Pedro Date: Thu, 14 Dec 2023 17:29:04 +0100 Subject: [PATCH] ARQ refactor with explicit states --- modem/arq_session.py | 28 +++- modem/arq_session_irs.py | 254 +++++++++++------------------ modem/arq_session_iss.py | 182 ++++++++------------- modem/command_arq_raw.py | 2 +- modem/frame_handler_arq_session.py | 58 +++---- tests/test_arq_session.py | 2 +- 6 files changed, 204 insertions(+), 322 deletions(-) diff --git a/modem/arq_session.py b/modem/arq_session.py index 5703afb4..5c1dfcfb 100644 --- a/modem/arq_session.py +++ b/modem/arq_session.py @@ -2,6 +2,7 @@ import queue, threading import codec2 import data_frame_factory import structlog +from modem_frametypes import FRAME_TYPE class ARQSession(): @@ -15,12 +16,17 @@ class ARQSession(): self.logger = structlog.get_logger(type(self).__name__) self.config = config + self.snr = [] + self.dxcall = dxcall + self.dx_snr = [] self.tx_frame_queue = tx_frame_queue self.speed_level = 0 + self.frames_per_burst = 1 self.frame_factory = data_frame_factory.DataFrameFactory(self.config) + self.event_frame_received = threading.Event() self.id = None @@ -46,10 +52,24 @@ class ARQSession(): self.log(f"{type(self).__name__} state change from {self.state} to {state}") self.state = state - def get_payload_size(self, speed_level): - mode = self.MODE_BY_SPEED[speed_level] - return codec2.get_bytes_per_frame(mode.value) + def get_data_payload_size(self): + return self.frame_factory.get_available_data_payload_for_mode( + FRAME_TYPE.ARQ_BURST_FRAME, + self.MODE_BY_SPEED[self.speed_level] + ) def set_details(self, snr, frequency_offset): - self.snr = snr + self.snr.append(snr) self.frequency_offset = frequency_offset + + def on_frame_received(self, frame): + self.event_frame_received.set() + frame_type = frame['frame_type_int'] + if self.state in self.STATE_TRANSITION: + if frame_type in self.STATE_TRANSITION[self.state]: + action_name = self.STATE_TRANSITION[self.state][frame_type] + getattr(self, action_name)(frame) + return + + self.log(f"Ignoring unknow transition from state {self.state} with frame {frame['frame_type']}") + \ No newline at end of file diff --git a/modem/arq_session_irs.py b/modem/arq_session_irs.py index afc08666..b5acccf9 100644 --- a/modem/arq_session_irs.py +++ b/modem/arq_session_irs.py @@ -3,14 +3,16 @@ import data_frame_factory import queue import arq_session import helpers +from modem_frametypes import FRAME_TYPE class ARQSessionIRS(arq_session.ARQSession): - STATE_CONN_REQ_RECEIVED = 0 - STATE_WAITING_INFO = 1 - STATE_WAITING_DATA = 2 - STATE_FAILED = 3 - STATE_ENDED = 10 + STATE_NEW = 0 + STATE_OPEN_ACK_SENT = 1 + STATE_INFO_ACK_SENT = 2 + STATE_BURST_REPLY_SENT = 3 + STATE_ENDED = 4 + STATE_FAILED = 5 RETRIES_CONNECT = 3 RETRIES_TRANSFER = 3 # we need to increase this @@ -18,140 +20,123 @@ class ARQSessionIRS(arq_session.ARQSession): TIMEOUT_CONNECT = 6 TIMEOUT_DATA = 6 + STATE_TRANSITION = { + STATE_NEW: { + FRAME_TYPE.ARQ_SESSION_OPEN.value : 'send_open_ack', + }, + STATE_OPEN_ACK_SENT: { + FRAME_TYPE.ARQ_SESSION_OPEN.value: 'send_open_ack', + FRAME_TYPE.ARQ_SESSION_INFO.value: 'send_info_ack', + }, + STATE_INFO_ACK_SENT: { + FRAME_TYPE.ARQ_SESSION_INFO.value: 'send_info_ack', + FRAME_TYPE.ARQ_BURST_FRAME.value: 'receive_data', + }, + STATE_BURST_REPLY_SENT: { + FRAME_TYPE.ARQ_BURST_FRAME.value: 'receive_data', + }, + } + def __init__(self, config: dict, tx_frame_queue: queue.Queue, dxcall: str, session_id: int): super().__init__(config, tx_frame_queue, dxcall) self.id = session_id - self.speed = 0 - self.frames_per_burst = 3 + self.dxcall = dxcall self.version = 1 - self.snr = 0 - self.dx_snr = 0 - self.retries = self.RETRIES_TRANSFER - self.state = self.STATE_CONN_REQ_RECEIVED + self.state = self.STATE_NEW - self.event_info_received = threading.Event() - self.event_data_received = threading.Event() - - self.frame_factory = data_frame_factory.DataFrameFactory(self.config) - - self.received_frame = None + self.total_length = 0 + self.total_crc = '' self.received_data = None self.received_bytes = 0 self.received_crc = None - def generate_id(self): - pass - - def set_state(self, state): - self.log(f"ARQ Session IRS {self.id} state {self.state}") - self.state = state - def set_modem_decode_modes(self, modes): pass - def _all_data_received(self): + def all_data_received(self): return self.received_bytes == len(self.received_data) - def _final_crc_check(self): - return self.received_crc == helpers.get_crc_32(bytes(self.received_data)).hex() + def final_crc_check(self): + return self.total_crc == helpers.get_crc_32(bytes(self.received_data)).hex() - def handshake_session(self): - if self.state in [self.STATE_CONN_REQ_RECEIVED, self.STATE_WAITING_INFO]: - self.send_open_ack() - self.set_state(self.STATE_WAITING_INFO) - return True - return False - - def handshake_info(self): - if self.state == self.STATE_WAITING_INFO and not self.event_info_received.wait(self.TIMEOUT_CONNECT): - return False - - self.send_info_ack() - self.set_state(self.STATE_WAITING_DATA) - return True - - def send_info_ack(self): - info_ack = self.frame_factory.build_arq_session_info_ack( - self.id, self.received_crc, self.snr, - self.speed_level, self.frames_per_burst) - self.transmit_frame(info_ack) - - - def receive_data(self): - self.retries = self.RETRIES_TRANSFER - while self.retries > 0 and not self._all_data_received(): - if self.event_data_received.wait(self.TIMEOUT_DATA): - self.process_incoming_data() - self.send_burst_ack_nack(True) - self.retries = self.RETRIES_TRANSFER - else: - self.send_burst_ack_nack(False) - self.retries -= 1 - - if self._all_data_received(): - if self._final_crc_check(): - self.set_state(self.STATE_ENDED) - self.logger.info("------ ALL DATA RECEIVED ------", state=self.state, dxcall=self.dxcall, snr=self.snr) - - else: - self.logger.warning("CRC check failed.") - self.set_state(self.STATE_FAILED) - - else: + def transmit_and_wait(self, frame, timeout): + self.transmit_frame(frame) + if not self.event_frame_received.wait(timeout): + self.log("Timeout waiting for ISS to say something") self.set_state(self.STATE_FAILED) - # finally send a data ack / nack - self.send_data_ack_nack() - - def runner(self): - - if not self.handshake_session(): - return False - - if not self.handshake_info(): - return False - - if not self.receive_data(): - return False - return True - - def run(self): - self.set_state(self.STATE_CONN_REQ_RECEIVED) - self.thread = threading.Thread(target=self.runner, - name=f"ARQ IRS Session {self.id}", daemon=False) - self.thread.start() - - def send_open_ack(self): + def launch_transmit_and_wait(self, frame, timeout): + thread_wait = threading.Thread(target = self.transmit_and_wait, + args = [frame, timeout]) + thread_wait.start() + + def send_open_ack(self, open_frame): ack_frame = self.frame_factory.build_arq_session_open_ack( self.id, self.dxcall, self.version, - self.snr) - self.transmit_frame(ack_frame) + self.snr[0]) + self.launch_transmit_and_wait(ack_frame, self.TIMEOUT_CONNECT) + self.set_state(self.STATE_OPEN_ACK_SENT) - def send_burst_ack_nack(self, ack: bool): - if ack: - builder = self.frame_factory.build_arq_burst_ack + def send_info_ack(self, info_frame): + # Get session info from ISS + self.received_data = bytearray(info_frame['total_length']) + self.total_crc = info_frame['total_crc'] + self.dx_snr.append(info_frame['snr']) + + info_ack = self.frame_factory.build_arq_session_info_ack( + self.id, self.total_crc, self.snr[0], + self.speed_level, self.frames_per_burst) + self.launch_transmit_and_wait(info_ack, self.TIMEOUT_CONNECT) + self.set_state(self.STATE_INFO_ACK_SENT) + + def process_incoming_data(self, frame): + if frame['offset'] != self.received_bytes: + self.logger.info(f"Discarding data frame due to wrong offset", frame=self.frame_received) + return False + + remaining_data_length = len(self.received_data) - self.received_bytes + + # Is this the last data part? + if remaining_data_length <= len(frame['data']): + # we only want the remaining length, not the entire frame data + data_part = frame['data'][:remaining_data_length] else: - builder = self.frame_factory.build_arq_burst_nack + # we want the entire frame data + data_part = frame['data'] - frame = builder ( + self.received_data[frame['offset']:] = data_part + self.received_bytes += len(data_part) + + return True + + def receive_data(self, burst_frame): + self.process_incoming_data(burst_frame) + + ack = self.frame_factory.build_arq_burst_ack( self.id, self.received_bytes, - self.speed_level, self.frames_per_burst, self.snr) - - self.transmit_frame(frame) + self.speed_level, self.frames_per_burst, self.snr[0]) - def send_data_ack_nack(self): + if not self.all_data_received(): + self.transmit_and_wait(ack) + self.set_state(self.STATE_BURST_REPLY_SENT) + return - builder = self.frame_factory.build_arq_data_ack_nack - frame = builder(self.id, self.state, self.snr) - self.transmit_frame(frame) + if self.final_crc_check(): + self.log("All data received successfully!") + self.transmit_frame(ack) + self.set_state(self.STATE_ENDED) + else: + self.log("CRC fail at the end of transmission!") + self.set_state(self.STATE_FAILED) def calibrate_speed_settings(self): - + return + # decrement speed level after the 2nd retry if self.RETRIES_TRANSFER - self.retries >= 2: self.speed -= 1 @@ -164,58 +149,3 @@ class ARQSessionIRS(arq_session.ARQSession): self.speed = self.speed self.frames_per_burst = self.frames_per_burst - - def on_info_received(self, frame): - if self.state != self.STATE_WAITING_INFO: - self.logger.warning("Discarding received INFO.", state=self.state) - return - - self.received_data = bytearray(frame['total_length']) - self.received_crc = frame['total_crc'] - self.dx_snr = frame['snr'] - - self.calibrate_speed_settings() - self.set_modem_decode_modes(None) - - self.event_info_received.set() - - def on_data_received(self, frame): - if self.state != self.STATE_WAITING_DATA: - self.logger.warning(f"ARQ Session: Received data while in state {self.state}. Ignoring.") - return - - self.received_frame = frame - self.event_data_received.set() - - def process_incoming_data(self): - if self.received_frame['offset'] != self.received_bytes: - self.logger.info(f"Discarding data frame due to wrong offset", frame=self.frame_received) - return False - - remaining_data_length = len(self.received_data) - self.received_bytes - - # Is this the last data part? - if remaining_data_length <= len(self.received_frame['data']): - # we only want the remaining length, not the entire frame data - data_part = self.received_frame['data'][:remaining_data_length] - else: - # we want the entire frame data - data_part = self.received_frame['data'] - - self.received_data[self.received_frame['offset']:] = data_part - self.received_bytes += len(data_part) - - return True - - def on_burst_ack_received(self, ack): - self.event_transfer_ack_received.set() - self.speed_level = ack['speed_level'] - - def on_burst_nack_received(self, nack): - self.speed_level = nack['speed_level'] - - def on_disconnect_received(self): - self.abort() - - def abort(self): - self.state = self.STATE_DISCONNECTED diff --git a/modem/arq_session_iss.py b/modem/arq_session_iss.py index 067a39c9..b68d6970 100644 --- a/modem/arq_session_iss.py +++ b/modem/arq_session_iss.py @@ -3,17 +3,18 @@ import data_frame_factory import queue import random from codec2 import FREEDV_MODE +from modem_frametypes import FRAME_TYPE import arq_session import helpers class ARQSessionISS(arq_session.ARQSession): - STATE_DISCONNECTED = 0 - STATE_CONNECTING = 1 - STATE_CONNECTED = 2 - STATE_SENDING = 3 - - STATE_ENDED = 10 + STATE_NEW = 0 + STATE_OPEN_SENT = 1 + STATE_INFO_SENT = 2 + STATE_BURST_SENT = 3 + STATE_ENDED = 4 + STATE_FAILED = 5 RETRIES_CONNECT = 3 RETRIES_TRANSFER = 3 @@ -21,139 +22,84 @@ class ARQSessionISS(arq_session.ARQSession): TIMEOUT_CONNECT_ACK = 5 TIMEOUT_TRANSFER = 2 + STATE_TRANSITION = { + STATE_OPEN_SENT: { + FRAME_TYPE.ARQ_SESSION_OPEN_ACK.value: 'send_info', + }, + STATE_INFO_SENT: { + FRAME_TYPE.ARQ_SESSION_OPEN_ACK.value: 'send_info', + FRAME_TYPE.ARQ_SESSION_INFO_ACK.value: 'send_data', + }, + STATE_BURST_SENT: { + FRAME_TYPE.ARQ_SESSION_INFO_ACK.value: 'send_data', + FRAME_TYPE.ARQ_BURST_ACK.value: 'send_data', + FRAME_TYPE.ARQ_BURST_NACK.value: 'send_data', + }, + } + def __init__(self, config: dict, tx_frame_queue: queue.Queue, dxcall: str, data: bytearray): super().__init__(config, tx_frame_queue, dxcall) self.data = data + self.data_crc = '' - self.state = self.STATE_DISCONNECTED + self.confirmed_bytes = 0 + + self.state = self.STATE_NEW self.id = self.generate_id() - self.event_open_ack_received = threading.Event() - self.event_info_ack_received = threading.Event() - self.event_transfer_ack_received = threading.Event() - self.event_transfer_data_ack_nack_received = threading.Event() - self.frame_factory = data_frame_factory.DataFrameFactory(self.config) def generate_id(self): return random.randint(1,255) - - def set_state(self, state): - self.logger.info(f"ARQ Session ISS {self.id} state {self.state}") - self.state = state - - def runner(self): - self.state = self.STATE_CONNECTING - - if not self.session_open(): - return False - - if not self.session_info(): - return False - - return self.send_data() - def run(self): - self.thread = threading.Thread(target=self.runner, name=f"ARQ ISS Session {self.id}", daemon=False) - self.thread.run() - - def handshake(self, frame, event): - retries = self.RETRIES_CONNECT + def transmit_wait_and_retry(self, frame_or_burst, timeout, retries): while retries > 0: - self.transmit_frame(frame) - self.logger.info("Waiting...") - if event.wait(self.TIMEOUT_CONNECT_ACK): - return True + if isinstance(frame_or_burst, list): burst = frame_or_burst + else: burst = [frame_or_burst] + for f in burst: + self.transmit_frame(f) + if self.event_frame_received.wait(timeout): + self.log("Timeout interrupted due to received frame.") + break retries = retries - 1 - self.set_state(self.STATE_DISCONNECTED) - return False + def launch_twr(self, frame_or_burst, timeout, retries): + twr = threading.Thread(target = self.transmit_wait_and_retry, args=[frame_or_burst, timeout, retries]) + twr.start() - def session_open(self): - open_frame = self.frame_factory.build_arq_session_open(self.dxcall, self.id) - return self.handshake(open_frame, self.event_open_ack_received) + def start(self): + session_open_frame = self.frame_factory.build_arq_session_open(self.dxcall, self.id) + self.launch_twr(session_open_frame, self.TIMEOUT_CONNECT_ACK, self.RETRIES_CONNECT) + self.set_state(self.STATE_OPEN_SENT) - def session_info(self): + def set_speed_and_frames_per_burst(self, frame): + self.speed_level = frame['speed_level'] + self.log(f"Speed level set to {self.speed_level}") + self.frames_per_burst = frame['frames_per_burst'] + self.log(f"Frames per burst set to {self.frames_per_burst}") + + def send_info(self, open_ack_frame): info_frame = self.frame_factory.build_arq_session_info(self.id, len(self.data), helpers.get_crc_32(self.data), - self.snr) - return self.handshake(info_frame, self.event_info_ack_received) + self.snr[0]) + self.launch_twr(info_frame, self.TIMEOUT_CONNECT_ACK, self.RETRIES_CONNECT) + self.set_state(self.STATE_INFO_SENT) - def on_open_ack_received(self, ack): - if self.state != self.STATE_CONNECTING: - raise RuntimeError(f"ARQ Session: Received OPEN ACK while in state {self.state}") + def send_data(self, irs_frame): + self.set_speed_and_frames_per_burst(irs_frame) - self.event_open_ack_received.set() + if 'offset' in irs_frame: + self.confirmed_bytes = irs_frame['offset'] - def on_info_ack_received(self, ack): - if self.state != self.STATE_CONNECTING: - raise RuntimeError(f"ARQ Session: Received INFO ACK while in state {self.state}") - - self.event_info_ack_received.set() - - # Sends the full payload in multiple frames - def send_data(self): - offset = 0 - while offset < len(self.data): - max_size = self.get_payload_size(self.speed_level) - end_offset = min(len(self.data), max_size) - frame_payload = self.data[offset:end_offset] + payload_size = self.get_data_payload_size() + burst = [] + for f in range(0, self.frames_per_burst): + offset = self.confirmed_bytes + payload = self.data[offset : offset + payload_size] data_frame = self.frame_factory.build_arq_burst_frame( self.MODE_BY_SPEED[self.speed_level], - self.id, offset, frame_payload) - self.set_state(self.STATE_SENDING) - if not self.send_arq(data_frame): - return False - offset = end_offset + 1 + self.id, self.confirmed_bytes, payload) + burst.append(data_frame) - self.awaiting_data_ack_nack() - - - # Send part of the payload using ARQ - def send_arq(self, frame): - retries = self.RETRIES_TRANSFER - while retries > 0: - # to know later if it has changed - speed_level = self.speed_level - self.transmit_frame(frame) - # wait for ack - if self.event_transfer_ack_received.wait(self.TIMEOUT_TRANSFER): - speed_level = self.speed_level - return True - - # don't decrement retries if speed level is changing - if self.speed_level == speed_level: - retries = retries - 1 - - self.set_state(self.STATE_DISCONNECTED) - return False - - def awaiting_data_ack_nack(self): - # TODO Implement the final logics after receiving an ACK or NACK for transmitted data - self.logger.info(f"Awaiting data ack/nack") - if not self.event_transfer_data_ack_nack_received.wait(self.TIMEOUT_TRANSFER): - self.logger.warning(f"data ack / nack missed after timeout") - self.logger.info(f"data ack nack received...") - - def on_burst_ack_received(self, ack): - self.speed_level = ack['speed_level'] - self.event_transfer_ack_received.set() - - def on_burst_nack_received(self, nack): - self.speed_level = nack['speed_level'] - self.event_transfer_ack_received.set() - - def on_data_ack_nack_received(self, ack_nack): - self.event_transfer_data_ack_nack_received.set() - state = ack_nack['state'] - print(state) - - def on_disconnect_received(self): - self.abort() - - def abort(self): - self.state = self.STATE_DISCONNECTED - self.event_connection_ack_received.set() - self.event_connection_ack_received.clear() - self.event_transfer_feedback.set() - self.event_transfer_feedback.clear() + self.launch_twr(burst, self.TIMEOUT_CONNECT_ACK, self.RETRIES_CONNECT) + self.set_state(self.STATE_BURST_SENT) diff --git a/modem/command_arq_raw.py b/modem/command_arq_raw.py index 00b67963..09d14373 100644 --- a/modem/command_arq_raw.py +++ b/modem/command_arq_raw.py @@ -19,5 +19,5 @@ class ARQRawCommand(TxCommand): iss = ARQSessionISS(self.config, tx_frame_queue, self.dxcall, self.data) self.state_manager.register_arq_iss_session(iss) - iss.run() + iss.start() return iss diff --git a/modem/frame_handler_arq_session.py b/modem/frame_handler_arq_session.py index acc3b404..f257ad2b 100644 --- a/modem/frame_handler_arq_session.py +++ b/modem/frame_handler_arq_session.py @@ -10,55 +10,41 @@ class ARQFrameHandler(frame_handler.FrameHandler): def follow_protocol(self): frame = self.details['frame'] + session_id = frame['session_id'] snr = self.details["snr"] frequency_offset = self.details["frequency_offset"] if frame['frame_type_int'] == FR.ARQ_SESSION_OPEN.value: + # Lost OPEN_ACK case .. ISS will retry opening a session - if frame['session_id'] in self.states.arq_irs_sessions: - session = self.states.arq_irs_sessions[frame['session_id']] - if session.state in [ARQSessionIRS.STATE_CONN_REQ_RECEIVED, ARQSessionIRS.STATE_WAITING_INFO]: - session.set_details(snr, frequency_offset) - else: - self.logger.warning(f"IRS Session conflict for session {session.id}") + if session_id in self.states.arq_irs_sessions: + session = self.states.arq_irs_sessions[session_id] + # Normal case when receiving a SESSION_OPEN for the first time else: session = ARQSessionIRS(self.config, self.tx_frame_queue, frame['origin'], - frame['session_id']) + session_id) self.states.register_arq_irs_session(session) - session.set_details(snr, frequency_offset) - session.run() - elif frame['frame_type_int'] == FR.ARQ_SESSION_OPEN_ACK.value: - session:ARQSessionISS = self.states.get_arq_iss_session(frame['session_id']) - session.set_details(snr, frequency_offset) - session.on_open_ack_received(frame) + elif frame['frame_type_int'] in [ + FR.ARQ_SESSION_INFO.value, + FR.ARQ_BURST_FRAME.value, + ]: + session = self.states.get_arq_irs_session(session_id) - elif frame['frame_type_int'] == FR.ARQ_SESSION_INFO.value: - session:ARQSessionIRS = self.states.get_arq_irs_session(frame['session_id']) - session.set_details(snr, frequency_offset) - session.on_info_received(frame) + elif frame['frame_type_int'] in [ + FR.ARQ_SESSION_OPEN_ACK.value, + FR.ARQ_SESSION_INFO_ACK.value, + FR.ARQ_BURST_ACK.value, + FR.ARQ_DATA_ACK_NACK.value + ]: + session = self.states.get_arq_iss_session(session_id) - elif frame['frame_type_int'] == FR.ARQ_SESSION_INFO_ACK.value: - session:ARQSessionISS = self.states.get_arq_iss_session(frame['session_id']) - session.set_details(snr, frequency_offset) - session.on_info_ack_received(frame) - - elif frame['frame_type_int'] == FR.ARQ_BURST_FRAME.value: - session:ARQSessionIRS = self.states.get_arq_irs_session(frame['session_id']) - session.set_details(snr, frequency_offset) - session.on_data_received(frame) - - elif frame['frame_type_int'] == FR.ARQ_BURST_ACK.value: - session:ARQSessionISS = self.states.get_arq_iss_session(frame['session_id']) - session.set_details(snr, frequency_offset) - session.on_burst_ack_received(frame) - - elif frame['frame_type_int'] == FR.ARQ_DATA_ACK_NACK.value: - session:ARQSessionISS = self.states.get_arq_iss_session(frame['session_id']) - session.set_details(snr, frequency_offset) - session.on_data_ack_nack_received(frame) else: self.logger.warning("DISCARDING FRAME", frame=frame) + return + + session.set_details(snr, frequency_offset) + session.on_frame_received(frame) diff --git a/tests/test_arq_session.py b/tests/test_arq_session.py index 6027f530..305c5acf 100644 --- a/tests/test_arq_session.py +++ b/tests/test_arq_session.py @@ -72,7 +72,7 @@ class TestARQSession(unittest.TestCase): def testARQSession(self): # set Packet Error Rate (PER) / frame loss probability - self.loss_probability = 20 + self.loss_probability = 0 self.establishChannels() params = {