Refactor the notifier.wait_for_events code to be clearer. Add _NotifierUserStream.new_listener that accpets a token to avoid races.

This commit is contained in:
Erik Johnston 2015-06-18 15:49:05 +01:00
parent 050ebccf30
commit 22049ea700
3 changed files with 72 additions and 69 deletions

View File

@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.types import StreamToken
import synapse.metrics
@ -45,20 +45,16 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
__slots__ = ["deferred"]
def __init__(self, deferred):
self.deferred = deferred
object.__setattr__(self, "deferred", deferred)
def notified(self):
return self.deferred.called
def __getattr__(self, name):
return getattr(self.deferred, name)
def notify(self, token):
""" Inform whoever is listening about the new events.
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
def __setattr__(self, name, value):
setattr(self.deferred, name, value)
class _NotifierUserStream(object):
@ -75,11 +71,12 @@ class _NotifierUserStream(object):
appservice=None):
self.user = str(user)
self.appservice = appservice
self.listeners = set()
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an
event source.
@ -91,12 +88,10 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
if self.listeners:
self.last_notified_ms = time_now_ms
listeners = self.listeners
self.listeners = set()
for listener in listeners:
listener.notify(self.current_token)
noify_deferred = self.notify_deferred
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier
@ -114,6 +109,18 @@ class _NotifierUserStream(object):
self.appservice, set()
).discard(self)
def count_listeners(self):
return len(self.noify_deferred.observers())
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
"""
if self.current_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
@ -158,7 +165,7 @@ class Notifier(object):
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(len(stream.listeners) for stream in all_user_streams)
return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners)
metrics.register_callback(
@ -286,10 +293,6 @@ class Notifier(object):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
deferred = defer.Deferred()
time_now_ms = self.clock.time_msec()
user = str(user)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None:
@ -302,54 +305,38 @@ class Notifier(object):
rooms=rooms,
appservice=appservice,
current_token=current_token,
time_now_ms=time_now_ms,
time_now_ms=self.clock.time_msec(),
)
self._register_with_keys(user_stream)
else:
current_token = user_stream.current_token
result = None
if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
if result:
defer.returnValue(result)
if timeout:
timer = [None]
listeners = []
timed_out = [False]
listener = None
timer = self.clock.call_later(
timeout/1000., lambda: listener.cancel()
)
def notify_listeners():
user_stream.listeners.difference_update(listeners)
for listener in listeners:
listener.notify(current_token)
del listeners[:]
def _timeout_listener():
timed_out[0] = True
timer[0] = None
notify_listeners()
# We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
deferred = defer.Deferred()
notify_listeners()
listeners.append(_NotificationListener(deferred))
user_stream.listeners.update(listeners)
new_token = yield deferred
result = yield callback(current_token, new_token)
current_token = new_token
if timer[0] is not None:
prev_token = from_token
while not result:
try:
self.clock.cancel_call_later(timer[0])
except:
logger.exception("Failed to cancel notifer timer")
# We need to start listening to the streams *before* doing
# the callback, as otherwise we may miss something.
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
if result:
break
prev_token = current_token
listener = user_stream.new_listener(prev_token)
yield listener.deferred
except defer.CancelledError:
break
self.clock.cancel_call_later(timer, ignore_errs=True)
else:
current_token = user_stream.current_token
result = yield callback(from_token, current_token)
defer.returnValue(result)
@ -367,6 +354,9 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
defer.returnValue(None)
events = []
end_token = from_token
for name, source in self.event_sources.sources.items():
@ -401,7 +391,7 @@ class Notifier(object):
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values():
if stream.listeners:
if stream.count_listeners():
continue
if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream)

View File

@ -91,8 +91,12 @@ class Clock(object):
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer):
def cancel_call_later(self, timer, ignore_errs=False):
try:
timer.cancel()
except:
if not ignore_errs:
raise
def time_bound_deferred(self, given_deferred, time_out):
if given_deferred.called:

View File

@ -45,7 +45,7 @@ class ObservableDeferred(object):
def __init__(self, deferred, consumeErrors=False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", [])
object.__setattr__(self, "_observers", set())
def callback(r):
self._result = (True, r)
@ -74,12 +74,21 @@ class ObservableDeferred(object):
def observe(self):
if not self._result:
d = defer.Deferred()
self._observers.append(d)
def remove(r):
self._observers.discard(d)
return r
d.addBoth(remove)
self._observers.add(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
def __getattr__(self, name):
return getattr(self._deferred, name)