# This file is part of Scapy # See http://www.secdev.org/projects/scapy for more information # Copyright (C) Philippe Biondi # Copyright (C) Gabriel Potter # This program is published under a GPLv2 license """ Automata with states, transitions and actions. """ from __future__ import absolute_import import types import itertools import time import os import sys import traceback from select import select from collections import deque import threading from scapy.config import conf from scapy.utils import do_graph from scapy.error import log_runtime, warning from scapy.plist import PacketList from scapy.data import MTU from scapy.supersocket import SuperSocket from scapy.consts import WINDOWS import scapy.modules.six as six """ In Windows, select.select is not available for custom objects. Here's the implementation of scapy to re-create this functionality # noqa: E501 # Passive way: using no-ressources locks +---------+ +---------------+ +-------------------------+ # noqa: E501 | Start +------------->Select_objects +----->+Linux: call select.select| # noqa: E501 +---------+ |(select.select)| +-------------------------+ # noqa: E501 +-------+-------+ | +----v----+ +--------+ | Windows | |Time Out+----------------------------------+ # noqa: E501 +----+----+ +----+---+ | # noqa: E501 | ^ | # noqa: E501 Event | | | # noqa: E501 + | | | # noqa: E501 | +-------v-------+ | | # noqa: E501 | +------+Selectable Sel.+-----+-----------------+-----------+ | # noqa: E501 | | +-------+-------+ | | | v +-----v-----+ # noqa: E501 +-------v----------+ | | | | | Passive lock<-----+release_all<------+ # noqa: E501 |Data added to list| +----v-----+ +-----v-----+ +----v-----+ v v + +-----------+ | # noqa: E501 +--------+---------+ |Selectable| |Selectable | |Selectable| ............ | | # noqa: E501 | +----+-----+ +-----------+ +----------+ | | # noqa: E501 | v | | # noqa: E501 v +----+------+ +------------------+ +-------------v-------------------+ | # noqa: E501 +-----+------+ |wait_return+-->+ check_recv: | | | | # noqa: E501 |call_release| +----+------+ |If data is in list| | END state: selectable returned | +---+--------+ # noqa: E501 +-----+-------- v +-------+----------+ | | | exit door | # noqa: E501 | else | +---------------------------------+ +---+--------+ # noqa: E501 | + | | # noqa: E501 | +----v-------+ | | # noqa: E501 +--------->free -->Passive lock| | | # noqa: E501 +----+-------+ | | # noqa: E501 | | | # noqa: E501 | v | # noqa: E501 +------------------Selectable-Selector-is-advertised-that-the-selectable-is-readable---------+ """ class SelectableObject(object): """DEV: to implement one of those, you need to add 2 things to your object: - add "check_recv" function - call "self.call_release" once you are ready to be read You can set the __selectable_force_select__ to True in the class, if you want to # noqa: E501 force the handler to use fileno(). This may only be usable on sockets created using # noqa: E501 the builtin socket API.""" __selectable_force_select__ = False def __init__(self): self.hooks = [] def check_recv(self): """DEV: will be called only once (at beginning) to check if the object is ready.""" # noqa: E501 raise OSError("This method must be overwritten.") def _wait_non_ressources(self, callback): """This get started as a thread, and waits for the data lock to be freed then advertise itself to the SelectableSelector using the callback""" # noqa: E501 self.trigger = threading.Lock() self.was_ended = False self.trigger.acquire() self.trigger.acquire() if not self.was_ended: callback(self) def wait_return(self, callback): """Entry point of SelectableObject: register the callback""" if self.check_recv(): return callback(self) _t = threading.Thread(target=self._wait_non_ressources, args=(callback,)) # noqa: E501 _t.setDaemon(True) _t.start() def register_hook(self, hook): """DEV: When call_release() will be called, the hook will also""" self.hooks.append(hook) def call_release(self, arborted=False): """DEV: Must be call when the object becomes ready to read. Relesases the lock of _wait_non_ressources""" self.was_ended = arborted try: self.trigger.release() except (threading.ThreadError, AttributeError): pass # Trigger hooks for hook in self.hooks: hook() class SelectableSelector(object): """ Select SelectableObject objects. inputs: objects to process remain: timeout. If 0, return []. customTypes: types of the objects that have the check_recv function. """ def _release_all(self): """Releases all locks to kill all threads""" for i in self.inputs: i.call_release(True) self.available_lock.release() def _timeout_thread(self, remain): """Timeout before releasing every thing, if nothing was returned""" time.sleep(remain) if not self._ended: self._ended = True self._release_all() def _exit_door(self, _input): """This function is passed to each SelectableObject as a callback The SelectableObjects have to call it once there are ready""" self.results.append(_input) if self._ended: return self._ended = True self._release_all() def __init__(self, inputs, remain): self.results = [] self.inputs = list(inputs) self.remain = remain self.available_lock = threading.Lock() self.available_lock.acquire() self._ended = False def process(self): """Entry point of SelectableSelector""" if WINDOWS: select_inputs = [] for i in self.inputs: if not isinstance(i, SelectableObject): warning("Unknown ignored object type: %s", type(i)) elif i.__selectable_force_select__: # Then use select.select select_inputs.append(i) elif not self.remain and i.check_recv(): self.results.append(i) elif self.remain: i.wait_return(self._exit_door) if select_inputs: # Use default select function self.results.extend(select(select_inputs, [], [], self.remain)[0]) # noqa: E501 if not self.remain: return self.results threading.Thread(target=self._timeout_thread, args=(self.remain,)).start() # noqa: E501 if not self._ended: self.available_lock.acquire() return self.results else: r, _, _ = select(self.inputs, [], [], self.remain) return r def select_objects(inputs, remain): """ Select SelectableObject objects. Same than: ``select.select([inputs], [], [], remain)`` But also works on Windows, only on SelectableObject. :param inputs: objects to process :param remain: timeout. If 0, return []. """ handler = SelectableSelector(inputs, remain) return handler.process() class ObjectPipe(SelectableObject): read_allowed_exceptions = () def __init__(self): self.closed = False self.rd, self.wr = os.pipe() self.queue = deque() SelectableObject.__init__(self) def fileno(self): return self.rd def check_recv(self): return len(self.queue) > 0 def send(self, obj): self.queue.append(obj) os.write(self.wr, b"X") self.call_release() def write(self, obj): self.send(obj) def flush(self): pass def recv(self, n=0): if self.closed: if self.check_recv(): return self.queue.popleft() return None os.read(self.rd, 1) return self.queue.popleft() def read(self, n=0): return self.recv(n) def close(self): if not self.closed: self.closed = True os.close(self.rd) os.close(self.wr) self.queue.clear() def __del__(self): self.close() @staticmethod def select(sockets, remain=conf.recv_poll_rate): # Only handle ObjectPipes results = [] for s in sockets: if s.closed: results.append(s) if results: return results, None return select_objects(sockets, remain), None class Message: def __init__(self, **args): self.__dict__.update(args) def __repr__(self): return "" % " ".join("%s=%r" % (k, v) for (k, v) in six.iteritems(self.__dict__) # noqa: E501 if not k.startswith("_")) class _instance_state: def __init__(self, instance): self.__self__ = instance.__self__ self.__func__ = instance.__func__ self.__self__.__class__ = instance.__self__.__class__ def __getattr__(self, attr): return getattr(self.__func__, attr) def __call__(self, *args, **kargs): return self.__func__(self.__self__, *args, **kargs) def breaks(self): return self.__self__.add_breakpoints(self.__func__) def intercepts(self): return self.__self__.add_interception_points(self.__func__) def unbreaks(self): return self.__self__.remove_breakpoints(self.__func__) def unintercepts(self): return self.__self__.remove_interception_points(self.__func__) ############## # Automata # ############## class ATMT: STATE = "State" ACTION = "Action" CONDITION = "Condition" RECV = "Receive condition" TIMEOUT = "Timeout condition" IOEVENT = "I/O event" class NewStateRequested(Exception): def __init__(self, state_func, automaton, *args, **kargs): self.func = state_func self.state = state_func.atmt_state self.initial = state_func.atmt_initial self.error = state_func.atmt_error self.final = state_func.atmt_final Exception.__init__(self, "Request state [%s]" % self.state) self.automaton = automaton self.args = args self.kargs = kargs self.action_parameters() # init action parameters def action_parameters(self, *args, **kargs): self.action_args = args self.action_kargs = kargs return self def run(self): return self.func(self.automaton, *self.args, **self.kargs) def __repr__(self): return "NewStateRequested(%s)" % self.state @staticmethod def state(initial=0, final=0, error=0): def deco(f, initial=initial, final=final): f.atmt_type = ATMT.STATE f.atmt_state = f.__name__ f.atmt_initial = initial f.atmt_final = final f.atmt_error = error def state_wrapper(self, *args, **kargs): return ATMT.NewStateRequested(f, self, *args, **kargs) state_wrapper.__name__ = "%s_wrapper" % f.__name__ state_wrapper.atmt_type = ATMT.STATE state_wrapper.atmt_state = f.__name__ state_wrapper.atmt_initial = initial state_wrapper.atmt_final = final state_wrapper.atmt_error = error state_wrapper.atmt_origfunc = f return state_wrapper return deco @staticmethod def action(cond, prio=0): def deco(f, cond=cond): if not hasattr(f, "atmt_type"): f.atmt_cond = {} f.atmt_type = ATMT.ACTION f.atmt_cond[cond.atmt_condname] = prio return f return deco @staticmethod def condition(state, prio=0): def deco(f, state=state): f.atmt_type = ATMT.CONDITION f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ f.atmt_prio = prio return f return deco @staticmethod def receive_condition(state, prio=0): def deco(f, state=state): f.atmt_type = ATMT.RECV f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ f.atmt_prio = prio return f return deco @staticmethod def ioevent(state, name, prio=0, as_supersocket=None): def deco(f, state=state): f.atmt_type = ATMT.IOEVENT f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ f.atmt_ioname = name f.atmt_prio = prio f.atmt_as_supersocket = as_supersocket return f return deco @staticmethod def timeout(state, timeout): def deco(f, state=state, timeout=timeout): f.atmt_type = ATMT.TIMEOUT f.atmt_state = state.atmt_state f.atmt_timeout = timeout f.atmt_condname = f.__name__ return f return deco class _ATMT_Command: RUN = "RUN" NEXT = "NEXT" FREEZE = "FREEZE" STOP = "STOP" END = "END" EXCEPTION = "EXCEPTION" SINGLESTEP = "SINGLESTEP" BREAKPOINT = "BREAKPOINT" INTERCEPT = "INTERCEPT" ACCEPT = "ACCEPT" REPLACE = "REPLACE" REJECT = "REJECT" class _ATMT_supersocket(SuperSocket, SelectableObject): def __init__(self, name, ioevent, automaton, proto, *args, **kargs): SelectableObject.__init__(self) self.name = name self.ioevent = ioevent self.proto = proto # write, read self.spa, self.spb = ObjectPipe(), ObjectPipe() # Register recv hook self.spb.register_hook(self.call_release) kargs["external_fd"] = {ioevent: (self.spa, self.spb)} self.atmt = automaton(*args, **kargs) self.atmt.runbg() def fileno(self): return self.spb.fileno() def send(self, s): if not isinstance(s, bytes): s = bytes(s) return self.spa.send(s) def check_recv(self): return self.spb.check_recv() def recv(self, n=MTU): r = self.spb.recv(n) if self.proto is not None: r = self.proto(r) return r def close(self): if not self.closed: self.atmt.stop() self.spa.close() self.spb.close() self.closed = True @staticmethod def select(sockets, remain=conf.recv_poll_rate): return select_objects(sockets, remain), None class _ATMT_to_supersocket: def __init__(self, name, ioevent, automaton): self.name = name self.ioevent = ioevent self.automaton = automaton def __call__(self, proto, *args, **kargs): return _ATMT_supersocket( self.name, self.ioevent, self.automaton, proto, *args, **kargs ) class Automaton_metaclass(type): def __new__(cls, name, bases, dct): cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct) cls.states = {} cls.state = None cls.recv_conditions = {} cls.conditions = {} cls.ioevents = {} cls.timeout = {} cls.actions = {} cls.initial_states = [] cls.ionames = [] cls.iosupersockets = [] members = {} classes = [cls] while classes: c = classes.pop(0) # order is important to avoid breaking method overloading # noqa: E501 classes += list(c.__bases__) for k, v in six.iteritems(c.__dict__): if k not in members: members[k] = v decorated = [v for v in six.itervalues(members) if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")] # noqa: E501 for m in decorated: if m.atmt_type == ATMT.STATE: s = m.atmt_state cls.states[s] = m cls.recv_conditions[s] = [] cls.ioevents[s] = [] cls.conditions[s] = [] cls.timeout[s] = [] if m.atmt_initial: cls.initial_states.append(m) elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]: # noqa: E501 cls.actions[m.atmt_condname] = [] for m in decorated: if m.atmt_type == ATMT.CONDITION: cls.conditions[m.atmt_state].append(m) elif m.atmt_type == ATMT.RECV: cls.recv_conditions[m.atmt_state].append(m) elif m.atmt_type == ATMT.IOEVENT: cls.ioevents[m.atmt_state].append(m) cls.ionames.append(m.atmt_ioname) if m.atmt_as_supersocket is not None: cls.iosupersockets.append(m) elif m.atmt_type == ATMT.TIMEOUT: cls.timeout[m.atmt_state].append((m.atmt_timeout, m)) elif m.atmt_type == ATMT.ACTION: for c in m.atmt_cond: cls.actions[c].append(m) for v in six.itervalues(cls.timeout): v.sort(key=lambda x: x[0]) v.append((None, None)) for v in itertools.chain(six.itervalues(cls.conditions), six.itervalues(cls.recv_conditions), six.itervalues(cls.ioevents)): v.sort(key=lambda x: x.atmt_prio) for condname, actlst in six.iteritems(cls.actions): actlst.sort(key=lambda x: x.atmt_cond[condname]) for ioev in cls.iosupersockets: setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls)) # noqa: E501 return cls def build_graph(self): s = 'digraph "%s" {\n' % self.__class__.__name__ se = "" # Keep initial nodes at the beginning for better rendering for st in six.itervalues(self.states): if st.atmt_initial: se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se # noqa: E501 elif st.atmt_final: se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state # noqa: E501 elif st.atmt_error: se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state # noqa: E501 s += se for st in six.itervalues(self.states): for n in st.atmt_origfunc.__code__.co_names + st.atmt_origfunc.__code__.co_consts: # noqa: E501 if n in self.states: s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n) # noqa: E501 for c, k, v in ([("purple", k, v) for k, v in self.conditions.items()] + # noqa: E501 [("red", k, v) for k, v in self.recv_conditions.items()] + # noqa: E501 [("orange", k, v) for k, v in self.ioevents.items()]): for f in v: for n in f.__code__.co_names + f.__code__.co_consts: if n in self.states: line = f.atmt_condname for x in self.actions[f.atmt_condname]: line += "\\l>[%s]" % x.__name__ s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k, n, line, c) # noqa: E501 for k, v in six.iteritems(self.timeout): for t, f in v: if f is None: continue for n in f.__code__.co_names + f.__code__.co_consts: if n in self.states: line = "%s/%.1fs" % (f.atmt_condname, t) for x in self.actions[f.atmt_condname]: line += "\\l>[%s]" % x.__name__ s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line) # noqa: E501 s += "}\n" return s def graph(self, **kargs): s = self.build_graph() return do_graph(s, **kargs) class Automaton(six.with_metaclass(Automaton_metaclass)): def parse_args(self, debug=0, store=1, **kargs): self.debug_level = debug self.socket_kargs = kargs self.store_packets = store def master_filter(self, pkt): return True def my_send(self, pkt): self.send_sock.send(pkt) # Utility classes and exceptions class _IO_fdwrapper(SelectableObject): def __init__(self, rd, wr): if rd is not None and not isinstance(rd, (int, ObjectPipe)): rd = rd.fileno() if wr is not None and not isinstance(wr, (int, ObjectPipe)): wr = wr.fileno() self.rd = rd self.wr = wr SelectableObject.__init__(self) def fileno(self): if isinstance(self.rd, ObjectPipe): return self.rd.fileno() return self.rd def check_recv(self): return self.rd.check_recv() def read(self, n=65535): if isinstance(self.rd, ObjectPipe): return self.rd.recv(n) return os.read(self.rd, n) def write(self, msg): self.call_release() if isinstance(self.wr, ObjectPipe): self.wr.send(msg) return return os.write(self.wr, msg) def recv(self, n=65535): return self.read(n) def send(self, msg): return self.write(msg) class _IO_mixer(SelectableObject): def __init__(self, rd, wr): self.rd = rd self.wr = wr SelectableObject.__init__(self) def fileno(self): if isinstance(self.rd, int): return self.rd return self.rd.fileno() def check_recv(self): return self.rd.check_recv() def recv(self, n=None): return self.rd.recv(n) def read(self, n=None): return self.recv(n) def send(self, msg): self.wr.send(msg) return self.call_release() def write(self, msg): return self.send(msg) class AutomatonException(Exception): def __init__(self, msg, state=None, result=None): Exception.__init__(self, msg) self.state = state self.result = result class AutomatonError(AutomatonException): pass class ErrorState(AutomatonException): pass class Stuck(AutomatonException): pass class AutomatonStopped(AutomatonException): pass class Breakpoint(AutomatonStopped): pass class Singlestep(AutomatonStopped): pass class InterceptionPoint(AutomatonStopped): def __init__(self, msg, state=None, result=None, packet=None): Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result) # noqa: E501 self.packet = packet class CommandMessage(AutomatonException): pass # Services def debug(self, lvl, msg): if self.debug_level >= lvl: log_runtime.debug(msg) def send(self, pkt): if self.state.state in self.interception_points: self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary()) self.intercepted_packet = pkt cmd = Message(type=_ATMT_Command.INTERCEPT, state=self.state, pkt=pkt) # noqa: E501 self.cmdout.send(cmd) cmd = self.cmdin.recv() self.intercepted_packet = None if cmd.type == _ATMT_Command.REJECT: self.debug(3, "INTERCEPT: packet rejected") return elif cmd.type == _ATMT_Command.REPLACE: pkt = cmd.pkt self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary()) # noqa: E501 elif cmd.type == _ATMT_Command.ACCEPT: self.debug(3, "INTERCEPT: packet accepted") else: raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % cmd.type) # noqa: E501 self.my_send(pkt) self.debug(3, "SENT : %s" % pkt.summary()) if self.store_packets: self.packets.append(pkt.copy()) # Internals def __init__(self, *args, **kargs): external_fd = kargs.pop("external_fd", {}) self.send_sock_class = kargs.pop("ll", conf.L3socket) self.recv_sock_class = kargs.pop("recvsock", conf.L2listen) self.started = threading.Lock() self.threadid = None self.breakpointed = None self.breakpoints = set() self.interception_points = set() self.intercepted_packet = None self.debug_level = 0 self.init_args = args self.init_kargs = kargs self.io = type.__new__(type, "IOnamespace", (), {}) self.oi = type.__new__(type, "IOnamespace", (), {}) self.cmdin = ObjectPipe() self.cmdout = ObjectPipe() self.ioin = {} self.ioout = {} for n in self.ionames: extfd = external_fd.get(n) if not isinstance(extfd, tuple): extfd = (extfd, extfd) ioin, ioout = extfd if ioin is None: ioin = ObjectPipe() elif not isinstance(ioin, SelectableObject): ioin = self._IO_fdwrapper(ioin, None) if ioout is None: ioout = ObjectPipe() elif not isinstance(ioout, SelectableObject): ioout = self._IO_fdwrapper(None, ioout) self.ioin[n] = ioin self.ioout[n] = ioout ioin.ioname = n ioout.ioname = n setattr(self.io, n, self._IO_mixer(ioout, ioin)) setattr(self.oi, n, self._IO_mixer(ioin, ioout)) for stname in self.states: setattr(self, stname, _instance_state(getattr(self, stname))) self.start() def __iter__(self): return self def __del__(self): self.stop() def _run_condition(self, cond, *args, **kargs): try: self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 cond(self, *args, **kargs) except ATMT.NewStateRequested as state_req: self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state)) # noqa: E501 if cond.atmt_type == ATMT.RECV: if self.store_packets: self.packets.append(args[0]) for action in self.actions[cond.atmt_condname]: self.debug(2, " + Running action [%s]" % action.__name__) action(self, *state_req.action_args, **state_req.action_kargs) raise except Exception as e: self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e)) # noqa: E501 raise else: self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 def _do_start(self, *args, **kargs): ready = threading.Event() _t = threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs) # noqa: E501 _t.setDaemon(True) _t.start() ready.wait() def _do_control(self, ready, *args, **kargs): with self.started: self.threadid = threading.currentThread().ident # Update default parameters a = args + self.init_args[len(args):] k = self.init_kargs.copy() k.update(kargs) self.parse_args(*a, **k) # Start the automaton self.state = self.initial_states[0](self) self.send_sock = self.send_sock_class(**self.socket_kargs) self.listen_sock = self.recv_sock_class(**self.socket_kargs) self.packets = PacketList(name="session[%s]" % self.__class__.__name__) # noqa: E501 singlestep = True iterator = self._do_iter() self.debug(3, "Starting control thread [tid=%i]" % self.threadid) # Sync threads ready.set() try: while True: c = self.cmdin.recv() self.debug(5, "Received command %s" % c.type) if c.type == _ATMT_Command.RUN: singlestep = False elif c.type == _ATMT_Command.NEXT: singlestep = True elif c.type == _ATMT_Command.FREEZE: continue elif c.type == _ATMT_Command.STOP: break while True: state = next(iterator) if isinstance(state, self.CommandMessage): break elif isinstance(state, self.Breakpoint): c = Message(type=_ATMT_Command.BREAKPOINT, state=state) # noqa: E501 self.cmdout.send(c) break if singlestep: c = Message(type=_ATMT_Command.SINGLESTEP, state=state) # noqa: E501 self.cmdout.send(c) break except (StopIteration, RuntimeError): c = Message(type=_ATMT_Command.END, result=self.final_state_output) self.cmdout.send(c) except Exception as e: exc_info = sys.exc_info() self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, traceback.format_exception(*exc_info))) # noqa: E501 m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info) # noqa: E501 self.cmdout.send(m) self.debug(3, "Stopping control thread (tid=%i)" % self.threadid) self.threadid = None def _do_iter(self): while True: try: self.debug(1, "## state=[%s]" % self.state.state) # Entering a new state. First, call new state function if self.state.state in self.breakpoints and self.state.state != self.breakpointed: # noqa: E501 self.breakpointed = self.state.state yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state, # noqa: E501 state=self.state.state) self.breakpointed = None state_output = self.state.run() if self.state.error: raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), # noqa: E501 result=state_output, state=self.state.state) # noqa: E501 if self.state.final: self.final_state_output = state_output return if state_output is None: state_output = () elif not isinstance(state_output, list): state_output = state_output, # Then check immediate conditions for cond in self.conditions[self.state.state]: self._run_condition(cond, *state_output) # If still there and no conditions left, we are stuck! if (len(self.recv_conditions[self.state.state]) == 0 and len(self.ioevents[self.state.state]) == 0 and len(self.timeout[self.state.state]) == 1): raise self.Stuck("stuck in [%s]" % self.state.state, state=self.state.state, result=state_output) # noqa: E501 # Finally listen and pay attention to timeouts expirations = iter(self.timeout[self.state.state]) next_timeout, timeout_func = next(expirations) t0 = time.time() fds = [self.cmdin] if len(self.recv_conditions[self.state.state]) > 0: fds.append(self.listen_sock) for ioev in self.ioevents[self.state.state]: fds.append(self.ioin[ioev.atmt_ioname]) while True: t = time.time() - t0 if next_timeout is not None: if next_timeout <= t: self._run_condition(timeout_func, *state_output) next_timeout, timeout_func = next(expirations) if next_timeout is None: remain = None else: remain = next_timeout - t self.debug(5, "Select on %r" % fds) r = select_objects(fds, remain) self.debug(5, "Selected %r" % r) for fd in r: self.debug(5, "Looking at %r" % fd) if fd == self.cmdin: yield self.CommandMessage("Received command message") # noqa: E501 elif fd == self.listen_sock: pkt = self.listen_sock.recv(MTU) if pkt is not None: if self.master_filter(pkt): self.debug(3, "RECVD: %s" % pkt.summary()) # noqa: E501 for rcvcond in self.recv_conditions[self.state.state]: # noqa: E501 self._run_condition(rcvcond, pkt, *state_output) # noqa: E501 else: self.debug(4, "FILTR: %s" % pkt.summary()) # noqa: E501 else: self.debug(3, "IOEVENT on %s" % fd.ioname) for ioevt in self.ioevents[self.state.state]: if ioevt.atmt_ioname == fd.ioname: self._run_condition(ioevt, fd, *state_output) # noqa: E501 except ATMT.NewStateRequested as state_req: self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state)) # noqa: E501 self.state = state_req yield state_req # Public API def add_interception_points(self, *ipts): for ipt in ipts: if hasattr(ipt, "atmt_state"): ipt = ipt.atmt_state self.interception_points.add(ipt) def remove_interception_points(self, *ipts): for ipt in ipts: if hasattr(ipt, "atmt_state"): ipt = ipt.atmt_state self.interception_points.discard(ipt) def add_breakpoints(self, *bps): for bp in bps: if hasattr(bp, "atmt_state"): bp = bp.atmt_state self.breakpoints.add(bp) def remove_breakpoints(self, *bps): for bp in bps: if hasattr(bp, "atmt_state"): bp = bp.atmt_state self.breakpoints.discard(bp) def start(self, *args, **kargs): if not self.started.locked(): self._do_start(*args, **kargs) def run(self, resume=None, wait=True): if resume is None: resume = Message(type=_ATMT_Command.RUN) self.cmdin.send(resume) if wait: try: c = self.cmdout.recv() except KeyboardInterrupt: self.cmdin.send(Message(type=_ATMT_Command.FREEZE)) return if c.type == _ATMT_Command.END: return c.result elif c.type == _ATMT_Command.INTERCEPT: raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt) # noqa: E501 elif c.type == _ATMT_Command.SINGLESTEP: raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state) # noqa: E501 elif c.type == _ATMT_Command.BREAKPOINT: raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state) # noqa: E501 elif c.type == _ATMT_Command.EXCEPTION: six.reraise(c.exc_info[0], c.exc_info[1], c.exc_info[2]) def runbg(self, resume=None, wait=False): self.run(resume, wait) def next(self): return self.run(resume=Message(type=_ATMT_Command.NEXT)) __next__ = next def stop(self): self.cmdin.send(Message(type=_ATMT_Command.STOP)) with self.started: # Flush command pipes while True: r = select_objects([self.cmdin, self.cmdout], 0) if not r: break for fd in r: fd.recv() def restart(self, *args, **kargs): self.stop() self.start(*args, **kargs) def accept_packet(self, pkt=None, wait=False): rsm = Message() if pkt is None: rsm.type = _ATMT_Command.ACCEPT else: rsm.type = _ATMT_Command.REPLACE rsm.pkt = pkt return self.run(resume=rsm, wait=wait) def reject_packet(self, wait=False): rsm = Message(type=_ATMT_Command.REJECT) return self.run(resume=rsm, wait=wait)