diff --git a/modem/arq_session_iss.py b/modem/arq_session_iss.py index 68c4e922..9c69522c 100644 --- a/modem/arq_session_iss.py +++ b/modem/arq_session_iss.py @@ -182,7 +182,6 @@ class ARQSessionISS(arq_session.ARQSession): self.states.setARQ(False) self.arq_data_type_handler.failed(self.type_byte, self.data) - return None, None def abort_transmission(self, irs_frame=None): diff --git a/modem/message_system_db_messages.py b/modem/message_system_db_messages.py index 255b8ab5..639ac92f 100644 --- a/modem/message_system_db_messages.py +++ b/modem/message_system_db_messages.py @@ -171,21 +171,27 @@ class DatabaseManagerMessages(DatabaseManager): finally: session.remove() - def increment_message_attempts(self, message_id): - session = self.get_thread_scoped_session() + def increment_message_attempts(self, message_id, session=None): + own_session = False + if not session: + session = self.get_thread_scoped_session() + own_session = True try: message = session.query(P2PMessage).filter_by(id=message_id).first() if message: message.attempt += 1 - session.commit() + if own_session: + session.commit() self.log(f"Incremented attempt count for message {message_id}") else: self.log(f"Message with ID {message_id} not found") except Exception as e: - session.rollback() + if own_session: + session.rollback() self.log(f"An error occurred while incrementing attempts for message {message_id}: {e}") finally: - session.remove() + if own_session: + session.remove() def mark_message_as_read(self, message_id): session = self.get_thread_scoped_session() @@ -217,23 +223,21 @@ class DatabaseManagerMessages(DatabaseManager): return # Query for messages with the specified callsign, 'failed' status, and fewer than 10 attempts - messages = session.query(P2PMessage) \ - .filter(P2PMessage.origin_callsign == callsign) \ + message = session.query(P2PMessage) \ + .filter(P2PMessage.destination_callsign == callsign) \ .filter(P2PMessage.status_id == failed_status.id) \ .filter(P2PMessage.attempt < 10) \ - .all() + .first() - if messages: - # Update each message's status to 'queued' - for message in messages: - # Increment attempt count using the existing function - self.increment_message_attempts(message.id) + if message: + # Increment attempt count using the existing function + self.increment_message_attempts(message.id, session) - message.status_id = queued_status.id - self.log(f"Set message {message.id} to queued and incremented attempt") + message.status_id = queued_status.id + self.log(f"Set message {message.id} to queued and incremented attempt") session.commit() - return {'status': 'success', 'message': f'{len(messages)} message(s) set to queued'} + return {'status': 'success', 'message': f'{len(message)} message(s) set to queued'} else: return {'status': 'failure', 'message': 'No eligible messages found'} except Exception as e: diff --git a/tests/test_arq_session.py b/tests/test_arq_session.py index 961c2262..185e90e3 100644 --- a/tests/test_arq_session.py +++ b/tests/test_arq_session.py @@ -103,7 +103,7 @@ class TestARQSession(unittest.TestCase): def waitForSession(self, q, outbound = False): key = 'arq-transfer-outbound' if outbound else 'arq-transfer-inbound' - while True: + while True and self.channels_running: ev = q.get() if key in ev and ('success' in ev[key] or 'ABORTED' in ev[key]): self.logger.info(f"[{threading.current_thread().name}] {key} session ended.") @@ -125,6 +125,7 @@ class TestARQSession(unittest.TestCase): def waitAndCloseChannels(self): self.waitForSession(self.iss_event_queue, True) + self.channels_running = False self.waitForSession(self.irs_event_queue, False) self.channels_running = False diff --git a/tests/test_message_protocol.py b/tests/test_message_protocol.py index 591daefa..eedf7a00 100644 --- a/tests/test_message_protocol.py +++ b/tests/test_message_protocol.py @@ -108,7 +108,7 @@ class TestMessageProtocol(unittest.TestCase): def waitForSession(self, q, outbound=False): key = 'arq-transfer-outbound' if outbound else 'arq-transfer-inbound' - while True: + while True and self.channels_running: ev = q.get() if key in ev and ('success' in ev[key] or 'ABORTED' in ev[key]): self.logger.info(f"[{threading.current_thread().name}] {key} session ended.") @@ -130,6 +130,7 @@ class TestMessageProtocol(unittest.TestCase): def waitAndCloseChannels(self): self.waitForSession(self.iss_event_queue, True) + self.channels_running = False self.waitForSession(self.irs_event_queue, False) self.channels_running = False