Merge branch 'develop' of github.com:matrix-org/synapse into erikj/persist_event_perf

Erik Johnston, 2015-06-19 11:50:34 +01:00
commit 63141e77e7
17 changed files with 388 additions and 240 deletions

View File

@@ -38,3 +38,7 @@ Brabo <brabo at riseup.net>
 Ivan Shapovalov <intelfx100 at gmail.com>
  * contrib/systemd: a sample systemd unit file and a logger configuration
+
+Eric Myhre <hash at exultant.us>
+ * Fix bug where ``media_store_path`` config option was ignored by v0 content
+   repository API.

View File

@@ -370,6 +370,8 @@ class Auth(object):
                 user_agent=user_agent
             )

+            request.authenticated_entity = user.to_string()
+
             defer.returnValue((user, ClientInfo(device_id, token_id)))
         except KeyError:
             raise AuthError(

View File

@@ -35,7 +35,6 @@ from twisted.enterprise import adbapi
 from twisted.web.resource import Resource, EncodingResourceWrapper
 from twisted.web.static import File
 from twisted.web.server import Site, GzipEncoderFactory, Request
-from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
 from synapse.http.server import JsonResource, RootRedirect
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@@ -61,10 +60,13 @@ import twisted.manhole.telnet
 import synapse

+import contextlib
 import logging
 import os
+import re
 import resource
 import subprocess
+import time

 logger = logging.getLogger("synapse.app.homeserver")
@@ -112,7 +114,7 @@ class SynapseHomeServer(HomeServer):
     def build_resource_for_content_repo(self):
         return ContentRepoResource(
-            self, self.upload_dir, self.auth, self.content_addr
+            self, self.config.uploads_path, self.auth, self.content_addr
         )

     def build_resource_for_media_repository(self):
@@ -142,6 +144,7 @@ class SynapseHomeServer(HomeServer):
         port = listener_config["port"]
         bind_address = listener_config.get("bind_address", "")
         tls = listener_config.get("tls", False)
+        site_tag = listener_config.get("tag", port)

         if tls and config.no_tls:
             return
@@ -197,7 +200,8 @@ class SynapseHomeServer(HomeServer):
                 reactor.listenSSL(
                     port,
                     SynapseSite(
-                        "synapse.access.https",
+                        "synapse.access.https.%s" % (site_tag,),
+                        site_tag,
                         listener_config,
                         root_resource,
                     ),
@@ -208,7 +212,8 @@ class SynapseHomeServer(HomeServer):
                 reactor.listenTCP(
                     port,
                     SynapseSite(
-                        "synapse.access.https",
+                        "synapse.access.http.%s" % (site_tag,),
+                        site_tag,
                         listener_config,
                         root_resource,
                     ),
@@ -375,7 +380,6 @@ def setup(config_options):

     hs = SynapseHomeServer(
         config.server_name,
-        upload_dir=os.path.abspath("uploads"),
         db_config=config.database_config,
         tls_context_factory=tls_context_factory,
         config=config,
@@ -433,9 +437,70 @@ class SynapseService(service.Service):
         return self._port.stopListening()


-class XForwardedForRequest(Request):
-    def __init__(self, *args, **kw):
+class SynapseRequest(Request):
+    def __init__(self, site, *args, **kw):
         Request.__init__(self, *args, **kw)
+        self.site = site
+        self.authenticated_entity = None
+        self.start_time = 0
+
+    def __repr__(self):
+        # We overwrite this so that we don't log ``access_token``
+        return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+            self.__class__.__name__,
+            id(self),
+            self.method,
+            self.get_redacted_uri(),
+            self.clientproto,
+            self.site.site_tag,
+        )
+
+    def get_redacted_uri(self):
+        return re.sub(
+            r'(\?.*access_token=)[^&]*(.*)$',
+            r'\1<redacted>\2',
+            self.uri
+        )
+
+    def get_user_agent(self):
+        return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+
+    def started_processing(self):
+        self.site.access_logger.info(
+            "%s - %s - Received request: %s %s",
+            self.getClientIP(),
+            self.site.site_tag,
+            self.method,
+            self.get_redacted_uri()
+        )
+        self.start_time = int(time.time() * 1000)
+
+    def finished_processing(self):
+        self.site.access_logger.info(
+            "%s - %s - {%s}"
+            " Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
+            self.getClientIP(),
+            self.site.site_tag,
+            self.authenticated_entity,
+            int(time.time() * 1000) - self.start_time,
+            self.sentLength,
+            self.code,
+            self.method,
+            self.get_redacted_uri(),
+            self.clientproto,
+            self.get_user_agent(),
+        )
+
+    @contextlib.contextmanager
+    def processing(self):
+        self.started_processing()
+        yield
+        self.finished_processing()
+
+
+class XForwardedForRequest(SynapseRequest):
+    def __init__(self, *args, **kw):
+        SynapseRequest.__init__(self, *args, **kw)

     """
     Add a layer on top of another request that only uses the value of an
@@ -451,8 +516,16 @@ class XForwardedForRequest(Request):
             b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()


-def XForwardedFactory(*args, **kwargs):
-    return XForwardedForRequest(*args, **kwargs)
+class SynapseRequestFactory(object):
+    def __init__(self, site, x_forwarded_for):
+        self.site = site
+        self.x_forwarded_for = x_forwarded_for
+
+    def __call__(self, *args, **kwargs):
+        if self.x_forwarded_for:
+            return XForwardedForRequest(self.site, *args, **kwargs)
+        else:
+            return SynapseRequest(self.site, *args, **kwargs)


 class SynapseSite(Site):
@@ -460,18 +533,17 @@ class SynapseSite(Site):
     Subclass of a twisted http Site that does access logging with python's
     standard logging
     """
-    def __init__(self, logger_name, config, resource, *args, **kwargs):
+    def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
         Site.__init__(self, resource, *args, **kwargs)
-        if config.get("x_forwarded", False):
-            self.requestFactory = XForwardedFactory
-            self._log_formatter = proxiedLogFormatter
-        else:
-            self._log_formatter = combinedLogFormatter
+
+        self.site_tag = site_tag
+
+        proxied = config.get("x_forwarded", False)
+        self.requestFactory = SynapseRequestFactory(self, proxied)
+
         self.access_logger = logging.getLogger(logger_name)

     def log(self, request):
-        line = self._log_formatter(self._logDateTime, request)
-        self.access_logger.info(line)
+        pass


 def create_resource_tree(desired_tree, redirect_root_to_web_client=True):

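The new SynapseRequest.processing() context manager above replaces ad-hoc per-handler logging with a single start/finish bracket around each request. A minimal, standalone sketch of that pattern follows; the class, logger, and attribute names here are illustrative stand-ins, not Synapse's real ones.

# Standalone sketch of the access-logging bracket that SynapseRequest.processing()
# implements above. TimedRequest and "example.access" are illustrative names only.
import contextlib
import logging
import time

access_logger = logging.getLogger("example.access")


class TimedRequest(object):
    def __init__(self, site_tag, method, uri):
        self.site_tag = site_tag
        self.method = method
        self.uri = uri
        self.start_time = 0

    def started_processing(self):
        access_logger.info(
            "%s - Received request: %s %s", self.site_tag, self.method, self.uri
        )
        self.start_time = int(time.time() * 1000)

    def finished_processing(self):
        access_logger.info(
            "%s - Processed request: %dms %s %s",
            self.site_tag,
            int(time.time() * 1000) - self.start_time,
            self.method,
            self.uri,
        )

    @contextlib.contextmanager
    def processing(self):
        # Everything run inside the ``with`` block is bracketed by the two
        # log lines, so individual handlers no longer log start/stop themselves.
        self.started_processing()
        yield
        self.finished_processing()


# Usage, mirroring the http.server change later in this commit:
#     with request.processing():
#         handle(request)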
View File

@@ -148,7 +148,7 @@ class Config(object):
         if not config_args.config_path:
             config_parser.error(
                 "Must supply a config file.\nA config file can be automatically"
-                " generated using \"--generate-config -h SERVER_NAME"
+                " generated using \"--generate-config -H SERVER_NAME"
                 " -c CONFIG-FILE\""
             )
@@ -209,7 +209,7 @@ class Config(object):
         if not config_args.config_path:
             config_parser.error(
                 "Must supply a config file.\nA config file can be automatically"
-                " generated using \"--generate-config -h SERVER_NAME"
+                " generated using \"--generate-config -H SERVER_NAME"
                 " -c CONFIG-FILE\""
             )

View File

@@ -21,13 +21,18 @@ class ContentRepositoryConfig(Config):
         self.max_upload_size = self.parse_size(config["max_upload_size"])
         self.max_image_pixels = self.parse_size(config["max_image_pixels"])
         self.media_store_path = self.ensure_directory(config["media_store_path"])
+        self.uploads_path = self.ensure_directory(config["uploads_path"])

     def default_config(self, config_dir_path, server_name):
         media_store = self.default_path("media_store")
+        uploads_path = self.default_path("uploads")
         return """
         # Directory where uploaded images and attachments are stored.
         media_store_path: "%(media_store)s"

+        # Directory where in-progress uploads are stored.
+        uploads_path: "%(uploads_path)s"
+
         # The largest allowed upload size in bytes
         max_upload_size: "10M"

View File

@@ -94,6 +94,7 @@ class TransportLayerServer(object):
             yield self.keyring.verify_json_for_server(origin, json_request)

             logger.info("Request from %s", origin)
+            request.authenticated_entity = origin

             defer.returnValue((origin, content))

View File

@@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
             return

         user_info = yield self.store.get_user_by_id(user_id)
-        if not user_info:
+        if user_info:
             defer.returnValue(False)
             return

View File

@@ -380,15 +380,6 @@ class MessageHandler(BaseHandler):
         if limit is None:
             limit = 10

-        messages, token = yield self.store.get_recent_events_for_room(
-            room_id,
-            limit=limit,
-            end_token=now_token.room_key,
-        )
-
-        start_token = now_token.copy_and_replace("room_key", token[0])
-        end_token = now_token.copy_and_replace("room_key", token[1])
-
         room_members = [
             m for m in current_state.values()
             if m.type == EventTypes.Member
@@ -396,20 +387,39 @@ class MessageHandler(BaseHandler):
         ]

         presence_handler = self.hs.get_handlers().presence_handler
-        presence = []
-        for m in room_members:
-            try:
-                member_presence = yield presence_handler.get_state(
-                    target_user=UserID.from_string(m.user_id),
-                    auth_user=auth_user,
-                    as_event=True,
-                )
-                presence.append(member_presence)
-            except SynapseError:
-                logger.exception(
-                    "Failed to get member presence of %r", m.user_id
-                )
+
+        @defer.inlineCallbacks
+        def get_presence():
+            presence_defs = yield defer.DeferredList(
+                [
+                    presence_handler.get_state(
+                        target_user=UserID.from_string(m.user_id),
+                        auth_user=auth_user,
+                        as_event=True,
+                        check_auth=False,
+                    )
+                    for m in room_members
+                ],
+                consumeErrors=True,
+            )
+
+            defer.returnValue([p for success, p in presence_defs if success])
+
+        presence, (messages, token) = yield defer.gatherResults(
+            [
+                get_presence(),
+                self.store.get_recent_events_for_room(
+                    room_id,
+                    limit=limit,
+                    end_token=now_token.room_key,
+                )
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        start_token = now_token.copy_and_replace("room_key", token[0])
+        end_token = now_token.copy_and_replace("room_key", token[1])

         time_now = self.clock.time_msec()

         defer.returnValue({

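The rewritten snapshot path above no longer fetches member presence serially: the per-member lookups go through defer.DeferredList (so one failure is dropped rather than aborting the batch), and the presence batch runs concurrently with the recent-events query via defer.gatherResults. A minimal sketch of that combination, with hypothetical fetch functions standing in for the store and presence handler:

# Sketch of the concurrency pattern used above. lookup_one() and
# fetch_messages() are hypothetical stand-ins, not Synapse APIs.
from twisted.internet import defer


def lookup_one(user_id):
    # Hypothetical per-user presence lookup; one user fails to show how
    # consumeErrors keeps the rest of the batch alive.
    if user_id == "@bad:example.com":
        return defer.fail(RuntimeError("presence not visible"))
    return defer.succeed({"user_id": user_id, "presence": "online"})


def fetch_messages():
    # Hypothetical recent-events fetch returning (events, (start, end)).
    return defer.succeed((["event1", "event2"], ("s10", "s12")))


@defer.inlineCallbacks
def get_presence(user_ids):
    # DeferredList waits for every lookup and yields (success, value) pairs,
    # so a single failed lookup does not abort the others.
    results = yield defer.DeferredList(
        [lookup_one(user_id) for user_id in user_ids],
        consumeErrors=True,
    )
    defer.returnValue([value for success, value in results if success])


@defer.inlineCallbacks
def snapshot(user_ids):
    # gatherResults runs the two independent fetches concurrently.
    presence, (messages, token) = yield defer.gatherResults(
        [get_presence(user_ids), fetch_messages()],
        consumeErrors=True,
    )
    defer.returnValue({"presence": presence, "messages": messages, "token": token})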
View File

@@ -191,8 +191,9 @@ class PresenceHandler(BaseHandler):
             defer.returnValue(False)

     @defer.inlineCallbacks
-    def get_state(self, target_user, auth_user, as_event=False):
+    def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
         if self.hs.is_mine(target_user):
+            if check_auth:
                 visible = yield self.is_presence_visible(
                     observer_user=auth_user,
                     observed_user=target_user
@@ -200,15 +201,14 @@ class PresenceHandler(BaseHandler):
                 if not visible:
                     raise SynapseError(404, "Presence information not visible")

-            if target_user in self._user_cachemap:
-                state = self._user_cachemap[target_user].get_state()
-            else:
-                state = yield self.store.get_presence_state(target_user.localpart)
-                if "mtime" in state:
-                    del state["mtime"]
-                state["presence"] = state.pop("state")
+            state = yield self.store.get_presence_state(target_user.localpart)
+            if "mtime" in state:
+                del state["mtime"]
+            state["presence"] = state.pop("state")
+
+            if target_user in self._user_cachemap:
+                cached_state = self._user_cachemap[target_user].get_state()
+                if "last_active" in cached_state:
+                    state["last_active"] = cached_state["last_active"]
         else:
             # TODO(paul): Have remote server send us permissions set
             state = self._get_or_offline_usercache(target_user).get_state()

View File

@@ -61,21 +61,31 @@ class SimpleHttpClient(object):
         self.agent = Agent(reactor, pool=pool)
         self.version_string = hs.version_string

-    def request(self, method, *args, **kwargs):
+    def request(self, method, uri, *args, **kwargs):
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
         outgoing_requests_counter.inc(method)
         d = preserve_context_over_fn(
             self.agent.request,
-            method, *args, **kwargs
+            method, uri, *args, **kwargs
         )

+        logger.info("Sending request %s %s", method, uri)
+
         def _cb(response):
             incoming_responses_counter.inc(method, response.code)
+            logger.info(
+                "Received response to %s %s: %s",
+                method, uri, response.code
+            )
             return response

         def _eb(failure):
             incoming_responses_counter.inc(method, "ERR")
+            logger.info(
+                "Error sending request to %s %s: %s %s",
+                method, uri, failure.type, failure.getErrorMessage()
+            )
             return failure

         d.addCallbacks(_cb, _eb)
@@ -84,7 +94,9 @@ class SimpleHttpClient(object):

     @defer.inlineCallbacks
     def post_urlencoded_get_json(self, uri, args={}):
+        # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
+
         query_bytes = urllib.urlencode(args, True)

         response = yield self.request(
@@ -105,7 +117,7 @@
     def post_json_get_json(self, uri, post_json):
         json_str = encode_canonical_json(post_json)

-        logger.info("HTTP POST %s -> %s", json_str, uri)
+        logger.debug("HTTP POST %s -> %s", json_str, uri)

         response = yield self.request(
             "POST",

View File

@@ -35,11 +35,13 @@ from syutil.crypto.jsonsign import sign_json

 import simplejson as json
 import logging
+import sys
 import urllib
 import urlparse

 logger = logging.getLogger(__name__)
+outbound_logger = logging.getLogger("synapse.http.outbound")

 metrics = synapse.metrics.get_metrics_for(__name__)
@@ -109,6 +111,8 @@ class MatrixFederationHttpClient(object):
         self.clock = hs.get_clock()
         self.version_string = hs.version_string

+        self._next_id = 1
+
     @defer.inlineCallbacks
     def _create_request(self, destination, method, path_bytes,
                         body_callback, headers_dict={}, param_bytes=b"",
@@ -123,16 +127,12 @@ class MatrixFederationHttpClient(object):
             ("", "", path_bytes, param_bytes, query_bytes, "",)
         )

-        logger.info("Sending request to %s: %s %s",
-                    destination, method, url_bytes)
+        txn_id = "%s-%s" % (method, self._next_id)
+        self._next_id = (self._next_id + 1) % (sys.maxint - 1)

-        logger.debug(
-            "Types: %s",
-            [
-                type(destination), type(method), type(path_bytes),
-                type(param_bytes),
-                type(query_bytes)
-            ]
+        outbound_logger.info(
+            "{%s} [%s] Sending request: %s %s",
+            txn_id, destination, method, url_bytes
         )

         # XXX: Would be much nicer to retry only at the transaction-layer
@@ -141,6 +141,8 @@ class MatrixFederationHttpClient(object):

         endpoint = self._getEndpoint(reactor, destination)

+        log_result = None
+        try:
             while True:
                 producer = None
                 if body_callback:
@@ -164,7 +166,7 @@ class MatrixFederationHttpClient(object):
                     time_out=timeout/1000. if timeout else 60,
                 )

-                logger.debug("Got response to %s", method)
+                log_result = "%d %s" % (response.code, response.phrase,)
                 break
             except Exception as e:
                 if not retry_on_dns_fail and isinstance(e, DNSLookupError):
@@ -173,10 +175,14 @@ class MatrixFederationHttpClient(object):
                         destination,
                         e
                     )
+                    log_result = "DNS Lookup failed to %s with %s" % (
+                        destination, e
+                    )
                     raise

                 logger.warn(
-                    "Sending request failed to %s: %s %s: %s - %s",
+                    "{%s} Sending request failed to %s: %s %s: %s - %s",
+                    txn_id,
                     destination,
                     method,
                     url_bytes,
@@ -184,19 +190,21 @@ class MatrixFederationHttpClient(object):
                     _flatten_response_never_received(e),
                 )

+                log_result = "%s - %s" % (
+                    type(e).__name__, _flatten_response_never_received(e),
+                )
+
                 if retries_left and not timeout:
                     yield sleep(2 ** (5 - retries_left))
                     retries_left -= 1
                 else:
                     raise
-
-        logger.info(
-            "Received response %d %s for %s: %s %s",
-            response.code,
-            response.phrase,
-            destination,
-            method,
-            url_bytes
-        )
+        finally:
+            outbound_logger.info(
+                "{%s} [%s] Result: %s",
+                txn_id,
+                destination,
+                log_result,
+            )

         if 200 <= response.code < 300:

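The federation client changes above tag every outbound request with a local txn_id and emit exactly one outcome line from a finally: block, whether the request succeeds, fails DNS resolution, or exhausts its retries. A small sketch of that logging shape, with illustrative names (send_with_result_logging and "example.outbound" are not Synapse APIs):

# Sketch of the "one outcome line per request" pattern introduced above:
# record a short result string along the way and emit it from ``finally``.
import itertools
import logging

outbound_logger = logging.getLogger("example.outbound")
_ids = itertools.count(1)


def send_with_result_logging(destination, method, do_send):
    txn_id = "%s-%s" % (method, next(_ids))
    outbound_logger.info("{%s} [%s] Sending request: %s", txn_id, destination, method)

    log_result = None
    try:
        code, phrase = do_send()
        log_result = "%d %s" % (code, phrase)
        return code, phrase
    except Exception as e:
        log_result = "%s - %s" % (type(e).__name__, e)
        raise
    finally:
        # Exactly one result line per request, success or failure.
        outbound_logger.info("{%s} [%s] Result: %s", txn_id, destination, log_result)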
View File

@@ -79,17 +79,11 @@ def request_handler(request_handler):
         _next_request_id += 1
         with LoggingContext(request_id) as request_context:
             request_context.request = request_id
-            code = None
-            start = self.clock.time_msec()
+            with request.processing():
                 try:
-                    logger.info(
-                        "Received request: %s %s",
-                        request.method, request.path
-                    )
                     d = request_handler(self, request)
                     with PreserveLoggingContext():
                         yield d
-                    code = request.code
                 except CodeMessageException as e:
                     code = e.code
                     if isinstance(e, SynapseError):
@@ -105,7 +99,6 @@ def request_handler(request_handler):
                         version_string=self.version_string,
                     )
                 except:
-                    code = 500
                     logger.exception(
                         "Failed handle request %s.%s on %r: %r",
                         request_handler.__module__,
@@ -119,13 +112,6 @@ def request_handler(request_handler):
                         {"error": "Internal server error"},
                         send_cors=True
                     )
-            finally:
-                code = str(code) if code else "-"
-                end = self.clock.time_msec()
-                logger.info(
-                    "Processed request: %dms %s %s %s",
-                    end-start, code, request.method, request.path
-                )

     return wrapped_request_handler

View File

@@ -16,7 +16,7 @@
 from twisted.internet import defer

 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, ObservableDeferred
 from synapse.types import StreamToken

 import synapse.metrics
@@ -45,21 +45,11 @@ class _NotificationListener(object):
     The events stream handler will have yielded to the deferred, so to
     notify the handler it is sufficient to resolve the deferred.
     """
+    __slots__ = ["deferred"]

     def __init__(self, deferred):
         self.deferred = deferred

-    def notified(self):
-        return self.deferred.called
-
-    def notify(self, token):
-        """ Inform whoever is listening about the new events.
-        """
-        try:
-            self.deferred.callback(token)
-        except defer.AlreadyCalledError:
-            pass
-

 class _NotifierUserStream(object):
     """This represents a user connected to the event stream.
@@ -75,11 +65,12 @@ class _NotifierUserStream(object):
                  appservice=None):
         self.user = str(user)
         self.appservice = appservice
-        self.listeners = set()
         self.rooms = set(rooms)
         self.current_token = current_token
         self.last_notified_ms = time_now_ms

+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+
     def notify(self, stream_key, stream_id, time_now_ms):
         """Notify any listeners for this user of a new event from an
         event source.
@@ -91,12 +82,10 @@ class _NotifierUserStream(object):
         self.current_token = self.current_token.copy_and_advance(
             stream_key, stream_id
         )
-        if self.listeners:
-            self.last_notified_ms = time_now_ms
-            listeners = self.listeners
-            self.listeners = set()
-            for listener in listeners:
-                listener.notify(self.current_token)
+        self.last_notified_ms = time_now_ms
+        noify_deferred = self.notify_deferred
+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+        noify_deferred.callback(self.current_token)

     def remove(self, notifier):
         """ Remove this listener from all the indexes in the Notifier
@@ -114,6 +103,18 @@ class _NotifierUserStream(object):
             self.appservice, set()
         ).discard(self)

+    def count_listeners(self):
+        return len(self.noify_deferred.observers())
+
+    def new_listener(self, token):
+        """Returns a deferred that is resolved when there is a new token
+        greater than the given token.
+        """
+        if self.current_token.is_after(token):
+            return _NotificationListener(defer.succeed(self.current_token))
+        else:
+            return _NotificationListener(self.notify_deferred.observe())
+

 class Notifier(object):
     """ This class is responsible for notifying any listeners when there are
@@ -158,7 +159,7 @@ class Notifier(object):
             for x in self.appservice_to_user_streams.values():
                 all_user_streams |= x

-            return sum(len(stream.listeners) for stream in all_user_streams)
+            return sum(stream.count_listeners() for stream in all_user_streams)

         metrics.register_callback("listeners", count_listeners)

         metrics.register_callback(
@@ -286,10 +287,6 @@ class Notifier(object):
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
-
-        deferred = defer.Deferred()
-        time_now_ms = self.clock.time_msec()
-
         user = str(user)
         user_stream = self.user_to_user_stream.get(user)
         if user_stream is None:
@@ -302,55 +299,44 @@ class Notifier(object):
                 rooms=rooms,
                 appservice=appservice,
                 current_token=current_token,
-                time_now_ms=time_now_ms,
+                time_now_ms=self.clock.time_msec(),
             )
             self._register_with_keys(user_stream)
-        else:
-            current_token = user_stream.current_token
-
-        listener = [_NotificationListener(deferred)]
-
-        if timeout and not current_token.is_after(from_token):
-            user_stream.listeners.add(listener[0])
-
-        if current_token.is_after(from_token):
-            result = yield callback(from_token, current_token)
-        else:
-            result = None
-
-        timer = [None]
-
-        if result:
-            user_stream.listeners.discard(listener[0])
-            defer.returnValue(result)
-            return
-
-        if timeout:
-            timed_out = [False]
-
-            def _timeout_listener():
-                timed_out[0] = True
-                timer[0] = None
-                user_stream.listeners.discard(listener[0])
-                listener[0].notify(current_token)
-
-            # We create multiple notification listeners so we have to manage
-            # canceling the timeout ourselves.
-            timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
-
-            while not result and not timed_out[0]:
-                new_token = yield deferred
-                deferred = defer.Deferred()
-                listener[0] = _NotificationListener(deferred)
-                user_stream.listeners.add(listener[0])
-                result = yield callback(current_token, new_token)
-                current_token = new_token
-
-            if timer[0] is not None:
-                try:
-                    self.clock.cancel_call_later(timer[0])
-                except:
-                    logger.exception("Failed to cancel notifer timer")
+
+        result = None
+        if timeout:
+            # Will be set to a _NotificationListener that we'll be waiting on.
+            # Allows us to cancel it.
+            listener = None
+
+            def timed_out():
+                if listener:
+                    listener.deferred.cancel()
+            timer = self.clock.call_later(timeout/1000., timed_out)
+
+            prev_token = from_token
+            while not result:
+                try:
+                    current_token = user_stream.current_token
+
+                    result = yield callback(prev_token, current_token)
+                    if result:
+                        break
+
+                    # Now we wait for the _NotifierUserStream to be told there
+                    # is a new token.
+                    # We need to supply the token we supplied to callback so
+                    # that we don't miss any current_token updates.
+                    prev_token = current_token
+                    listener = user_stream.new_listener(prev_token)
+                    yield listener.deferred
+                except defer.CancelledError:
+                    break
+
+            self.clock.cancel_call_later(timer, ignore_errs=True)
+        else:
+            current_token = user_stream.current_token
+            result = yield callback(from_token, current_token)

         defer.returnValue(result)
@@ -368,6 +354,9 @@ class Notifier(object):
         @defer.inlineCallbacks
         def check_for_updates(before_token, after_token):
+            if not after_token.is_after(before_token):
+                defer.returnValue(None)
+
             events = []
             end_token = from_token
             for name, source in self.event_sources.sources.items():
@@ -402,7 +391,7 @@ class Notifier(object):
         expired_streams = []
         expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
         for stream in self.user_to_user_stream.values():
-            if stream.listeners:
+            if stream.count_listeners():
                 continue
             if stream.last_notified_ms < expire_before_ts:
                 expired_streams.append(stream)

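The new wait_for_events() loop above polls the callback and, when there is nothing to return yet, parks on a listener deferred obtained from the user stream; the timeout simply cancels whichever deferred the loop is currently waiting on. A condensed sketch of that idiom, where poll and next_wakeup are hypothetical callables returning Deferreds:

# Condensed sketch of the waiting idiom used by wait_for_events() above.
from twisted.internet import defer, reactor


@defer.inlineCallbacks
def wait_for_result(poll, next_wakeup, timeout_ms):
    result = None
    waiting_on = [None]

    def timed_out():
        # Cancel whatever deferred the loop is currently parked on.
        if waiting_on[0] is not None:
            waiting_on[0].cancel()

    timer = reactor.callLater(timeout_ms / 1000.0, timed_out)

    while not result:
        try:
            result = yield poll()
            if result:
                break
            # Nothing yet: park until the stream signals a new token.
            waiting_on[0] = next_wakeup()
            yield waiting_on[0]
        except defer.CancelledError:
            # The timeout fired; give up and return whatever we have.
            break

    if timer.active():
        timer.cancel()
    defer.returnValue(result)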
View File

@@ -39,10 +39,10 @@ class HttpTransactionStore(object):
             A tuple of (HTTP response code, response content) or None.
         """
         try:
-            logger.debug("get_response Key: %s TxnId: %s", key, txn_id)
+            logger.debug("get_response TxnId: %s", txn_id)
             (last_txn_id, response) = self.transactions[key]
             if txn_id == last_txn_id:
-                logger.info("get_response: Returning a response for %s", key)
+                logger.info("get_response: Returning a response for %s", txn_id)
                 return response
         except KeyError:
             pass
@@ -58,7 +58,7 @@ class HttpTransactionStore(object):
             txn_id (str): The transaction ID for this request.
             response (tuple): A tuple of (HTTP response code, response content)
         """
-        logger.debug("store_response Key: %s TxnId: %s", key, txn_id)
+        logger.debug("store_response TxnId: %s", txn_id)
         self.transactions[key] = (txn_id, response)

     def store_client_transaction(self, request, txn_id, response):

View File

@@ -91,8 +91,12 @@ class Clock(object):
         with PreserveLoggingContext():
             return reactor.callLater(delay, wrapped_callback, *args, **kwargs)

-    def cancel_call_later(self, timer):
+    def cancel_call_later(self, timer, ignore_errs=False):
+        try:
             timer.cancel()
+        except:
+            if not ignore_errs:
+                raise

     def time_bound_deferred(self, given_deferred, time_out):
         if given_deferred.called:

View File

@@ -38,6 +38,9 @@ class ObservableDeferred(object):
     deferred.

     If consumeErrors is true errors will be captured from the origin deferred.
+
+    Cancelling or otherwise resolving an observer will not affect the original
+    ObservableDeferred.
     """

     __slots__ = ["_deferred", "_observers", "_result"]
@@ -45,7 +48,7 @@ class ObservableDeferred(object):
     def __init__(self, deferred, consumeErrors=False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
-        object.__setattr__(self, "_observers", [])
+        object.__setattr__(self, "_observers", set())

         def callback(r):
             self._result = (True, r)
def observe(self): def observe(self):
if not self._result: if not self._result:
d = defer.Deferred() d = defer.Deferred()
self._observers.append(d)
def remove(r):
self._observers.discard(d)
return r
d.addBoth(remove)
self._observers.add(d)
return d return d
else: else:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._deferred, name) return getattr(self._deferred, name)

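The notifier now multiplexes many waiters over a single ObservableDeferred, which is why observe() gains the self-removal hook and the observers() accessor above. A usage sketch relying only on the API visible in this diff (observe(), observers(), and forwarding of callback() to the wrapped deferred via __getattr__):

# Usage sketch for ObservableDeferred as extended above.
from twisted.internet import defer

from synapse.util.async import ObservableDeferred

source = ObservableDeferred(defer.Deferred(), consumeErrors=True)

first = source.observe()
second = source.observe()
assert len(source.observers()) == 2

# Cancelling one observer only resolves that observer: the addBoth(remove)
# hook added above drops it from the set, and the source is untouched.
first.addErrback(lambda f: f.trap(defer.CancelledError))
first.cancel()
assert len(source.observers()) == 1

results = []
second.addCallback(results.append)

# callback() is forwarded to the wrapped deferred, so every remaining
# observer is resolved with the same value.
source.callback("new_token")
assert results == ["new_token"]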
View File

@@ -57,6 +57,49 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             interested_service, event
         )

+    @defer.inlineCallbacks
+    def test_query_user_exists_unknown_user(self):
+        user_id = "@someone:anywhere"
+        services = [self._mkservice(is_interested=True)]
+        services[0].is_interested_in_user = Mock(return_value=True)
+        self.mock_store.get_app_services = Mock(return_value=services)
+        self.mock_store.get_user_by_id = Mock(return_value=None)
+
+        event = Mock(
+            sender=user_id,
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.mock_as_api.push = Mock()
+        self.mock_as_api.query_user = Mock()
+        yield self.handler.notify_interested_services(event)
+        self.mock_as_api.query_user.assert_called_once_with(
+            services[0], user_id
+        )
+
+    @defer.inlineCallbacks
+    def test_query_user_exists_known_user(self):
+        user_id = "@someone:anywhere"
+        services = [self._mkservice(is_interested=True)]
+        services[0].is_interested_in_user = Mock(return_value=True)
+        self.mock_store.get_app_services = Mock(return_value=services)
+        self.mock_store.get_user_by_id = Mock(return_value={
+            "name": user_id
+        })
+
+        event = Mock(
+            sender=user_id,
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.mock_as_api.push = Mock()
+        self.mock_as_api.query_user = Mock()
+        yield self.handler.notify_interested_services(event)
+        self.assertFalse(
+            self.mock_as_api.query_user.called,
+            "query_user called when it shouldn't have been."
+        )
+
     @defer.inlineCallbacks
     def test_query_room_alias_exists(self):
         room_alias_str = "#foo:bar"