1031 lines
40 KiB
Python
1031 lines
40 KiB
Python
|
# This file is part of Scapy
|
||
|
# See http://www.secdev.org/projects/scapy for more information
|
||
|
# Copyright (C) Philippe Biondi <phil@secdev.org>
|
||
|
# Copyright (C) Gabriel Potter <gabriel@potter.fr>
|
||
|
# This program is published under a GPLv2 license
|
||
|
|
||
|
"""
|
||
|
Automata with states, transitions and actions.
|
||
|
"""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
import types
|
||
|
import itertools
|
||
|
import time
|
||
|
import os
|
||
|
import sys
|
||
|
import traceback
|
||
|
from select import select
|
||
|
from collections import deque
|
||
|
import threading
|
||
|
from scapy.config import conf
|
||
|
from scapy.utils import do_graph
|
||
|
from scapy.error import log_runtime, warning
|
||
|
from scapy.plist import PacketList
|
||
|
from scapy.data import MTU
|
||
|
from scapy.supersocket import SuperSocket
|
||
|
from scapy.consts import WINDOWS
|
||
|
import scapy.modules.six as six
|
||
|
|
||
|
|
||
|
""" In Windows, select.select is not available for custom objects. Here is Scapy's implementation that re-creates this functionality # noqa: E501
|
||
|
# Passive way: using no-resource locks
|
||
|
+---------+ +---------------+ +-------------------------+ # noqa: E501
|
||
|
| Start +------------->Select_objects +----->+Linux: call select.select| # noqa: E501
|
||
|
+---------+ |(select.select)| +-------------------------+ # noqa: E501
|
||
|
+-------+-------+
|
||
|
|
|
||
|
+----v----+ +--------+
|
||
|
| Windows | |Time Out+----------------------------------+ # noqa: E501
|
||
|
+----+----+ +----+---+ | # noqa: E501
|
||
|
| ^ | # noqa: E501
|
||
|
Event | | | # noqa: E501
|
||
|
+ | | | # noqa: E501
|
||
|
| +-------v-------+ | | # noqa: E501
|
||
|
| +------+Selectable Sel.+-----+-----------------+-----------+ | # noqa: E501
|
||
|
| | +-------+-------+ | | | v +-----v-----+ # noqa: E501
|
||
|
+-------v----------+ | | | | | Passive lock<-----+release_all<------+ # noqa: E501
|
||
|
|Data added to list| +----v-----+ +-----v-----+ +----v-----+ v v + +-----------+ | # noqa: E501
|
||
|
+--------+---------+ |Selectable| |Selectable | |Selectable| ............ | | # noqa: E501
|
||
|
| +----+-----+ +-----------+ +----------+ | | # noqa: E501
|
||
|
| v | | # noqa: E501
|
||
|
v +----+------+ +------------------+ +-------------v-------------------+ | # noqa: E501
|
||
|
+-----+------+ |wait_return+-->+ check_recv: | | | | # noqa: E501
|
||
|
|call_release| +----+------+ |If data is in list| | END state: selectable returned | +---+--------+ # noqa: E501
|
||
|
+-----+-------- v +-------+----------+ | | | exit door | # noqa: E501
|
||
|
| else | +---------------------------------+ +---+--------+ # noqa: E501
|
||
|
| + | | # noqa: E501
|
||
|
| +----v-------+ | | # noqa: E501
|
||
|
+--------->free -->Passive lock| | | # noqa: E501
|
||
|
+----+-------+ | | # noqa: E501
|
||
|
| | | # noqa: E501
|
||
|
| v | # noqa: E501
|
||
|
+------------------Selectable-Selector-is-advertised-that-the-selectable-is-readable---------+
|
||
|
"""
|
||
|
|
||
|
|
||
|
class SelectableObject(object):
    """DEV: to implement one of those, you need to add 2 things to your object:
    - add "check_recv" function
    - call "self.call_release" once you are ready to be read

    You can set the __selectable_force_select__ to True in the class, if you want to  # noqa: E501
    force the handler to use fileno(). This may only be usable on sockets created using  # noqa: E501
    the builtin socket API."""
    __selectable_force_select__ = False

    def __init__(self):
        self.hooks = []

    def check_recv(self):
        """DEV: will be called only once (at beginning) to check if the object is ready."""  # noqa: E501
        raise OSError("This method must be overwritten.")

    def _wait_non_ressources(self, callback, trigger=None):
        """Waiter thread body: block until call_release() frees the trigger
        lock, then advertise the object to the selector through `callback`
        (unless the release was an abort).

        :param callback: function called with this object once it is ready
        :param trigger: an already-acquired lock armed by wait_return(). When
            None (legacy callers), a fresh lock is created and armed here.
        """
        if trigger is None:
            trigger = threading.Lock()
            self.trigger = trigger
            self.was_ended = False
            trigger.acquire()
        # The trigger is held: this second acquire blocks until
        # call_release() releases it.
        trigger.acquire()
        if not self.was_ended:
            callback(self)

    def wait_return(self, callback):
        """Entry point of SelectableObject: register the callback.

        If the object is already ready, `callback` is called synchronously
        and its result returned; otherwise a daemon thread waits for
        call_release().
        """
        # Arm the trigger *before* checking for data: a call_release()
        # happening between check_recv() and the waiter thread start then
        # simply leaves the lock free, instead of being lost.
        trigger = threading.Lock()
        trigger.acquire()
        self.was_ended = False
        self.trigger = trigger
        if self.check_recv():
            return callback(self)
        _t = threading.Thread(target=self._wait_non_ressources,
                              args=(callback, trigger))
        _t.daemon = True
        _t.start()

    def register_hook(self, hook):
        """DEV: When call_release() will be called, the hook will also"""
        self.hooks.append(hook)

    def call_release(self, arborted=False):
        """DEV: Must be call when the object becomes ready to read.
        Releases the lock of _wait_non_ressources.

        :param arborted: when True, the waiter thread exits without calling
            its callback (used by selectors to discard pending waits).
        """
        self.was_ended = arborted
        try:
            self.trigger.release()
        except (threading.ThreadError, AttributeError):
            # No waiter armed yet, or the lock is already free
            pass
        # Trigger hooks
        for hook in self.hooks:
            hook()
|
||
|
|
||
|
|
||
|
class SelectableSelector(object):
    """
    Select SelectableObject objects.

    inputs: objects to process
    remain: timeout in seconds. 0 (or a negative value) only polls the
            objects that are already ready; None waits until at least one
            object is ready, like select.select.
    """

    def __init__(self, inputs, remain):
        self.results = []
        self.inputs = list(inputs)
        self.remain = remain
        # Objects on which wait_return() was armed; released when done
        self._armed = []
        # Set as soon as one armed object reports itself ready
        self._done = threading.Event()
        # Protects `results` and `_ended` against the waiter threads
        self._lock = threading.Lock()
        self._ended = False

    def _release_all(self):
        """Abort the pending waits so that every waiter thread exits"""
        for i in self._armed:
            i.call_release(True)

    def _exit_door(self, _input):
        """This function is passed to each SelectableObject as a callback
        The SelectableObjects have to call it once they are ready"""
        with self._lock:
            self.results.append(_input)
            self._ended = True
        self._done.set()

    def process(self):
        """Entry point of SelectableSelector"""
        if not WINDOWS:
            r, _, _ = select(self.inputs, [], [], self.remain)
            return r
        poll = self.remain is not None and self.remain <= 0
        select_inputs = []
        for i in self.inputs:
            if not isinstance(i, SelectableObject):
                warning("Unknown ignored object type: %s", type(i))
            elif i.__selectable_force_select__:
                # Then use select.select
                select_inputs.append(i)
            elif poll:
                if i.check_recv():
                    with self._lock:
                        self.results.append(i)
            else:
                self._armed.append(i)
                i.wait_return(self._exit_door)
        if select_inputs:
            # Use default select function
            ready = select(select_inputs, [], [],
                           0 if poll else self.remain)[0]
            with self._lock:
                self.results.extend(ready)
        if not poll:
            with self._lock:
                already_ready = bool(self.results)
            if not already_ready:
                # Blocks forever when remain is None
                self._done.wait(self.remain)
            self._release_all()
        with self._lock:
            self._ended = True
            # Copy: late waiter threads may still append to self.results
            return list(self.results)
|
||
|
|
||
|
|
||
|
def select_objects(inputs, remain):
    """
    Wait for SelectableObject objects to become readable.

    Behaves like ``select.select([inputs], [], [], remain)`` and returns the
    list of readable objects, but also works on Windows (for
    SelectableObject instances only).

    :param inputs: objects to watch
    :param remain: timeout. If 0, only the objects already ready are returned.
    """
    return SelectableSelector(inputs, remain).process()
|
||
|
|
||
|
|
||
|
class ObjectPipe(SelectableObject):
    """Queue of Python objects paired with an OS pipe.

    Every queued object also writes one byte to the pipe, so fileno() is
    readable exactly while objects are pending and the object can be used
    with select().
    """
    read_allowed_exceptions = ()

    def __init__(self):
        self.closed = False
        self.rd, self.wr = os.pipe()
        self.queue = deque()
        SelectableObject.__init__(self)

    def fileno(self):
        return self.rd

    def check_recv(self):
        return len(self.queue) > 0

    def send(self, obj):
        self.queue.append(obj)
        os.write(self.wr, b"X")
        self.call_release()

    def write(self, obj):
        self.send(obj)

    def flush(self):
        pass

    def recv(self, n=0):
        """Pop the next object (blocking). Returns None once closed and
        empty. `n` is ignored."""
        if self.closed:
            if self.check_recv():
                return self.queue.popleft()
            return None
        # Consume the marker byte (blocks until something is queued)
        os.read(self.rd, 1)
        return self.queue.popleft()

    def read(self, n=0):
        return self.recv(n)

    def close(self):
        if not self.closed:
            self.closed = True
            try:
                os.close(self.rd)
            finally:
                # Do not leak the write end if closing rd failed
                os.close(self.wr)
                self.queue.clear()

    def __del__(self):
        # __init__ may have failed (e.g. os.pipe() hitting the fd limit)
        # before the descriptors existed: nothing to release then.
        if hasattr(self, "wr"):
            self.close()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # Only handle ObjectPipes; closed pipes are reported as readable
        results = []
        for s in sockets:
            if s.closed:
                results.append(s)
        if results:
            return results, None
        return select_objects(sockets, remain), None
|
||
|
|
||
|
|
||
|
class Message:
    """Plain attribute bag; keyword arguments become attributes."""

    def __init__(self, **args):
        self.__dict__.update(args)

    def __repr__(self):
        # Attributes starting with "_" are considered private and hidden
        shown = ["%s=%r" % (name, value)
                 for name, value in self.__dict__.items()
                 if not name.startswith("_")]
        return "<Message %s>" % " ".join(shown)
|
||
|
|
||
|
|
||
|
class _instance_state:
|
||
|
def __init__(self, instance):
|
||
|
self.__self__ = instance.__self__
|
||
|
self.__func__ = instance.__func__
|
||
|
self.__self__.__class__ = instance.__self__.__class__
|
||
|
|
||
|
def __getattr__(self, attr):
|
||
|
return getattr(self.__func__, attr)
|
||
|
|
||
|
def __call__(self, *args, **kargs):
|
||
|
return self.__func__(self.__self__, *args, **kargs)
|
||
|
|
||
|
def breaks(self):
|
||
|
return self.__self__.add_breakpoints(self.__func__)
|
||
|
|
||
|
def intercepts(self):
|
||
|
return self.__self__.add_interception_points(self.__func__)
|
||
|
|
||
|
def unbreaks(self):
|
||
|
return self.__self__.remove_breakpoints(self.__func__)
|
||
|
|
||
|
def unintercepts(self):
|
||
|
return self.__self__.remove_interception_points(self.__func__)
|
||
|
|
||
|
|
||
|
##############
|
||
|
# Automata #
|
||
|
##############
|
||
|
|
||
|
class ATMT:
    """Decorators used to declare automaton states, conditions and actions.

    The decorators only tag functions with ``atmt_*`` attributes; the
    automaton metaclass collects the tagged functions afterwards.
    """
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """Request for a transition to the state implemented by state_func."""

        def __init__(self, state_func, automaton, *args, **kargs):
            self.func = state_func
            # Copy the flags set by ATMT.state() onto the request
            for flag in ("state", "initial", "error", "final"):
                setattr(self, flag, getattr(state_func, "atmt_" + flag))
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NOTE: replaces Exception.args with the state arguments
            self.args = args
            self.kargs = kargs
            self.action_parameters()  # init action parameters

        def action_parameters(self, *args, **kargs):
            """Store the arguments given to the transition's actions."""
            self.action_args, self.action_kargs = args, kargs
            return self

        def run(self):
            """Call the state function on the automaton."""
            target = self.automaton
            return self.func(target, *self.args, **self.kargs)

        def __repr__(self):
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0, final=0, error=0):
        def deco(f, initial=initial, final=final):
            tags = {
                "atmt_type": ATMT.STATE,
                "atmt_state": f.__name__,
                "atmt_initial": initial,
                "atmt_final": final,
                "atmt_error": error,
            }

            def state_wrapper(self, *args, **kargs):
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            # Both the original function and its wrapper carry the tags
            for target in (f, state_wrapper):
                for key, value in tags.items():
                    setattr(target, key, value)
            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco

    @staticmethod
    def action(cond, prio=0):
        def deco(f, cond=cond):
            # First action decorator on f: start the condition -> prio map
            untagged = not hasattr(f, "atmt_type")
            if untagged:
                f.atmt_cond, f.atmt_type = {}, ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def condition(state, prio=0):
        def deco(f, state=state):
            f.atmt_type = ATMT.CONDITION
            f.atmt_state, f.atmt_condname = state.atmt_state, f.__name__
            f.atmt_prio = prio
            return f
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        def deco(f, state=state):
            f.atmt_type = ATMT.RECV
            f.atmt_state, f.atmt_condname = state.atmt_state, f.__name__
            f.atmt_prio = prio
            return f
        return deco

    @staticmethod
    def ioevent(state, name, prio=0, as_supersocket=None):
        def deco(f, state=state):
            f.atmt_type = ATMT.IOEVENT
            f.atmt_state, f.atmt_condname = state.atmt_state, f.__name__
            f.atmt_ioname, f.atmt_prio = name, prio
            f.atmt_as_supersocket = as_supersocket
            return f
        return deco

    @staticmethod
    def timeout(state, timeout):
        def deco(f, state=state, timeout=timeout):
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state, f.atmt_timeout = state.atmt_state, timeout
            f.atmt_condname = f.__name__
            return f
        return deco
|
||
|
|
||
|
|
||
|
class _ATMT_Command:
    """Values of the ``type`` field of the Message objects exchanged with
    the automaton control thread through its cmdin/cmdout pipes."""
    # Commands sent to the control thread (cmdin)
    RUN = "RUN"
    NEXT = "NEXT"
    FREEZE = "FREEZE"
    STOP = "STOP"
    # Notifications sent back by the control thread (cmdout)
    END = "END"
    EXCEPTION = "EXCEPTION"
    SINGLESTEP = "SINGLESTEP"
    BREAKPOINT = "BREAKPOINT"
    INTERCEPT = "INTERCEPT"
    # Verdicts for an intercepted packet (sent on cmdin)
    ACCEPT = "ACCEPT"
    REPLACE = "REPLACE"
    REJECT = "REJECT"
|
||
|
|
||
|
|
||
|
class _ATMT_supersocket(SuperSocket, SelectableObject):
    """SuperSocket interface to one I/O event of an automaton.

    The automaton is instantiated with this object's two pipes as the
    external file descriptors of `ioevent` and run in background: send()
    feeds the automaton's input, recv() returns what the automaton writes.
    """

    def __init__(self, name, ioevent, automaton, proto, *args, **kargs):
        SelectableObject.__init__(self)
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        # write, read
        self.spa, self.spb = ObjectPipe(), ObjectPipe()
        # Register recv hook: wake our own selectors when data arrives
        self.spb.register_hook(self.call_release)
        kargs["external_fd"] = {ioevent: (self.spa, self.spb)}
        self.atmt = automaton(*args, **kargs)
        self.atmt.runbg()

    def fileno(self):
        return self.spb.fileno()

    def send(self, s):
        if not isinstance(s, bytes):
            s = bytes(s)
        return self.spa.send(s)

    def check_recv(self):
        return self.spb.check_recv()

    def recv(self, n=MTU):
        r = self.spb.recv(n)
        # ObjectPipe.recv() returns None once the pipe is closed: do not
        # hand that to the protocol class.
        if r is not None and self.proto is not None:
            r = self.proto(r)
        return r

    def close(self):
        if not self.closed:
            self.atmt.stop()
            self.spa.close()
            self.spb.close()
            self.closed = True

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        return select_objects(sockets, remain), None
|
||
|
|
||
|
|
||
|
class _ATMT_to_supersocket:
|
||
|
def __init__(self, name, ioevent, automaton):
|
||
|
self.name = name
|
||
|
self.ioevent = ioevent
|
||
|
self.automaton = automaton
|
||
|
|
||
|
def __call__(self, proto, *args, **kargs):
|
||
|
return _ATMT_supersocket(
|
||
|
self.name, self.ioevent, self.automaton,
|
||
|
proto, *args, **kargs
|
||
|
)
|
||
|
|
||
|
|
||
|
class Automaton_metaclass(type):
    """Collect the ATMT-decorated functions of an automaton class (and of
    its bases) into per-state tables used by Automaton at runtime."""

    def __new__(cls, name, bases, dct):
        cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct)
        cls.states = {}
        cls.state = None
        cls.recv_conditions = {}
        cls.conditions = {}
        cls.ioevents = {}
        cls.timeout = {}
        cls.actions = {}
        cls.initial_states = []
        cls.ionames = []
        cls.iosupersockets = []

        # Breadth-first walk of the class hierarchy: the first definition
        # seen for a name (the most derived one) wins.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0)  # order is important to avoid breaking method overloading  # noqa: E501
            classes += list(c.__bases__)
            for k, v in six.iteritems(c.__dict__):
                if k not in members:
                    members[k] = v

        decorated = [v for v in six.itervalues(members)
                     if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")]  # noqa: E501

        # First pass: create the per-state / per-condition containers
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s] = []
                cls.ioevents[s] = []
                cls.conditions[s] = []
                cls.timeout[s] = []
                if m.atmt_initial:
                    cls.initial_states.append(m)
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:  # noqa: E501
                cls.actions[m.atmt_condname] = []

        # Second pass: fill them
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for c in m.atmt_cond:
                    cls.actions[c].append(m)

        # Timeouts in increasing order, terminated by a (None, None) sentinel
        for v in six.itervalues(cls.timeout):
            v.sort(key=lambda x: x[0])
            v.append((None, None))
        for v in itertools.chain(six.itervalues(cls.conditions),
                                 six.itervalues(cls.recv_conditions),
                                 six.itervalues(cls.ioevents)):
            v.sort(key=lambda x: x.atmt_prio)
        for condname, actlst in six.iteritems(cls.actions):
            actlst.sort(key=lambda x: x.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls))  # noqa: E501

        return cls

    def build_graph(self):
        """Return a Graphviz "dot" description of the automaton class.

        Edges are found by scanning the names and constants referenced in
        the code of states and conditions for state names.
        """
        # `self` is the automaton class itself (this is a metaclass method)
        s = 'digraph "%s" {\n' % self.__name__

        se = ""  # Keep initial nodes at the beginning for better rendering
        for st in six.itervalues(self.states):
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se  # noqa: E501
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state  # noqa: E501
        s += se

        for st in six.itervalues(self.states):
            for n in st.atmt_origfunc.__code__.co_names + st.atmt_origfunc.__code__.co_consts:  # noqa: E501
                if n in self.states:
                    s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n)  # noqa: E501

        for c, k, v in ([("purple", k, v) for k, v in self.conditions.items()] +  # noqa: E501
                        [("red", k, v) for k, v in self.recv_conditions.items()] +  # noqa: E501
                        [("orange", k, v) for k, v in self.ioevents.items()]):
            for f in v:
                for n in f.__code__.co_names + f.__code__.co_consts:
                    if n in self.states:
                        line = f.atmt_condname
                        for x in self.actions[f.atmt_condname]:
                            line += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k, n, line, c)  # noqa: E501
        for k, v in six.iteritems(self.timeout):
            for t, f in v:
                if f is None:
                    # (None, None) end-of-list sentinel
                    continue
                for n in f.__code__.co_names + f.__code__.co_consts:
                    if n in self.states:
                        line = "%s/%.1fs" % (f.atmt_condname, t)
                        for x in self.actions[f.atmt_condname]:
                            line += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line)  # noqa: E501
        s += "}\n"
        return s

    def graph(self, **kargs):
        """Render build_graph() through do_graph()."""
        s = self.build_graph()
        return do_graph(s, **kargs)
|
||
|
|
||
|
|
||
|
class Automaton(six.with_metaclass(Automaton_metaclass)):
    """Base class of automata; the control loop runs in a daemon thread
    driven through the cmdin/cmdout command pipes."""

    def parse_args(self, debug=0, store=1, **kargs):
        self.debug_level = debug
        self.socket_kargs = kargs
        self.store_packets = store

    def master_filter(self, pkt):
        return True

    def my_send(self, pkt):
        self.send_sock.send(pkt)

    # Utility classes and exceptions
    class _IO_fdwrapper(SelectableObject):
        def __init__(self, rd, wr):
            if rd is not None and not isinstance(rd, (int, ObjectPipe)):
                rd = rd.fileno()
            if wr is not None and not isinstance(wr, (int, ObjectPipe)):
                wr = wr.fileno()
            self.rd = rd
            self.wr = wr
            SelectableObject.__init__(self)

        def fileno(self):
            if isinstance(self.rd, ObjectPipe):
                return self.rd.fileno()
            return self.rd

        def check_recv(self):
            return self.rd.check_recv()

        def read(self, n=65535):
            if isinstance(self.rd, ObjectPipe):
                return self.rd.recv(n)
            return os.read(self.rd, n)

        def write(self, msg):
            self.call_release()
            if isinstance(self.wr, ObjectPipe):
                self.wr.send(msg)
                return
            return os.write(self.wr, msg)

        def recv(self, n=65535):
            return self.read(n)

        def send(self, msg):
            return self.write(msg)

    class _IO_mixer(SelectableObject):
        def __init__(self, rd, wr):
            self.rd = rd
            self.wr = wr
            SelectableObject.__init__(self)

        def fileno(self):
            if isinstance(self.rd, int):
                return self.rd
            return self.rd.fileno()

        def check_recv(self):
            return self.rd.check_recv()

        def recv(self, n=None):
            return self.rd.recv(n)

        def read(self, n=None):
            return self.recv(n)

        def send(self, msg):
            self.wr.send(msg)
            return self.call_release()

        def write(self, msg):
            return self.send(msg)

    class AutomatonException(Exception):
        def __init__(self, msg, state=None, result=None):
            Exception.__init__(self, msg)
            self.state = state
            self.result = result

    class AutomatonError(AutomatonException):
        pass

    class ErrorState(AutomatonException):
        pass

    class Stuck(AutomatonException):
        pass

    class AutomatonStopped(AutomatonException):
        pass

    class Breakpoint(AutomatonStopped):
        pass

    class Singlestep(AutomatonStopped):
        pass

    class InterceptionPoint(AutomatonStopped):
        def __init__(self, msg, state=None, result=None, packet=None):
            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)  # noqa: E501
            self.packet = packet

    class CommandMessage(AutomatonException):
        pass

    # Services
    def debug(self, lvl, msg):
        if self.debug_level >= lvl:
            log_runtime.debug(msg)

    def send(self, pkt):
        if self.state.state in self.interception_points:
            self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary())
            self.intercepted_packet = pkt
            cmd = Message(type=_ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)  # noqa: E501
            self.cmdout.send(cmd)
            # Wait for the verdict (accept_packet / reject_packet)
            cmd = self.cmdin.recv()
            self.intercepted_packet = None
            if cmd.type == _ATMT_Command.REJECT:
                self.debug(3, "INTERCEPT: packet rejected")
                return
            elif cmd.type == _ATMT_Command.REPLACE:
                pkt = cmd.pkt
                self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary())  # noqa: E501
            elif cmd.type == _ATMT_Command.ACCEPT:
                self.debug(3, "INTERCEPT: packet accepted")
            else:
                raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % cmd.type)  # noqa: E501
        self.my_send(pkt)
        self.debug(3, "SENT : %s" % pkt.summary())

        if self.store_packets:
            self.packets.append(pkt.copy())

    # Internals
    def __init__(self, *args, **kargs):
        external_fd = kargs.pop("external_fd", {})
        self.send_sock_class = kargs.pop("ll", conf.L3socket)
        self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
        self.started = threading.Lock()
        self.threadid = None
        self.breakpointed = None
        self.breakpoints = set()
        self.interception_points = set()
        self.intercepted_packet = None
        self.debug_level = 0
        self.init_args = args
        self.init_kargs = kargs
        self.io = type.__new__(type, "IOnamespace", (), {})
        self.oi = type.__new__(type, "IOnamespace", (), {})
        self.cmdin = ObjectPipe()
        self.cmdout = ObjectPipe()
        self.ioin = {}
        self.ioout = {}
        for n in self.ionames:
            extfd = external_fd.get(n)
            if not isinstance(extfd, tuple):
                extfd = (extfd, extfd)
            ioin, ioout = extfd
            if ioin is None:
                ioin = ObjectPipe()
            elif not isinstance(ioin, SelectableObject):
                ioin = self._IO_fdwrapper(ioin, None)
            if ioout is None:
                ioout = ObjectPipe()
            elif not isinstance(ioout, SelectableObject):
                ioout = self._IO_fdwrapper(None, ioout)

            self.ioin[n] = ioin
            self.ioout[n] = ioout
            ioin.ioname = n
            ioout.ioname = n
            setattr(self.io, n, self._IO_mixer(ioout, ioin))
            setattr(self.oi, n, self._IO_mixer(ioin, ioout))

        for stname in self.states:
            setattr(self, stname,
                    _instance_state(getattr(self, stname)))

        self.start()

    def __iter__(self):
        return self

    def __del__(self):
        # __init__ may have failed before the command pipes existed
        if getattr(self, "cmdout", None) is not None:
            self.stop()

    def _run_condition(self, cond, *args, **kargs):
        try:
            self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))  # noqa: E501
            cond(self, *args, **kargs)
        except ATMT.NewStateRequested as state_req:
            self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))  # noqa: E501
            if cond.atmt_type == ATMT.RECV:
                if self.store_packets:
                    self.packets.append(args[0])
            for action in self.actions[cond.atmt_condname]:
                self.debug(2, " + Running action [%s]" % action.__name__)
                action(self, *state_req.action_args, **state_req.action_kargs)
            raise
        except Exception as e:
            self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e))  # noqa: E501
            raise
        else:
            self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))  # noqa: E501

    def _do_start(self, *args, **kargs):
        ready = threading.Event()
        _t = threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs)  # noqa: E501
        _t.daemon = True
        _t.start()
        ready.wait()

    def _do_control(self, ready, *args, **kargs):
        """Control thread: set the automaton up, then execute the commands
        received on cmdin and report to cmdout until the end."""
        with self.started:
            self.threadid = threading.current_thread().ident
            # Per-run state, so that nothing leaks from a previous run
            self.final_state_output = None
            self.send_sock = None
            self.listen_sock = None
            try:
                try:
                    # Update default parameters
                    a = args + self.init_args[len(args):]
                    k = self.init_kargs.copy()
                    k.update(kargs)
                    self.parse_args(*a, **k)

                    # Start the automaton
                    self.state = self.initial_states[0](self)
                    self.send_sock = self.send_sock_class(**self.socket_kargs)
                    self.listen_sock = self.recv_sock_class(**self.socket_kargs)  # noqa: E501
                    self.packets = PacketList(name="session[%s]" % self.__class__.__name__)  # noqa: E501

                    singlestep = True
                    iterator = self._do_iter()
                    self.debug(3, "Starting control thread [tid=%i]" % self.threadid)  # noqa: E501
                finally:
                    # Sync threads. Done even if the setup failed, otherwise
                    # _do_start() would wait forever; the error is then
                    # reported through cmdout.
                    ready.set()
                try:
                    while True:
                        c = self.cmdin.recv()
                        self.debug(5, "Received command %s" % c.type)
                        if c.type == _ATMT_Command.RUN:
                            singlestep = False
                        elif c.type == _ATMT_Command.NEXT:
                            singlestep = True
                        elif c.type == _ATMT_Command.FREEZE:
                            continue
                        elif c.type == _ATMT_Command.STOP:
                            break
                        while True:
                            state = next(iterator)
                            if isinstance(state, self.CommandMessage):
                                break
                            elif isinstance(state, self.Breakpoint):
                                c = Message(type=_ATMT_Command.BREAKPOINT, state=state)  # noqa: E501
                                self.cmdout.send(c)
                                break
                            if singlestep:
                                c = Message(type=_ATMT_Command.SINGLESTEP, state=state)  # noqa: E501
                                self.cmdout.send(c)
                                break
                except (StopIteration, RuntimeError):
                    c = Message(type=_ATMT_Command.END,
                                result=self.final_state_output)
                    self.cmdout.send(c)
            except Exception as e:
                exc_info = sys.exc_info()
                self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info))))  # noqa: E501
                m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)  # noqa: E501
                self.cmdout.send(m)
            finally:
                self.debug(3, "Stopping control thread (tid=%i)" % self.threadid)  # noqa: E501
                self.threadid = None
                # Release the sockets opened for this run (a restart opens
                # new ones, so they would leak otherwise)
                try:
                    if self.listen_sock is not None:
                        self.listen_sock.close()
                finally:
                    if self.send_sock is not None:
                        self.send_sock.close()

    def _do_iter(self):
        """Generator running the automaton; yields at breakpoints, command
        messages and state changes."""
        while True:
            try:
                self.debug(1, "## state=[%s]" % self.state.state)

                # Entering a new state. First, call new state function
                if self.state.state in self.breakpoints and self.state.state != self.breakpointed:  # noqa: E501
                    self.breakpointed = self.state.state
                    yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,  # noqa: E501
                                          state=self.state.state)
                self.breakpointed = None
                state_output = self.state.run()
                if self.state.error:
                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),  # noqa: E501
                                          result=state_output, state=self.state.state)  # noqa: E501
                if self.state.final:
                    self.final_state_output = state_output
                    return

                if state_output is None:
                    state_output = ()
                elif not isinstance(state_output, list):
                    state_output = state_output,

                # Then check immediate conditions
                for cond in self.conditions[self.state.state]:
                    self._run_condition(cond, *state_output)

                # If still there and no conditions left, we are stuck!
                if (len(self.recv_conditions[self.state.state]) == 0 and
                    len(self.ioevents[self.state.state]) == 0 and
                        len(self.timeout[self.state.state]) == 1):
                    raise self.Stuck("stuck in [%s]" % self.state.state,
                                     state=self.state.state, result=state_output)  # noqa: E501

                # Finally listen and pay attention to timeouts
                expirations = iter(self.timeout[self.state.state])
                next_timeout, timeout_func = next(expirations)
                t0 = time.time()

                fds = [self.cmdin]
                if len(self.recv_conditions[self.state.state]) > 0:
                    fds.append(self.listen_sock)
                for ioev in self.ioevents[self.state.state]:
                    fds.append(self.ioin[ioev.atmt_ioname])
                while True:
                    t = time.time() - t0
                    if next_timeout is not None:
                        if next_timeout <= t:
                            self._run_condition(timeout_func, *state_output)
                            next_timeout, timeout_func = next(expirations)
                    if next_timeout is None:
                        remain = None
                    else:
                        # The next timeout may already have elapsed too:
                        # select() rejects negative timeouts.
                        remain = max(0, next_timeout - t)

                    self.debug(5, "Select on %r" % fds)
                    r = select_objects(fds, remain)
                    self.debug(5, "Selected %r" % r)
                    for fd in r:
                        self.debug(5, "Looking at %r" % fd)
                        if fd == self.cmdin:
                            yield self.CommandMessage("Received command message")  # noqa: E501
                        elif fd == self.listen_sock:
                            pkt = self.listen_sock.recv(MTU)
                            if pkt is not None:
                                if self.master_filter(pkt):
                                    self.debug(3, "RECVD: %s" % pkt.summary())  # noqa: E501
                                    for rcvcond in self.recv_conditions[self.state.state]:  # noqa: E501
                                        self._run_condition(rcvcond, pkt, *state_output)  # noqa: E501
                                else:
                                    self.debug(4, "FILTR: %s" % pkt.summary())  # noqa: E501
                        else:
                            self.debug(3, "IOEVENT on %s" % fd.ioname)
                            for ioevt in self.ioevents[self.state.state]:
                                if ioevt.atmt_ioname == fd.ioname:
                                    self._run_condition(ioevt, fd, *state_output)  # noqa: E501

            except ATMT.NewStateRequested as state_req:
                self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state))  # noqa: E501
                self.state = state_req
                yield state_req

    # Public API
    def add_interception_points(self, *ipts):
        for ipt in ipts:
            if hasattr(ipt, "atmt_state"):
                ipt = ipt.atmt_state
            self.interception_points.add(ipt)

    def remove_interception_points(self, *ipts):
        for ipt in ipts:
            if hasattr(ipt, "atmt_state"):
                ipt = ipt.atmt_state
            self.interception_points.discard(ipt)

    def add_breakpoints(self, *bps):
        for bp in bps:
            if hasattr(bp, "atmt_state"):
                bp = bp.atmt_state
            self.breakpoints.add(bp)

    def remove_breakpoints(self, *bps):
        for bp in bps:
            if hasattr(bp, "atmt_state"):
                bp = bp.atmt_state
            self.breakpoints.discard(bp)

    def start(self, *args, **kargs):
        if not self.started.locked():
            self._do_start(*args, **kargs)

    def run(self, resume=None, wait=True):
        if resume is None:
            resume = Message(type=_ATMT_Command.RUN)
        self.cmdin.send(resume)
        if wait:
            try:
                c = self.cmdout.recv()
            except KeyboardInterrupt:
                self.cmdin.send(Message(type=_ATMT_Command.FREEZE))
                return
            if c.type == _ATMT_Command.END:
                return c.result
            elif c.type == _ATMT_Command.INTERCEPT:
                raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)  # noqa: E501
            elif c.type == _ATMT_Command.SINGLESTEP:
                raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state)  # noqa: E501
            elif c.type == _ATMT_Command.BREAKPOINT:
                raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state)  # noqa: E501
            elif c.type == _ATMT_Command.EXCEPTION:
                six.reraise(c.exc_info[0], c.exc_info[1], c.exc_info[2])

    def runbg(self, resume=None, wait=False):
        self.run(resume, wait)

    def next(self):
        return self.run(resume=Message(type=_ATMT_Command.NEXT))
    __next__ = next

    def stop(self):
        self.cmdin.send(Message(type=_ATMT_Command.STOP))
        # Wait for the control thread to finish
        with self.started:
            # Flush command pipes
            while True:
                r = select_objects([self.cmdin, self.cmdout], 0)
                if not r:
                    break
                for fd in r:
                    fd.recv()

    def restart(self, *args, **kargs):
        self.stop()
        self.start(*args, **kargs)

    def accept_packet(self, pkt=None, wait=False):
        rsm = Message()
        if pkt is None:
            rsm.type = _ATMT_Command.ACCEPT
        else:
            rsm.type = _ATMT_Command.REPLACE
            rsm.pkt = pkt
        return self.run(resume=rsm, wait=wait)

    def reject_packet(self, wait=False):
        rsm = Message(type=_ATMT_Command.REJECT)
        return self.run(resume=rsm, wait=wait)
|