Merge branch 'release-v0.20.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-04-11 11:12:37 +01:00
commit 4902db1fc9
100 changed files with 3724 additions and 1355 deletions


@@ -1,3 +1,56 @@
Changes in synapse v0.20.0 (2017-04-11)
=======================================
Bug fixes:
* Fix joining rooms over federation where not all servers in the room saw the
new server had joined (PR #2094)
Changes in synapse v0.20.0-rc1 (2017-03-30)
===========================================
Features:
* Add delete_devices API (PR #1993)
* Add phone number registration/login support (PR #1994, #2055)
Changes:
* Use JSONSchema for validation of filters. Thanks @pik! (PR #1783)
* Reread log config on SIGHUP (PR #1982)
* Speed up public room list (PR #1989)
* Add helpful texts to logger config options (PR #1990)
* Minor ``/sync`` performance improvements. (PR #2002, #2013, #2022)
* Add some debug to help diagnose weird federation issue (PR #2035)
* Correctly limit retries for all federation requests (PR #2050, #2061)
* Don't lock table when persisting new one time keys (PR #2053)
* Reduce some CPU work on DB threads (PR #2054)
* Cache hosts in room (PR #2060)
* Batch sending of device list pokes (PR #2063)
* Speed up persist event path in certain edge cases (PR #2070)
Bug fixes:
* Fix bug where current_state_events renamed to current_state_ids (PR #1849)
* Fix routing loop when fetching remote media (PR #1992)
* Fix current_state_events table to not lie (PR #1996)
* Fix CAS login to handle PartialDownloadError (PR #1997)
* Fix assertion to stop transaction queue getting wedged (PR #2010)
* Fix presence to fallback to last_active_ts if it beats the last sync time.
Thanks @Half-Shot! (PR #2014)
* Fix bug when federation received a PDU while a room join is in progress (PR
#2016)
* Fix resetting state on rejected events (PR #2025)
* Fix installation issues in readme. Thanks @ricco386 (PR #2037)
* Fix caching of remote servers' signature keys (PR #2042)
* Fix some leaking log context (PR #2048, #2049, #2057, #2058)
* Fix rejection of invites not reaching sync (PR #2056)
Changes in synapse v0.19.3 (2017-03-20)
=======================================


@@ -146,6 +146,7 @@ To install the synapse homeserver run::
    virtualenv -p python2.7 ~/.synapse
    source ~/.synapse/bin/activate
    pip install --upgrade pip
    pip install --upgrade setuptools
    pip install https://github.com/matrix-org/synapse/tarball/master
@@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
    New user localpart: erikj
    Password:
    Confirm password:
    Make admin [no]:
    Success!

This process uses a setting ``registration_shared_secret`` in
@@ -808,7 +810,7 @@ directory of your choice::
Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv::

    virtualenv -p python2.7 env
    source env/bin/activate
    python synapse/python_dependencies.py | xargs pip install
    pip install lxml mock


@@ -39,7 +39,9 @@ loggers:
    synapse:
        level: INFO

    synapse.storage.SQL:
        # beware: increasing this to DEBUG will make synapse log sensitive
        # information such as access tokens.
        level: INFO

    # example of enabling debugging for a component:


@@ -1,10 +1,446 @@
What do I do about "Unexpected logging context" debug log-lines everywhere?

<Mjark> The logging context lives in thread local storage
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
Unanswered: how and when should we preserve log context?

Log contexts
============

.. contents::

To help track the processing of individual requests, synapse uses a
'log context' to track which request it is handling at any given moment. This
is done via a thread-local variable; a ``logging.Filter`` is then used to fish
the information back out of the thread-local variable and add it to each log
record.
Logcontexts are also used for CPU and database accounting, so that we can track
which requests were responsible for high CPU use or database activity.
The ``synapse.util.logcontext`` module provides facilities for managing the
current log context (as well as providing the ``LoggingContextFilter`` class).
Deferreds make the whole thing complicated, so this document describes how it
all works, and how to write code which follows the rules.
Logcontexts without Deferreds
-----------------------------
In the absence of any Deferred voodoo, things are simple enough. As with any
code of this nature, the rule is that our function should leave things as it
found them:

.. code:: python

    from synapse.util import logcontext         # omitted from future snippets

    def handle_request(request_id):
        request_context = logcontext.LoggingContext()

        calling_context = logcontext.LoggingContext.current_context()
        logcontext.LoggingContext.set_current_context(request_context)
        try:
            request_context.request = request_id
            do_request_handling()
            logger.debug("finished")
        finally:
            logcontext.LoggingContext.set_current_context(calling_context)

    def do_request_handling():
        logger.debug("phew")  # this will be logged against request_id

LoggingContext implements the context management methods, so the above can be
written much more succinctly as:

.. code:: python

    def handle_request(request_id):
        with logcontext.LoggingContext() as request_context:
            request_context.request = request_id
            do_request_handling()
            logger.debug("finished")

    def do_request_handling():
        logger.debug("phew")

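For the curious, the save-and-restore dance that the context-manager form hides can be recreated with stdlib pieces alone. This is a sketch with invented names, not synapse's real ``LoggingContext`` (which also does CPU and database accounting):

```python
import threading

# Thread-local slot holding the current context; invented for this sketch.
_current = threading.local()

class ToyLoggingContext(object):
    """Minimal stand-in for LoggingContext: saves the calling context on
    entry, installs itself, and restores the old context on exit."""
    def __init__(self):
        self.request = None
        self._previous = None

    @classmethod
    def current_context(cls):
        return getattr(_current, "context", None)

    def __enter__(self):
        self._previous = ToyLoggingContext.current_context()
        _current.context = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # put back whatever context the caller had
        _current.context = self._previous

with ToyLoggingContext() as ctx:
    ctx.request = "REQ-1"
    assert ToyLoggingContext.current_context() is ctx
assert ToyLoggingContext.current_context() is None
```

Because ``__exit__`` restores the previous context rather than simply clearing it, these contexts nest correctly, which is what makes the "leave things as you found them" rule composable.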
Using logcontexts with Deferreds
--------------------------------
Deferreds — and in particular, ``defer.inlineCallbacks`` — break
the linear flow of code so that there is no longer a single entry point where
we should set the logcontext and a single exit point where we should remove it.
Consider the example above, where ``do_request_handling`` needs to do some
blocking operation, and returns a deferred:

.. code:: python

    @defer.inlineCallbacks
    def handle_request(request_id):
        with logcontext.LoggingContext() as request_context:
            request_context.request = request_id
            yield do_request_handling()
            logger.debug("finished")

In the above flow:
* The logcontext is set
* ``do_request_handling`` is called, and returns a deferred
* ``handle_request`` yields the deferred
* The ``inlineCallbacks`` wrapper of ``handle_request`` returns a deferred
So we have stopped processing the request (and will probably go on to start
processing the next), without clearing the logcontext.
To circumvent this problem, synapse code assumes that, wherever you have a
deferred, you will want to yield on it. To that end, whereever functions return
a deferred, we adopt the following conventions:
**Rules for functions returning deferreds:**
* If the deferred is already complete, the function returns with the same
logcontext it started with.
* If the deferred is incomplete, the function clears the logcontext before
returning; when the deferred completes, it restores the logcontext before
running any callbacks.
That sounds complicated, but actually it means a lot of code (including the
example above) "just works". There are two cases:
* If ``do_request_handling`` returns a completed deferred, then the logcontext
will still be in place. In this case, execution will continue immediately
after the ``yield``; the "finished" line will be logged against the right
context, and the ``with`` block restores the original context before we
return to the caller.
* If the returned deferred is incomplete, ``do_request_handling`` clears the
logcontext before returning. The logcontext is therefore clear when
``handle_request`` yields the deferred. At that point, the ``inlineCallbacks``
wrapper adds a callback to the deferred, and returns another (incomplete)
deferred to the caller, and it is safe to begin processing the next request.
Once ``do_request_handling``'s deferred completes, it will reinstate the
logcontext, before running the callback added by the ``inlineCallbacks``
wrapper. That callback runs the second half of ``handle_request``, so again
the "finished" line will be logged against the right
context, and the ``with`` block restores the original context.
As an aside, it's worth noting that ``handle_request`` follows our rules -
though that only matters if the caller has its own logcontext which it cares
about.
The following sections describe pitfalls and helpful patterns when implementing
these rules.
Always yield your deferreds
---------------------------
Whenever you get a deferred back from a function, you should ``yield`` on it
as soon as possible. (Returning it directly to your caller is ok too, if you're
not doing ``inlineCallbacks``.) Do not pass go; do not do any logging; do not
call any other functions.

.. code:: python

    @defer.inlineCallbacks
    def fun():
        logger.debug("starting")
        yield do_some_stuff()       # just like this

        d = more_stuff()
        result = yield d            # also fine, of course

        defer.returnValue(result)

    def nonInlineCallbacksFun():
        logger.debug("just a wrapper really")
        return do_some_stuff()      # this is ok too - the caller will yield on
                                    # it anyway.

Provided this pattern is followed all the way back up the callchain to where
the logcontext was set, this will make things work out ok: provided
``do_some_stuff`` and ``more_stuff`` follow the rules above, then so will
``fun`` (as wrapped by ``inlineCallbacks``) and ``nonInlineCallbacksFun``.
It's all too easy to forget to ``yield``: for instance if we forgot that
``do_some_stuff`` returned a deferred, we might plough on regardless. This
leads to a mess; it will probably work itself out eventually, but not before
a load of stuff has been logged against the wrong context. (Normally, other
things will break, more obviously, if you forget to ``yield``, so this tends
not to be a major problem in practice.)
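The mis-attribution can be seen even in a deferred-free toy (all names below are invented for the sketch): a rules-following function clears the context before handing back incomplete work, so if the caller drops the result instead of yielding on it, everything it does afterwards runs without a context:

```python
import threading

# Invented stand-in for the thread-local logcontext.
_ctx = threading.local()

def current_request():
    return getattr(_ctx, "request", None)

def do_some_stuff():
    # Follows the rules above: the work is incomplete, so clear the
    # logcontext before returning; it would be restored on completion.
    _ctx.request = None
    return "incomplete-work"

def fun():
    _ctx.request = "REQ-1"
    do_some_stuff()            # BUG: result dropped - no yield
    # anything logged from here on is attributed to no context at all
    return current_request()

print(fun())  # prints None: "REQ-1" has been lost
```

In the real codebase the context only comes back when the dropped deferred eventually fires, which is why things usually "work themselves out eventually" while logging against the wrong request in the meantime.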
Of course sometimes you need to do something a bit fancier with your Deferreds
- not all code follows the linear A-then-B-then-C pattern. Notes on
implementing more complex patterns are in later sections.
Where you create a new Deferred, make it follow the rules
---------------------------------------------------------
Most of the time, a Deferred comes from another synapse function. Sometimes,
though, we need to make up a new Deferred, or we get a Deferred back from
external code. We need to make it follow our rules.
The easy way to do it is with a combination of ``defer.inlineCallbacks``, and
``logcontext.PreserveLoggingContext``. Suppose we want to implement ``sleep``,
which returns a deferred which will run its callbacks after a given number of
seconds. That might look like:

.. code:: python

    # not a logcontext-rules-compliant function
    def get_sleep_deferred(seconds):
        d = defer.Deferred()
        reactor.callLater(seconds, d.callback, None)
        return d

That doesn't follow the rules, but we can fix it by wrapping it with
``PreserveLoggingContext`` and ``yield``-ing on it:

.. code:: python

    @defer.inlineCallbacks
    def sleep(seconds):
        with PreserveLoggingContext():
            yield get_sleep_deferred(seconds)

This technique works equally for external functions which return deferreds,
or deferreds we have made ourselves.
You can also use ``logcontext.make_deferred_yieldable``, which just does the
boilerplate for you, so the above could be written:

.. code:: python

    def sleep(seconds):
        return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))

Fire-and-forget
---------------
Sometimes you want to fire off a chain of execution, but not wait for its
result. That might look a bit like this:

.. code:: python

    @defer.inlineCallbacks
    def do_request_handling():
        yield foreground_operation()

        # *don't* do this
        background_operation()

        logger.debug("Request handling complete")

    @defer.inlineCallbacks
    def background_operation():
        yield first_background_step()
        logger.debug("Completed first step")
        yield second_background_step()
        logger.debug("Completed second step")

The above code does a couple of steps in the background after
``do_request_handling`` has finished. The log lines are still logged against
the ``request_context`` logcontext, which may or may not be desirable. There
are two big problems with the above, however. The first problem is that, if
``background_operation`` returns an incomplete Deferred, it will expect its
caller to ``yield`` immediately, so will have cleared the logcontext. In this
example, that means that 'Request handling complete' will be logged without any
context.
The second problem, which is potentially even worse, is that when the Deferred
returned by ``background_operation`` completes, it will restore the original
logcontext. There is nothing waiting on that Deferred, so the logcontext will
leak into the reactor and possibly get attached to some arbitrary future
operation.
There are two potential solutions to this.
One option is to surround the call to ``background_operation`` with a
``PreserveLoggingContext`` call. That will reset the logcontext before
starting ``background_operation`` (so the context restored when the deferred
completes will be the empty logcontext), and will restore the current
logcontext before continuing the foreground process:

.. code:: python

    @defer.inlineCallbacks
    def do_request_handling():
        yield foreground_operation()

        # start background_operation off in the empty logcontext, to
        # avoid leaking the current context into the reactor.
        with PreserveLoggingContext():
            background_operation()

        # this will now be logged against the request context
        logger.debug("Request handling complete")

Obviously that option means that the operations done in
``background_operation`` would not be logged against a logcontext (though
that might be fixed by setting a different logcontext via a ``with
LoggingContext(...)`` in ``background_operation``).
The second option is to use ``logcontext.preserve_fn``, which wraps a function
so that it doesn't reset the logcontext even when it returns an incomplete
deferred, and adds a callback to the returned deferred to reset the
logcontext. In other words, it turns a function that follows the Synapse rules
about logcontexts and Deferreds into one which behaves more like an external
function — the opposite operation to that described in the previous section.
It can be used like this:

.. code:: python

    @defer.inlineCallbacks
    def do_request_handling():
        yield foreground_operation()

        logcontext.preserve_fn(background_operation)()

        # this will now be logged against the request context
        logger.debug("Request handling complete")

XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
but the fact that it does ``preserve_context_over_deferred`` on its results
means that its use is fraught with difficulty.
Passing synapse deferreds into third-party functions
----------------------------------------------------
A typical example of this is where we want to collect together two or more
deferreds via ``defer.gatherResults``:

.. code:: python

    d1 = operation1()
    d2 = operation2()
    d3 = defer.gatherResults([d1, d2])

This is really a variation of the fire-and-forget problem above, in that we are
firing off ``d1`` and ``d2`` without yielding on them. The difference
is that we now have third-party code attached to their callbacks. Anyway either
technique given in the `Fire-and-forget`_ section will work.
Of course, the new Deferred returned by ``gatherResults`` needs to be wrapped
in order to make it follow the logcontext rules before we can yield it, as
described in `Where you create a new Deferred, make it follow the rules`_.
So, option one: reset the logcontext before starting the operations to be
gathered:

.. code:: python

    @defer.inlineCallbacks
    def do_request_handling():
        with PreserveLoggingContext():
            d1 = operation1()
            d2 = operation2()
            result = yield defer.gatherResults([d1, d2])

In this case particularly, though, option two, of using
``logcontext.preserve_fn`` almost certainly makes more sense, so that
``operation1`` and ``operation2`` are both logged against the original
logcontext. This looks like:

.. code:: python

    @defer.inlineCallbacks
    def do_request_handling():
        d1 = logcontext.preserve_fn(operation1)()
        d2 = logcontext.preserve_fn(operation2)()
        with PreserveLoggingContext():
            result = yield defer.gatherResults([d1, d2])

Was all this really necessary?
------------------------------
The conventions used work fine for a linear flow where everything happens in
series via ``defer.inlineCallbacks`` and ``yield``, but are certainly tricky to
follow for any more exotic flows. It's hard not to wonder if we could have done
something else.
We're not going to rewrite Synapse now, so the following is entirely of
academic interest, but I'd like to record some thoughts on an alternative
approach.
I briefly prototyped some code following an alternative set of rules. I think
it would work, but I certainly didn't get as far as thinking how it would
interact with concepts as complicated as the cache descriptors.
My alternative rules were:
* functions always preserve the logcontext of their caller, whether or not they
are returning a Deferred.
* Deferreds returned by synapse functions run their callbacks in the same
context as the function was originally called in.
The main point of this scheme is that everywhere that sets the logcontext is
responsible for clearing it before returning control to the reactor.
So, for example, if you were the function which started a ``with
LoggingContext`` block, you wouldn't ``yield`` within it — instead you'd start
off the background process, and then leave the ``with`` block to wait for it:

.. code:: python

    def handle_request(request_id):
        with logcontext.LoggingContext() as request_context:
            request_context.request = request_id
            d = do_request_handling()

            def cb(r):
                logger.debug("finished")

            d.addCallback(cb)

        return d

(in general, mixing ``with LoggingContext`` blocks and
``defer.inlineCallbacks`` in the same function leads to slightly
counter-intuitive code, under this scheme).
Because we leave the original ``with`` block as soon as the Deferred is
returned (as opposed to waiting for it to be resolved, as we do today), the
logcontext is cleared before control passes back to the reactor; so if there is
some code within ``do_request_handling`` which needs to wait for a Deferred to
complete, there is no need for it to worry about clearing the logcontext before
doing so:

.. code:: python

    def handle_request():
        r = do_some_stuff()
        r.addCallback(do_some_more_stuff)
        return r

— and provided ``do_some_stuff`` follows the rules of returning a Deferred which
runs its callbacks in the original logcontext, all is happy.
The business of a Deferred which runs its callbacks in the original logcontext
isn't hard to achieve — we have it today, in the shape of
``logcontext._PreservingContextDeferred``:

.. code:: python

    def do_some_stuff():
        deferred = do_some_io()
        pcd = _PreservingContextDeferred(LoggingContext.current_context())
        deferred.chainDeferred(pcd)
        return pcd

It turns out that, thanks to the way that Deferreds chain together, we
automatically get the property of a context-preserving deferred with
``defer.inlineCallbacks``, provided the final Deferred the function ``yields``
on has that property. So we can just write:

.. code:: python

    @defer.inlineCallbacks
    def handle_request():
        yield do_some_stuff()
        yield do_some_more_stuff()

To conclude: I think this scheme would have worked equally well, with less
danger of messing it up, and probably made some more esoteric code easier to
write. But again — changing the conventions of the entire Synapse codebase is
not a sensible option for the marginal improvement offered.


@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""

__version__ = "0.20.0"


@@ -23,7 +23,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
from synapse.util import logcontext
from synapse.util.metrics import Measure

logger = logging.getLogger(__name__)

@@ -209,8 +209,7 @@ class Auth(object):
            default=[""]
        )[0]
        if user and access_token and ip_addr:
            logcontext.preserve_fn(self.store.insert_client_ip)(
                user=user,
                access_token=access_token,
                ip=ip_addr,


@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -44,6 +45,7 @@ class JoinRules(object):

class LoginType(object):
    PASSWORD = u"m.login.password"
    EMAIL_IDENTITY = u"m.login.email.identity"
    MSISDN = u"m.login.msisdn"
    RECAPTCHA = u"m.login.recaptcha"
    DUMMY = u"m.login.dummy"


@@ -15,6 +15,7 @@
"""Contains exceptions and error codes."""

import json
import logging

logger = logging.getLogger(__name__)

@@ -50,27 +51,35 @@ class Codes(object):

class CodeMessageException(RuntimeError):
    """An exception with integer code and message string attributes.

    Attributes:
        code (int): HTTP error code
        msg (str): string describing the error
    """
    def __init__(self, code, msg):
        super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
        self.code = code
        self.msg = msg

    def error_dict(self):
        return cs_error(self.msg)


class SynapseError(CodeMessageException):
    """A base exception type for matrix errors which have an errcode and error
    message (as well as an HTTP status code).

    Attributes:
        errcode (str): Matrix error code e.g 'M_FORBIDDEN'
    """
    def __init__(self, code, msg, errcode=Codes.UNKNOWN):
        """Constructs a synapse error.

        Args:
            code (int): The integer error code (an HTTP response code)
            msg (str): The human-readable error message.
            errcode (str): The matrix error code e.g 'M_FORBIDDEN'
        """
        super(SynapseError, self).__init__(code, msg)
        self.errcode = errcode

@@ -81,6 +90,39 @@ class SynapseError(CodeMessageException):
            self.errcode,
        )

    @classmethod
    def from_http_response_exception(cls, err):
        """Make a SynapseError based on an HTTPResponseException

        This is useful when a proxied request has failed, and we need to
        decide how to map the failure onto a matrix error to send back to the
        client.

        An attempt is made to parse the body of the http response as a matrix
        error. If that succeeds, the errcode and error message from the body
        are used as the errcode and error message in the new synapse error.

        Otherwise, the errcode is set to M_UNKNOWN, and the error message is
        set to the reason code from the HTTP response.

        Args:
            err (HttpResponseException):

        Returns:
            SynapseError:
        """
        # try to parse the body as json, to get better errcode/msg, but
        # default to M_UNKNOWN with the HTTP status as the error text
        try:
            j = json.loads(err.response)
        except ValueError:
            j = {}
        errcode = j.get('errcode', Codes.UNKNOWN)
        errmsg = j.get('error', err.msg)

        res = SynapseError(err.code, errmsg, errcode)
        return res


class RegistrationError(SynapseError):
    """An error raised when a registration event fails."""

@@ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError):

class NotFoundError(SynapseError):
    """An error indicating we can't find the thing you asked for"""
    def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
        super(NotFoundError, self).__init__(
            404,
            msg,
            errcode=errcode
        )

@@ -173,7 +213,6 @@ class LimitExceededError(SynapseError):
                 errcode=Codes.LIMIT_EXCEEDED):
        super(LimitExceededError, self).__init__(code, msg, errcode)
        self.retry_after_ms = retry_after_ms

    def error_dict(self):
        return cs_error(

@@ -243,6 +282,19 @@ class FederationError(RuntimeError):

class HttpResponseException(CodeMessageException):
    """
    Represents an HTTP-level failure of an outbound request

    Attributes:
        response (str): body of response
    """
    def __init__(self, code, msg, response):
        """
        Args:
            code (int): HTTP status code
            msg (str): reason phrase from HTTP response status line
            response (str): body of response
        """
        super(HttpResponseException, self).__init__(code, msg)
        self.response = response


@@ -13,11 +13,174 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID

from twisted.internet import defer

import ujson as json
import jsonschema
from jsonschema import FormatChecker
FILTER_SCHEMA = {
    "additionalProperties": False,
    "type": "object",
    "properties": {
        "limit": {
            "type": "number"
        },
        "senders": {
            "$ref": "#/definitions/user_id_array"
        },
        "not_senders": {
            "$ref": "#/definitions/user_id_array"
        },
        # TODO: We don't limit event type values but we probably should...
        # check types are valid event types
        "types": {
            "type": "array",
            "items": {
                "type": "string"
            }
        },
        "not_types": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    }
}

ROOM_FILTER_SCHEMA = {
    "additionalProperties": False,
    "type": "object",
    "properties": {
        "not_rooms": {
            "$ref": "#/definitions/room_id_array"
        },
        "rooms": {
            "$ref": "#/definitions/room_id_array"
        },
        "ephemeral": {
            "$ref": "#/definitions/room_event_filter"
        },
        "include_leave": {
            "type": "boolean"
        },
        "state": {
            "$ref": "#/definitions/room_event_filter"
        },
        "timeline": {
            "$ref": "#/definitions/room_event_filter"
        },
        "account_data": {
            "$ref": "#/definitions/room_event_filter"
        },
    }
}

ROOM_EVENT_FILTER_SCHEMA = {
    "additionalProperties": False,
    "type": "object",
    "properties": {
        "limit": {
            "type": "number"
        },
        "senders": {
            "$ref": "#/definitions/user_id_array"
        },
        "not_senders": {
            "$ref": "#/definitions/user_id_array"
        },
        "types": {
            "type": "array",
            "items": {
                "type": "string"
            }
        },
        "not_types": {
            "type": "array",
            "items": {
                "type": "string"
            }
        },
        "rooms": {
            "$ref": "#/definitions/room_id_array"
        },
        "not_rooms": {
            "$ref": "#/definitions/room_id_array"
        },
        "contains_url": {
            "type": "boolean"
        }
    }
}

USER_ID_ARRAY_SCHEMA = {
    "type": "array",
    "items": {
        "type": "string",
        "format": "matrix_user_id"
    }
}

ROOM_ID_ARRAY_SCHEMA = {
    "type": "array",
    "items": {
        "type": "string",
        "format": "matrix_room_id"
    }
}

USER_FILTER_SCHEMA = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "schema for a Sync filter",
    "type": "object",
    "definitions": {
        "room_id_array": ROOM_ID_ARRAY_SCHEMA,
        "user_id_array": USER_ID_ARRAY_SCHEMA,
        "filter": FILTER_SCHEMA,
        "room_filter": ROOM_FILTER_SCHEMA,
        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
    },
    "properties": {
        "presence": {
            "$ref": "#/definitions/filter"
        },
        "account_data": {
            "$ref": "#/definitions/filter"
        },
        "room": {
            "$ref": "#/definitions/room_filter"
        },
        "event_format": {
            "type": "string",
            "enum": ["client", "federation"]
        },
        "event_fields": {
            "type": "array",
            "items": {
                "type": "string",
                # Don't allow '\\' in event field filters. This makes matching
                # events a lot easier as we can then use a negative lookbehind
                # assertion to split '\.' If we allowed \\ then it would
                # incorrectly split '\\.' See synapse.events.utils.serialize_event
                "pattern": "^((?!\\\).)*$"
            }
        }
    },
    "additionalProperties": False
}


@FormatChecker.cls_checks('matrix_room_id')
def matrix_room_id_validator(room_id_str):
    return RoomID.from_string(room_id_str)


@FormatChecker.cls_checks('matrix_user_id')
def matrix_user_id_validator(user_id_str):
    return UserID.from_string(user_id_str)
class Filtering(object):
@@ -52,98 +215,11 @@ class Filtering(object):
        # NB: Filters are the complete json blobs. "Definitions" are an
        # individual top-level key e.g. public_user_data. Filters are made of
        # many definitions.
        try:
            jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
                                format_checker=FormatChecker())
        except jsonschema.ValidationError as e:
            raise SynapseError(400, e.message)

        top_level_definitions = [
            "presence", "account_data"
        ]

        room_level_definitions = [
            "state", "timeline", "ephemeral", "account_data"
        ]

        for key in top_level_definitions:
            if key in user_filter_json:
                self._check_definition(user_filter_json[key])

        if "room" in user_filter_json:
            self._check_definition_room_lists(user_filter_json["room"])
            for key in room_level_definitions:
                if key in user_filter_json["room"]:
                    self._check_definition(user_filter_json["room"][key])

        if "event_fields" in user_filter_json:
            if type(user_filter_json["event_fields"]) != list:
                raise SynapseError(400, "event_fields must be a list of strings")
            for field in user_filter_json["event_fields"]:
                if not isinstance(field, basestring):
                    raise SynapseError(400, "Event field must be a string")
                # Don't allow '\\' in event field filters. This makes matching
                # events a lot easier as we can then use a negative lookbehind
                # assertion to split '\.' If we allowed \\ then it would
                # incorrectly split '\\.' See synapse.events.utils.serialize_event
                if r'\\' in field:
                    raise SynapseError(
                        400, r'The escape character \ cannot itself be escaped'
                    )

    def _check_definition_room_lists(self, definition):
        """Check that "rooms" and "not_rooms" are lists of room ids if they
        are present

        Args:
            definition(dict): The filter definition
        Raises:
            SynapseError: If there was a problem with this definition.
        """
        # check rooms are valid room IDs
        room_id_keys = ["rooms", "not_rooms"]
        for key in room_id_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for room_id in definition[key]:
                    RoomID.from_string(room_id)

    def _check_definition(self, definition):
        """Check if the provided definition is valid.

        This inspects not only the types but also the values to make sure they
        make sense.

        Args:
            definition(dict): The filter definition
        Raises:
            SynapseError: If there was a problem with this definition.
        """
        # NB: Filters are the complete json blobs. "Definitions" are an
        # individual top-level key e.g. public_user_data. Filters are made of
        # many definitions.
        if type(definition) != dict:
            raise SynapseError(
                400, "Expected JSON object, not %s" % (definition,)
            )

        self._check_definition_room_lists(definition)

        # check senders are valid user IDs
        user_id_keys = ["senders", "not_senders"]
        for key in user_id_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
class FilterCollection(object):

@@ -253,19 +329,35 @@ class Filter(object):
        Returns:
            bool: True if the event matches
        """
-        sender = event.get("sender", None)
-        if not sender:
-            # Presence events have their 'sender' in content.user_id
-            content = event.get("content")
-            # account_data has been allowed to have non-dict content, so check type first
-            if isinstance(content, dict):
-                sender = content.get("user_id")
-
-        return self.check_fields(
-            event.get("room_id", None),
-            sender,
-            event.get("type", None),
-            "url" in event.get("content", {})
-        )
+        # We usually get the full "events" as dictionaries coming through,
+        # except for presence which actually gets passed around as its own
+        # namedtuple type.
+        if isinstance(event, UserPresenceState):
+            sender = event.user_id
+            room_id = None
+            ev_type = "m.presence"
+            is_url = False
+        else:
+            sender = event.get("sender", None)
+            if not sender:
+                # Presence events had their 'sender' in content.user_id, but are
+                # now handled above. We don't know if anything else uses this
+                # form. TODO: Check this and probably remove it.
+                content = event.get("content")
+                # account_data has been allowed to have non-dict content, so
+                # check type first
+                if isinstance(content, dict):
+                    sender = content.get("user_id")
+
+            room_id = event.get("room_id", None)
+            ev_type = event.get("type", None)
+            is_url = "url" in event.get("content", {})
+
+        return self.check_fields(
+            room_id,
+            sender,
+            ev_type,
+            is_url,
+        )

    def check_fields(self, room_id, sender, event_type, contains_url):
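The reworked `check` above normalises two event shapes into one tuple for `check_fields`. A minimal stand-alone sketch of that dispatch, with a hypothetical `UserPresenceState` namedtuple standing in for synapse's storage type:

```python
from collections import namedtuple

# Illustrative stand-in for synapse's UserPresenceState storage type.
UserPresenceState = namedtuple("UserPresenceState", ["user_id", "state"])

def extract_fields(event):
    """Normalise either a presence namedtuple or a plain event dict into
    the (room_id, sender, type, contains_url) tuple check_fields expects."""
    if isinstance(event, UserPresenceState):
        return (None, event.user_id, "m.presence", False)
    sender = event.get("sender")
    if not sender:
        # legacy form: presence events carried their sender in content.user_id
        content = event.get("content")
        if isinstance(content, dict):
            sender = content.get("user_id")
    return (
        event.get("room_id"),
        sender,
        event.get("type"),
        "url" in event.get("content", {}),
    )

print(extract_fields(UserPresenceState("@alice:example.com", "online")))
# (None, '@alice:example.com', 'm.presence', False)
```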


@@ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -157,7 +157,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.appservice"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -187,7 +187,11 @@ def start(config_options):
    ps.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:


@@ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
    DirectoryStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
+    TransactionStore,
    BaseSlavedStore,
    ClientIpStore,  # After BaseSlavedStore because the constructor is different
):
@@ -171,7 +173,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.client_reader"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -193,7 +195,11 @@ def start(config_options):
    ss.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:


@@ -31,7 +31,7 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -162,7 +162,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.federation_reader"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -184,7 +184,11 @@ def start(config_options):
    ss.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:


@@ -35,7 +35,7 @@ from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -160,7 +160,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.federation_sender"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -193,7 +193,11 @@ def start(config_options):
    ps.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:


@@ -20,6 +20,8 @@ import gc
import logging
import os
import sys

+import synapse.config.logger
+
from synapse.config._base import ConfigError
from synapse.python_dependencies import (
@@ -50,7 +52,7 @@ from synapse.api.urls import (
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
@@ -286,7 +288,7 @@ def setup(config_options):
        # generating config files and shouldn't try to continue.
        sys.exit(0)

-    config.setup_logging()
+    synapse.config.logger.setup_logging(config, use_worker_options=False)

    # check any extra requirements we have now we have a config
    check_requirements(config)
@@ -454,7 +456,12 @@ def run(hs):
    def in_thread():
        # Uncomment to enable tracing of log context changes.
        # sys.settrace(logcontext_tracer)
-        with LoggingContext("run"):
+
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            change_resource_limit(hs.config.soft_file_limit)
            if hs.config.gc_thresholds:
                gc.set_threshold(*hs.config.gc_thresholds)


@@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
@@ -32,7 +33,7 @@ from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
+    TransactionStore,
    BaseSlavedStore,
    MediaRepositoryStore,
    ClientIpStore,
@@ -168,7 +170,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.media_repository"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -190,7 +192,11 @@ def start(config_options):
    ss.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:


@@ -31,7 +31,8 @@ from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+    PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -245,7 +246,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.pusher"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -275,7 +276,11 @@ def start(config_options):
    ps.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:


@@ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.events import FrozenEvent
from synapse.handlers.presence import PresenceHandler
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
@@ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, preserve_fn, \
+    PreserveLoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
@@ -399,8 +399,7 @@ class SynchrotronServer(HomeServer):
                position = row[position_index]
                user_id = row[user_index]

-                rooms = yield store.get_rooms_for_user(user_id)
-                room_ids = [r.room_id for r in rooms]
+                room_ids = yield store.get_rooms_for_user(user_id)

                notifier.on_new_event(
                    "device_list_key", position, rooms=room_ids,
@@ -411,11 +410,16 @@ class SynchrotronServer(HomeServer):
            stream = result.get("events")
            if stream:
                max_position = stream["position"]
+
+                event_map = yield store.get_events([row[1] for row in stream["rows"]])
+
                for row in stream["rows"]:
                    position = row[0]
-                    internal = json.loads(row[1])
-                    event_json = json.loads(row[2])
-                    event = FrozenEvent(event_json, internal_metadata_dict=internal)
+                    event_id = row[1]
+                    event = event_map.get(event_id, None)
+                    if not event:
+                        continue
+
                    extra_users = ()
                    if event.type == EventTypes.Member:
                        extra_users = (event.state_key,)
@@ -478,7 +482,7 @@ def start(config_options):
    assert config.worker_app == "synapse.app.synchrotron"

-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)

    synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -497,7 +501,11 @@ def start(config_options):
    ss.start_listening(config.worker_listeners)

    def run():
-        with LoggingContext("run"):
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:
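The synchrotron change above replaces per-row JSON deserialisation with one batched `get_events` call plus a map lookup, silently skipping ids the store no longer returns. A toy version of that pattern; all names here are illustrative, not synapse APIs:

```python
# Sketch of the batching pattern: fetch all events in one call, then look
# each row's event up by id instead of deserialising row-by-row.
def fetch_events_batched(store_get_events, rows):
    event_map = store_get_events([row[1] for row in rows])
    events = []
    for row in rows:
        event = event_map.get(row[1])
        if event is None:
            continue  # e.g. rejected or purged events are simply skipped
        events.append(event)
    return events

# Toy "store": event_id -> event dict.
store = {"$a": {"type": "m.room.message"}, "$b": {"type": "m.room.member"}}
rows = [(1, "$a"), (2, "$missing"), (3, "$b")]
result = fetch_events_batched(
    lambda ids: {i: store[i] for i in ids if i in store}, rows)
print([e["type"] for e in result])  # ['m.room.message', 'm.room.member']
```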


@@ -23,14 +23,27 @@ import signal
import subprocess
import sys
import yaml
+import errno
+import time

SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]

GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m"


+def pid_running(pid):
+    try:
+        os.kill(pid, 0)
+        return True
+    except OSError, err:
+        if err.errno == errno.EPERM:
+            return True
+        return False
+
+
def write(message, colour=NORMAL, stream=sys.stdout):
    if colour == NORMAL:
        stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
        stream.write(colour + message + NORMAL + "\n")


+def abort(message, colour=RED, stream=sys.stderr):
+    write(message, colour, stream)
+    sys.exit(1)
+
+
def start(configfile):
    write("Starting ...")
    args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):

    try:
        subprocess.check_call(args)
-        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+        write("started synapse.app.homeserver(%r)" %
+              (configfile,), colour=GREEN)
    except subprocess.CalledProcessError as e:
        write(
            "error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
def stop(pidfile, app):
    if os.path.exists(pidfile):
        pid = int(open(pidfile).read())
-        os.kill(pid, signal.SIGTERM)
-        write("stopped %s" % (app,), colour=GREEN)
+        try:
+            os.kill(pid, signal.SIGTERM)
+            write("stopped %s" % (app,), colour=GREEN)
+        except OSError, err:
+            if err.errno == errno.ESRCH:
+                write("%s not running" % (app,), colour=YELLOW)
+            elif err.errno == errno.EPERM:
+                abort("Cannot stop %s: Operation not permitted" % (app,))
+            else:
+                abort("Cannot stop %s: Unknown error" % (app,))


Worker = collections.namedtuple("Worker", [
@@ -190,7 +217,19 @@ def main():
        if start_stop_synapse:
            stop(pidfile, "synapse.app.homeserver")

-    # TODO: Wait for synapse to actually shutdown before starting it again
+    # Wait for synapse to actually shutdown before starting it again
+    if action == "restart":
+        running_pids = []
+        if start_stop_synapse and os.path.exists(pidfile):
+            running_pids.append(int(open(pidfile).read()))
+        for worker in workers:
+            if os.path.exists(worker.pidfile):
+                running_pids.append(int(open(worker.pidfile).read()))
+        if len(running_pids) > 0:
+            write("Waiting for process to exit before restarting...")
+            for running_pid in running_pids:
+                while pid_running(running_pid):
+                    time.sleep(0.2)

    if action == "start" or action == "restart":
        if start_stop_synapse:

@ -45,7 +45,6 @@ handlers:
maxBytes: 104857600 maxBytes: 104857600
backupCount: 10 backupCount: 10
filters: [context] filters: [context]
level: INFO
console: console:
class: logging.StreamHandler class: logging.StreamHandler
formatter: precise formatter: precise
@ -56,6 +55,8 @@ loggers:
level: INFO level: INFO
synapse.storage.SQL: synapse.storage.SQL:
# beware: increasing this to DEBUG will make synapse log sensitive
# information such as access tokens.
level: INFO level: INFO
root: root:
@ -68,6 +69,7 @@ class LoggingConfig(Config):
def read_config(self, config): def read_config(self, config):
self.verbosity = config.get("verbose", 0) self.verbosity = config.get("verbose", 0)
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
@ -77,10 +79,10 @@ class LoggingConfig(Config):
os.path.join(config_dir_path, server_name + ".log.config") os.path.join(config_dir_path, server_name + ".log.config")
) )
return """ return """
# Logging verbosity level. # Logging verbosity level. Ignored if log_config is specified.
verbose: 0 verbose: 0
# File to write logging to # File to write logging to. Ignored if log_config is specified.
log_file: "%(log_file)s" log_file: "%(log_file)s"
# A yaml python logging config file # A yaml python logging config file
@ -90,6 +92,8 @@ class LoggingConfig(Config):
def read_arguments(self, args): def read_arguments(self, args):
if args.verbose is not None: if args.verbose is not None:
self.verbosity = args.verbose self.verbosity = args.verbose
if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio
if args.log_config is not None: if args.log_config is not None:
self.log_config = args.log_config self.log_config = args.log_config
if args.log_file is not None: if args.log_file is not None:
@ -99,16 +103,22 @@ class LoggingConfig(Config):
logging_group = parser.add_argument_group("logging") logging_group = parser.add_argument_group("logging")
logging_group.add_argument( logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count', '-v', '--verbose', dest="verbose", action='count',
help="The verbosity level." help="The verbosity level. Specify multiple times to increase "
"verbosity. (Ignored if --log-config is specified.)"
) )
logging_group.add_argument( logging_group.add_argument(
'-f', '--log-file', dest="log_file", '-f', '--log-file', dest="log_file",
help="File to log to." help="File to log to. (Ignored if --log-config is specified.)"
) )
logging_group.add_argument( logging_group.add_argument(
'--log-config', dest="log_config", default=None, '--log-config', dest="log_config", default=None,
help="Python logging config file" help="Python logging config file"
) )
logging_group.add_argument(
'-n', '--no-redirect-stdio',
action='store_true', default=None,
help="Do not redirect stdout/stderr to the log"
)
def generate_files(self, config): def generate_files(self, config):
log_config = config.get("log_config") log_config = config.get("log_config")
@ -118,11 +128,22 @@ class LoggingConfig(Config):
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
) )
def setup_logging(self):
setup_logging(self.log_config, self.log_file, self.verbosity)
def setup_logging(config, use_worker_options=False):
""" Set up python logging
Args:
config (LoggingConfig | synapse.config.workers.WorkerConfig):
configuration data
use_worker_options (bool): True to use 'worker_log_config' and
'worker_log_file' options instead of 'log_config' and 'log_file'.
"""
log_config = (config.worker_log_config if use_worker_options
else config.log_config)
log_file = (config.worker_log_file if use_worker_options
else config.log_file)
def setup_logging(log_config=None, log_file=None, verbosity=None):
log_format = ( log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s" " - %(message)s"
@ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
level = logging.INFO level = logging.INFO
level_for_storage = logging.INFO level_for_storage = logging.INFO
if verbosity: if config.verbosity:
level = logging.DEBUG level = logging.DEBUG
if verbosity > 1: if config.verbosity > 1:
level_for_storage = logging.DEBUG level_for_storage = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option # FIXME: we need a logging.WARN for a -q quiet option
@ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
logger.info("Closing log file due to SIGHUP") logger.info("Closing log file due to SIGHUP")
handler.doRollover() handler.doRollover()
logger.info("Opened new log file due to SIGHUP") logger.info("Opened new log file due to SIGHUP")
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
else:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
@@ -169,9 +182,26 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
logger.addHandler(handler)
else:
def load_log_config():
with open(log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
def sighup(signum, stack):
# it might be better to use a file watcher or something for this.
logging.info("Reloading log config from %s due to SIGHUP",
log_config)
load_log_config()
load_log_config()
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
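The SIGHUP-driven reload added above can be exercised in isolation. This sketch registers a reload callback the same way (and inherits the caveat from the TODO comment about claiming SIGHUP process-wide); note the diff uses `getattr(signal, "SIGHUP")` without a default, which would raise rather than skip on platforms lacking the signal, so the sketch passes a default instead:

```python
import logging
import signal

def install_sighup_reload(load_log_config):
    # On SIGHUP, re-read the log config, as the handler in the diff does.
    def sighup(signum, stack):
        logging.info("Reloading log config due to SIGHUP")
        load_log_config()

    # Guard for platforms without SIGHUP (e.g. Windows).
    if getattr(signal, "SIGHUP", None) is not None:
        signal.signal(signal.SIGHUP, sighup)
    return sighup  # returned so it can also be invoked directly
```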
# It's critical to point twisted's internal logging somewhere, otherwise it
# stacks up and leaks up to 64K objects;
# see: https://twistedmatrix.com/trac/ticket/8164
@@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
#
# However this may not be too much of a problem if we are just writing to a file.
observer = STDLibLogObserver()
globalLogBeginner.beginLoggingTo(
[observer],
redirectStandardIO=not config.no_redirect_stdio,
)
@@ -15,7 +15,6 @@
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
@@ -96,10 +95,11 @@ class Keyring(object):
verify_requests = []
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
logger.warn("Request from %s: no supported signature keys",
server_name)
deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
@@ -108,6 +108,9 @@ class Keyring(object):
else:
deferred = defer.Deferred()
logger.debug("Verifying for %s with key_ids %s",
server_name, key_ids)
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
@@ -142,6 +145,9 @@ class Keyring(object):
json_object = verify_request.json_object
logger.debug("Got key %s %s:%s for server %s, verifying" % (
key_id, verify_key.alg, verify_key.version, server_name,
))
try:
verify_signed_json(json_object, server_name, verify_key)
except:
@@ -231,8 +237,14 @@ class Keyring(object):
d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
Args:
verify_requests (list[VerifyKeyRequest]): list of verify requests
""" """
# These are functions that produce keys given a list of key ids # These are functions that produce keys given a list of key ids
@@ -245,8 +257,11 @@
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
# dict[str, dict[str, VerifyKey]]: results so far.
# map server_name -> key_id -> VerifyKey
merged_results = {}
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
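The seeding of `missing_keys` in `do_iterations()` can be illustrated on its own, assuming only the `server_name` and `key_ids` fields of the verify request:

```python
from collections import namedtuple

# Only the two fields the grouping step needs; the real VerifyKeyRequest
# also carries json_object and deferred.
Request = namedtuple("Request", ["server_name", "key_ids"])

def collect_missing_keys(verify_requests):
    # server_name -> set of key_ids still to fetch, as in do_iterations().
    missing_keys = {}
    for request in verify_requests:
        missing_keys.setdefault(request.server_name, set()).update(request.key_ids)
    return missing_keys
```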
@@ -308,6 +323,16 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (list[(str, iterable[str])]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns:
Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
server_name -> key_id -> VerifyKey
"""
res = yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.get_server_verify_keys)(
@@ -356,12 +381,6 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
@@ -15,6 +15,32 @@
class EventContext(object):
"""
Attributes:
current_state_ids (dict[(str, str), str]):
The current state map including the current event.
(type, state_key) -> event_id
prev_state_ids (dict[(str, str), str]):
The current state map excluding the current event.
(type, state_key) -> event_id
state_group (int): state group id
rejected (bool|str): A rejection reason if the event was rejected, else
False
push_actions (list[(str, list[object])]): list of (user_id, actions)
tuples
prev_group (int): Previously persisted state group. ``None`` for an
outlier.
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
(type, state_key) -> event_id. ``None`` for an outlier.
prev_state_events (?): XXX: is this ever set to anything other than
the empty list?
"""
__slots__ = [
"current_state_ids",
"prev_state_ids",
@@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
from synapse.util.retryutils import NotRetryingDestination
import copy
import itertools
@@ -88,7 +88,7 @@ class FederationClient(FederationBase):
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False, ignore_backoff=False):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@@ -98,6 +98,8 @@ class FederationClient(FederationBase):
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
a Deferred which will eventually yield a JSON object from the
@@ -106,7 +108,8 @@ class FederationClient(FederationBase):
sent_queries_counter.inc(query_type)
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@log_function
@@ -234,13 +237,6 @@ class FederationClient(FederationBase):
continue
try:
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
with limiter:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
)
@@ -52,7 +52,6 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._server_linearizer = Linearizer("fed_server")
# We cache responses to state queries, as they take a while and often
@@ -147,11 +146,15 @@ class FederationServer(FederationBase):
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if transaction.origin != get_domain_from_id(pdu.event_id):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
@@ -165,7 +168,7 @@ class FederationServer(FederationBase):
)
try:
yield self._handle_received_pdu(transaction.origin, pdu)
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
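The origin check above boils down to a small predicate; here is a sketch with hypothetical scalar arguments standing in for the transaction and PDU objects:

```python
def should_discard_pdu(origin, event_id_domain, pdu_type, membership):
    # Accept when the sending server matches the domain of the event id;
    # otherwise accept only join membership events, which a "helper" server
    # legitimately sends on our behalf during the federation join dance.
    if origin == event_id_domain:
        return False
    return not (pdu_type == "m.room.member" and membership == "join")
```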
@@ -497,27 +500,16 @@ class FederationServer(FederationBase):
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, get_missing=True):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
return
def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
Args:
origin (str): server which sent the pdu
pdu (FrozenEvent): received pdu
Returns (Deferred): completes with None
Raises: FederationError if the signatures / hash do not match
"""
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
@@ -529,143 +521,7 @@ class FederationServer(FederationBase):
affected=pdu.event_id,
)
state = None
yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
auth_chain = []
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
fetch_state = False
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
"Acquiring lock for room %r to fetch %d missing events: %r...",
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"Acquired lock for room %r to fetch %d missing events",
pdu.room_id, len(prevs - seen),
)
# We recalculate seen, since it may have changed.
have_seen = yield self.store.have_events(prevs)
seen = set(have_seen.keys())
if prevs - seen:
latest = yield self.store.get_latest_event_ids_in_room(
pdu.room_id
)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest |= seen
logger.info(
"Missing %d events for room %r: %r...",
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
# XXX: we set timeout to 10s to help workaround
# https://github.com/matrix-org/synapse/issues/1733.
# The reason is to avoid holding the linearizer lock
# whilst processing inbound /send transactions, causing
# FDs to stack up and block other inbound transactions
# which empirically can currently take up to 30 minutes.
#
# N.B. this explicitly disables retry attempts.
#
# N.B. this also increases our chances of falling back to
# fetching fresh state for the room if the missing event
# can't be found, which slightly reduces our security.
# it may also increase our DAG extremity count for the room,
# causing additional state resolution? See #1760.
# However, fetching state doesn't hold the linearizer lock
# apparently.
#
# see https://github.com/matrix-org/synapse/pull/1744
missing_events = yield self.get_missing_events(
origin,
pdu.room_id,
earliest_events_ids=list(latest),
latest_events=[pdu],
limit=10,
min_depth=min_depth,
timeout=10000,
)
# We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
for e in missing_events:
yield self._handle_new_pdu(
origin,
e,
get_missing=False
)
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
logger.info(
"Still missing %d events for room %r: %r...",
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
fetch_state = True
if fetch_state:
# We need to get the state at this event, since we haven't
# processed all the prev events.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
try:
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
except:
logger.exception("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu(
origin,
pdu,
state=state,
auth_chain=auth_chain,
)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
@@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
def __init__(self, hs):
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.presence_map = {}
self.presence_changed = sorteddict()
@@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
else:
self.edus[pos] = edu
self.notifier.on_new_replication_data()
def send_presence(self, destination, states):
"""As per TransactionQueue"""
pos = self._next_pos()
@@ -199,16 +202,20 @@
(destination, state.user_id) for state in states
]
self.notifier.on_new_replication_data()
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
self.notifier.on_new_replication_data()
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.device_messages[pos] = destination
self.notifier.on_new_replication_data()
def get_current_token(self):
return self.pos - 1
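The pattern added throughout this file — every enqueue ends by poking the notifier so replication consumers wake up promptly — can be sketched with a toy queue and a counting notifier (illustrative classes, not Synapse's own):

```python
class CountingNotifier:
    # Stand-in for Synapse's Notifier; just counts the pokes.
    def __init__(self):
        self.pokes = 0

    def on_new_replication_data(self):
        self.pokes += 1

class ToySendQueue:
    def __init__(self, notifier):
        self.notifier = notifier
        self.items = []

    def send(self, item):
        self.items.append(item)
        # As in the diff: notify after *every* write so pollers see it.
        self.notifier.on_new_replication_data()
```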
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from twisted.internet import defer
@@ -22,9 +22,7 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state
@@ -99,7 +97,12 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
# destination -> stream_id of last successfully sent device list
# update.
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id
@@ -300,20 +303,20 @@ class TransactionQueue(object):
)
return
pending_pdus = []
try:
self.pending_transactions[destination] = 1
# This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client,
# hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store)
# XXX: what's this for?
yield run_on_reactor()
pending_pdus = []
while True:
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
)
device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination)
)
@@ -369,7 +372,6 @@ class TransactionQueue(object):
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
)
if success:
# Remove the acknowledged device messages from the database
@@ -387,12 +389,24 @@ class TransactionQueue(object):
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now",
destination,
datetime.datetime.fromtimestamp(
(e.retry_last_ts + e.retry_interval) / 1000.0
),
) )
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
destination,
e,
)
for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,
destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
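The next-retry timestamp logged in the `NotRetryingDestination` branch above is derived from two millisecond fields; a sketch of the arithmetic:

```python
import datetime

def next_retry_at(retry_last_ts_ms, retry_interval_ms):
    # retry_last_ts is milliseconds since the epoch and retry_interval is
    # milliseconds of backoff, hence the division by 1000.0 before handing
    # the sum to datetime, as in the log message in the diff.
    return datetime.datetime.fromtimestamp(
        (retry_last_ts_ms + retry_interval_ms) / 1000.0
    )
```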
@@ -432,7 +446,7 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@@ -442,7 +456,6 @@ class TransactionQueue(object):
success = True
try:
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@@ -483,7 +496,6 @@ class TransactionQueue(object):
len(failures),
)
with limiter:
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
@@ -543,31 +555,5 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination
)
success = False
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
defer.returnValue(success)
@@ -163,6 +163,7 @@ class TransportLayerClient(object):
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
backoff_on_404=True,  # If we get a 404 the other side has gone
)
logger.debug(
@@ -174,7 +175,8 @@
@defer.inlineCallbacks
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
path = PREFIX + "/query/%s" % query_type
content = yield self.client.get_json(
@@ -183,6 +185,7 @@
args=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
ignore_backoff=ignore_backoff,
)
defer.returnValue(content)
@@ -242,6 +245,7 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
defer.returnValue(response)
@@ -269,6 +273,7 @@ class TransportLayerClient(object):
destination=remote_server,
path=path,
args=args,
ignore_backoff=True,
)
defer.returnValue(response)
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,6 +48,7 @@ class AuthHandler(BaseHandler):
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth,
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
@@ -307,31 +309,47 @@ class AuthHandler(BaseHandler):
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
return self._check_threepid('email', authdict)
def _check_msisdn(self, authdict, _):
return self._check_threepid('msisdn', authdict)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid['medium'] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'" % (
medium, threepid['medium'],
),
errcode=Codes.UNAUTHORIZED
)
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
@@ -169,6 +169,40 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
""" Delete several devices
Args:
user_id (str):
device_ids (list[str]): The list of device IDs to delete
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_devices(user_id, device_ids)
except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, device_ids)
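A minimal sketch of the delete path's error tolerance, with a plain callable standing in for the store method and `KeyError` standing in for a 404 `StoreError` (the real handler wraps the bulk delete, not each device, in the try/except):

```python
def delete_devices_tolerantly(store_delete, device_ids):
    # A device that is already gone is not an error: mirror the handler's
    # treatment of a 404 from the store as "no match".
    deleted = []
    for device_id in device_ids:
        try:
            store_delete(device_id)
        except KeyError:
            pass  # no match; treat as already deleted
        else:
            deleted.append(device_id)
    return deleted
```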
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@@ -214,8 +248,7 @@ class DeviceHandler(BaseHandler):
user_id, device_ids, list(hosts)
)
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
room_ids = yield self.store.get_rooms_for_user(user_id)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
@@ -236,8 +269,7 @@ class DeviceHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
@@ -262,7 +294,7 @@ class DeviceHandler(BaseHandler):
# ordering: treat it the same as a new room
event_ids = []
current_state_ids = yield self.store.get_current_state_ids(room_id)
# special-case for an empty prev state: include all members
# in the changed list
@@ -313,8 +345,8 @@ class DeviceHandler(BaseHandler):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
room_ids = yield self.store.get_rooms_for_user(user_id)
if not room_ids:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
@ -370,8 +402,8 @@ class DeviceListEduUpdater(object):
logger.warning("Got device list update edu for %r from %r", user_id, origin) logger.warning("Got device list update edu for %r from %r", user_id, origin)
return return
rooms = yield self.store.get_rooms_for_user(user_id) room_ids = yield self.store.get_rooms_for_user(user_id)
if not rooms: if not room_ids:
# We don't share any rooms with this user. Ignore update, as we # We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates. # probably won't get any further updates.
return return
@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
}, },
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
logging.warn("Error retrieving alias") logging.warn("Error retrieving alias")
@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -121,10 +121,6 @@ class E2eKeysHandler(object):
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries_not_in_cache[destination] destination_query = remote_queries_not_in_cache[destination]
try: try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.query_client_keys( remote_result = yield self.federation.query_client_keys(
destination, destination,
{"device_keys": destination_query}, {"device_keys": destination_query},
@ -239,10 +235,6 @@ class E2eKeysHandler(object):
def claim_client_keys(destination): def claim_client_keys(destination):
device_keys = remote_queries[destination] device_keys = remote_queries[destination]
try: try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.claim_client_keys( remote_result = yield self.federation.claim_client_keys(
destination, destination,
{"one_time_keys": device_keys}, {"one_time_keys": device_keys},
@ -316,7 +308,7 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we # old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with # need to double-check the device is registered to avoid ending up with
# keys without a corresponding device. # keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id) yield self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""Contains handlers for federation events.""" """Contains handlers for federation events."""
import synapse.util.logcontext
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -31,7 +32,7 @@ from synapse.util.logcontext import (
) )
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor, Linearizer
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures, compute_event_signature, add_hashes_and_signatures,
@ -79,29 +80,216 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
@log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): @log_function
""" Called by the ReplicationLayer when we have a new pdu. We need to def on_receive_pdu(self, origin, pdu, get_missing=True):
do auth checks and put it through the StateHandler. """ Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
auth_chain and state are None if we already have the necessary state Args:
and prev_events in the db origin (str): server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
pdu (FrozenEvent): received PDU
get_missing (bool): True if we should fetch missing prev_events
Returns (Deferred): completes with None
""" """
event = pdu
logger.debug("Got event: %s", event.event_id) # We reprocess pdus when we have seen them only as outliers
existing = yield self.get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
return
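The dedup check above (a stored copy counts as "seen" unless it is only an outlier that the incoming full event could upgrade) can be sketched as a pure predicate; the boolean arguments here are illustrative stand-ins for the `internal_metadata` flags:

```python
def is_already_seen(existing_is_outlier, incoming_is_outlier):
    """Sketch of the 'already seen' test: existing_is_outlier is None if
    no copy of the PDU is stored, else whether the stored copy is an
    outlier.  A stored outlier only blocks reprocessing if the incoming
    copy is itself an outlier (i.e. could not upgrade it)."""
    if existing_is_outlier is None:
        return False  # never seen this PDU at all
    return (not existing_is_outlier) or incoming_is_outlier

assert is_already_seen(None, False) is False   # nothing stored: process it
assert is_already_seen(False, False) is True   # full copy stored: skip
assert is_already_seen(True, False) is False   # outlier stored, full incoming: reprocess
assert is_already_seen(True, True) is True     # outlier vs outlier: skip
```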
# If we are currently in the process of joining this room, then we # If we are currently in the process of joining this room, then we
# queue up events for later processing. # queue up events for later processing.
if event.room_id in self.room_queues: if pdu.room_id in self.room_queues:
self.room_queues[event.room_id].append((pdu, origin)) logger.info("Ignoring PDU %s for room %s from %s for now; join "
"in progress", pdu.event_id, pdu.room_id, origin)
self.room_queues[pdu.room_id].append((pdu, origin))
return return
logger.debug("Processing event: %s", event.event_id) state = None
logger.debug("Event: %s", event) auth_chain = []
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
fetch_state = False
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
"Acquiring lock for room %r to fetch %d missing events: %r...",
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"Acquired lock for room %r to fetch %d missing events",
pdu.room_id, len(prevs - seen),
)
yield self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
logger.info(
"Still missing %d events for room %r: %r...",
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
fetch_state = True
if fetch_state:
# We need to get the state at this event, since we haven't
# processed all the prev events.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
try:
state, auth_chain = yield self.replication_layer.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
except:
logger.exception("Failed to get state for event: %s", pdu.event_id)
yield self._process_received_pdu(
origin,
pdu,
state=state,
auth_chain=auth_chain,
)
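The per-room serialisation above (only one fetch of missing events per room at a time, via `_room_pdu_linearizer`) can be sketched with a much simpler asyncio-based stand-in. `RoomLinearizer` here is a hypothetical toy, not Synapse's actual `Linearizer`:

```python
import asyncio
from collections import defaultdict


class RoomLinearizer:
    """Toy per-key lock: work for the same room runs one piece at a
    time, while different rooms proceed concurrently."""

    def __init__(self):
        self._locks = defaultdict(asyncio.Lock)

    def queue(self, room_id):
        return self._locks[room_id]


async def fetch_missing(linearizer, room_id, log):
    async with linearizer.queue(room_id):
        log.append(("start", room_id))
        await asyncio.sleep(0)  # simulate the network round-trip
        log.append(("end", room_id))


async def main():
    lin = RoomLinearizer()
    log = []
    await asyncio.gather(
        fetch_missing(lin, "!a", log),
        fetch_missing(lin, "!a", log),
    )
    return log

log = asyncio.run(main())
# Work for the same room never interleaves:
assert log == [("start", "!a"), ("end", "!a"), ("start", "!a"), ("end", "!a")]
```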
@defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be queried for the missing events
pdu: received pdu
prevs (str[]): List of event ids which we are missing
min_depth (int): Minimum depth of events to return.
Returns:
Deferred<dict(str, str?)>: updated have_seen dictionary
"""
# We recalculate seen, since it may have changed.
have_seen = yield self.store.have_events(prevs)
seen = set(have_seen.keys())
if not prevs - seen:
# nothing left to do
defer.returnValue(have_seen)
latest = yield self.store.get_latest_event_ids_in_room(
pdu.room_id
)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest |= seen
logger.info(
"Missing %d events for room %r: %r...",
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
# XXX: we set timeout to 10s to help workaround
# https://github.com/matrix-org/synapse/issues/1733.
# The reason is to avoid holding the linearizer lock
# whilst processing inbound /send transactions, causing
# FDs to stack up and block other inbound transactions
# which empirically can currently take up to 30 minutes.
#
# N.B. this explicitly disables retry attempts.
#
# N.B. this also increases our chances of falling back to
# fetching fresh state for the room if the missing event
# can't be found, which slightly reduces our security.
# it may also increase our DAG extremity count for the room,
# causing additional state resolution? See #1760.
# However, fetching state doesn't hold the linearizer lock
# apparently.
#
# see https://github.com/matrix-org/synapse/pull/1744
missing_events = yield self.replication_layer.get_missing_events(
origin,
pdu.room_id,
earliest_events_ids=list(latest),
latest_events=[pdu],
limit=10,
min_depth=min_depth,
timeout=10000,
)
# We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
for e in missing_events:
yield self.on_receive_pdu(
origin,
e,
get_missing=False
)
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
defer.returnValue(have_seen)
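The set arithmetic above (tell the remote server which events we already hold, so `/get_missing_events` only returns genuinely missing ones) reduces to this small sketch; the argument names are hypothetical:

```python
def missing_events_request(prev_event_ids, seen_event_ids, forward_extremities):
    """Sketch of how the handler frames a get_missing_events request:
    the 'earliest' set is our forward extremities plus the prev_events
    we have already seen, so the remote server doesn't resend those."""
    missing = set(prev_event_ids) - set(seen_event_ids)
    earliest = set(forward_extremities) | set(seen_event_ids)
    return missing, earliest

missing, earliest = missing_events_request(
    prev_event_ids={"$a", "$b"},
    seen_event_ids={"$a"},
    forward_extremities={"$c"},
)
assert missing == {"$b"}
assert earliest == {"$a", "$c"}
```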
@log_function
@defer.inlineCallbacks
def _process_received_pdu(self, origin, pdu, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
event = pdu
logger.debug("Processing event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently # FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work # in the room work
@ -670,8 +858,6 @@ class FederationHandler(BaseHandler):
""" """
logger.debug("Joining %s to %s", joinee, room_id) logger.debug("Joining %s to %s", joinee, room_id)
yield self.store.clean_room_for_join(room_id)
origin, event = yield self._make_and_verify_event( origin, event = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
@ -680,7 +866,15 @@ class FederationHandler(BaseHandler):
content, content,
) )
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
assert (room_id not in self.room_queues)
self.room_queues[room_id] = [] self.room_queues[room_id] = []
yield self.store.clean_room_for_join(room_id)
handled_events = set() handled_events = set()
try: try:
@ -733,17 +927,36 @@ class FederationHandler(BaseHandler):
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
del self.room_queues[room_id] del self.room_queues[room_id]
for p, origin in room_queue: # we don't need to wait for the queued events to be processed -
if p.event_id in handled_events: # it's just a best-effort thing at this point. We do want to do
continue # them roughly in order, though, otherwise we'll end up making
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
try: synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
self.on_receive_pdu(origin, p) room_queue
except: )
logger.exception("Couldn't handle pdu")
defer.returnValue(True) defer.returnValue(True)
@defer.inlineCallbacks
def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
room_queue (list[(FrozenEvent, str)]): list of PDUs to be processed
and the servers that sent them
"""
for p, origin in room_queue:
try:
logger.info("Processing queued PDU %s which was received "
"while we were joining %s", p.event_id, p.room_id)
yield self.on_receive_pdu(origin, p)
except Exception as e:
logger.warn(
"Error handling queued PDU %s from %s: %s",
p.event_id, origin, e)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, room_id, user_id): def on_make_join_request(self, room_id, user_id):
@ -791,9 +1004,19 @@ class FederationHandler(BaseHandler):
) )
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
# Send this event on behalf of the origin server since they may not # Send this event on behalf of the origin server.
# have an up to date view of the state of the room at this event so #
# will not know which servers to send the event to. # The reasons we have the destination server rather than the origin
# server send it are slightly mysterious: the origin server should have
# all the necessary state once it gets the response to the send_join,
# so it could send the event itself if it wanted to. It may be that
# doing it this way reduces failure modes, or avoids certain attacks
# where a new server selectively tells a subset of the federation that
# it has joined.
#
# The fact is that, as of the current writing, Synapse doesn't send out
# the join event over federation after joining, and changing it now
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event( context, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -878,15 +1101,15 @@ class FederationHandler(BaseHandler):
user_id, user_id,
"leave" "leave"
) )
signed_event = self._sign_event(event) event = self._sign_event(event)
except SynapseError: except SynapseError:
raise raise
except CodeMessageException as e: except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e) logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite") raise SynapseError(500, "Failed to reject invite")
# Try the host we successfully got a response to /make_join/ # Try the host that we successfully called /make_leave/ on first for
# request first. # the /send_leave/ request.
try: try:
target_hosts.remove(origin) target_hosts.remove(origin)
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
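Moving the host that answered `/make_leave/` to the front of the retry list is a small trick worth isolating; a hedged sketch (this helper also handles the absent-host case, which the surrounding try/except covers in the real code):

```python
def prioritise_host(hosts, preferred):
    """Return a copy of `hosts` with `preferred` moved to the front,
    if present; otherwise the order is unchanged."""
    hosts = list(hosts)
    try:
        hosts.remove(preferred)
        hosts.insert(0, preferred)
    except ValueError:
        pass  # preferred host wasn't in the list
    return hosts

assert prioritise_host(["a", "b", "c"], "b") == ["b", "a", "c"]
assert prioritise_host(["a", "c"], "x") == ["a", "c"]
```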
@ -896,7 +1119,7 @@ class FederationHandler(BaseHandler):
try: try:
yield self.replication_layer.send_leave( yield self.replication_layer.send_leave(
target_hosts, target_hosts,
signed_event event
) )
except SynapseError: except SynapseError:
raise raise
@ -1325,7 +1548,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None): def _prep_event(self, origin, event, state=None, auth_events=None):
"""
Args:
origin:
event:
state:
auth_events:
Returns:
Deferred, which resolves to synapse.events.snapshot.EventContext
"""
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
event, old_state=state, event, old_state=state,
) )
@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -150,7 +151,7 @@ class IdentityHandler(BaseHandler):
params.update(kwargs) params.update(kwargs)
try: try:
data = yield self.http_client.post_urlencoded_get_json( data = yield self.http_client.post_json_get_json(
"https://%s%s" % ( "https://%s%s" % (
id_server, id_server,
"/_matrix/identity/api/v1/validate/email/requestToken" "/_matrix/identity/api/v1/validate/email/requestToken"
@ -161,3 +162,37 @@ class IdentityHandler(BaseHandler):
except CodeMessageException as e: except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e) logger.info("Proxied requestToken failed: %r", e)
raise e raise e
@defer.inlineCallbacks
def requestMsisdnToken(
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
):
yield run_on_reactor()
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
)
params = {
'country': country,
'phone_number': phone_number,
'client_secret': client_secret,
'send_attempt': send_attempt,
}
params.update(kwargs)
try:
data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/msisdn/requestToken"
),
params
)
defer.returnValue(data)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e
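The request that the new `requestMsisdnToken` proxies can be sketched as a pure helper that builds the target URL and POST body; this helper is illustrative, not part of Synapse, and the `next_link` mention is an assumption by analogy with the email variant:

```python
def build_msisdn_request(id_server, country, phone_number,
                         client_secret, send_attempt, **kwargs):
    """Build the identity-server requestToken call made above."""
    url = "https://%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % (
        id_server,
    )
    params = {
        "country": country,
        "phone_number": phone_number,
        "client_secret": client_secret,
        "send_attempt": send_attempt,
    }
    params.update(kwargs)  # extra fields, e.g. next_link
    return url, params

url, params = build_msisdn_request(
    "id.example.org", "GB", "07700900001", "s3cr3t", 1,
)
assert url.endswith("/validate/msisdn/requestToken")
assert params["country"] == "GB"
```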
@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import ( from synapse.types import (
UserID, StreamToken, UserID, StreamToken,
@ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler):
"content": content, "content": content,
}) })
now = self.clock.time_msec()
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": [
{
"type": "m.presence",
"content": format_user_presence_state(event, now),
}
for event in presence
],
"account_data": account_data_events, "account_data": account_data_events,
"receipts": receipt, "receipts": receipt,
"end": now_token.to_string(), "end": now_token.to_string(),
@ -29,6 +29,7 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -556,9 +557,9 @@ class PresenceHandler(object):
room_ids_to_states = {} room_ids_to_states = {}
users_to_states = {} users_to_states = {}
for state in states: for state in states:
events = yield self.store.get_rooms_for_user(state.user_id) room_ids = yield self.store.get_rooms_for_user(state.user_id)
for e in events: for room_id in room_ids:
room_ids_to_states.setdefault(e.room_id, []).append(state) room_ids_to_states.setdefault(room_id, []).append(state)
plist = yield self.store.get_presence_list_observers_accepted(state.user_id) plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
for u in plist: for u in plist:
@ -574,8 +575,7 @@ class PresenceHandler(object):
if not local_states: if not local_states:
continue continue
users = yield self.store.get_users_in_room(room_id) hosts = yield self.store.get_hosts_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts: for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
@ -719,9 +719,7 @@ class PresenceHandler(object):
for state in updates for state in updates
]) ])
else: else:
defer.returnValue([ defer.returnValue(updates)
format_user_presence_state(state, now) for state in updates
])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False): def set_state(self, target_user, state, ignore_status_msg=False):
@ -795,6 +793,9 @@ class PresenceHandler(object):
as_event=False, as_event=False,
) )
now = self.clock.time_msec()
results[:] = [format_user_presence_state(r, now) for r in results]
is_accepted = { is_accepted = {
row["observed_user_id"]: row["accepted"] for row in presence_list row["observed_user_id"]: row["accepted"] for row in presence_list
} }
@ -847,6 +848,7 @@ class PresenceHandler(object):
) )
state_dict = yield self.get_state(observed_user, as_event=False) state_dict = yield self.get_state(observed_user, as_event=False)
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
self.federation.send_edu( self.federation.send_edu(
destination=observer_user.domain, destination=observer_user.domain,
@ -910,11 +912,12 @@ class PresenceHandler(object):
def is_visible(self, observed_user, observer_user): def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence. """Returns whether a user can see another user's presence.
""" """
observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) observer_room_ids = yield self.store.get_rooms_for_user(
observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) observer_user.to_string()
)
observer_room_ids = set(r.room_id for r in observer_rooms) observed_room_ids = yield self.store.get_rooms_for_user(
observed_room_ids = set(r.room_id for r in observed_rooms) observed_user.to_string()
)
if observer_room_ids & observed_room_ids: if observer_room_ids & observed_room_ids:
defer.returnValue(True) defer.returnValue(True)
@ -979,14 +982,18 @@ def should_notify(old_state, new_state):
return False return False
def format_user_presence_state(state, now): def format_user_presence_state(state, now, include_user_id=True):
"""Convert UserPresenceState to a format that can be sent down to clients """Convert UserPresenceState to a format that can be sent down to clients
and to other servers. and to other servers.
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
""" """
content = { content = {
"presence": state.state, "presence": state.state,
"user_id": state.user_id,
} }
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts: if state.last_active_ts:
content["last_active_ago"] = now - state.last_active_ts content["last_active_ago"] = now - state.last_active_ts
if state.status_msg and state.state != PresenceState.OFFLINE: if state.status_msg and state.state != PresenceState.OFFLINE:
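The new `include_user_id` flag makes the formatter reusable both for `/sync` responses (where the user id is implicit) and for federation. A standalone sketch, using a plain dict in place of `UserPresenceState`:

```python
def format_presence(state, now, include_user_id=True):
    """Sketch of format_user_presence_state; `state` is a dict standing
    in for UserPresenceState, with timestamps in milliseconds."""
    content = {"presence": state["state"]}
    if include_user_id:
        content["user_id"] = state["user_id"]
    if state.get("last_active_ts"):
        content["last_active_ago"] = now - state["last_active_ts"]
    if state.get("status_msg") and state["state"] != "offline":
        content["status_msg"] = state["status_msg"]
    return content

state = {"state": "online", "user_id": "@alice:example.org",
         "last_active_ts": 900}
assert format_presence(state, now=1000) == {
    "presence": "online",
    "user_id": "@alice:example.org",
    "last_active_ago": 100,
}
assert "user_id" not in format_presence(state, 1000, include_user_id=False)
```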
@ -1025,7 +1032,6 @@ class PresenceEventSource(object):
# sending down the rare duplicate is not a concern. # sending down the rare duplicate is not a concern.
with Measure(self.clock, "presence.get_new_events"): with Measure(self.clock, "presence.get_new_events"):
user_id = user.to_string()
if from_key is not None: if from_key is not None:
from_key = int(from_key) from_key = int(from_key)
@ -1034,18 +1040,7 @@ class PresenceEventSource(object):
max_token = self.store.get_current_presence_token() max_token = self.store.get_current_presence_token()
plist = yield self.store.get_presence_list_accepted(user.localpart) users_interested_in = yield self._get_interested_in(user, explicit_room_id)
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(explicit_room_id)
users_interested_in.update(user_ids)
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
@ -1073,15 +1068,12 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed) updates = yield presence.current_state_for_users(user_ids_changed)
now = self.clock.time_msec() if include_offline:
defer.returnValue((updates.values(), max_token))
else:
defer.returnValue(([ defer.returnValue(([
{ s for s in updates.itervalues()
"type": "m.presence", if s.state != PresenceState.OFFLINE
"content": format_user_presence_state(s, now),
}
for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE
], max_token)) ], max_token))
def get_current_key(self): def get_current_key(self):
@ -1090,6 +1082,31 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
return self.get_new_events(user, from_key=None, include_offline=False) return self.get_new_events(user, from_key=None, include_offline=False)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
updates for
"""
user_id = user.to_string()
plist = yield self.store.get_presence_list_accepted(
user.localpart, on_invalidate=cache_context.invalidate,
)
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate,
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(
explicit_room_id, on_invalidate=cache_context.invalidate,
)
users_interested_in.update(user_ids)
defer.returnValue(users_interested_in)
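Stripped of the caching and invalidation hooks, the interest-set computation being cached here is just a union; a sketch with hypothetical plain-data arguments:

```python
def interested_in(user_id, presence_list, users_sharing_room,
                  explicit_room_members=()):
    """Union of everyone whose presence `user_id` should see."""
    users = {row["observed_user_id"] for row in presence_list}
    users.add(user_id)  # so that we receive our own presence
    users.update(users_sharing_room)
    users.update(explicit_room_members)  # only set when peeking at a room
    return users

result = interested_in(
    "@me:hs",
    presence_list=[{"observed_user_id": "@friend:hs"}],
    users_sharing_room={"@roommate:hs"},
)
assert result == {"@me:hs", "@friend:hs", "@roommate:hs"}
```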
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
@ -1157,7 +1174,10 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# If there have been no syncs for a while (and none ongoing), # If there have been no syncs for a while (and none ongoing),
# set presence to offline # set presence to offline
if user_id not in syncing_user_ids: if user_id not in syncing_user_ids:
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: # If the user has done something recently but hasn't synced,
# don't set them as offline.
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace( state = state.copy_and_replace(
state=PresenceState.OFFLINE, state=PresenceState.OFFLINE,
status_msg=None, status_msg=None,
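The fix from PR #2014 boils down to taking the max of the last sync time and the last activity time before comparing against the timeout. A sketch (the timeout value is an assumption for illustration only):

```python
SYNC_ONLINE_TIMEOUT_MS = 30 * 1000  # assumed value, for illustration

def should_mark_offline(now, last_user_sync_ts, last_active_ts,
                        currently_syncing):
    """Sketch of the timeout check: recent activity keeps a user online
    even if their last /sync was a while ago."""
    if currently_syncing:
        return False
    sync_or_active = max(last_user_sync_ts, last_active_ts)
    return now - sync_or_active > SYNC_ONLINE_TIMEOUT_MS

# Last sync was long ago, but the user was active recently: stay online.
assert not should_mark_offline(100000, last_user_sync_ts=0,
                               last_active_ts=90000,
                               currently_syncing=False)
# No sync and no activity within the window: mark offline.
assert should_mark_offline(100000, last_user_sync_ts=0,
                           last_active_ts=50000,
                           currently_syncing=False)
```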
@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
args={ args={
"user_id": target_user.to_string(), "user_id": target_user.to_string(),
"field": "displayname", "field": "displayname",
} },
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
args={ args={
"user_id": target_user.to_string(), "user_id": target_user.to_string(),
"field": "avatar_url", "field": "avatar_url",
} },
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
@ -156,11 +158,11 @@ class ProfileHandler(BaseHandler):
self.ratelimit(requester) self.ratelimit(requester)
joins = yield self.store.get_rooms_for_user( room_ids = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),
) )
for j in joins: for room_id in room_ids:
handler = self.hs.get_handlers().room_member_handler handler = self.hs.get_handlers().room_member_handler
try: try:
# Assume the user isn't a guest because we don't let guests set # Assume the user isn't a guest because we don't let guests set
@ -171,12 +173,12 @@ class ProfileHandler(BaseHandler):
yield handler.update_membership( yield handler.update_membership(
requester, requester,
user, user,
j.room_id, room_id,
"join", # We treat a profile update like a join. "join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic. ratelimit=False, # Try to hide that these events aren't atomic.
) )
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to update join event for room %s - %s", "Failed to update join event for room %s - %s",
j.room_id, str(e.message) room_id, str(e.message)
) )
@ -210,10 +210,9 @@ class ReceiptEventSource(object):
else: else:
from_key = None from_key = None
rooms = yield self.store.get_rooms_for_user(user.to_string()) room_ids = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms( events = yield self.store.get_linearized_receipts_for_rooms(
rooms, room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
) )
@ -21,6 +21,7 @@ from synapse.api.constants import (
EventTypes, JoinRules, EventTypes, JoinRules,
) )
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
@ -62,6 +63,10 @@ class RoomListHandler(BaseHandler):
appservice and network id to use an appservice specific one. appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists. Setting to None returns all public rooms across all lists.
""" """
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
limit, since_token, bool(search_filter), network_tuple,
)
if search_filter: if search_filter:
# We explicitly don't bother caching searches or requests for # We explicitly don't bother caching searches or requests for
# appservice specific lists. # appservice specific lists.
@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
rooms_to_order_value = {} rooms_to_order_value = {}
rooms_to_num_joined = {} rooms_to_num_joined = {}
rooms_to_latest_event_ids = {}
newly_visible = [] newly_visible = []
newly_unpublished = [] newly_unpublished = []
@@ -116,12 +120,18 @@ class RoomListHandler(BaseHandler):
         @defer.inlineCallbacks
         def get_order_for_room(room_id):
-            latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
-            if not latest_event_ids:
+            # Most of the rooms won't have changed between the since token and
+            # now (especially if the since token is "now"). So, we can ask what
+            # the current users are in a room (that will hit a cache) and then
+            # check if the room has changed since the since token. (We have to
+            # do it in that order to avoid races).
+            # If things have changed then fall back to getting the current state
+            # at the since token.
+            joined_users = yield self.store.get_users_in_room(room_id)
+            if self.store.has_room_changed_since(room_id, stream_token):
                 latest_event_ids = yield self.store.get_forward_extremeties_for_room(
                     room_id, stream_token
                 )
-                rooms_to_latest_event_ids[room_id] = latest_event_ids

-            if not latest_event_ids:
-                return
+                if not latest_event_ids:
+                    return
@@ -129,6 +139,7 @@ class RoomListHandler(BaseHandler):
-            joined_users = yield self.state_handler.get_current_user_in_room(
-                room_id, latest_event_ids,
-            )
+                joined_users = yield self.state_handler.get_current_user_in_room(
+                    room_id, latest_event_ids,
+                )
+
             num_joined_users = len(joined_users)
             rooms_to_num_joined[room_id] = num_joined_users
@@ -165,19 +176,19 @@ class RoomListHandler(BaseHandler):
             rooms_to_scan = rooms_to_scan[:since_token.current_limit]
             rooms_to_scan.reverse()

-        # Actually generate the entries. _generate_room_entry will append to
+        # Actually generate the entries. _append_room_entry_to_chunk will append to
         # chunk but will stop if len(chunk) > limit
         chunk = []
         if limit and not search_filter:
             step = limit + 1
             for i in xrange(0, len(rooms_to_scan), step):
                 # We iterate here because the vast majority of cases we'll stop
-                # at first iteration, but occaisonally _generate_room_entry
+                # at first iteration, but occaisonally _append_room_entry_to_chunk
                 # won't append to the chunk and so we need to loop again.
                 # We don't want to scan over the entire range either as that
                 # would potentially waste a lot of work.
                 yield concurrently_execute(
-                    lambda r: self._generate_room_entry(
+                    lambda r: self._append_room_entry_to_chunk(
                         r, rooms_to_num_joined[r],
                         chunk, limit, search_filter
                     ),
@@ -187,7 +198,7 @@ class RoomListHandler(BaseHandler):
                     break
         else:
             yield concurrently_execute(
-                lambda r: self._generate_room_entry(
+                lambda r: self._append_room_entry_to_chunk(
                     r, rooms_to_num_joined[r],
                     chunk, limit, search_filter
                 ),
@@ -256,21 +267,35 @@ class RoomListHandler(BaseHandler):
         defer.returnValue(results)

     @defer.inlineCallbacks
-    def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
-                             search_filter):
+    def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
+                                    search_filter):
+        """Generate the entry for a room in the public room list and append it
+        to the `chunk` if it matches the search filter
+        """
         if limit and len(chunk) > limit + 1:
             # We've already got enough, so lets just drop it.
             return

+        result = yield self._generate_room_entry(room_id, num_joined_users)
+
+        if result and _matches_room_entry(result, search_filter):
+            chunk.append(result)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _generate_room_entry(self, room_id, num_joined_users, cache_context):
+        """Returns the entry for a room
+        """
         result = {
             "room_id": room_id,
             "num_joined_members": num_joined_users,
         }

-        current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+        current_state_ids = yield self.store.get_current_state_ids(
+            room_id, on_invalidate=cache_context.invalidate,
+        )

         event_map = yield self.store.get_events([
-            event_id for key, event_id in current_state_ids.items()
+            event_id for key, event_id in current_state_ids.iteritems()
             if key[0] in (
                 EventTypes.JoinRules,
                 EventTypes.Name,
@@ -294,7 +319,9 @@ class RoomListHandler(BaseHandler):
         if join_rule and join_rule != JoinRules.PUBLIC:
             defer.returnValue(None)

-        aliases = yield self.store.get_aliases_for_room(room_id)
+        aliases = yield self.store.get_aliases_for_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
         if aliases:
             result["aliases"] = aliases
@@ -334,8 +361,7 @@ class RoomListHandler(BaseHandler):
         if avatar_url:
             result["avatar_url"] = avatar_url

-        if _matches_room_entry(result, search_filter):
-            chunk.append(result)
+        defer.returnValue(result)

     @defer.inlineCallbacks
     def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
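The refactor above splits chunk-appending from entry generation so the latter can be cached with `cachedInlineCallbacks(cache_context=True)`: a room's entry is computed once and reused until one of its inputs (current state, aliases) invalidates it. A minimal synchronous sketch of that dependency pattern — the names here (`RoomEntryCache`, `compute_entry`) are illustrative, not Synapse's real cache API:

```python
# Illustrative sketch only: a per-room cache whose entries can be explicitly
# invalidated when an upstream input (e.g. the room's current state) changes.
class RoomEntryCache:
    def __init__(self, compute):
        self._compute = compute  # function(room_id) -> entry dict
        self._cache = {}

    def get(self, room_id):
        # Compute at most once per room until invalidated.
        if room_id not in self._cache:
            self._cache[room_id] = self._compute(room_id)
        return self._cache[room_id]

    def invalidate(self, room_id):
        # Called when the room's state or aliases change.
        self._cache.pop(room_id, None)

calls = []

def compute_entry(room_id):
    calls.append(room_id)
    return {"room_id": room_id, "num_joined_members": 3}

cache = RoomEntryCache(compute_entry)
cache.get("!a:hs")       # computed
cache.get("!a:hs")       # served from cache
cache.invalidate("!a:hs")
cache.get("!a:hs")       # recomputed after invalidation
```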


@@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.visibility import filter_events_for_client
+from synapse.types import RoomStreamToken

 from twisted.internet import defer
@@ -225,8 +226,7 @@ class SyncHandler(object):
         with Measure(self.clock, "ephemeral_by_room"):
             typing_key = since_token.typing_key if since_token else "0"

-            rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
-            room_ids = [room.room_id for room in rooms]
+            room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())

             typing_source = self.event_sources.sources["typing"]
             typing, typing_key = yield typing_source.get_new_events(
@@ -568,16 +568,15 @@ class SyncHandler(object):
         since_token = sync_result_builder.since_token

         if since_token and since_token.device_list_key:
-            rooms = yield self.store.get_rooms_for_user(user_id)
-            room_ids = set(r.room_id for r in rooms)
+            room_ids = yield self.store.get_rooms_for_user(user_id)

             user_ids_changed = set()
             changed = yield self.store.get_user_whose_devices_changed(
                 since_token.device_list_key
             )
             for other_user_id in changed:
-                other_rooms = yield self.store.get_rooms_for_user(other_user_id)
-                if room_ids.intersection(e.room_id for e in other_rooms):
+                other_room_ids = yield self.store.get_rooms_for_user(other_user_id)
+                if room_ids.intersection(other_room_ids):
                     user_ids_changed.add(other_user_id)

             defer.returnValue(user_ids_changed)
@@ -721,14 +720,14 @@ class SyncHandler(object):
                     extra_users_ids.update(users)
             extra_users_ids.discard(user.to_string())

-            states = yield self.presence_handler.get_states(
-                extra_users_ids,
-                as_event=True,
-            )
-            presence.extend(states)
-
-            # Deduplicate the presence entries so that there's at most one per user
-            presence = {p["content"]["user_id"]: p for p in presence}.values()
+            if extra_users_ids:
+                states = yield self.presence_handler.get_states(
+                    extra_users_ids,
+                )
+                presence.extend(states)
+
+                # Deduplicate the presence entries so that there's at most one per user
+                presence = {p.user_id: p for p in presence}.values()

         presence = sync_config.filter_collection.filter_presence(
             presence
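The deduplication step kept in the new branch relies on a dict comprehension keyed on user id, which retains the last entry seen per user. A tiny self-contained illustration (the `PresenceState` stand-in is hypothetical, not Synapse's class):

```python
# Stand-in for a presence state object; only user_id matters for the dedup.
class PresenceState:
    def __init__(self, user_id, state):
        self.user_id = user_id
        self.state = state

presence = [
    PresenceState("@a:hs", "online"),
    PresenceState("@b:hs", "offline"),
    PresenceState("@a:hs", "unavailable"),
]

# Later entries overwrite earlier ones for the same key, so each user keeps
# only their most recent entry.
deduped = list({p.user_id: p for p in presence}.values())
```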
@@ -765,6 +764,21 @@ class SyncHandler(object):
         )
         sync_result_builder.now_token = now_token

+        # We check up front if anything has changed, if it hasn't then there is
+        # no point in going futher.
+        since_token = sync_result_builder.since_token
+        if not sync_result_builder.full_state:
+            if since_token and not ephemeral_by_room and not account_data_by_room:
+                have_changed = yield self._have_rooms_changed(sync_result_builder)
+                if not have_changed:
+                    tags_by_room = yield self.store.get_updated_tags(
+                        user_id,
+                        since_token.account_data_key,
+                    )
+                    if not tags_by_room:
+                        logger.debug("no-oping sync")
+                        defer.returnValue(([], []))
+
         ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
             "m.ignored_user_list", user_id=user_id,
         )
@@ -774,13 +788,12 @@ class SyncHandler(object):
         else:
             ignored_users = frozenset()

-        if sync_result_builder.since_token:
+        if since_token:
             res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms = res

             tags_by_room = yield self.store.get_updated_tags(
-                user_id,
-                sync_result_builder.since_token.account_data_key,
+                user_id, since_token.account_data_key,
             )
         else:
             res = yield self._get_all_rooms(sync_result_builder, ignored_users)
@@ -805,7 +818,7 @@ class SyncHandler(object):
         # Now we want to get any newly joined users
         newly_joined_users = set()
-        if sync_result_builder.since_token:
+        if since_token:
             for joined_sync in sync_result_builder.joined:
                 it = itertools.chain(
                     joined_sync.timeline.events, joined_sync.state.values()
@@ -817,6 +830,38 @@ class SyncHandler(object):
         defer.returnValue((newly_joined_rooms, newly_joined_users))

+    @defer.inlineCallbacks
+    def _have_rooms_changed(self, sync_result_builder):
+        """Returns whether there may be any new events that should be sent down
+        the sync. Returns True if there are.
+        """
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+        now_token = sync_result_builder.now_token
+
+        assert since_token
+
+        # Get a list of membership change events that have happened.
+        rooms_changed = yield self.store.get_membership_changes_for_user(
+            user_id, since_token.room_key, now_token.room_key
+        )
+        if rooms_changed:
+            defer.returnValue(True)
+
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service:
+            rooms = yield self.store.get_app_service_rooms(app_service)
+            joined_room_ids = set(r.room_id for r in rooms)
+        else:
+            joined_room_ids = yield self.store.get_rooms_for_user(user_id)
+
+        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        for room_id in joined_room_ids:
+            if self.store.has_room_changed_since(room_id, stream_id):
+                defer.returnValue(True)
+        defer.returnValue(False)
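The early-exit check added here asks the cheapest store-level questions first and bails out on the first sign of change. A hedged, synchronous sketch of the same shape — the store is a stub standing in for Synapse's datastore, not its real API:

```python
# StubStore mimics just enough of the datastore for the check below.
class StubStore:
    def __init__(self, membership_changes, room_positions):
        self._membership_changes = membership_changes  # stream ids of changes
        self._room_positions = room_positions          # room_id -> last stream id

    def get_membership_changes_for_user(self, user_id, since_id):
        return [c for c in self._membership_changes if c > since_id]

    def get_rooms_for_user(self, user_id):
        return set(self._room_positions)

    def has_room_changed_since(self, room_id, since_id):
        return self._room_positions[room_id] > since_id

def have_rooms_changed(store, user_id, since_id):
    # Joins/leaves since the token always mean the sync has work to do.
    if store.get_membership_changes_for_user(user_id, since_id):
        return True
    # Otherwise, see whether any joined room advanced past the token.
    return any(
        store.has_room_changed_since(room_id, since_id)
        for room_id in store.get_rooms_for_user(user_id)
    )

quiet = StubStore(membership_changes=[], room_positions={"!a:hs": 5})
busy = StubStore(membership_changes=[], room_positions={"!a:hs": 9})
```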
     @defer.inlineCallbacks
     def _get_rooms_changed(self, sync_result_builder, ignored_users):
         """Gets the the changes that have happened since the last sync.
@@ -841,8 +886,7 @@ class SyncHandler(object):
             rooms = yield self.store.get_app_service_rooms(app_service)
             joined_room_ids = set(r.room_id for r in rooms)
         else:
-            rooms = yield self.store.get_rooms_for_user(user_id)
-            joined_room_ids = set(r.room_id for r in rooms)
+            joined_room_ids = yield self.store.get_rooms_for_user(user_id)

         # Get a list of membership change events that have happened.
         rooms_changed = yield self.store.get_membership_changes_for_user(


@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import synapse.util.retryutils
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, HTTPConnectionPool, Agent
@@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone

 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util import logcontext
 import synapse.metrics

 from canonicaljson import encode_canonical_json
@@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
             reactor, MatrixFederationEndpointFactory(hs), pool=pool
         )
         self.clock = hs.get_clock()
+        self._store = hs.get_datastore()
         self.version_string = hs.version_string
         self._next_id = 1
@@ -103,12 +103,40 @@ class MatrixFederationHttpClient(object):
         )

     @defer.inlineCallbacks
-    def _create_request(self, destination, method, path_bytes,
-                        body_callback, headers_dict={}, param_bytes=b"",
-                        query_bytes=b"", retry_on_dns_fail=True,
-                        timeout=None, long_retries=False):
-        """ Creates and sends a request to the given url
+    def _request(self, destination, method, path,
+                 body_callback, headers_dict={}, param_bytes=b"",
+                 query_bytes=b"", retry_on_dns_fail=True,
+                 timeout=None, long_retries=False,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
+        """ Creates and sends a request to the given server
+        Args:
+            destination (str): The remote server to send the HTTP request to.
+            method (str): HTTP method
+            path (str): The HTTP path
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): Back off if we get a 404

+        Returns:
+            Deferred: resolves with the http response object on success.
+
+            Fails with ``HTTPRequestException``: if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
+        limiter = yield synapse.util.retryutils.get_retry_limiter(
+            destination,
+            self.clock,
+            self._store,
+            backoff_on_404=backoff_on_404,
+            ignore_backoff=ignore_backoff,
+        )
+
+        destination = destination.encode("ascii")
+        path_bytes = path.encode("ascii")
+        with limiter:
             headers_dict[b"User-Agent"] = [self.version_string]
             headers_dict[b"Host"] = [destination]
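The limiter acquired here wraps the whole request so failures feed back into per-destination backoff. An illustrative sketch of that pattern (not Synapse's `retryutils`, whose limiter is asynchronous and DB-backed): a context manager that refuses to run while the destination is in backoff, and extends the backoff window when the wrapped request fails.

```python
import time

class NotRetryingDestination(Exception):
    pass

class RetryLimiter:
    # destination -> earliest time we may retry; class-level so it is shared
    # across instances, loosely like Synapse's DB-backed destinations table.
    retry_after = {}

    def __init__(self, destination, min_backoff=1.0, now=time.monotonic):
        self.destination = destination
        self.min_backoff = min_backoff
        self.now = now  # injectable clock, handy for testing

    def __enter__(self):
        # Refuse to even attempt the request while still backing off.
        if self.now() < self.retry_after.get(self.destination, 0):
            raise NotRetryingDestination(self.destination)
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None:
            # The request failed: back off future attempts.
            self.retry_after[self.destination] = self.now() + self.min_backoff
        else:
            self.retry_after.pop(self.destination, None)
        return False  # never swallow the exception
```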
@@ -144,8 +172,7 @@ class MatrixFederationHttpClient(object):
             try:
                 def send_request():
-                    request_deferred = preserve_context_over_fn(
-                        self.agent.request,
+                    request_deferred = self.agent.request(
                         method,
                         url_bytes,
                         Headers(headers_dict),
@@ -157,7 +184,8 @@ class MatrixFederationHttpClient(object):
                     time_out=timeout / 1000. if timeout else 60,
                 )

-                response = yield preserve_context_over_fn(send_request)
+                with logcontext.PreserveLoggingContext():
+                    response = yield send_request()

                 log_result = "%d %s" % (response.code, response.phrase,)
                 break
@@ -214,7 +242,8 @@ class MatrixFederationHttpClient(object):
             else:
                 # :'(
                 # Update transactions table?
-                body = yield preserve_context_over_fn(readBody, response)
+                with logcontext.PreserveLoggingContext():
+                    body = yield readBody(response)
                 raise HttpResponseException(
                     response.code, response.phrase, body
                 )
@@ -248,7 +277,9 @@ class MatrixFederationHttpClient(object):

     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False, timeout=None):
+                 long_retries=False, timeout=None,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
         """ Sends the specifed json data using PUT

         Args:
@@ -263,11 +294,19 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): True if we should count a 404 response as
+                a failure of the server (and should therefore back off future
+                requests)

         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body. On a 4xx or 5xx error response a
             CodeMessageException is raised.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """

         if not json_data_callback:
@@ -282,26 +321,29 @@ class MatrixFederationHttpClient(object):
                 producer = _JsonProducer(json_data)
                 return producer

-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "PUT",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
+            backoff_on_404=backoff_on_404,
         )

         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)

-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)

         defer.returnValue(json.loads(body))
     @defer.inlineCallbacks
     def post_json(self, destination, path, data={}, long_retries=False,
-                  timeout=None):
+                  timeout=None, ignore_backoff=False):
         """ Sends the specifed json data using POST

         Args:
@@ -314,11 +356,15 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
+            ignore_backoff (bool): true to ignore the historical backoff data and
+                try the request anyway.

         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body. On a 4xx or 5xx error response a
             CodeMessageException is raised.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """

         def body_callback(method, url_bytes, headers_dict):
@@ -327,27 +373,29 @@ class MatrixFederationHttpClient(object):
             )
             return _JsonProducer(data)

-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "POST",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
        )

         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)

-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)

         defer.returnValue(json.loads(body))
     @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
-                 timeout=None):
+                 timeout=None, ignore_backoff=False):
         """ GETs some json from the given host homeserver and path

         Args:
@@ -359,11 +407,16 @@ class MatrixFederationHttpClient(object):
             timeout (int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout and that the request will
                 be retried.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.

         Returns:
             Deferred: Succeeds when we get *any* HTTP response.

             The result of the deferred is a tuple of `(code, response)`,
             where `response` is a dict representing the decoded JSON body.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """
         logger.debug("get_json args: %s", args)
@@ -380,36 +433,47 @@ class MatrixFederationHttpClient(object):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None

-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )

         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)

-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)

         defer.returnValue(json.loads(body))
     @defer.inlineCallbacks
     def get_file(self, destination, path, output_stream, args={},
-                 retry_on_dns_fail=True, max_size=None):
+                 retry_on_dns_fail=True, max_size=None,
+                 ignore_backoff=False):
         """GETs a file from a given homeserver

         Args:
             destination (str): The remote server to send the HTTP request to.
             path (str): The HTTP path to GET.
             output_stream (file): File to write the response body to.
             args (dict): Optional dictionary used to create the query string.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.

         Returns:
-            A (int,dict) tuple of the file length and a dict of the response
-            headers.
+            Deferred: resolves with an (int,dict) tuple of the file length and
+            a dict of the response headers.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response code
+            >= 300
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
         """

         encoded_args = {}
@@ -419,26 +483,27 @@ class MatrixFederationHttpClient(object):
             encoded_args[k] = [v.encode("UTF-8") for v in vs]

         query_bytes = urllib.urlencode(encoded_args, True)
-        logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+        logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)

         def body_callback(method, url_bytes, headers_dict):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None

-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
-            retry_on_dns_fail=retry_on_dns_fail
+            retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )

         headers = dict(response.headers.getAllRawHeaders())

         try:
-            length = yield preserve_context_over_fn(
-                _readBodyToFile,
-                response, output_stream, max_size
-            )
+            with logcontext.PreserveLoggingContext():
+                length = yield _readBodyToFile(
+                    response, output_stream, max_size
+                )
         except:


@@ -192,6 +192,16 @@ def parse_json_object_from_request(request):
     return content

+def assert_params_in_request(body, required):
+    absent = []
+    for k in required:
+        if k not in body:
+            absent.append(k)
+
+    if len(absent) > 0:
+        raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+
 class RestServlet(object):
     """ A Synapse REST Servlet.
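The `assert_params_in_request` helper added above can be exercised standalone. In this sketch, `SynapseError` and `Codes.MISSING_PARAM` are stood in by simple placeholders rather than Synapse's real classes:

```python
# Placeholder for synapse.api.errors.SynapseError.
class SynapseError(Exception):
    def __init__(self, code, msg, errcode):
        super(SynapseError, self).__init__(msg)
        self.code = code
        self.errcode = errcode

# Placeholder for Codes.MISSING_PARAM.
M_MISSING_PARAM = "M_MISSING_PARAM"

def assert_params_in_request(body, required):
    # Collect every required key absent from the request body.
    absent = [k for k in required if k not in body]
    if absent:
        raise SynapseError(400, "Missing params: %r" % absent, M_MISSING_PARAM)

# A complete body passes silently; a missing key raises a 400.
assert_params_in_request({"user": "@a:hs", "password": "s3cret"}, ["user", "password"])
```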


@@ -16,6 +16,7 @@
 from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
+from synapse.handlers.presence import format_user_presence_state
 from synapse.util import DeferredTimedOutError
 from synapse.util.logutils import log_function
@@ -37,6 +38,10 @@ metrics = synapse.metrics.get_metrics_for(__name__)
 notified_events_counter = metrics.register_counter("notified_events")

+users_woken_by_stream_counter = metrics.register_counter(
+    "users_woken_by_stream", labels=["stream"]
+)
+
 # TODO(paul): Should be shared somewhere
 def count(func, l):
@@ -73,6 +78,13 @@ class _NotifierUserStream(object):
         self.user_id = user_id
         self.rooms = set(rooms)
         self.current_token = current_token
+
+        # The last token for which we should wake up any streams that have a
+        # token that comes before it. This gets updated everytime we get poked.
+        # We start it at the current token since if we get any streams
+        # that have a token from before we have no idea whether they should be
+        # woken up or not, so lets just wake them up.
+        self.last_notified_token = current_token
         self.last_notified_ms = time_now_ms

         with PreserveLoggingContext():
@@ -89,9 +101,12 @@ class _NotifierUserStream(object):
         self.current_token = self.current_token.copy_and_advance(
             stream_key, stream_id
         )
+        self.last_notified_token = self.current_token
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred

+        users_woken_by_stream_counter.inc(stream_key)
+
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
             noify_deferred.callback(self.current_token)
@@ -113,8 +128,14 @@ class _NotifierUserStream(object):
     def new_listener(self, token):
         """Returns a deferred that is resolved when there is a new token
         greater than the given token.
+
+        Args:
+            token: The token from which we are streaming from, i.e. we shouldn't
+                notify for things that happened before this.
         """
-        if self.current_token.is_after(token):
+        # Immediately wake up stream if something has already since happened
+        # since their last token.
+        if self.last_notified_token.is_after(token):
             return _NotificationListener(defer.succeed(self.current_token))
         else:
             return _NotificationListener(self.notify_deferred.observe())
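The wake-up rule introduced here — a listener whose token is older than the stream's `last_notified_token` fires immediately, because something has already happened since that token — can be sketched synchronously. Tokens are plain ints in this illustration; Synapse's are structured stream tokens, and its listeners are Deferreds rather than callbacks:

```python
class UserStream:
    def __init__(self, current_token):
        self.current_token = current_token
        self.last_notified_token = current_token
        self._waiters = []

    def notify(self, new_token):
        # A poke advances both tokens and wakes everyone waiting.
        self.current_token = new_token
        self.last_notified_token = new_token
        waiters, self._waiters = self._waiters, []
        for callback in waiters:
            callback(new_token)

    def new_listener(self, token, callback):
        if self.last_notified_token > token:
            # Something already happened since the caller's token: fire now.
            callback(self.current_token)
        else:
            self._waiters.append(callback)

fired = []
stream = UserStream(current_token=10)
stream.new_listener(5, fired.append)    # 10 > 5: fires immediately
stream.new_listener(10, fired.append)   # nothing new yet: waits
stream.notify(11)                       # wakes the waiting listener
```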
@@ -283,8 +304,7 @@ class Notifier(object):
         if user_stream is None:
             current_token = yield self.event_sources.get_current_token()
             if room_ids is None:
-                rooms = yield self.store.get_rooms_for_user(user_id)
-                room_ids = [room.room_id for room in rooms]
+                room_ids = yield self.store.get_rooms_for_user(user_id)

             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
@@ -294,40 +314,44 @@
             self._register_with_keys(user_stream)

         result = None
+        prev_token = from_token
         if timeout:
             end_time = self.clock.time_msec() + timeout

-            prev_token = from_token
             while not result:
                 try:
-                    current_token = user_stream.current_token
-
-                    result = yield callback(prev_token, current_token)
-                    if result:
-                        break
-
                     now = self.clock.time_msec()
                     if end_time <= now:
                         break

                     # Now we wait for the _NotifierUserStream to be told there
                     # is a new token.
-                    # We need to supply the token we supplied to callback so
-                    # that we don't miss any current_token updates.
-                    prev_token = current_token
                     listener = user_stream.new_listener(prev_token)
                     with PreserveLoggingContext():
                         yield self.clock.time_bound_deferred(
                             listener.deferred,
                             time_out=(end_time - now) / 1000.
                         )
+
+                    current_token = user_stream.current_token
+
+                    result = yield callback(prev_token, current_token)
+                    if result:
+                        break
+
+                    # Update the prev_token to the current_token since nothing
+                    # has happened between the old prev_token and the current_token
+                    prev_token = current_token
                 except DeferredTimedOutError:
                     break
                 except defer.CancelledError:
                     break
-        else:
+
+        if result is None:
+            # This happens if there was no timeout or if the timeout had
+            # already expired.
             current_token = user_stream.current_token
-            result = yield callback(from_token, current_token)
+            result = yield callback(prev_token, current_token)

         defer.returnValue(result)
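The reworked loop above waits for a wake-up first, then runs the callback over `(prev_token, current_token]`, and only advances `prev_token` once that range has been inspected, so no token updates can be skipped. A synchronous sketch of the same control flow (the iterable of wake-ups stands in for the time-bounded deferred, an assumption for illustration):

```python
def wait_for_events(get_current_token, wake_ups, callback, from_token):
    """Poll callback over successive token ranges until it returns a result."""
    prev_token = from_token
    result = None
    for _ in wake_ups:  # each iteration stands in for one wake-up
        current_token = get_current_token()
        result = callback(prev_token, current_token)
        if result:
            break
        # Nothing happened between prev_token and current_token: advance.
        prev_token = current_token
    if result is None:
        # No timeout, or the timeout already expired: one final check.
        result = callback(prev_token, get_current_token())
    return result


state = {"token": 0}

def get_current_token():
    state["token"] += 1
    return state["token"]

seen = []

def callback(prev, cur):
    seen.append((prev, cur))
    return "hit" if cur >= 3 else None

result = wait_for_events(get_current_token, range(5), callback, from_token=0)
```

The callback sees contiguous, non-overlapping ranges `(0,1), (1,2), (2,3)` and the loop stops as soon as it reports a result.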
@@ -388,6 +412,15 @@ class Notifier(object):
                     new_events,
                     is_peeking=is_peeking,
                 )
+            elif name == "presence":
+                now = self.clock.time_msec()
+                new_events[:] = [
+                    {
+                        "type": "m.presence",
+                        "content": format_user_presence_state(event, now),
+                    }
+                    for event in new_events
+                ]
             events.extend(new_events)
             end_token = end_token.copy_and_replace(keyname, new_key)
@@ -420,8 +453,7 @@ class Notifier(object):
     @defer.inlineCallbacks
     def _get_room_ids(self, user, explicit_room_id):
-        joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
-        joined_room_ids = map(lambda r: r.room_id, joined_rooms)
+        joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
         if explicit_room_id:
             if explicit_room_id in joined_room_ids:
                 defer.returnValue(([explicit_room_id], True))
@@ -139,7 +139,7 @@ class Mailer(object):
         @defer.inlineCallbacks
         def _fetch_room_state(room_id):
-            room_state = yield self.state_handler.get_current_state_ids(room_id)
+            room_state = yield self.store.get_current_state_ids(room_id)
             state_by_room[room_id] = room_state

         # Run at most 3 of these at once: sync does 10 at a time but email
@@ -17,6 +17,7 @@ import logging
 import re

 from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache

 logger = logging.getLogger(__name__)
@@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
         return self._value_cache.get(dotted_key, None)


+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
 def _glob_matches(glob, value, word_boundary=False):
     """Tests if value matches glob.
@@ -137,7 +143,29 @@ def _glob_matches(glob, value, word_boundary=False):
     Returns:
         bool
     """
     try:
+        r = regex_cache.get((glob, word_boundary), None)
+        if not r:
+            r = _glob_to_re(glob, word_boundary)
+            regex_cache[(glob, word_boundary)] = r
+        return r.search(value)
+    except re.error:
+        logger.warn("Failed to parse glob to regex: %r", glob)
+        return False
+
+
+def _glob_to_re(glob, word_boundary):
+    """Generates regex for a given glob.
+
+    Args:
+        glob (string)
+        word_boundary (bool): Whether to match against word boundaries or entire
+            string. Defaults to False.
+
+    Returns:
+        regex object
+    """
     if IS_GLOB.search(glob):
         r = re.escape(glob)
@@ -156,25 +184,20 @@ def _glob_matches(glob, value, word_boundary=False):
         )
         if word_boundary:
             r = r"\b%s\b" % (r,)
-            r = _compile_regex(r)
-            return r.search(value)
+            return re.compile(r, flags=re.IGNORECASE)
         else:
-            r = r + "$"
-            r = _compile_regex(r)
-            return r.match(value)
+            r = "^" + r + "$"
+            return re.compile(r, flags=re.IGNORECASE)
     elif word_boundary:
         r = re.escape(glob)
         r = r"\b%s\b" % (r,)
-        r = _compile_regex(r)
-        return r.search(value)
+        return re.compile(r, flags=re.IGNORECASE)
     else:
-        return value.lower() == glob.lower()
-    except re.error:
-        logger.warn("Failed to parse glob to regex: %r", glob)
-        return False
+        r = "^" + re.escape(glob) + "$"
+        return re.compile(r, flags=re.IGNORECASE)
 def _flatten_dict(d, prefix=[], result={}):
@@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
             _flatten_dict(value, prefix=(prefix + [key]), result=result)

     return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
-    r = regex_cache.get(regex_str, None)
-    if r:
-        return r
-
-    r = re.compile(regex_str, flags=re.IGNORECASE)
-    regex_cache[regex_str] = r
-    return r
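The refactor above splits compilation (`_glob_to_re`) from matching and keys the cache on the `(glob, word_boundary)` pair, so each push-rule glob is compiled at most once. A simplified standalone sketch of that scheme (a plain dict stands in for `LruCache`, and the character-class handling of the real code is omitted):

```python
import re

# Caches (glob, word_boundary) -> compiled regex, as in the diff above.
regex_cache = {}


def _glob_to_re(glob, word_boundary):
    """Compile a glob into a case-insensitive regex (simplified)."""
    r = re.escape(glob)
    # Expand the glob wildcards that re.escape just escaped.
    r = r.replace(r"\*", ".*?").replace(r"\?", ".")
    if word_boundary:
        r = r"\b%s\b" % (r,)
    else:
        # Whole-string match when not matching on word boundaries.
        r = "^" + r + "$"
    return re.compile(r, flags=re.IGNORECASE)


def glob_matches(glob, value, word_boundary=False):
    try:
        r = regex_cache.get((glob, word_boundary))
        if not r:
            r = _glob_to_re(glob, word_boundary)
            regex_cache[(glob, word_boundary)] = r
        return bool(r.search(value))
    except re.error:
        return False
```

With this, `glob_matches("cake*", "cake lie")` compiles once and every later lookup for the same `(glob, word_boundary)` pair reuses the cached pattern.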
@@ -33,13 +33,13 @@ def get_badge_count(store, user_id):
     badge = len(invites)

-    for r in joins:
-        if r.room_id in my_receipts_by_room:
-            last_unread_event_id = my_receipts_by_room[r.room_id]
+    for room_id in joins:
+        if room_id in my_receipts_by_room:
+            last_unread_event_id = my_receipts_by_room[room_id]

             notifs = yield (
                 store.get_unread_event_push_actions_by_room_for_user(
-                    r.room_id, user_id, last_unread_event_id
+                    room_id, user_id, last_unread_event_id
                 )
             )
             # return one badge count per conversation, as count per
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ from distutils.version import LooseVersion
 logger = logging.getLogger(__name__)

 REQUIREMENTS = {
+    "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
     "frozendict>=0.4": ["frozendict"],
     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
@@ -37,6 +39,7 @@ REQUIREMENTS = {
     "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
     "pymacaroons-pynacl": ["pymacaroons"],
     "msgpack-python>=0.3.0": ["msgpack"],
+    "phonenumbers>=8.2.0": ["phonenumbers"],
 }

 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
@@ -283,12 +283,12 @@ class ReplicationResource(Resource):
         if request_events != upto_events_token:
             writer.write_header_and_rows("events", res.new_forward_events, (
-                "position", "internal", "json", "state_group"
+                "position", "event_id", "room_id", "type", "state_key",
             ), position=upto_events_token)

         if request_backfill != upto_backfill_token:
             writer.write_header_and_rows("backfill", res.new_backfill_events, (
-                "position", "internal", "json", "state_group",
+                "position", "event_id", "room_id", "type", "state_key", "redacts",
             ), position=upto_backfill_token)

         writer.write_header_and_rows(
@@ -27,4 +27,9 @@ class SlavedIdTracker(object):
         self._current = (max if self.step > 0 else min)(self._current, new_id)

     def get_current_token(self):
+        """
+        Returns:
+            int
+        """
         return self._current
@@ -16,7 +16,6 @@ from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker

 from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
 from synapse.storage import DataStore
 from synapse.storage.roommember import RoomMemberStore
 from synapse.storage.event_federation import EventFederationStore
@@ -25,7 +24,6 @@ from synapse.storage.state import StateStore
 from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache

-import ujson as json
 import logging
@@ -109,6 +107,10 @@ class SlavedEventStore(BaseSlavedStore):
     get_recent_event_ids_for_room = (
         StreamStore.__dict__["get_recent_event_ids_for_room"]
     )
+    get_current_state_ids = (
+        StateStore.__dict__["get_current_state_ids"]
+    )
+    has_room_changed_since = DataStore.has_room_changed_since.__func__

     get_unread_push_actions_for_user_in_range_for_http = (
         DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
@@ -165,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
     _get_rooms_for_user_where_membership_is_txn = (
         DataStore._get_rooms_for_user_where_membership_is_txn.__func__
     )
-    _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
     _get_state_for_groups = DataStore._get_state_for_groups.__func__
     _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
     _get_events_around_txn = DataStore._get_events_around_txn.__func__
@@ -238,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore):
         return super(SlavedEventStore, self).process_replication(result)

     def _process_replication_row(self, row, backfilled):
-        internal = json.loads(row[1])
-        event_json = json.loads(row[2])
-        event = FrozenEvent(event_json, internal_metadata_dict=internal)
+        stream_ordering = row[0] if not backfilled else -row[0]
         self.invalidate_caches_for_event(
-            event, backfilled,
+            stream_ordering, row[1], row[2], row[3], row[4], row[5],
+            backfilled=backfilled,
         )

-    def invalidate_caches_for_event(self, event, backfilled):
-        self._invalidate_get_event_cache(event.event_id)
+    def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
+                                    etype, state_key, redacts, backfilled):
+        self._invalidate_get_event_cache(event_id)

-        self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+        self.get_latest_event_ids_in_room.invalidate((room_id,))

         self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-            (event.room_id,)
+            (room_id,)
         )

         if not backfilled:
             self._events_stream_cache.entity_has_changed(
-                event.room_id, event.internal_metadata.stream_ordering
+                room_id, stream_ordering
             )

-        # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-        #     (event.room_id,)
-        # )
+        if redacts:
+            self._invalidate_get_event_cache(redacts)

-        if event.type == EventTypes.Redaction:
-            self._invalidate_get_event_cache(event.redacts)
-
-        if event.type == EventTypes.Member:
+        if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(
-                event.state_key, event.internal_metadata.stream_ordering
+                state_key, stream_ordering
             )
-            self.get_invited_rooms_for_user.invalidate((event.state_key,))
-
-        if not event.is_state():
-            return
-
-        if backfilled:
-            return
-
-        if (not event.internal_metadata.is_invite_from_remote()
-                and event.internal_metadata.is_outlier()):
-            return
+            self.get_invited_rooms_for_user.invalidate((state_key,))
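The hunk above switches the replication stream from full event JSON to a slim `(position, event_id, room_id, type, state_key, redacts)` tuple, so the slave can invalidate its caches without deserialising a `FrozenEvent`. A standalone sketch of that bookkeeping, with hypothetical names (the returned list stands in for the cache-invalidation calls):

```python
MEMBER = "m.room.member"  # EventTypes.Member


def invalidations_for_row(row, backfilled):
    """Return the cache invalidations implied by one replication row."""
    position, event_id, room_id, etype, state_key, redacts = row
    # Backfilled events sort before the live stream, hence the negation.
    stream_ordering = position if not backfilled else -position

    out = [("event", event_id), ("latest_in_room", room_id)]
    if not backfilled:
        out.append(("room_changed", room_id, stream_ordering))
    if redacts:
        # A redaction also invalidates the event it redacts.
        out.append(("event", redacts))
    if etype == MEMBER:
        out.append(("membership", state_key, stream_ordering))
    return out


row = (42, "$evt:hs", "!room:hs", MEMBER, "@alice:hs", None)
invalidated = invalidations_for_row(row, backfilled=False)
```

Note everything needed is carried in the row itself; the slave no longer touches the event table at all on this path.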
@@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
             self.presence_stream_cache.entity_has_changed(
                 user_id, position
             )
+            self._get_presence_for_user.invalidate((user_id,))

         return super(SlavedPresenceStore, self).process_replication(result)
@@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes
 from synapse.types import UserID
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.util.msisdn import phone_number_to_msisdn

 from .base import ClientV1RestServlet, client_path_patterns
@@ -33,10 +34,55 @@ from saml2.client import Saml2Client
 import xml.etree.ElementTree as ET

+from twisted.web.client import PartialDownloadError
+
 logger = logging.getLogger(__name__)


+def login_submission_legacy_convert(submission):
+    """
+    If the input login submission is an old style object
+    (ie. with top-level user / medium / address) convert it
+    to a typed object.
+    """
+    if "user" in submission:
+        submission["identifier"] = {
+            "type": "m.id.user",
+            "user": submission["user"],
+        }
+        del submission["user"]
+
+    if "medium" in submission and "address" in submission:
+        submission["identifier"] = {
+            "type": "m.id.thirdparty",
+            "medium": submission["medium"],
+            "address": submission["address"],
+        }
+        del submission["medium"]
+        del submission["address"]
+
+
+def login_id_thirdparty_from_phone(identifier):
+    """
+    Convert a phone login identifier type to a generic threepid identifier.
+
+    Args:
+        identifier(dict): Login identifier dict of type 'm.id.phone'
+
+    Returns: Login identifier dict of type 'm.id.thirdparty'
+    """
+    if "country" not in identifier or "number" not in identifier:
+        raise SynapseError(400, "Invalid phone-type identifier")
+
+    msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+
+    return {
+        "type": "m.id.thirdparty",
+        "medium": "msisdn",
+        "address": msisdn,
+    }
+
+
 class LoginRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login$")
     PASS_TYPE = "m.login.password"
@@ -117,20 +163,52 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
-        if 'medium' in login_submission and 'address' in login_submission:
-            address = login_submission['address']
-            if login_submission['medium'] == 'email':
+        if "password" not in login_submission:
+            raise SynapseError(400, "Missing parameter: password")
+
+        login_submission_legacy_convert(login_submission)
+
+        if "identifier" not in login_submission:
+            raise SynapseError(400, "Missing param: identifier")
+
+        identifier = login_submission["identifier"]
+        if "type" not in identifier:
+            raise SynapseError(400, "Login identifier has no type")
+
+        # convert phone type identifiers to generic threepids
+        if identifier["type"] == "m.id.phone":
+            identifier = login_id_thirdparty_from_phone(identifier)
+
+        # convert threepid identifiers to user IDs
+        if identifier["type"] == "m.id.thirdparty":
+            if 'medium' not in identifier or 'address' not in identifier:
+                raise SynapseError(400, "Invalid thirdparty identifier")
+
+            address = identifier['address']
+            if identifier['medium'] == 'email':
                 # For emails, transform the address to lowercase.
                 # We store all email addresses as lowercase in the DB.
                 # (See add_threepid in synapse/handlers/auth.py)
                 address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], address
+                identifier['medium'], address
             )
             if not user_id:
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-        else:
-            user_id = login_submission['user']
+
+            identifier = {
+                "type": "m.id.user",
+                "user": user_id,
+            }
+
+        # by this point, the identifier should be an m.id.user: if it's anything
+        # else, we haven't understood it.
+        if identifier["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+        if "user" not in identifier:
+            raise SynapseError(400, "User identifier is missing 'user' key")
+
+        user_id = identifier["user"]

         if not user_id.startswith('@'):
             user_id = UserID.create(
@@ -341,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet):
             "ticket": request.args["ticket"],
             "service": self.cas_service_url
         }
-        body = yield http_client.get_raw(uri, args)
+        try:
+            body = yield http_client.get_raw(uri, args)
+        except PartialDownloadError as pde:
+            # Twisted raises this error if the connection is closed,
+            # even if that's being used old-http style to signal end-of-data
+            body = pde.response
         result = yield self.handle_cas_response(request, body, client_redirect_url)
         defer.returnValue(result)
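The login refactor above funnels every submission through a typed `identifier` dict, converting legacy top-level `user` / `medium` + `address` fields first. Its effect can be seen by running the converter (reproduced standalone here) over an old-style submission:

```python
def login_submission_legacy_convert(submission):
    """Rewrite old-style login fields into a typed identifier, in place."""
    if "user" in submission:
        submission["identifier"] = {
            "type": "m.id.user",
            "user": submission["user"],
        }
        del submission["user"]

    if "medium" in submission and "address" in submission:
        submission["identifier"] = {
            "type": "m.id.thirdparty",
            "medium": submission["medium"],
            "address": submission["address"],
        }
        del submission["medium"]
        del submission["address"]


legacy_email = {
    "type": "m.login.password",
    "medium": "email",
    "address": "Alice@Example.org",
    "password": "s3cret",
}
login_submission_legacy_convert(legacy_email)

legacy_user = {"type": "m.login.password", "user": "bob", "password": "s3cret"}
login_submission_legacy_convert(legacy_user)
```

After conversion, the handler only ever sees `m.id.user` / `m.id.thirdparty` / `m.id.phone` identifiers, which is what lets phone and email logins share one code path.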
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, AuthError
 from synapse.types import UserID
+from synapse.handlers.presence import format_user_presence_state
 from synapse.http.servlet import parse_json_object_from_request

 from .base import ClientV1RestServlet, client_path_patterns
@@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(PresenceStatusRestServlet, self).__init__(hs)
         self.presence_handler = hs.get_presence_handler()
+        self.clock = hs.get_clock()

     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
@@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
             raise AuthError(403, "You are not allowed to see their presence.")

         state = yield self.presence_handler.get_state(target_user=user)
+        state = format_user_presence_state(state, self.clock.time_msec())

         defer.returnValue((200, state))
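`format_user_presence_state` itself is not shown in this diff; judging from the call sites above, its job is to turn internal presence state into the wire format, converting the absolute `last_active_ts` into a relative `last_active_ago` against the clock passed in. A hedged sketch of that shape, with a plain dict standing in for Synapse's presence-state object:

```python
def format_user_presence_state(state, now_ms):
    """Illustrative sketch: internal presence state -> client-facing content."""
    content = {"presence": state["state"]}
    if state.get("last_active_ts"):
        # Clients get a relative age, not an absolute server timestamp.
        content["last_active_ago"] = now_ms - state["last_active_ts"]
    if state.get("status_msg") and state["state"] != "offline":
        content["status_msg"] = state["status_msg"]
    if state["state"] == "online":
        content["currently_active"] = state.get("currently_active", False)
    return content


formatted = format_user_presence_state(
    {"state": "online", "last_active_ts": 1000, "currently_active": True},
    now_ms=4000,
)
```

This is why the servlet now needs `self.clock`: the formatting has to happen at response time so `last_active_ago` is fresh.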
@@ -748,8 +748,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet):
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)

-        rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
-        room_ids = set(r.room_id for r in rooms)  # Ensure they're unique.
+        room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())

         defer.returnValue((200, {"joined_rooms": list(room_ids)}))
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +18,11 @@ from twisted.internet import defer

 from synapse.api.constants import LoginType
 from synapse.api.errors import LoginError, SynapseError, Codes
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, assert_params_in_request
+)
 from synapse.util.async import run_on_reactor
+from synapse.util.msisdn import phone_number_to_msisdn

 from ._base import client_v2_patterns
@@ -28,11 +32,11 @@ import logging
 logger = logging.getLogger(__name__)


-class PasswordRequestTokenRestServlet(RestServlet):
+class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/password/email/requestToken$")

     def __init__(self, hs):
-        super(PasswordRequestTokenRestServlet, self).__init__()
+        super(EmailPasswordRequestTokenRestServlet, self).__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
@@ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet):
     def on_POST(self, request):
         body = parse_json_object_from_request(request)

-        required = ['id_server', 'client_secret', 'email', 'send_attempt']
-        absent = []
-        for k in required:
-            if k not in body:
-                absent.append(k)
-
-        if absent:
-            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        assert_params_in_request(body, [
+            'id_server', 'client_secret', 'email', 'send_attempt'
+        ])

         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
@@ -60,6 +59,37 @@
         defer.returnValue((200, ret))
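The hunk above replaces the inline required-params loop with `assert_params_in_request`, whose implementation is not shown in this diff. A minimal version consistent with the error it replaces might look like this (`SynapseError` and the error code are stubbed here for illustration):

```python
MISSING_PARAM = "M_MISSING_PARAM"


class SynapseError(Exception):
    """Stub of synapse.api.errors.SynapseError for this sketch."""

    def __init__(self, code, msg, errcode="M_UNKNOWN"):
        super(SynapseError, self).__init__(msg)
        self.code = code
        self.msg = msg
        self.errcode = errcode


def assert_params_in_request(body, required):
    # Collect every missing key so the client sees them all at once.
    absent = [k for k in required if k not in body]
    if absent:
        raise SynapseError(400, "Missing params: %r" % absent, MISSING_PARAM)


# Usage: a body missing 'client_secret' should raise.
caught = None
try:
    assert_params_in_request({"id_server": "id.example.com"},
                             ["id_server", "client_secret"])
except SynapseError as e:
    caught = e

# A complete body passes silently.
assert_params_in_request({"id_server": "x", "client_secret": "y"},
                         ["id_server", "client_secret"])
```

Factoring this out is what lets the new email and MSISDN token servlets share the same validation in three lines each.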
+class MsisdnPasswordRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
+        self.hs = hs
+        self.datastore = self.hs.get_datastore()
+        self.identity_handler = hs.get_handlers().identity_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        assert_params_in_request(body, [
+            'id_server', 'client_secret',
+            'country', 'phone_number', 'send_attempt',
+        ])
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is None:
+            raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
+
+
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/password$")
@@ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()

     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet):
         authed, result, params, _ = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
-            [LoginType.EMAIL_IDENTITY]
+            [LoginType.EMAIL_IDENTITY],
+            [LoginType.MSISDN],
         ], body, self.hs.get_ip_from_request(request))

         if not authed:
@@ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet):
                 # (See add_threepid in synapse/handlers/auth.py)
                 threepid['address'] = threepid['address'].lower()
                 # if using email, we must know about the email they're authing with!
-                threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                threepid_user_id = yield self.datastore.get_user_id_by_threepid(
                     threepid['medium'], threepid['address']
                 )
                 if not threepid_user_id:
@@ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet):
         defer.returnValue((200, {}))


-class ThreepidRequestTokenRestServlet(RestServlet):
+class EmailThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")

     def __init__(self, hs):
         self.hs = hs
-        super(ThreepidRequestTokenRestServlet, self).__init__()
+        super(EmailThreepidRequestTokenRestServlet, self).__init__()
         self.identity_handler = hs.get_handlers().identity_handler
+        self.datastore = self.hs.get_datastore()

     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet):
         if absent:
             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)

-        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+        existingUid = yield self.datastore.get_user_id_by_threepid(
             'email', body['email']
         )

@@ -201,6 +234,44 @@
         defer.returnValue((200, ret))
+class MsisdnThreepidRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+        self.identity_handler = hs.get_handlers().identity_handler
+        self.datastore = self.hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        required = [
+            'id_server', 'client_secret',
+            'country', 'phone_number', 'send_attempt',
+        ]
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if absent:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is not None:
+            raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
+
+
 class ThreepidRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/3pid$")
@@ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet):
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()

     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet):
         requester = yield self.auth.get_user_by_req(request)

-        threepids = yield self.hs.get_datastore().user_get_threepids(
+        threepids = yield self.datastore.user_get_threepids(
             requester.user.to_string()
         )

@@ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet):
         if 'bind' in body and body['bind']:
             logger.debug(
-                "Binding emails %s to %s",
+                "Binding threepid %s to %s",
                 threepid, user_id
             )
             yield self.identity_handler.bind_threepid(
@@ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet):
 def register_servlets(hs, http_server):
-    PasswordRequestTokenRestServlet(hs).register(http_server)
+    EmailPasswordRequestTokenRestServlet(hs).register(http_server)
+    MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
-    ThreepidRequestTokenRestServlet(hs).register(http_server)
+    EmailThreepidRequestTokenRestServlet(hs).register(http_server)
+    MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
     ThreepidRestServlet(hs).register(http_server)
     ThreepidDeleteRestServlet(hs).register(http_server)
@ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices})) defer.returnValue((200, {"devices": devices}))
class DeleteDevicesRestServlet(servlet.RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
try:
body = servlet.parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
if 'devices' not in body:
raise errors.SynapseError(
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
authed, result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
body['devices'],
)
defer.returnValue((200, {}))
class DeviceRestServlet(servlet.RestServlet): class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False) releases=[], v2_alpha=False)
@@ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server)

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 - 2016 OpenMarket Ltd # Copyright 2015 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -19,7 +20,10 @@ import synapse
from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request
)
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns
@@ -43,7 +47,7 @@ else:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RegisterRequestTokenRestServlet(RestServlet): class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$") PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
@@ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
Args: Args:
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
""" """
super(RegisterRequestTokenRestServlet, self).__init__() super(EmailRegisterRequestTokenRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@@ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt'] assert_params_in_request(body, [
absent = [] 'id_server', 'client_secret', 'email', 'send_attempt'
for k in required: ])
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
@@ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
if existingUid is not None:
raise SynapseError(
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
ret = yield self.identity_handler.requestMsisdnToken(**body)
defer.returnValue((200, ret))
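The MSISDN request-token servlet above leans on `phone_number_to_msisdn` to canonicalise the submitted number before the uniqueness check. A minimal stand-alone sketch of that normalisation follows; the real helper in `synapse.util.msisdn` uses the `phonenumbers` library, and the country table here is a purely illustrative stand-in:

```python
# Illustrative sketch of msisdn normalisation: strip formatting
# characters, drop the national trunk prefix, and prepend the country
# calling code. COUNTRY_CODES is a hypothetical subset; the real
# implementation derives this from phonenumbers metadata.
COUNTRY_CODES = {"GB": "44", "US": "1", "FR": "33"}

def phone_number_to_msisdn_sketch(country, number):
    digits = "".join(c for c in number if c.isdigit())
    digits = digits.lstrip("0")  # remove trunk prefix, e.g. the GB leading 0
    return COUNTRY_CODES[country.upper()] + digits
```

For example, `phone_number_to_msisdn_sketch("GB", "07700 900123")` yields `"447700900123"`, i.e. the E.164 digits stored under the `msisdn` medium.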
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$") PATTERNS = client_v2_patterns("/register$")
@@ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id, assigned_user_id=registered_user_id,
) )
# Only give msisdn flows if the x_show_msisdn flag is given:
# this is a hack to work around the fact that clients were shipped
# that use fallback registration if they see any flows that they don't
# recognise, which means we break registration for these clients if we
# advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
# Android <=0.6.9 have fallen below an acceptable threshold, this
# parameter should go away and we should always advertise msisdn flows.
show_msisdn = False
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
flows = [ flows = [
[LoginType.RECAPTCHA], [LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA] [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
] ]
if show_msisdn:
flows.extend([
[LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else: else:
flows = [ flows = [
[LoginType.DUMMY], [LoginType.DUMMY],
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY],
] ]
if show_msisdn:
flows.extend([
[LoginType.MSISDN],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
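The branching above, which only advertises MSISDN flows behind the `x_show_msisdn` flag, can be sketched as a pure function. This uses the literal Matrix login-type identifiers in place of the `LoginType` constants:

```python
# Sketch of the flow-advertisement logic, assuming the standard Matrix
# login-type strings for LoginType.RECAPTCHA, EMAIL_IDENTITY, MSISDN
# and DUMMY.
def build_registration_flows(captcha_enabled, show_msisdn):
    if captcha_enabled:
        flows = [
            ["m.login.recaptcha"],
            ["m.login.email.identity", "m.login.recaptcha"],
        ]
        if show_msisdn:
            flows += [
                ["m.login.msisdn", "m.login.recaptcha"],
                ["m.login.msisdn", "m.login.email.identity", "m.login.recaptcha"],
            ]
    else:
        flows = [["m.login.dummy"], ["m.login.email.identity"]]
        if show_msisdn:
            flows += [
                ["m.login.msisdn"],
                ["m.login.msisdn", "m.login.email.identity"],
            ]
    return flows
```

With `show_msisdn` false the old clients see exactly the pre-existing flow lists, which is the point of the workaround.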
authed, auth_result, params, session_id = yield self.auth_handler.check_auth( authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
@@ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session", "Already registered user ID %r for this session",
registered_user_id registered_user_id
) )
# don't re-register the email address # don't re-register the threepids
add_email = False add_email = False
add_msisdn = False
else: else:
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: if 'password' not in params:
@@ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet):
) )
add_email = True add_email = True
add_msisdn = True
return_dict = yield self._create_registration_details( return_dict = yield self._create_registration_details(
registered_user_id, params registered_user_id, params
@@ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet):
params.get("bind_email") params.get("bind_email")
) )
if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(
registered_user_id, threepid, return_dict["access_token"],
params.get("bind_msisdn")
)
defer.returnValue((200, return_dict)) defer.returnValue((200, return_dict))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
@@ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet):
""" """
reqd = ('medium', 'address', 'validated_at') reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd): if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
defer.returnValue() return
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
@@ -371,6 +438,43 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
@defer.inlineCallbacks
def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
"""Add a phone number as a 3pid identifier
Also optionally binds msisdn to the given user_id on the identity server
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
token (str): access_token for the user
bind_msisdn (bool): true if the client requested the msisdn to be
bound at the identity server
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
return
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if bind_msisdn:
logger.info("bind_msisdn specified: binding")
logger.debug("Binding msisdn %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(
threepid['threepid_creds'], user_id
)
else:
logger.info("bind_msisdn not specified: not binding msisdn")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_registration_details(self, user_id, params): def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
@@ -433,7 +537,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
self.device_handler.check_device_registered( yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
@@ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRequestTokenRestServlet(hs).register(http_server) EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_string, parse_integer, parse_boolean RestServlet, parse_string, parse_integer, parse_boolean
) )
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events.utils import ( from synapse.events.utils import (
@@ -28,7 +29,6 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from ._base import client_v2_patterns from ._base import client_v2_patterns
import copy
import itertools import itertools
import logging import logging
@@ -194,12 +194,18 @@ class SyncRestServlet(RestServlet):
defer.returnValue((200, response_content)) defer.returnValue((200, response_content))
def encode_presence(self, events, time_now): def encode_presence(self, events, time_now):
formatted = [] return {
for event in events: "events": [
event = copy.deepcopy(event) {
event['sender'] = event['content'].pop('user_id') "type": "m.presence",
formatted.append(event) "sender": event.user_id,
return {"events": formatted} "content": format_user_presence_state(
event, time_now, include_user_id=False
),
}
for event in events
]
}
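The rewritten `encode_presence` builds the wire-format `m.presence` events directly rather than deep-copying the internal state. A self-contained sketch of the resulting shape, with a stubbed stand-in for `format_user_presence_state` (the stub and the dict-based state objects are assumptions for illustration):

```python
def encode_presence_sketch(states, time_now):
    # Stub of format_user_presence_state: builds the m.presence content
    # dict from an internal presence-state object, optionally omitting
    # user_id since it is carried in the "sender" field instead.
    def format_user_presence_state(state, now, include_user_id=True):
        content = {"presence": state["presence"]}
        if include_user_id:
            content["user_id"] = state["user_id"]
        return content

    return {
        "events": [
            {
                "type": "m.presence",
                "sender": state["user_id"],
                "content": format_user_presence_state(
                    state, time_now, include_user_id=False
                ),
            }
            for state in states
        ]
    }
```

Dropping the `copy.deepcopy` per event is where the `/sync` performance win comes from.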
def encode_joined(self, rooms, time_now, token_id, event_fields): def encode_joined(self, rooms, time_now, token_id, event_fields):
""" """

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.http.servlet
from ._base import parse_media_id, respond_with_file, respond_404 from ._base import parse_media_id, respond_with_file, respond_404
from twisted.web.resource import Resource from twisted.web.resource import Resource
@@ -81,6 +82,17 @@ class DownloadResource(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name): def _respond_remote_file(self, request, server_name, media_id, name):
# don't forward requests for remote media if allow_remote is false
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
media_info = yield self.media_repo.get_remote_media(server_name, media_id) media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"] media_type = media_info["media_type"]
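The `allow_remote` check above is the receiving half of the routing-loop fix: federation fetches send `allow_remote=false`, so a server that does not recognise the media's `server_name` as itself 404s instead of proxying the request back out. The decision can be sketched as (hypothetical helper name):

```python
def should_serve_media(server_name, local_server, allow_remote):
    # Local media is always served. Remote media is only proxied when
    # the caller permits it; server-to-server fetches pass
    # allow_remote=False so a misrouted request fails fast with a 404
    # rather than looping over federation.
    if server_name == local_server:
        return True
    return allow_remote
```

This mirrors the `respond_404` early-return in `_respond_remote_file`.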

View File

@@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, threads
import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
from .upload_resource import UploadResource from .upload_resource import UploadResource
from .download_resource import DownloadResource from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource from .thumbnail_resource import ThumbnailResource
from .identicon_resource import IdenticonResource from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths from .filepath import MediaFilePaths
from twisted.web.resource import Resource
from .thumbnailer import Thumbnailer from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, HttpResponseException, \
NotFoundError
from twisted.internet import defer, threads
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
@@ -157,11 +158,34 @@ class MediaRepository(object):
try: try:
length, headers = yield self.client.get_file( length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f, server_name, request_path, output_stream=f,
max_size=self.max_upload_size, max_size=self.max_upload_size, args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
}
) )
except Exception as e: except twisted.internet.error.DNSLookupError as e:
logger.warn("Failed to fetch remoted media %r", e) logger.warn("HTTP error fetching remote media %s/%s: %r",
raise SynapseError(502, "Failed to fetch remoted media") server_name, media_id, e)
raise NotFoundError()
except HttpResponseException as e:
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise
except Exception:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()

View File

@@ -177,17 +177,12 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The """Build an EventContext structure for the event.
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
If `event` has `auth_events` then this will also fill out the
`auth_events` field on `context` from the `current_state`.
Args: Args:
event (EventBase) event (synapse.events.EventBase):
Returns: Returns:
an EventContext synapse.events.snapshot.EventContext:
""" """
context = EventContext() context = EventContext()
@@ -200,11 +195,11 @@ class StateHandler(object):
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state(): if event.is_state():
context.current_state_events = dict(context.prev_state_ids) context.current_state_ids = dict(context.prev_state_ids)
key = (event.type, event.state_key) key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id context.current_state_ids[key] = event.event_id
else: else:
context.current_state_events = context.prev_state_ids context.current_state_ids = context.prev_state_ids
else: else:
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {} context.prev_state_ids = {}

View File

@@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self.txn, name, value) setattr(self.txn, name, value)
def __iter__(self):
return self.txn.__iter__()
def execute(self, sql, *args): def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
@@ -132,7 +135,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3): def interval(self, interval_duration, limit=3):
counters = [] counters = []
for name, (count, cum_time) in self.current_counters.items(): for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(( counters.append((
(cum_time - prev_time) / interval_duration, (cum_time - prev_time) / interval_duration,
@@ -357,7 +360,7 @@ class SQLBaseStore(object):
""" """
col_headers = list(intern(column[0]) for column in cursor.description) col_headers = list(intern(column[0]) for column in cursor.description)
results = list( results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall() dict(zip(col_headers, row)) for row in cursor
) )
return results return results
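Iterating the cursor directly, enabled by the `__iter__` delegation added to `LoggingTransaction`, streams rows one at a time instead of materialising the full `fetchall()` list, which is where the DB-thread CPU savings come from. A small stdlib `sqlite3` illustration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE t (x INTEGER)")
cur.executemany("INSERT INTO t VALUES (?)", [(1,), (2,), (3,)])

cur.execute("SELECT x FROM t ORDER BY x")
# Iterating the cursor yields rows lazily; fetchall() would first
# build the complete result list in memory before iteration starts.
rows = [r[0] for r in cur]
conn.close()
```

The DB-API cursor protocol makes `for row in txn` equivalent in results to `for row in txn.fetchall()`, just without the intermediate list.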
@@ -565,7 +568,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol): def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else: else:
where = "" where = ""
@@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol, def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"): desc="_simple_select_onecol"):
@@ -712,7 +715,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.items(): for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@@ -753,7 +756,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues): def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else: else:
where = "" where = ""
@@ -840,6 +843,47 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values()) return txn.execute(sql, keyvalues.values())
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
"""Executes a DELETE query on the named table.
Filters rows by whether the value of `column` is in `iterable`.
Args:
txn : Transaction object
table : string giving the table name
column : column name to test for inclusion against `iterable`
iterable : list
keyvalues : dict of column names and values to select the rows with
"""
if not iterable:
return
sql = "DELETE FROM %s" % table
clauses = []
values = []
clauses.append(
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
)
values.extend(iterable)
for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,))
values.append(value)
if clauses:
sql = "%s WHERE %s" % (
sql,
" AND ".join(clauses),
)
return txn.execute(sql, values)
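The query assembled by `_simple_delete_many_txn` can be exercised in isolation. A sketch of just the SQL-building step (hypothetical helper name), mirroring the clause construction above:

```python
def build_delete_many_sql(table, column, iterable, keyvalues):
    # An IN (...) filter on `column` with one placeholder per value,
    # plus an equality clause per keyvalue, ANDed together.
    clauses = ["%s IN (%s)" % (column, ",".join("?" for _ in iterable))]
    values = list(iterable)
    for key, value in keyvalues.items():
        clauses.append("%s = ?" % (key,))
        values.append(value)
    sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))
    return sql, values
```

Using placeholders for every value keeps the statement parameterised, so untrusted device IDs never reach the SQL string itself.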
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value, limit=100000): max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -860,16 +904,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = { cache = {
row[0]: int(row[1]) row[0]: int(row[1])
for row in rows for row in txn
} }
txn.close()
if cache: if cache:
min_val = min(cache.values()) min_val = min(cache.itervalues())
else: else:
min_val = max_value min_val = max_value

View File

@@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
global_account_data = { global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall() row[0]: json.loads(row[1]) for row in txn
} }
sql = ( sql = (
@@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
account_data_by_room = {} account_data_by_room = {}
for row in txn.fetchall(): for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2]) room_account_data[row[1]] = json.loads(row[2])

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.util.async
from ._base import SQLBaseStore from ._base import SQLBaseStore
from . import engines from . import engines
@@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_performance = {} self._background_update_performance = {}
self._background_update_queue = [] self._background_update_queue = []
self._background_update_handlers = {} self._background_update_handlers = {}
self._background_update_timer = None
@defer.inlineCallbacks @defer.inlineCallbacks
def start_doing_background_updates(self): def start_doing_background_updates(self):
assert self._background_update_timer is None, \
"background updates already running"
logger.info("Starting background schema updates") logger.info("Starting background schema updates")
while True: while True:
sleep = defer.Deferred() yield synapse.util.async.sleep(
self._background_update_timer = self._clock.call_later( self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
)
try:
yield sleep
finally:
self._background_update_timer = None
try: try:
result = yield self.do_next_background_update( result = yield self.do_next_background_update(

View File

@@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"]) message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall(): for row in txn:
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
device = row[0] device = row[0]
@@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + devices)
for row in txn.fetchall(): for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device = row[0] device = row[0]
@@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit user_id, device_id, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:
@@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
) )
txn.execute(sql, (last_pos, upper_pos)) txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn.fetchall()) rows.extend(txn)
return rows return rows
@@ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
""" """
Args: Args:
destination(str): The name of the remote server. destination(str): The name of the remote server.
last_stream_id(int): The last position of the device message stream last_stream_id(int|long): The last position of the device message stream
that the server sent up to. that the server sent up to.
current_stream_id(int): The current position of the device current_stream_id(int|long): The current position of the device
message stream. message stream.
Returns: Returns:
Deferred ([dict], int): List of messages for the device and where Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to. in the stream the messages got to.
""" """
@@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit destination, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:

View File

@@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
desc="delete_device", desc="delete_device",
) )
def delete_devices(self, user_id, device_ids):
"""Deletes several devices.
Args:
user_id (str): The ID of the user which owns the devices
device_ids (list): The IDs of the devices to delete
Returns:
defer.Deferred
"""
return self._simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
keyvalues={"user_id": user_id},
desc="delete_devices",
)
def update_device(self, user_id, device_id, new_display_name=None): def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device. """Update a device.
@@ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore):
"""Get stream of updates to send to remote servers """Get stream of updates to send to remote servers
Returns: Returns:
(now_stream_id, [ { updates }, .. ]) (int, list[dict]): current stream id and list of updates
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()
@@ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id GROUP BY user_id, device_id
LIMIT 20
""" """
txn.execute( txn.execute(
sql, (destination, from_stream_id, now_stream_id, False) sql, (destination, from_stream_id, now_stream_id, False)
) )
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id # maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows} query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True
) )
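The `LIMIT 20` together with the capping of `now_stream_id` is what batches device-list pokes (PR #2063): when a full batch comes back, the stream token only advances as far as the newest row actually returned, so the remaining rows are picked up on the next poke. The token logic in isolation (hypothetical helper name):

```python
def cap_device_list_token(query_map, now_stream_id, batch_size=20):
    # query_map maps (user_id, device_id) -> stream_id for the rows the
    # LIMIT clause returned. If the batch is full, stop the returned
    # token at the newest row we are sending; rows beyond it will be
    # fetched on the next iteration. Otherwise the caller's
    # now_stream_id covers everything.
    if len(query_map) >= batch_size:
        return max(query_map.values())
    return now_stream_id
```

This keeps any single transaction to a remote server bounded even when many devices changed at once.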

View File

@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
import ujson as json import ujson as json
@@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
return result return result
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn): """Insert some new one time keys for a device.
for (algorithm, key_id, json_bytes) in key_list:
self._simple_upsert_txn( Checks if any of the keys are already inserted, if they are then check
txn, table="e2e_one_time_keys_json", if they match. If they don't then we raise an error.
"""
# First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=[key_id for _, key_id, _ in key_list],
retcols=("algorithm", "key_id", "key_json",),
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
},
desc="add_e2e_one_time_keys_check",
)
existing_key_map = {
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
}
new_keys = [] # Keys that we need to insert
for algorithm, key_id, json_bytes in key_list:
ex_bytes = existing_key_map.get((algorithm, key_id), None)
if ex_bytes:
if json_bytes != ex_bytes:
raise SynapseError(
400, "One time key with key_id %r already exists" % (key_id,)
)
else:
new_keys.append((algorithm, key_id, json_bytes))
def _add_e2e_one_time_keys(txn):
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self._simple_insert_many_txn(
txn, table="e2e_one_time_keys_json",
values=[
{
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm, "algorithm": algorithm,
"key_id": key_id, "key_id": key_id,
},
values={
"ts_added_ms": time_now, "ts_added_ms": time_now,
"key_json": json_bytes, "key_json": json_bytes,
} }
for algorithm, key_id, json_bytes in new_keys
],
) )
return self.runInteraction( yield self.runInteraction(
"add_e2e_one_time_keys", _add_e2e_one_time_keys "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
) )
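The rewritten `add_e2e_one_time_keys` (PR #2053) swaps a per-key upsert for check-then-insert: keys already present must match byte-for-byte, only genuinely new keys are inserted, and a unique constraint resolves any race between concurrent uploads. A rough sketch of that pattern with sqlite3 — the schema here is deliberately simplified, not Synapse's actual table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE one_time_keys ("
    " algorithm TEXT, key_id TEXT, key_json TEXT,"
    " UNIQUE (algorithm, key_id))"
)
conn.execute("INSERT INTO one_time_keys VALUES ('curve25519', 'AAAA', '\"k1\"')")

def add_keys(key_list):
    # First look up what is already persisted for these key ids.
    existing = {
        (alg, kid): kj
        for alg, kid, kj in conn.execute(
            "SELECT algorithm, key_id, key_json FROM one_time_keys"
        )
    }
    new_keys = []
    for alg, kid, kj in key_list:
        ex = existing.get((alg, kid))
        if ex is not None:
            if ex != kj:
                # Re-upload of an existing key_id with different content.
                raise ValueError("One time key %r already exists" % (kid,))
        else:
            new_keys.append((alg, kid, kj))
    # The UNIQUE constraint protects against a race between the lookup and
    # this insert: a concurrent duplicate insert fails loudly instead of
    # silently overwriting (and, unlike an upsert, takes no table lock).
    conn.executemany("INSERT INTO one_time_keys VALUES (?, ?, ?)", new_keys)

add_keys([("curve25519", "AAAA", '"k1"'), ("curve25519", "BBBB", '"k2"')])
```

Re-submitting an existing key with identical content is a no-op; re-submitting it with different content raises, mirroring the 400 error in the hunk above.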
     def count_e2e_one_time_keys(self, user_id, device_id):
@@ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             txn.execute(sql, (user_id, device_id))
             result = {}
-            for algorithm, key_count in txn.fetchall():
+            for algorithm, key_count in txn:
                 result[algorithm] = key_count
             return result
         return self.runInteraction(
@@ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn.fetchall():
+                for key_id, key_json in txn:
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
             sql = (

View File

@@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
                     base_sql % (",".join(["?"] * len(chunk)),),
                     chunk
                 )
-                new_front.update([r[0] for r in txn.fetchall()])
+                new_front.update([r[0] for r in txn])

             new_front -= results
@@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
             txn.execute(sql, (room_id, False,))
-            return dict(txn.fetchall())
+            return dict(txn)

     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
@@ -201,9 +201,9 @@ class EventFederationStore(SQLBaseStore):
     def _update_min_depth_for_room_txn(self, txn, room_id, depth):
         min_depth = self._get_min_depth_interaction(txn, room_id)

-        do_insert = depth < min_depth if min_depth else True
+        if min_depth and depth >= min_depth:
+            return

-        if do_insert:
-            self._simple_upsert_txn(
-                txn,
-                table="room_depth",
+        self._simple_upsert_txn(
+            txn,
+            table="room_depth",
@@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
         def get_forward_extremeties_for_room_txn(txn):
             txn.execute(sql, (stream_ordering, room_id))
-            rows = txn.fetchall()
-            return [event_id for event_id, in rows]
+            return [event_id for event_id, in txn]

         return self.runInteraction(
             "get_forward_extremeties_for_room",
@@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
                     (room_id, event_id, False, limit - len(event_results))
                 )

-                for row in txn.fetchall():
+                for row in txn:
                     if row[1] not in event_results:
                         queue.put((-row[0], row[1]))
@@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
                     (room_id, event_id, False, limit - len(event_results))
                 )

-                for e_id, in txn.fetchall():
+                for e_id, in txn:
                     new_front.add(e_id)

                 new_front -= earliest_events

View File

@@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
                 " stream_ordering >= ? AND stream_ordering <= ?"
             )
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
-            return [r[0] for r in txn.fetchall()]
+            return [r[0] for r in txn]

         ret = yield self.runInteraction("get_push_action_users_in_range", f)
         defer.returnValue(ret)

View File

@@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json

 from collections import deque, namedtuple, OrderedDict
 from functools import wraps

-import synapse
 import synapse.metrics

 import logging
 import math
 import ujson as json

+# these are only included to make the type annotations work
+from synapse.events import EventBase  # noqa: F401
+from synapse.events.snapshot import EventContext  # noqa: F401

 logger = logging.getLogger(__name__)
@@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
             partitioned.setdefault(event.room_id, []).append((event, ctx))

         deferreds = []
-        for room_id, evs_ctxs in partitioned.items():
+        for room_id, evs_ctxs in partitioned.iteritems():
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
             deferreds.append(d)

-        for room_id in partitioned.keys():
+        for room_id in partitioned:
             self._maybe_start_persisting(room_id)

         return preserve_context_over_deferred(
@@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context, backfilled=False):
+        """
+        Args:
+            event (EventBase):
+            context (EventContext):
+            backfilled (bool):
+
+        Returns:
+            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
@@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     def _persist_events(self, events_and_contexts, backfilled=False,
                         delete_existing=False):
+        """Persist events to db
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+            delete_existing (bool):
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
         if not events_and_contexts:
             return
@@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
                     (event, context)
                 )

-            for room_id, ev_ctx_rm in events_by_room.items():
+            for room_id, ev_ctx_rm in events_by_room.iteritems():
                 # Work out new extremities by recursively adding and removing
                 # the new events.
                 latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
         # Now we need to work out the different state sets for
         # each state extremities
         state_sets = []
+        state_groups = set()
         missing_event_ids = []
         was_updated = False
         for event_id in new_latest_event_ids:
@@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
                 if event_id == ev.event_id:
                     if ctx.current_state_ids is None:
                         raise Exception("Unknown current state")
-                    state_sets.append(ctx.current_state_ids)
-                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                        was_updated = True
+
+                    # If we've already seen the state group don't bother adding
+                    # it to the state sets again
+                    if ctx.state_group not in state_groups:
+                        state_sets.append(ctx.current_state_ids)
+                        if ctx.delta_ids or hasattr(ev, "state_key"):
+                            was_updated = True
+                        if ctx.state_group:
+                            # Add this as a seen state group (if it has a state
+                            # group)
+                            state_groups.add(ctx.state_group)
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
@@ -425,31 +462,57 @@ class EventsStore(SQLBaseStore):
                 missing_event_ids,
             )

-            groups = set(event_to_groups.values())
-            group_to_state = yield self._get_state_for_groups(groups)
+            groups = set(event_to_groups.itervalues()) - state_groups

-            state_sets.extend(group_to_state.values())
+            if groups:
+                group_to_state = yield self._get_state_for_groups(groups)
+                state_sets.extend(group_to_state.itervalues())

         if not new_latest_event_ids:
             current_state = {}
         elif was_updated:
-            current_state = yield resolve_events(
-                state_sets,
-                state_map_factory=lambda ev_ids: self.get_events(
-                    ev_ids, get_prev_content=False, check_redacted=False,
-                ),
-            )
+            if len(state_sets) == 1:
+                # If there is only one state set, then we know what the current
+                # state is.
+                current_state = state_sets[0]
+            else:
+                # We work out the current state by passing the state sets to the
+                # state resolution algorithm. It may ask for some events, including
+                # the events we have yet to persist, so we need a slightly more
+                # complicated event lookup function than simply looking the events
+                # up in the db.
+                events_map = {ev.event_id: ev for ev, _ in events_context}
+
+                @defer.inlineCallbacks
+                def get_events(ev_ids):
+                    # We get the events by first looking at the list of events we
+                    # are trying to persist, and then fetching the rest from the DB.
+                    db = []
+                    to_return = {}
+                    for ev_id in ev_ids:
+                        ev = events_map.get(ev_id, None)
+                        if ev:
+                            to_return[ev_id] = ev
+                        else:
+                            db.append(ev_id)
+
+                    if db:
+                        evs = yield self.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        )
+                        to_return.update(evs)
+                    defer.returnValue(to_return)
+
+                current_state = yield resolve_events(
+                    state_sets,
+                    state_map_factory=get_events,
+                )
         else:
             return

-        existing_state_rows = yield self._simple_select_list(
-            table="current_state_events",
-            keyvalues={"room_id": room_id},
-            retcols=["event_id", "type", "state_key"],
-            desc="_calculate_state_delta",
-        )
+        existing_state = yield self.get_current_state_ids(room_id)

-        existing_events = set(row["event_id"] for row in existing_state_rows)
+        existing_events = set(existing_state.itervalues())
         new_events = set(ev_id for ev_id in current_state.itervalues())
         changed_events = existing_events ^ new_events
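The `get_events` factory in the hunk above checks an in-memory map of not-yet-persisted events before falling back to the database, since state resolution may ask for events that are in the very batch being persisted. The lookup-with-fallback shape, stripped of the Twisted machinery (the names here are illustrative, not Synapse's):

```python
def make_lookup(pending_events, fetch_from_db):
    """pending_events: dict of event_id -> event not yet in the DB.
    fetch_from_db: callable taking a list of ids and returning a dict."""
    def lookup(ev_ids):
        to_return = {}
        missing = []
        for ev_id in ev_ids:
            ev = pending_events.get(ev_id)
            if ev is not None:
                to_return[ev_id] = ev
            else:
                missing.append(ev_id)
        if missing:
            # Only hit the backing store for ids the in-memory map
            # could not satisfy.
            to_return.update(fetch_from_db(missing))
        return to_return
    return lookup

lookup = make_lookup(
    {"$a": "pending-a"},
    lambda ids: {i: "db-" + i for i in ids},
)
```

Here `lookup(["$a", "$b"])` resolves `$a` from the pending map and `$b` from the (stubbed) database.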
@@ -457,9 +520,8 @@ class EventsStore(SQLBaseStore):
             return

         to_delete = {
-            (row["type"], row["state_key"]): row["event_id"]
-            for row in existing_state_rows
-            if row["event_id"] in changed_events
+            key: ev_id for key, ev_id in existing_state.iteritems()
+            if ev_id in changed_events
         }
         events_to_insert = (new_events - existing_events)
         to_insert = {
@@ -535,11 +597,91 @@ class EventsStore(SQLBaseStore):
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.

-        If delete_existing is True then existing events will be purged from the
-        database before insertion. This is useful when retrying due to IntegrityError.
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]):
+                events to persist
+            backfilled (bool): True if the events were backfilled
+            delete_existing (bool): True to purge existing table rows for the
+                events from the database. This is useful when retrying due to
+                IntegrityError.
+            current_state_for_room (dict[str, (list[str], list[str])]):
+                The current-state delta for each room. For each room, a tuple
+                (to_delete, to_insert), being a list of event ids to be removed
+                from the current state, and a list of event ids to be added to
+                the current state.
+            new_forward_extremeties (dict[str, list[str]]):
+                The new forward extremities for each room. For each room, a
+                list of the event ids which are the forward extremities.
         """
+        self._update_current_state_txn(txn, current_state_for_room)
+
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state_tuple in current_state_for_room.iteritems():
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremeties,
+            max_stream_order=max_stream_order,
+        )
+
+        # Ensure that we don't have the same event twice.
+        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+            events_and_contexts,
+        )
+
+        self._update_room_depths_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # _update_outliers_txn filters out any events which have already been
+        # persisted, and returns the filtered list.
+        events_and_contexts = self._update_outliers_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only events that we haven't
+        # seen before.
+
+        if delete_existing:
+            # For paranoia reasons, we go and delete all the existing entries
+            # for these events so we can reinsert them.
+            # This gets around any problems with some tables already having
+            # entries.
+            self._delete_existing_rows_txn(
+                txn,
+                events_and_contexts=events_and_contexts,
+            )
+
+        self._store_event_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # Insert into the state_groups, state_groups_state, and
+        # event_to_state_groups tables.
+        self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
+
+    def _update_current_state_txn(self, txn, state_delta_by_room):
+        for room_id, current_state_tuple in state_delta_by_room.iteritems():
             to_delete, to_insert = current_state_tuple
             txn.executemany(
                 "DELETE FROM current_state_events WHERE event_id = ?",
@@ -585,7 +727,13 @@ class EventsStore(SQLBaseStore):
                 txn, self.get_users_in_room, (room_id,)
             )

-        for room_id, new_extrem in new_forward_extremeties.items():
+            self._invalidate_cache_and_stream(
+                txn, self.get_current_state_ids, (room_id,)
+            )
+
+    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+                                        max_stream_order):
+        for room_id, new_extrem in new_forward_extremities.iteritems():
             self._simple_delete_txn(
                 txn,
                 table="event_forward_extremities",
@@ -603,7 +751,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": ev_id,
                     "room_id": room_id,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for ev_id in new_extrem
             ],
         )
@@ -620,13 +768,22 @@ class EventsStore(SQLBaseStore):
                     "event_id": event_id,
                     "stream_ordering": max_stream_order,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for event_id in new_extrem
             ]
         )

-        # Ensure that we don't have the same event twice.
-        # Pick the earliest non-outlier if there is one, else the earliest one.
+    @classmethod
+    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+        """Ensure that we don't have the same event twice.
+
+        Pick the earliest non-outlier if there is one, else the earliest one.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+        Returns:
+            list[(EventBase, EventContext)]: filtered list
+        """
         new_events_and_contexts = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -639,9 +796,17 @@ class EventsStore(SQLBaseStore):
                     new_events_and_contexts[event.event_id] = (event, context)
             else:
                 new_events_and_contexts[event.event_id] = (event, context)
+        return new_events_and_contexts.values()

-        events_and_contexts = new_events_and_contexts.values()
+    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+        """Update min_depth for each room
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -657,9 +822,24 @@ class EventsStore(SQLBaseStore):
                 event.depth, depth_updates.get(event.room_id, event.depth)
             )

-        for room_id, depth in depth_updates.items():
+        for room_id, depth in depth_updates.iteritems():
             self._update_min_depth_for_room_txn(txn, room_id, depth)

+    def _update_outliers_txn(self, txn, events_and_contexts):
+        """Update any outliers with new event info.
+
+        This turns outliers into ex-outliers (unless the new event was
+        rejected).
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without events which
+            are already in the events table.
+        """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
                 ",".join(["?"] * len(events_and_contexts)),
@@ -669,24 +849,21 @@ class EventsStore(SQLBaseStore):
         have_persisted = {
             event_id: outlier
-            for event_id, outlier in txn.fetchall()
+            for event_id, outlier in txn
         }

         to_remove = set()
         for event, context in events_and_contexts:
-            if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
-                if event.event_id in have_persisted:
-                    # If we have already seen the event then ignore it.
-                    to_remove.add(event)
-                continue
-
             if event.event_id not in have_persisted:
                 continue

             to_remove.add(event)
+
+            if context.rejected:
+                # If the event is rejected then we don't care if the event
+                # was an outlier or not.
+                continue
+
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
@@ -741,34 +918,16 @@ class EventsStore(SQLBaseStore):
             # event isn't an outlier any more.
             self._update_backward_extremeties(txn, [event])

-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]

+    @classmethod
+    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return

-        # From this point onwards the events are only events that we haven't
-        # seen before.
-
-        def event_dict(event):
-            return {
-                k: v
-                for k, v in event.get_dict().items()
-                if k not in [
-                    "redacted",
-                    "redacted_because",
-                ]
-            }
-
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-            logger.info("Deleting existing")
+        logger.info("Deleting existing")

-            for table in (
+        for table in (
@@ -800,6 +959,25 @@ class EventsStore(SQLBaseStore):
             [(ev.event_id,) for ev, _ in events_and_contexts]
         )

+    def _store_event_txn(self, txn, events_and_contexts):
+        """Insert new events into the event and event_json tables
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        def event_dict(event):
+            d = event.get_dict()
+            d.pop("redacted", None)
+            d.pop("redacted_because", None)
+            return d
+
         self._simple_insert_many_txn(
             txn,
             table="event_json",
@@ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore):
             ],
         )

+    def _store_rejected_events_txn(self, txn, events_and_contexts):
+        """Add rows to the 'rejections' table for received events which were
+        rejected
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without the rejected
+            events.
+        """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
         to_remove = set()
@@ -853,16 +1044,23 @@ class EventsStore(SQLBaseStore):
                 )
                 to_remove.add(event)

-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]

-        if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
-            return
+    def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+        """Update all the miscellaneous tables for new events

-        # From this point onwards the events are only ones that weren't rejected.
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return

         for event, context in events_and_contexts:
             # Insert all the push actions into the event_push_actions table.
@@ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore):
             ],
         )

-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
-
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
@@ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore):
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)

-        if backfilled:
-            # Backfilled events come before the current state so we don't need
-            # to update the current state table
-            return
-
-        return
-
     def _add_to_cache(self, txn, events_and_contexts):
         to_prefill = []
@@ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore):
         def get_all_new_events_txn(txn):
             sql = (
-                "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
             if have_forward_events:
@@ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = []

             sql = (
-                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
-                " eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
-                " ORDER BY e.stream_ordering DESC"
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT ?"
             )
             if have_backfill_events:
@@ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore):
                     "state_key": key[1],
                     "event_id": state_id,
                 }
-                for key, state_id in curr_state.items()
+                for key, state_id in curr_state.iteritems()
             ],
         )

View File

@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
             key_ids
         Args:
             server_name (str): The name of the server.
-            key_ids (list of str): List of key_ids to try and look up.
+            key_ids (iterable[str]): key_ids to try and look up.
         Returns:
-            (list of VerifyKey): The verification keys.
+            Deferred: resolves to dict[str, VerifyKey]: map from
+                key_id to verification key.
         """
         keys = {}
         for key_id in key_ids:

View File

@@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
             ),
             (current_version,)
         )
-        applied_deltas = [d for d, in txn.fetchall()]
+        applied_deltas = [d for d, in txn]
         return current_version, applied_deltas, upgraded

     return None

View File

@@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
                 self.presence_stream_cache.entity_has_changed,
                 state.user_id, stream_id,
             )
-            self._invalidate_cache_and_stream(
-                txn, self._get_presence_for_user, (state.user_id,)
+            txn.call_after(
+                self._get_presence_for_user.invalidate, (state.user_id,)
             )

         # Actually insert new rows

View File

@@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
             )
             txn.execute(sql, (room_id, receipt_type, user_id))
-            results = txn.fetchall()

-            if results and topological_ordering:
-                for to, so, _ in results:
+            if topological_ordering:
+                for to, so, _ in txn:
                     if int(to) > topological_ordering:
                         return False
                     elif int(to) == topological_ordering and int(so) >= stream_ordering:

View File

@@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
-            return dict(txn.fetchall())
+            return dict(txn)

         return self.runInteraction("get_users_by_id_case_insensitive", f)

View File

@@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
                     sql % ("AND appservice_id IS NULL",),
                     (stream_id,)
                 )
-                return dict(txn.fetchall())
+                return dict(txn)
             else:
                 # We want to get from all lists, so we need to aggregate the results

@@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
                 results = {}

                 # A room is visible if its visible on any list.
-                for room_id, visibility in txn.fetchall():
+                for room_id, visibility in txn:
                     results[room_id] = bool(visibility) or results.get(room_id, False)

                 return results

View File

@@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)

+    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+    def get_hosts_in_room(self, room_id, cache_context):
+        """Returns the set of all hosts currently in the room
+        """
+        user_ids = yield self.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate,
+        )
+        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+        defer.returnValue(hosts)
+
     @cached(max_entries=500000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
-            rows = self._get_members_rows_txn(
-                txn,
-                room_id=room_id,
-                membership=Membership.JOIN,
+            sql = (
+                "SELECT m.user_id FROM room_memberships as m"
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id "
+                " AND m.room_id = c.room_id "
+                " AND m.user_id = c.state_key"
+                " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
             )
-            return [r["user_id"] for r in rows]
+            txn.execute(sql, (room_id, Membership.JOIN,))
+            return [r[0] for r in txn]
         return self.runInteraction("get_users_in_room", f)

     @cached()
@ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore):
return results return results
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): @cachedInlineCallbacks(max_entries=500000, iterable=True)
where_clause = "c.room_id = ?"
where_values = [room_id]
if membership:
where_clause += " AND m.membership = ?"
where_values.append(membership)
if user_id:
where_clause += " AND m.user_id = ?"
where_values.append(user_id)
sql = (
"SELECT m.* FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND %(where)s"
) % {
"where": where_clause,
}
txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn)
return rows
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is( """Returns a set of room_ids the user is currently joined to
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context): def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id` """Returns the set of users who share a room with `user_id`
""" """
rooms = yield self.get_rooms_for_user( room_ids = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate, user_id, on_invalidate=cache_context.invalidate,
) )
user_who_share_room = set() user_who_share_room = set()
for room in rooms: for room_id in room_ids:
user_ids = yield self.get_users_in_room( user_ids = yield self.get_users_in_room(
room.room_id, on_invalidate=cache_context.invalidate, room_id, on_invalidate=cache_context.invalidate,
) )
user_who_share_room.update(user_ids) user_who_share_room.update(user_ids)
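The rewritten `get_users_who_share_room_with_user` iterates plain room IDs and unions the per-room member lists. Stripped of the cache and Deferred machinery, the core set logic looks roughly like this (a hedged sketch; the dict arguments stand in for the cached lookups and are not Synapse's API):

```python
def users_who_share_room_with(user_id, rooms_for_user, users_in_room):
    """Union the member lists of every room the user is joined to.

    rooms_for_user: dict of user_id -> set of room_ids (stand-in for the
    cached get_rooms_for_user lookup); users_in_room: dict of
    room_id -> list of user_ids (stand-in for get_users_in_room).
    """
    shared = set()
    for room_id in rooms_for_user.get(user_id, ()):
        # set.update accepts any iterable, so repeated members across
        # rooms are deduplicated for free.
        shared.update(users_in_room.get(room_id, ()))
    return shared
```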


@@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
                 " WHERE event_id = ?"
             )
             txn.execute(query, (event_id, ))
-            return {k: v for k, v in txn.fetchall()}
+            return {k: v for k, v in txn}

     def _store_event_reference_hashes_txn(self, txn, events):
         """Store a hash for a PDU


@@ -14,7 +14,7 @@
 # limitations under the License.

 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 from synapse.util.caches import intern_string
 from synapse.storage.engines import PostgresEngine

@@ -69,6 +69,18 @@ class StateStore(SQLBaseStore):
             where_clause="type='m.room.member'",
         )

+    @cachedInlineCallbacks(max_entries=100000, iterable=True)
+    def get_current_state_ids(self, room_id):
+        rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+        defer.returnValue({
+            (r["type"], r["state_key"]): r["event_id"] for r in rows
+        })
     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
         if not event_ids:
@@ -78,7 +90,7 @@ class StateStore(SQLBaseStore):
             event_ids,
         )

-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups)

         defer.returnValue(group_to_state)
@@ -96,17 +108,18 @@ class StateStore(SQLBaseStore):
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in group_to_ids.values()
-                for ev_id in group_ids.values()
+                ev_id for group_ids in group_to_ids.itervalues()
+                for ev_id in group_ids.itervalues()
             ],
             get_prev_content=False
         )

         defer.returnValue({
             group: [
-                state_event_map[v] for v in event_id_map.values() if v in state_event_map
+                state_event_map[v] for v in event_id_map.itervalues()
+                if v in state_event_map
             ]
-            for group, event_id_map in group_to_ids.items()
+            for group, event_id_map in group_to_ids.iteritems()
         })
     def _have_persisted_state_group_txn(self, txn, state_group):
@@ -124,6 +137,16 @@ class StateStore(SQLBaseStore):
                 continue

             if context.current_state_ids is None:
+                # AFAIK, this can never happen
+                logger.error(
+                    "Non-outlier event %s had current_state_ids==None",
+                    event.event_id)
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
                 continue

             state_groups[event.event_id] = context.state_group
@@ -168,7 +191,7 @@ class StateStore(SQLBaseStore):
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in context.delta_ids.items()
+                    for key, state_id in context.delta_ids.iteritems()
                 ],
             )
         else:
@@ -183,7 +206,7 @@ class StateStore(SQLBaseStore):
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in context.current_state_ids.items()
+                    for key, state_id in context.current_state_ids.iteritems()
                 ],
             )
@@ -195,7 +218,7 @@ class StateStore(SQLBaseStore):
                     "state_group": state_group_id,
                     "event_id": event_id,
                 }
-                for event_id, state_group_id in state_groups.items()
+                for event_id, state_group_id in state_groups.iteritems()
             ],
         )
@@ -319,10 +342,10 @@ class StateStore(SQLBaseStore):
                 args.extend(where_args)

                 txn.execute(sql % (where_clause,), args)
-                rows = self.cursor_to_dict(txn)
-                for row in rows:
-                    key = (row["type"], row["state_key"])
-                    results[group][key] = row["event_id"]
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (typ, state_key)
+                    results[group][key] = event_id
         else:
             if types is not None:
                 where_clause = "AND (%s)" % (
@@ -351,12 +374,11 @@ class StateStore(SQLBaseStore):
                         " WHERE state_group = ? %s" % (where_clause,),
                         args
                     )
-                    rows = txn.fetchall()
-                    results[group].update({
-                        (typ, state_key): event_id
-                        for typ, state_key, event_id in rows
-                        if (typ, state_key) not in results[group]
-                    })
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
+                        if (typ, state_key) not in results[group]
+                    )

                     # If the lengths match then we must have all the types,
                     # so no need to go walk further down the tree.
@@ -393,21 +415,21 @@ class StateStore(SQLBaseStore):
             event_ids,
         )

-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)

         state_event_map = yield self.get_events(
-            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
             get_prev_content=False
         )

         event_to_state = {
             event_id: {
                 k: state_event_map[v]
-                for k, v in group_to_state[group].items()
+                for k, v in group_to_state[group].iteritems()
                 if v in state_event_map
             }
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }

         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -430,12 +452,12 @@ class StateStore(SQLBaseStore):
             event_ids,
         )

-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)

         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }

         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -474,7 +496,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])

-    @cached(num_args=2, max_entries=10000)
+    @cached(num_args=2, max_entries=100000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
@@ -547,7 +569,7 @@ class StateStore(SQLBaseStore):
         got_all = not (missing_types or types is None)

         return {
-            k: v for k, v in state_dict_ids.items()
+            k: v for k, v in state_dict_ids.iteritems()
             if include(k[0], k[1])
         }, missing_types, got_all
@@ -606,7 +628,7 @@ class StateStore(SQLBaseStore):
         # Now we want to update the cache with all the things we fetched
         # from the database.
-        for group, group_state_dict in group_to_state_dict.items():
+        for group, group_state_dict in group_to_state_dict.iteritems():
             if types:
                 # We delibrately put key -> None mappings into the cache to
                 # cache absence of the key, on the assumption that if we've
@@ -621,10 +643,10 @@ class StateStore(SQLBaseStore):
             else:
                 state_dict = results[group]

-                state_dict.update({
-                    (intern_string(k[0]), intern_string(k[1])): v
-                    for k, v in group_state_dict.items()
-                })
+                state_dict.update(
+                    ((intern_string(k[0]), intern_string(k[1])), v)
+                    for k, v in group_state_dict.iteritems()
+                )

                 self._state_group_cache.update(
                     cache_seq_num,
@@ -635,10 +657,10 @@ class StateStore(SQLBaseStore):
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
-        for group, state_dict in results.items():
+        for group, state_dict in results.iteritems():
             results[group] = {
                 key: event_id
-                for key, event_id in state_dict.items()
+                for key, event_id in state_dict.iteritems()
                 if event_id
             }
@@ -727,7 +749,7 @@ class StateStore(SQLBaseStore):
         # of keys

         delta_state = {
-            key: value for key, value in curr_state.items()
+            key: value for key, value in curr_state.iteritems()
             if prev_state.get(key, None) != value
         }
@@ -767,7 +789,7 @@ class StateStore(SQLBaseStore):
                     "state_key": key[1],
                     "event_id": state_id,
                 }
-                for key, state_id in delta_state.items()
+                for key, state_id in delta_state.iteritems()
             ],
         )
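The `delta_state` hunk above computes, in one dict comprehension, the set of state keys whose event changed between two state maps. Isolated (in Python 3 spelling, with `items()` rather than the Python 2 `iteritems()` used in the diff):

```python
def state_delta(prev_state, curr_state):
    """Return the entries of curr_state that are new or changed relative
    to prev_state, mirroring the delta_state comprehension above.

    Both arguments map (event_type, state_key) tuples to event IDs.
    """
    return {
        key: value for key, value in curr_state.items()
        # prev_state.get(key, None) != value is True both for changed
        # values and for keys absent from prev_state.
        if prev_state.get(key, None) != value
    }
```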


@@ -829,3 +829,6 @@ class StreamStore(SQLBaseStore):
             updatevalues={"stream_id": stream_id},
             desc="update_federation_out_pos",
         )
+
+    def has_room_changed_since(self, room_id, stream_id):
+        return self._events_stream_cache.has_entity_changed(room_id, stream_id)
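The new `has_room_changed_since` delegates to a stream-change cache, which records the last stream position at which each entity changed and answers "may it have changed since position X?". A minimal stand-in for the idea (names and structure are hypothetical, not Synapse's `StreamChangeCache`):

```python
class SimpleStreamChangeCache:
    """Tracks the last stream position at which each entity changed.

    Positions older than the earliest tracked position must be answered
    conservatively ("maybe changed"), since the cache has no data there.
    """
    def __init__(self, earliest_known_stream_pos=0):
        self._earliest = earliest_known_stream_pos
        self._last_changed = {}  # entity -> last stream_pos it changed at

    def entity_has_changed(self, entity, stream_pos):
        prev = self._last_changed.get(entity, 0)
        self._last_changed[entity] = max(stream_pos, prev)

    def has_entity_changed(self, entity, stream_pos):
        if stream_pos < self._earliest:
            # Too old to know; the caller must fall back to the database.
            return True
        return self._last_changed.get(entity, self._earliest) > stream_pos
```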


@@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
                 tags = []
-                for tag, content in txn.fetchall():
+                for tag, content in txn:
                     tags.append(json.dumps(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
                 results.append((stream_id, user_id, room_id, tag_json))
@@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
                 " WHERE user_id = ? AND stream_id > ?"
             )
             txn.execute(sql, (user_id, stream_id))
-            room_ids = [row[0] for row in txn.fetchall()]
+            room_ids = [row[0] for row in txn]
             return room_ids

         changed = self._account_data_stream_cache.has_entity_changed(


@@ -30,6 +30,17 @@ class IdGenerator(object):

 def _load_current_id(db_conn, table, column, step=1):
+    """
+    Args:
+        db_conn (object):
+        table (str):
+        column (str):
+        step (int):
+
+    Returns:
+        int
+    """
     cur = db_conn.cursor()
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@@ -131,6 +142,9 @@ class StreamIdGenerator(object):
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
+
+        Returns:
+            int
         """
         with self._lock:
             if self._unfinished_ids:


@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)

 class DeferredTimedOutError(SynapseError):
     def __init__(self):
-        super(SynapseError).__init__(504, "Timed out")
+        super(SynapseError, self).__init__(504, "Timed out")
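The old call `super(SynapseError).__init__(...)` builds an *unbound* super object and calls `__init__` on that object itself, so the parent constructor never runs; passing `self` as the second argument binds the proxy correctly. A minimal illustration with plain classes (not the Synapse ones):

```python
class BaseError(Exception):
    def __init__(self, code, msg):
        self.code = code
        super().__init__(msg)

class TimedOutError(BaseError):
    def __init__(self):
        # Two-argument form: a *bound* super proxy, so BaseError.__init__
        # actually runs and self.code gets set. The one-argument form
        # super(TimedOutError) would not dispatch to the parent here.
        super(TimedOutError, self).__init__(504, "Timed out")

err = TimedOutError()
```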
 def unwrapFirstError(failure):
@@ -93,8 +95,10 @@ class Clock(object):
         ret_deferred = defer.Deferred()

         def timed_out_fn():
+            e = DeferredTimedOutError()
+
             try:
-                ret_deferred.errback(DeferredTimedOutError())
+                ret_deferred.errback(e)
             except:
                 pass
@@ -114,7 +118,7 @@ class Clock(object):

         ret_deferred.addBoth(cancel)

-        def sucess(res):
+        def success(res):
             try:
                 ret_deferred.callback(res)
             except:
@@ -128,7 +132,7 @@ class Clock(object):
             except:
                 pass

-        given_deferred.addCallbacks(callback=success, errback=err)
+        given_deferred.addCallbacks(callback=success, errback=err)

         timer = self.call_later(time_out, timed_out_fn)


@@ -15,12 +15,9 @@
 import logging

 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)

 from . import DEBUG_CACHES, register_cache
@@ -189,7 +186,55 @@ class Cache(object):
         self.cache.clear()


-class CacheDescriptor(object):
+class _CacheDescriptorBase(object):
+    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        arg_spec = inspect.getargspec(orig)
+        all_args = arg_spec.args
+
+        if "cache_context" in all_args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        if num_args is None:
+            num_args = len(all_args) - 1
+            if cache_context:
+                num_args -= 1
+
+        if len(all_args) < num_args + 1:
+            raise Exception(
+                "Not enough explicit positional arguments to key off for %r: "
+                "got %i args, but wanted %i. (@cached cannot key off *args or "
+                "**kwargs)"
+                % (orig.__name__, len(all_args), num_args)
+            )
+
+        self.num_args = num_args
+        self.arg_names = all_args[1:num_args + 1]
+
+        if "cache_context" in self.arg_names:
+            raise Exception(
+                "cache_context arg cannot be included among the cache keys"
+            )
+
+        self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.

     This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +262,24 @@ class CacheDescriptor(object):
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)

+    Args:
+        num_args (int): number of positional arguments (excluding ``self`` and
+            ``cache_context``) to use as cache keys. Defaults to all named
+            args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
                  inlineCallbacks=False, cache_context=False, iterable=False):
+
+        super(CacheDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context)
+
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)

-        self.orig = orig
-
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
         self.max_entries = max_entries
-        self.num_args = num_args
         self.tree = tree
         self.iterable = iterable

-        all_args = inspect.getargspec(orig)
-        self.arg_names = all_args.args[1:num_args + 1]
-
-        if "cache_context" in all_args.args:
-            if not cache_context:
-                raise ValueError(
-                    "Cannot have a 'cache_context' arg without setting"
-                    " cache_context=True"
-                )
-            try:
-                self.arg_names.remove("cache_context")
-            except ValueError:
-                pass
-        elif cache_context:
-            raise ValueError(
-                "Cannot have cache_context=True without having an arg"
-                " named `cache_context`"
-            )
-        self.add_cache_context = cache_context
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwargs)"
-                % (orig.__name__,)
-            )
-
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
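The new `_CacheDescriptorBase` derives the cache-key argument names from the decorated function's signature and validates them in one place. The core of that logic, spelled for Python 3 with `inspect.getfullargspec` (the diff's Python 2 code uses the older `getargspec`; this is a sketch, not Synapse's class):

```python
import inspect

def cache_key_arg_names(fn, num_args=None, cache_context=False):
    """Return the positional-arg names used as the cache key, mirroring
    the derivation and checks in _CacheDescriptorBase."""
    all_args = inspect.getfullargspec(fn).args
    if num_args is None:
        num_args = len(all_args) - 1      # drop `self`
        if cache_context:
            num_args -= 1                 # drop the cache_context arg too
    if len(all_args) < num_args + 1:
        raise Exception(
            "Not enough explicit positional arguments to key off for %r"
            % fn.__name__
        )
    arg_names = all_args[1:num_args + 1]
    if "cache_context" in arg_names:
        raise Exception("cache_context arg cannot be included among the cache keys")
    return arg_names

def get_users_in_room(self, room_id):
    pass
```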
@@ -308,11 +325,9 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)

-                return preserve_context_over_deferred(observer)
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
@@ -322,10 +337,11 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)

-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()

-                return preserve_context_over_deferred(ret.observe())
+            return logcontext.make_deferred_yieldable(observer)

         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -338,48 +354,40 @@ class CacheDescriptor(object):
         return wrapped


-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.

     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """

-    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+    def __init__(self, orig, cached_method_name, list_name, num_args=None,
                  inlineCallbacks=False):
         """
         Args:
             orig (function)
-            method_name (str); The name of the chached method.
+            cached_method_name (str): The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
+            num_args (int): number of positional arguments (excluding ``self``,
+                but including list_name) to use as cache keys. Defaults to all
+                named args of the function.
             inlineCallbacks (bool): Whether orig is a generator that should
                 be wrapped by defer.inlineCallbacks
         """
-        self.orig = orig
+        super(CacheListDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)

-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
         self.list_name = list_name

-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
-
         self.cached_method_name = cached_method_name

         self.sentinel = object()

-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
@@ -425,8 +433,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing

                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
@@ -435,7 +442,6 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
@@ -463,7 +469,7 @@ class CacheListDescriptor(object):
                     results.update(res)
                     return results

-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
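The list descriptor's core job is: look each key up in the cache, call the underlying function once with only the misses, then merge and back-fill the cache. Stripped of the Deferred machinery, the pattern looks roughly like this (a hedged sketch; `fetch_many` and the plain-dict cache are stand-ins, not Synapse's API):

```python
def batch_lookup(keys, cache, fetch_many):
    """Return {key: value} for all keys, hitting `fetch_many` only for
    keys absent from `cache`, and populating the cache with the results."""
    results = {k: cache[k] for k in keys if k in cache}
    missing = [k for k in keys if k not in cache]
    if missing:
        fetched = fetch_many(missing)   # one call covers every miss
        cache.update(fetched)
        results.update(fetched)
    return results
```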
@@ -487,7 +493,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)


-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
            iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
@@ -499,8 +505,8 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
     )


-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
-                          iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+                          cache_context=False, iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -512,7 +518,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, cache_contex
     )


-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.

     Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +531,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False)
         cache (Cache): The underlying cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache.
+        num_args (int): Number of arguments to use as the key in the cache
+            (including list_name). Defaults to all named parameters.
         inlineCallbacks (bool): Should the function be wrapped in an
             `defer.inlineCallbacks`?


@@ -50,7 +50,7 @@ class StreamChangeCache(object):
     def has_entity_changed(self, entity, stream_pos):
         """Returns True if the entity may have been updated since stream_pos
         """
-        assert type(stream_pos) is int
+        assert type(stream_pos) is int or type(stream_pos) is long

         if stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()


@@ -12,6 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
 from twisted.internet import defer

 import threading
@@ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
 def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
+
+    Deprecated: this almost certainly doesn't do want you want, ie make
+    the deferred follow the synapse logcontext rules: try
+    ``make_deferred_yieldable`` instead.
     """
     if context is None:
         context = LoggingContext.current_context()
@@ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None):


 def preserve_fn(f):
-    """Ensures that function is called with correct context and that context is
-    restored after return. Useful for wrapping functions that return a deferred
-    which you don't yield on.
+    """Wraps a function, to ensure that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the funtion completes.
+
+    Useful for wrapping functions that return a deferred which you don't yield
+    on.
     """
+    def reset_context(result):
+        LoggingContext.set_current_context(LoggingContext.sentinel)
+        return result
+
+    # XXX: why is this here rather than inside g? surely we want to preserve
+    # the context from the time the function was called, not when it was
+    # wrapped?
     current = LoggingContext.current_context()

     def g(*args, **kwargs):
-        with PreserveLoggingContext(current):
-            res = f(*args, **kwargs)
-            if isinstance(res, defer.Deferred):
-                return preserve_context_over_deferred(
-                    res, context=LoggingContext.sentinel
-                )
+        res = f(*args, **kwargs)
+        if isinstance(res, defer.Deferred) and not res.called:
+            # The function will have reset the context before returning, so
+            # we need to restore it now.
+            LoggingContext.set_current_context(current)
+
+            # The original context will be restored when the deferred
+            # completes, but there is nothing waiting for it, so it will
+            # get leaked into the reactor or some other function which
+            # wasn't expecting it. We therefore need to reset the context
+            # here.
+            #
+            # (If this feels asymmetric, consider it this way: we are
+            # effectively forking a new thread of execution. We are
+            # probably currently within a ``with LoggingContext()`` block,
+            # which is supposed to have a single entry and exit point. But
+            # by spawning off another deferred, we are effectively
+            # adding a new exit point.)
+            res.addBoth(reset_context)
         return res
     return g
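The idea behind `preserve_fn`: run background work under the context captured at wrapping time, then make sure the caller's own context is not leaked or clobbered once the work is fired off. A rough analogue with the stdlib `contextvars` module (illustrative only; Synapse's real mechanism is its own thread-local context stack, and the names here are hypothetical):

```python
import contextvars

log_context = contextvars.ContextVar("log_context", default="sentinel")

def preserve_context(f):
    """Wrap f so it always runs under the context current at wrapping
    time, restoring the caller's context afterwards."""
    current = log_context.get()   # captured when the wrapper is built

    def g(*args, **kwargs):
        token = log_context.set(current)
        try:
            return f(*args, **kwargs)
        finally:
            log_context.reset(token)   # back to the caller's context
    return g

seen = []
def background_job():
    seen.append(log_context.get())

tok = log_context.set("request-1")
job = preserve_context(background_job)   # captures "request-1"
log_context.reset(tok)                   # request finished; back to sentinel
job()                                    # still runs under "request-1"
```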
+
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to preserve_fn.)
+    """
+    with PreserveLoggingContext():
+        r = yield deferred
+    defer.returnValue(r)

 # modules to ignore in `logcontext_tracer`
 _to_ignore = [
     "synapse.util.logcontext",

synapse/util/msisdn.py (new file)

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import phonenumbers
from synapse.api.errors import SynapseError
def phone_number_to_msisdn(country, number):
"""
Takes an ISO-3166-1 2 letter country code and phone number and
returns an msisdn representing the canonical version of that
phone number.
Args:
country (str): ISO-3166-1 2 letter country code
number (str): Phone number in a national or international format
Returns:
(str) The canonical form of the phone number, as an msisdn
Raises:
SynapseError if the number could not be parsed.
"""
try:
phoneNumber = phonenumbers.parse(number, country)
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
return phonenumbers.format_number(
phoneNumber, phonenumbers.PhoneNumberFormat.E164
)[1:]
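For illustration, the canonicalisation that `phone_number_to_msisdn` delegates to the `phonenumbers` library can be sketched in miniature. Everything below (the function name, the two-entry calling-code table, the trunk-prefix handling) is a hypothetical toy for the concept only; the real library handles validation, national formats, and every country:

```python
# Toy sketch of msisdn canonicalisation (illustration only -- the real
# implementation above delegates to the `phonenumbers` library).
COUNTRY_CALLING_CODES = {"GB": "44", "US": "1"}  # hypothetical subset

def phone_number_to_msisdn_sketch(country, number):
    # keep digits and a leading "+", discarding spaces and punctuation
    digits = "".join(c for c in number if c.isdigit() or c == "+")
    if digits.startswith("+"):
        return digits[1:]  # already international: strip the "+"
    cc = COUNTRY_CALLING_CODES[country]
    # drop a national trunk prefix of "0" before prepending the calling code
    if digits.startswith("0"):
        digits = digits[1:]
    return cc + digits

print(phone_number_to_msisdn_sketch("GB", "07700 900123"))  # -> 447700900123
print(phone_number_to_msisdn_sketch("US", "+1 555 0100"))   # -> 15550100
```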


@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.util.logcontext
from twisted.internet import defer

from synapse.api.errors import CodeMessageException
@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, **kwargs): def get_retry_limiter(destination, clock, store, ignore_backoff=False,
**kwargs):
"""For a given destination check if we have previously failed to """For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination. send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a If we are not ready to retry the destination, this will raise a
@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
that will mark the destination as down if an exception is thrown (excluding that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500) CodeMessageException with code < 500)
Args:
destination (str): name of homeserver
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. We will still update the next
retry_interval on success/failure.
    Example usage:

        try:
@@ -66,7 +75,7 @@
    now = int(clock.time_msec())

    if not ignore_backoff and retry_last_ts + retry_interval > now:
        raise NotRetryingDestination(
            retry_last_ts=retry_last_ts,
            retry_interval=retry_interval,
@@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
    def __exit__(self, exc_type, exc_val, exc_tb):
        valid_err_code = False
        if exc_type is None:
            valid_err_code = True
        elif not issubclass(exc_type, Exception):
            # avoid treating exceptions which don't derive from Exception as
            # failures; this is mostly so as not to catch defer._DefGen.
            valid_err_code = True
        elif issubclass(exc_type, CodeMessageException):
            # Some error codes are perfectly fine for some APIs, whereas other
            # APIs may expect never to receive e.g. a 404. It's important to
            # handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
        else:
            valid_err_code = False

        if valid_err_code:
            # We connected successfully.
            if not self.retry_interval:
                return

            logger.debug("Connection to %s was successful; clearing backoff",
                         self.destination)
            retry_last_ts = 0
            self.retry_interval = 0
        else:
@@ -160,6 +177,10 @@
        else:
            self.retry_interval = self.min_retry_interval

        logger.debug(
            "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
            self.destination, exc_type, exc_val, self.retry_interval
        )
        retry_last_ts = int(self.clock.time_msec())

@defer.inlineCallbacks
@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
"Failed to store set_destination_retry_timings", "Failed to store set_destination_retry_timings",
) )
store_retry_timings() # we deliberately do this in the background.
synapse.util.logcontext.preserve_fn(store_retry_timings)()


@@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
        if prev_membership not in MEMBERSHIP_PRIORITY:
            prev_membership = "leave"

        # Always allow the user to see their own leave events, otherwise
        # they won't see the room disappear if they reject the invite
        if membership == "leave" and (
            prev_membership == "join" or prev_membership == "invite"
        ):
            return True

        new_priority = MEMBERSHIP_PRIORITY.index(membership)
        old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
        if old_priority < new_priority:
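The rule this hunk adds can be exercised in isolation. The `MEMBERSHIP_PRIORITY` ordering below mirrors the tuple defined elsewhere in the visibility module; treat the exact ordering and the helper name as assumptions made for illustration:

```python
# Sketch of the membership-visibility rule from the hunk above.  The
# MEMBERSHIP_PRIORITY ordering is assumed, for illustration only.
MEMBERSHIP_PRIORITY = ("join", "invite", "knock", "leave", "ban")

def can_see_own_transition(membership, prev_membership):
    if prev_membership not in MEMBERSHIP_PRIORITY:
        prev_membership = "leave"
    # the new special case: always show a user their own leave after a
    # join or invite, so a rejected invite visibly disappears
    if membership == "leave" and prev_membership in ("join", "invite"):
        return True
    # otherwise, visible when the user's membership got "worse"
    new_priority = MEMBERSHIP_PRIORITY.index(membership)
    old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
    return old_priority < new_priority

print(can_see_own_transition("leave", "invite"))  # -> True (the fix)
print(can_see_own_transition("join", "leave"))    # -> False
```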


@@ -23,6 +23,9 @@ from tests.utils import (
from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
from synapse.api.errors import SynapseError

import jsonschema

user_localpart = "test_user"

@@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
        self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
{"event_fields": ["\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
{"presence": {"senders": ["@bar;pik.test.com"]}}
]
for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error:
self.filtering.check_valid_filter(filter)
self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self):
valid_filters = [
{
"room": {
"timeline": {"limit": 20},
"state": {"not_types": ["m.room.member"]},
"ephemeral": {"limit": 0, "not_types": ["*"]},
"include_leave": False,
"rooms": ["!dee:pik-test"],
"not_rooms": ["!gee:pik-test"],
"account_data": {"limit": 0, "types": ["*"]}
}
},
{
"room": {
"state": {
"types": ["m.room.*"],
"not_rooms": ["!726s6s6q:example.com"]
},
"timeline": {
"limit": 10,
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
}
},
"presence": {
"types": ["m.presence"],
"not_senders": ["@alice:example.com"]
},
"event_format": "client",
"event_fields": ["type", "content", "sender"]
}
]
for filter in valid_filters:
try:
self.filtering.check_valid_filter(filter)
except jsonschema.ValidationError as e:
self.fail(e)
def test_limits_are_applied(self):
# TODO
pass
    def test_definition_types_works_with_literals(self):
        definition = {
            "types": ["m.room.message", "org.matrix.foo.bar"]


@@ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase):
                "room_alias": "#another:remote",
            },
            retry_on_dns_fail=False,
            ignore_backoff=True,
        )

    @defer.inlineCallbacks


@@ -324,7 +324,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
        state = UserPresenceState.default(user_id)
        state = state.copy_and_replace(
            state=PresenceState.ONLINE,
            last_active_ts=0,
            last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
        )


@@ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase):
        self.mock_federation.make_query.assert_called_with(
            destination="remote",
            query_type="profile",
            args={"user_id": "@alice:remote", "field": "displayname"},
            ignore_backoff=True,
        )

    @defer.inlineCallbacks


@@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                ),
                json_data_callback=ANY,
                long_retries=True,
                backoff_on_404=True,
            ),
            defer.succeed((200, "OK"))
        )
@@ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                ),
                json_data_callback=ANY,
                long_retries=True,
                backoff_on_404=True,
            ),
            defer.succeed((200, "OK"))
        )


@@ -68,7 +68,7 @@ class ReplicationResourceCase(unittest.TestCase):
        code, body = yield get
        self.assertEquals(code, 200)
        self.assertEquals(body["events"]["field_names"], [
            "position", "event_id", "room_id", "type", "state_key",
        ])

    @defer.inlineCallbacks


@@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):

    USER_ID = "@apple:test"
    EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
    EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
    TO_REGISTER = [filter]

    @defer.inlineCallbacks


@@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_select_one_1col(self):
        self.mock_txn.rowcount = 1
        self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))

        value = yield self.datastore._simple_select_one_onecol(
            table="tablename",
@@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_select_list(self):
        self.mock_txn.rowcount = 3
        self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
        self.mock_txn.description = (
            ("colA", None, None, None, None, None, None),
        )
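The switch from stubbing `fetchall` to stubbing `__iter__` (the store now iterates the cursor directly) relies on a `unittest.mock` detail: magic methods must be configured explicitly, since `iter(obj)` looks them up on the type rather than the instance, and `Mock` special-cases such assignments. A minimal sketch:

```python
from unittest.mock import Mock

# Configuring __iter__ explicitly, as the tests above do: mock special-cases
# magic-method assignment so that iter(mock_txn) finds it even though
# Python looks dunders up on the type, not the instance.
mock_txn = Mock()
mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))

rows = [r[0] for r in mock_txn]  # what code iterating the cursor sees
print(rows)  # -> [1, 2, 3]
```

Note the returned iterator is single-use: a second `iter(mock_txn)` would yield nothing unless re-stubbed.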

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signedjson.key
from twisted.internet import defer
import tests.unittest
import tests.utils
class KeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.keys.KeyStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_server_verify_keys(self):
key1 = signedjson.key.decode_verify_key_base64(
"ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
key2 = signedjson.key.decode_verify_key_base64(
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key1
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key2
)
res = yield self.store.get_server_verify_keys(
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res["ed25519:key1"].version, "key1")
self.assertEqual(res["ed25519:key2"].version, "key2")

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import mock
from synapse.api.errors import SynapseError
from synapse.util import async
from synapse.util import logcontext
from twisted.internet import defer
from synapse.util.caches import descriptors
from tests import unittest
logger = logging.getLogger(__name__)
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(1, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(1, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
@defer.inlineCallbacks
def test_cache_num_args(self):
"""Only the first num_args arguments should matter to the cache"""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached(num_args=1)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3)
obj.mock.reset_mock()
# the two values should now be cached; we should be able to vary
# the second argument and still get the cached result.
r = yield obj.fn(1, 4)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 5)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
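The `num_args` behaviour this test exercises can be illustrated with a tiny standalone memoiser. The decorator below is a hypothetical toy, not Synapse's `@cached` descriptor (which additionally handles Deferreds, invalidation, and cache sizing):

```python
import functools

def cached_sketch(num_args):
    """Toy version of @cached(num_args=N): key the cache on only the
    first num_args positional arguments (illustration only)."""
    def decorator(fn):
        cache = {}
        @functools.wraps(fn)
        def wrapper(*args):
            key = args[:num_args]      # trailing args don't affect the key
            if key not in cache:
                cache[key] = fn(*args)
            return cache[key]
        return wrapper
    return decorator

calls = []

@cached_sketch(num_args=1)
def fn(arg1, arg2):
    calls.append((arg1, arg2))
    return arg1 * 10

print(fn(1, 2))  # -> 10 (computed)
print(fn(1, 4))  # -> 10 (cache hit: second argument ignored)
print(calls)     # -> [(1, 2)]
```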
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
complete_lookup = defer.Deferred()
class Cls(object):
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
with logcontext.PreserveLoggingContext():
yield complete_lookup
defer.returnValue(1)
return inner_fn()
@defer.inlineCallbacks
def do_lookup():
with logcontext.LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
self.assertEqual(logcontext.LoggingContext.current_context(),
c1)
defer.returnValue(r)
def check_result(r):
self.assertEqual(r, 1)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel)
d2.addCallback(check_result)
# let the lookup complete
complete_lookup.callback(None)
return defer.gatherResults([d1, d2])
def test_cache_logcontexts_with_exception(self):
"""Check that the cache sets and restores logcontexts correctly when
the lookup function throws an exception"""
class Cls(object):
@descriptors.cached()
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
yield async.run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
@defer.inlineCallbacks
def do_lookup():
with logcontext.LoggingContext() as c1:
c1.name = "c1"
try:
yield obj.fn(1)
self.fail("No exception thrown")
except SynapseError:
pass
self.assertEqual(logcontext.LoggingContext.current_context(),
c1)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel)
return d1

tests/util/test_clock.py (new file)

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse import util
from twisted.internet import defer
from tests import unittest
class ClockTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_time_bound_deferred(self):
# just a deferred which never resolves
slow_deferred = defer.Deferred()
clock = util.Clock()
time_bound = clock.time_bound_deferred(slow_deferred, 0.001)
try:
yield time_bound
self.fail("Expected timedout error, but got nothing")
except util.DeferredTimedOutError:
pass
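`time_bound_deferred` is Twisted-specific, but the contract the test checks (a too-slow operation turns into a timeout error rather than hanging) can be approximated with the standard library. The helper name below is illustrative, not Synapse's API:

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import time

# Rough stdlib analogue of clock.time_bound_deferred: run an operation and
# raise if it exceeds the bound (illustration only, not Twisted-based).
def time_bound_call(fn, timeout):
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(fn).result(timeout=timeout)

def slow():           # stands in for the never-resolving deferred
    time.sleep(1.0)
    return "done"

try:
    time_bound_call(slow, timeout=0.01)
    raise AssertionError("expected a timeout")
except TimeoutError:
    print("timed out as expected")
```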


@@ -1,8 +1,10 @@
import twisted.python.failure
from twisted.internet import defer
from twisted.internet import reactor

from .. import unittest

from synapse.util.async import sleep
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase):
            context_one.test_key = "one"
            yield sleep(0)
            self._check_test_key("one")
def _test_preserve_fn(self, function):
sentinel_context = LoggingContext.current_context()
callback_completed = [False]
@defer.inlineCallbacks
def cb():
context_one.test_key = "one"
yield function()
self._check_test_key("one")
callback_completed[0] = True
with LoggingContext() as context_one:
context_one.test_key = "one"
# fire off function, but don't wait on it.
logcontext.preserve_fn(cb)()
self._check_test_key("one")
# now wait for the function under test to have run, and check that
# the logcontext is left in a sane state.
d2 = defer.Deferred()
def check_logcontext():
if not callback_completed[0]:
reactor.callLater(0.01, check_logcontext)
return
# make sure that the context was reset before it got thrown back
# into the reactor
try:
self.assertIs(LoggingContext.current_context(),
sentinel_context)
d2.callback(None)
except BaseException:
d2.errback(twisted.python.failure.Failure())
reactor.callLater(0.01, check_logcontext)
# test is done once d2 finishes
return d2
def test_preserve_fn_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
yield sleep(0)
return self._test_preserve_fn(blocking_function)
def test_preserve_fn_with_non_blocking_fn(self):
@defer.inlineCallbacks
def nonblocking_function():
with logcontext.PreserveLoggingContext():
yield defer.succeed(None)
return self._test_preserve_fn(nonblocking_function)