Allow channel message handlers to short circuit
- a message handler can return logical True to prevent subsequent message handlers from running
This commit is contained in:
parent
a61b15cf6a
commit
e005826151
|
@ -152,7 +152,7 @@ def client_connected(link):
|
||||||
# Register message types and add callback to channel
|
# Register message types and add callback to channel
|
||||||
channel = link.get_channel()
|
channel = link.get_channel()
|
||||||
channel.register_message_type(StringMessage)
|
channel.register_message_type(StringMessage)
|
||||||
channel.add_message_callback(server_message_received)
|
channel.add_message_handler(server_message_received)
|
||||||
|
|
||||||
def client_disconnected(link):
|
def client_disconnected(link):
|
||||||
RNS.log("Client disconnected")
|
RNS.log("Client disconnected")
|
||||||
|
@ -290,7 +290,7 @@ def link_established(link):
|
||||||
# Register messages and add handler to channel
|
# Register messages and add handler to channel
|
||||||
channel = link.get_channel()
|
channel = link.get_channel()
|
||||||
channel.register_message_type(StringMessage)
|
channel.register_message_type(StringMessage)
|
||||||
channel.add_message_callback(client_message_received)
|
channel.add_message_handler(client_message_received)
|
||||||
|
|
||||||
# Inform the user that the server is
|
# Inform the user that the server is
|
||||||
# connected
|
# connected
|
||||||
|
|
|
@ -150,28 +150,34 @@ class Channel(contextlib.AbstractContextManager):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def register_message_type(self, message_class: Type[MessageBase]):
|
def register_message_type(self, message_class: Type[MessageBase]):
|
||||||
if not issubclass(message_class, MessageBase):
|
with self._lock:
|
||||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.")
|
if not issubclass(message_class, MessageBase):
|
||||||
if message_class.MSGTYPE is None:
|
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
||||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.")
|
f"{message_class} is not a subclass of {MessageBase}.")
|
||||||
try:
|
if message_class.MSGTYPE is None:
|
||||||
message_class()
|
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
||||||
except Exception as ex:
|
f"{message_class} has invalid MSGTYPE class attribute.")
|
||||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
try:
|
||||||
f"{message_class} raised an exception when constructed with no arguments: {ex}")
|
message_class()
|
||||||
|
except Exception as ex:
|
||||||
|
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
||||||
|
f"{message_class} raised an exception when constructed with no arguments: {ex}")
|
||||||
|
|
||||||
self._message_factories[message_class.MSGTYPE] = message_class
|
self._message_factories[message_class.MSGTYPE] = message_class
|
||||||
|
|
||||||
def add_message_callback(self, callback: MessageCallbackType):
|
def add_message_handler(self, callback: MessageCallbackType):
|
||||||
if callback not in self._message_callbacks:
|
with self._lock:
|
||||||
self._message_callbacks.append(callback)
|
if callback not in self._message_callbacks:
|
||||||
|
self._message_callbacks.append(callback)
|
||||||
|
|
||||||
def remove_message_callback(self, callback: MessageCallbackType):
|
def remove_message_handler(self, callback: MessageCallbackType):
|
||||||
self._message_callbacks.remove(callback)
|
with self._lock:
|
||||||
|
self._message_callbacks.remove(callback)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self._message_callbacks.clear()
|
with self._lock:
|
||||||
self.clear_rings()
|
self._message_callbacks.clear()
|
||||||
|
self.clear_rings()
|
||||||
|
|
||||||
def clear_rings(self):
|
def clear_rings(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -205,19 +211,29 @@ class Channel(contextlib.AbstractContextManager):
|
||||||
env.tracked = False
|
env.tracked = False
|
||||||
self._rx_ring.remove(env)
|
self._rx_ring.remove(env)
|
||||||
|
|
||||||
|
def _run_callbacks(self, message: MessageBase):
|
||||||
|
with self._lock:
|
||||||
|
cbs = self._message_callbacks.copy()
|
||||||
|
|
||||||
|
for cb in cbs:
|
||||||
|
try:
|
||||||
|
if cb(message):
|
||||||
|
return
|
||||||
|
except Exception as ex:
|
||||||
|
RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR)
|
||||||
|
|
||||||
def receive(self, raw: bytes):
|
def receive(self, raw: bytes):
|
||||||
try:
|
try:
|
||||||
envelope = Envelope(outlet=self._outlet, raw=raw)
|
envelope = Envelope(outlet=self._outlet, raw=raw)
|
||||||
message = envelope.unpack(self._message_factories)
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
message = envelope.unpack(self._message_factories)
|
||||||
is_new = self.emplace_envelope(envelope, self._rx_ring)
|
is_new = self.emplace_envelope(envelope, self._rx_ring)
|
||||||
self.prune_rx_ring()
|
self.prune_rx_ring()
|
||||||
if not is_new:
|
if not is_new:
|
||||||
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG)
|
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG)
|
||||||
return
|
return
|
||||||
RNS.log(f"Message received: {message}", RNS.LOG_DEBUG)
|
RNS.log(f"Message received: {message}", RNS.LOG_DEBUG)
|
||||||
for cb in self._message_callbacks:
|
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start()
|
||||||
threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start()
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
RNS.log(f"Channel: Error receiving data: {ex}")
|
RNS.log(f"Channel: Error receiving data: {ex}")
|
||||||
|
|
||||||
|
|
|
@ -245,13 +245,49 @@ class TestChannel(unittest.TestCase):
|
||||||
self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
|
self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
|
||||||
self.assertFalse(envelope.tracked)
|
self.assertFalse(envelope.tracked)
|
||||||
|
|
||||||
|
def test_multiple_handler(self):
|
||||||
|
handler1_called = 0
|
||||||
|
handler1_return = True
|
||||||
|
handler2_called = 0
|
||||||
|
|
||||||
|
def handler1(msg: MessageBase):
|
||||||
|
nonlocal handler1_called, handler1_return
|
||||||
|
self.assertIsInstance(msg, MessageTest)
|
||||||
|
handler1_called += 1
|
||||||
|
return handler1_return
|
||||||
|
|
||||||
|
def handler2(msg: MessageBase):
|
||||||
|
nonlocal handler2_called
|
||||||
|
self.assertIsInstance(msg, MessageTest)
|
||||||
|
handler2_called += 1
|
||||||
|
|
||||||
|
message = MessageTest()
|
||||||
|
self.h.channel.register_message_type(MessageTest)
|
||||||
|
self.h.channel.add_message_handler(handler1)
|
||||||
|
self.h.channel.add_message_handler(handler2)
|
||||||
|
envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0)
|
||||||
|
raw = envelope.pack()
|
||||||
|
self.h.channel.receive(raw)
|
||||||
|
|
||||||
|
self.assertEqual(1, handler1_called)
|
||||||
|
self.assertEqual(0, handler2_called)
|
||||||
|
|
||||||
|
handler1_return = False
|
||||||
|
envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1)
|
||||||
|
raw = envelope.pack()
|
||||||
|
self.h.channel.receive(raw)
|
||||||
|
|
||||||
|
self.assertEqual(2, handler1_called)
|
||||||
|
self.assertEqual(1, handler2_called)
|
||||||
|
|
||||||
|
|
||||||
def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):
|
def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):
|
||||||
decoded: [MessageBase] = []
|
decoded: [MessageBase] = []
|
||||||
|
|
||||||
def handle_message(message: MessageBase):
|
def handle_message(message: MessageBase):
|
||||||
decoded.append(message)
|
decoded.append(message)
|
||||||
|
|
||||||
self.h.channel.add_message_callback(handle_message)
|
self.h.channel.add_message_handler(handle_message)
|
||||||
self.assertEqual(len(self.h.outlet.packets), 0)
|
self.assertEqual(len(self.h.outlet.packets), 0)
|
||||||
|
|
||||||
envelope = self.h.channel.send(message)
|
envelope = self.h.channel.send(message)
|
||||||
|
|
|
@ -382,7 +382,7 @@ class TestLink(unittest.TestCase):
|
||||||
|
|
||||||
channel = l1.get_channel()
|
channel = l1.get_channel()
|
||||||
channel.register_message_type(MessageTest)
|
channel.register_message_type(MessageTest)
|
||||||
channel.add_message_callback(handle_message)
|
channel.add_message_handler(handle_message)
|
||||||
channel.send(test_message)
|
channel.send(test_message)
|
||||||
|
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
@ -466,7 +466,7 @@ def targets(yp=False):
|
||||||
message.data = message.data + " back"
|
message.data = message.data + " back"
|
||||||
channel.send(message)
|
channel.send(message)
|
||||||
channel.register_message_type(MessageTest)
|
channel.register_message_type(MessageTest)
|
||||||
channel.add_message_callback(handle_message)
|
channel.add_message_handler(handle_message)
|
||||||
|
|
||||||
m_rns = RNS.Reticulum("./tests/rnsconfig")
|
m_rns = RNS.Reticulum("./tests/rnsconfig")
|
||||||
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
||||||
|
|
Loading…
Reference in New Issue