Merge branch 'release-v0.20.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
4902db1fc9
53
CHANGES.rst
53
CHANGES.rst
|
@ -1,3 +1,56 @@
|
||||||
|
Changes in synapse v0.20.0 (2017-04-11)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix joining rooms over federation where not all servers in the room saw the
|
||||||
|
new server had joined (PR #2094)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.20.0-rc1 (2017-03-30)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add delete_devices API (PR #1993)
|
||||||
|
* Add phone number registration/login support (PR #1994, #2055)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Use JSONSchema for validation of filters. Thanks @pik! (PR #1783)
|
||||||
|
* Reread log config on SIGHUP (PR #1982)
|
||||||
|
* Speed up public room list (PR #1989)
|
||||||
|
* Add helpful texts to logger config options (PR #1990)
|
||||||
|
* Minor ``/sync`` performance improvements. (PR #2002, #2013, #2022)
|
||||||
|
* Add some debug to help diagnose weird federation issue (PR #2035)
|
||||||
|
* Correctly limit retries for all federation requests (PR #2050, #2061)
|
||||||
|
* Don't lock table when persisting new one time keys (PR #2053)
|
||||||
|
* Reduce some CPU work on DB threads (PR #2054)
|
||||||
|
* Cache hosts in room (PR #2060)
|
||||||
|
* Batch sending of device list pokes (PR #2063)
|
||||||
|
* Speed up persist event path in certain edge cases (PR #2070)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix bug where current_state_events renamed to current_state_ids (PR #1849)
|
||||||
|
* Fix routing loop when fetching remote media (PR #1992)
|
||||||
|
* Fix current_state_events table to not lie (PR #1996)
|
||||||
|
* Fix CAS login to handle PartialDownloadError (PR #1997)
|
||||||
|
* Fix assertion to stop transaction queue getting wedged (PR #2010)
|
||||||
|
* Fix presence to fallback to last_active_ts if it beats the last sync time.
|
||||||
|
Thanks @Half-Shot! (PR #2014)
|
||||||
|
* Fix bug when federation received a PDU while a room join is in progress (PR
|
||||||
|
#2016)
|
||||||
|
* Fix resetting state on rejected events (PR #2025)
|
||||||
|
* Fix installation issues in readme. Thanks @ricco386 (PR #2037)
|
||||||
|
* Fix caching of remote servers' signature keys (PR #2042)
|
||||||
|
* Fix some leaking log context (PR #2048, #2049, #2057, #2058)
|
||||||
|
* Fix rejection of invites not reaching sync (PR #2056)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.19.3 (2017-03-20)
|
Changes in synapse v0.19.3 (2017-03-20)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
|
|
@ -146,6 +146,7 @@ To install the synapse homeserver run::
|
||||||
|
|
||||||
virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
source ~/.synapse/bin/activate
|
source ~/.synapse/bin/activate
|
||||||
|
pip install --upgrade pip
|
||||||
pip install --upgrade setuptools
|
pip install --upgrade setuptools
|
||||||
pip install https://github.com/matrix-org/synapse/tarball/master
|
pip install https://github.com/matrix-org/synapse/tarball/master
|
||||||
|
|
||||||
|
@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
|
||||||
New user localpart: erikj
|
New user localpart: erikj
|
||||||
Password:
|
Password:
|
||||||
Confirm password:
|
Confirm password:
|
||||||
|
Make admin [no]:
|
||||||
Success!
|
Success!
|
||||||
|
|
||||||
This process uses a setting ``registration_shared_secret`` in
|
This process uses a setting ``registration_shared_secret`` in
|
||||||
|
@ -808,7 +810,7 @@ directory of your choice::
|
||||||
Synapse has a number of external dependencies, that are easiest
|
Synapse has a number of external dependencies, that are easiest
|
||||||
to install using pip and a virtualenv::
|
to install using pip and a virtualenv::
|
||||||
|
|
||||||
virtualenv env
|
virtualenv -p python2.7 env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
python synapse/python_dependencies.py | xargs pip install
|
python synapse/python_dependencies.py | xargs pip install
|
||||||
pip install lxml mock
|
pip install lxml mock
|
||||||
|
|
|
@ -39,7 +39,9 @@ loggers:
|
||||||
synapse:
|
synapse:
|
||||||
level: INFO
|
level: INFO
|
||||||
|
|
||||||
synapse.storage:
|
synapse.storage.SQL:
|
||||||
|
# beware: increasing this to DEBUG will make synapse log sensitive
|
||||||
|
# information such as access tokens.
|
||||||
level: INFO
|
level: INFO
|
||||||
|
|
||||||
# example of enabling debugging for a component:
|
# example of enabling debugging for a component:
|
||||||
|
|
|
@ -1,10 +1,446 @@
|
||||||
What do I do about "Unexpected logging context" debug log-lines everywhere?
|
Log contexts
|
||||||
|
============
|
||||||
|
|
||||||
<Mjark> The logging context lives in thread local storage
|
.. contents::
|
||||||
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
|
|
||||||
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
|
|
||||||
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
|
|
||||||
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
|
|
||||||
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
|
|
||||||
|
|
||||||
Unanswered: how and when should we preserve log context?
|
To help track the processing of individual requests, synapse uses a
|
||||||
|
'log context' to track which request it is handling at any given moment. This
|
||||||
|
is done via a thread-local variable; a ``logging.Filter`` is then used to fish
|
||||||
|
the information back out of the thread-local variable and add it to each log
|
||||||
|
record.
|
||||||
|
|
||||||
|
Logcontexts are also used for CPU and database accounting, so that we can track
|
||||||
|
which requests were responsible for high CPU use or database activity.
|
||||||
|
|
||||||
|
The ``synapse.util.logcontext`` module provides a facilities for managing the
|
||||||
|
current log context (as well as providing the ``LoggingContextFilter`` class).
|
||||||
|
|
||||||
|
Deferreds make the whole thing complicated, so this document describes how it
|
||||||
|
all works, and how to write code which follows the rules.
|
||||||
|
|
||||||
|
Logcontexts without Deferreds
|
||||||
|
-----------------------------
|
||||||
|
|
||||||
|
In the absence of any Deferred voodoo, things are simple enough. As with any
|
||||||
|
code of this nature, the rule is that our function should leave things as it
|
||||||
|
found them:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
from synapse.util import logcontext # omitted from future snippets
|
||||||
|
|
||||||
|
def handle_request(request_id):
|
||||||
|
request_context = logcontext.LoggingContext()
|
||||||
|
|
||||||
|
calling_context = logcontext.LoggingContext.current_context()
|
||||||
|
logcontext.LoggingContext.set_current_context(request_context)
|
||||||
|
try:
|
||||||
|
request_context.request = request_id
|
||||||
|
do_request_handling()
|
||||||
|
logger.debug("finished")
|
||||||
|
finally:
|
||||||
|
logcontext.LoggingContext.set_current_context(calling_context)
|
||||||
|
|
||||||
|
def do_request_handling():
|
||||||
|
logger.debug("phew") # this will be logged against request_id
|
||||||
|
|
||||||
|
|
||||||
|
LoggingContext implements the context management methods, so the above can be
|
||||||
|
written much more succinctly as:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
def handle_request(request_id):
|
||||||
|
with logcontext.LoggingContext() as request_context:
|
||||||
|
request_context.request = request_id
|
||||||
|
do_request_handling()
|
||||||
|
logger.debug("finished")
|
||||||
|
|
||||||
|
def do_request_handling():
|
||||||
|
logger.debug("phew")
|
||||||
|
|
||||||
|
|
||||||
|
Using logcontexts with Deferreds
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
Deferreds — and in particular, ``defer.inlineCallbacks`` — break
|
||||||
|
the linear flow of code so that there is no longer a single entry point where
|
||||||
|
we should set the logcontext and a single exit point where we should remove it.
|
||||||
|
|
||||||
|
Consider the example above, where ``do_request_handling`` needs to do some
|
||||||
|
blocking operation, and returns a deferred:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_request(request_id):
|
||||||
|
with logcontext.LoggingContext() as request_context:
|
||||||
|
request_context.request = request_id
|
||||||
|
yield do_request_handling()
|
||||||
|
logger.debug("finished")
|
||||||
|
|
||||||
|
|
||||||
|
In the above flow:
|
||||||
|
|
||||||
|
* The logcontext is set
|
||||||
|
* ``do_request_handling`` is called, and returns a deferred
|
||||||
|
* ``handle_request`` yields the deferred
|
||||||
|
* The ``inlineCallbacks`` wrapper of ``handle_request`` returns a deferred
|
||||||
|
|
||||||
|
So we have stopped processing the request (and will probably go on to start
|
||||||
|
processing the next), without clearing the logcontext.
|
||||||
|
|
||||||
|
To circumvent this problem, synapse code assumes that, wherever you have a
|
||||||
|
deferred, you will want to yield on it. To that end, whereever functions return
|
||||||
|
a deferred, we adopt the following conventions:
|
||||||
|
|
||||||
|
**Rules for functions returning deferreds:**
|
||||||
|
|
||||||
|
* If the deferred is already complete, the function returns with the same
|
||||||
|
logcontext it started with.
|
||||||
|
* If the deferred is incomplete, the function clears the logcontext before
|
||||||
|
returning; when the deferred completes, it restores the logcontext before
|
||||||
|
running any callbacks.
|
||||||
|
|
||||||
|
That sounds complicated, but actually it means a lot of code (including the
|
||||||
|
example above) "just works". There are two cases:
|
||||||
|
|
||||||
|
* If ``do_request_handling`` returns a completed deferred, then the logcontext
|
||||||
|
will still be in place. In this case, execution will continue immediately
|
||||||
|
after the ``yield``; the "finished" line will be logged against the right
|
||||||
|
context, and the ``with`` block restores the original context before we
|
||||||
|
return to the caller.
|
||||||
|
|
||||||
|
* If the returned deferred is incomplete, ``do_request_handling`` clears the
|
||||||
|
logcontext before returning. The logcontext is therefore clear when
|
||||||
|
``handle_request`` yields the deferred. At that point, the ``inlineCallbacks``
|
||||||
|
wrapper adds a callback to the deferred, and returns another (incomplete)
|
||||||
|
deferred to the caller, and it is safe to begin processing the next request.
|
||||||
|
|
||||||
|
Once ``do_request_handling``'s deferred completes, it will reinstate the
|
||||||
|
logcontext, before running the callback added by the ``inlineCallbacks``
|
||||||
|
wrapper. That callback runs the second half of ``handle_request``, so again
|
||||||
|
the "finished" line will be logged against the right
|
||||||
|
context, and the ``with`` block restores the original context.
|
||||||
|
|
||||||
|
As an aside, it's worth noting that ``handle_request`` follows our rules -
|
||||||
|
though that only matters if the caller has its own logcontext which it cares
|
||||||
|
about.
|
||||||
|
|
||||||
|
The following sections describe pitfalls and helpful patterns when implementing
|
||||||
|
these rules.
|
||||||
|
|
||||||
|
Always yield your deferreds
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
Whenever you get a deferred back from a function, you should ``yield`` on it
|
||||||
|
as soon as possible. (Returning it directly to your caller is ok too, if you're
|
||||||
|
not doing ``inlineCallbacks``.) Do not pass go; do not do any logging; do not
|
||||||
|
call any other functions.
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def fun():
|
||||||
|
logger.debug("starting")
|
||||||
|
yield do_some_stuff() # just like this
|
||||||
|
|
||||||
|
d = more_stuff()
|
||||||
|
result = yield d # also fine, of course
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
def nonInlineCallbacksFun():
|
||||||
|
logger.debug("just a wrapper really")
|
||||||
|
return do_some_stuff() # this is ok too - the caller will yield on
|
||||||
|
# it anyway.
|
||||||
|
|
||||||
|
Provided this pattern is followed all the way back up to the callchain to where
|
||||||
|
the logcontext was set, this will make things work out ok: provided
|
||||||
|
``do_some_stuff`` and ``more_stuff`` follow the rules above, then so will
|
||||||
|
``fun`` (as wrapped by ``inlineCallbacks``) and ``nonInlineCallbacksFun``.
|
||||||
|
|
||||||
|
It's all too easy to forget to ``yield``: for instance if we forgot that
|
||||||
|
``do_some_stuff`` returned a deferred, we might plough on regardless. This
|
||||||
|
leads to a mess; it will probably work itself out eventually, but not before
|
||||||
|
a load of stuff has been logged against the wrong content. (Normally, other
|
||||||
|
things will break, more obviously, if you forget to ``yield``, so this tends
|
||||||
|
not to be a major problem in practice.)
|
||||||
|
|
||||||
|
Of course sometimes you need to do something a bit fancier with your Deferreds
|
||||||
|
- not all code follows the linear A-then-B-then-C pattern. Notes on
|
||||||
|
implementing more complex patterns are in later sections.
|
||||||
|
|
||||||
|
Where you create a new Deferred, make it follow the rules
|
||||||
|
---------------------------------------------------------
|
||||||
|
|
||||||
|
Most of the time, a Deferred comes from another synapse function. Sometimes,
|
||||||
|
though, we need to make up a new Deferred, or we get a Deferred back from
|
||||||
|
external code. We need to make it follow our rules.
|
||||||
|
|
||||||
|
The easy way to do it is with a combination of ``defer.inlineCallbacks``, and
|
||||||
|
``logcontext.PreserveLoggingContext``. Suppose we want to implement ``sleep``,
|
||||||
|
which returns a deferred which will run its callbacks after a given number of
|
||||||
|
seconds. That might look like:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
# not a logcontext-rules-compliant function
|
||||||
|
def get_sleep_deferred(seconds):
|
||||||
|
d = defer.Deferred()
|
||||||
|
reactor.callLater(seconds, d.callback, None)
|
||||||
|
return d
|
||||||
|
|
||||||
|
That doesn't follow the rules, but we can fix it by wrapping it with
|
||||||
|
``PreserveLoggingContext`` and ``yield`` ing on it:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def sleep(seconds):
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
yield get_sleep_deferred(seconds)
|
||||||
|
|
||||||
|
This technique works equally for external functions which return deferreds,
|
||||||
|
or deferreds we have made ourselves.
|
||||||
|
|
||||||
|
You can also use ``logcontext.make_deferred_yieldable``, which just does the
|
||||||
|
boilerplate for you, so the above could be written:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
def sleep(seconds):
|
||||||
|
return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))
|
||||||
|
|
||||||
|
|
||||||
|
Fire-and-forget
|
||||||
|
---------------
|
||||||
|
|
||||||
|
Sometimes you want to fire off a chain of execution, but not wait for its
|
||||||
|
result. That might look a bit like this:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request_handling():
|
||||||
|
yield foreground_operation()
|
||||||
|
|
||||||
|
# *don't* do this
|
||||||
|
background_operation()
|
||||||
|
|
||||||
|
logger.debug("Request handling complete")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def background_operation():
|
||||||
|
yield first_background_step()
|
||||||
|
logger.debug("Completed first step")
|
||||||
|
yield second_background_step()
|
||||||
|
logger.debug("Completed second step")
|
||||||
|
|
||||||
|
The above code does a couple of steps in the background after
|
||||||
|
``do_request_handling`` has finished. The log lines are still logged against
|
||||||
|
the ``request_context`` logcontext, which may or may not be desirable. There
|
||||||
|
are two big problems with the above, however. The first problem is that, if
|
||||||
|
``background_operation`` returns an incomplete Deferred, it will expect its
|
||||||
|
caller to ``yield`` immediately, so will have cleared the logcontext. In this
|
||||||
|
example, that means that 'Request handling complete' will be logged without any
|
||||||
|
context.
|
||||||
|
|
||||||
|
The second problem, which is potentially even worse, is that when the Deferred
|
||||||
|
returned by ``background_operation`` completes, it will restore the original
|
||||||
|
logcontext. There is nothing waiting on that Deferred, so the logcontext will
|
||||||
|
leak into the reactor and possibly get attached to some arbitrary future
|
||||||
|
operation.
|
||||||
|
|
||||||
|
There are two potential solutions to this.
|
||||||
|
|
||||||
|
One option is to surround the call to ``background_operation`` with a
|
||||||
|
``PreserveLoggingContext`` call. That will reset the logcontext before
|
||||||
|
starting ``background_operation`` (so the context restored when the deferred
|
||||||
|
completes will be the empty logcontext), and will restore the current
|
||||||
|
logcontext before continuing the foreground process:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request_handling():
|
||||||
|
yield foreground_operation()
|
||||||
|
|
||||||
|
# start background_operation off in the empty logcontext, to
|
||||||
|
# avoid leaking the current context into the reactor.
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
background_operation()
|
||||||
|
|
||||||
|
# this will now be logged against the request context
|
||||||
|
logger.debug("Request handling complete")
|
||||||
|
|
||||||
|
Obviously that option means that the operations done in
|
||||||
|
``background_operation`` would be not be logged against a logcontext (though
|
||||||
|
that might be fixed by setting a different logcontext via a ``with
|
||||||
|
LoggingContext(...)`` in ``background_operation``).
|
||||||
|
|
||||||
|
The second option is to use ``logcontext.preserve_fn``, which wraps a function
|
||||||
|
so that it doesn't reset the logcontext even when it returns an incomplete
|
||||||
|
deferred, and adds a callback to the returned deferred to reset the
|
||||||
|
logcontext. In other words, it turns a function that follows the Synapse rules
|
||||||
|
about logcontexts and Deferreds into one which behaves more like an external
|
||||||
|
function — the opposite operation to that described in the previous section.
|
||||||
|
It can be used like this:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request_handling():
|
||||||
|
yield foreground_operation()
|
||||||
|
|
||||||
|
logcontext.preserve_fn(background_operation)()
|
||||||
|
|
||||||
|
# this will now be logged against the request context
|
||||||
|
logger.debug("Request handling complete")
|
||||||
|
|
||||||
|
XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
|
||||||
|
but the fact that it does ``preserve_context_over_deferred`` on its results
|
||||||
|
means that its use is fraught with difficulty.
|
||||||
|
|
||||||
|
Passing synapse deferreds into third-party functions
|
||||||
|
----------------------------------------------------
|
||||||
|
|
||||||
|
A typical example of this is where we want to collect together two or more
|
||||||
|
deferred via ``defer.gatherResults``:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
d1 = operation1()
|
||||||
|
d2 = operation2()
|
||||||
|
d3 = defer.gatherResults([d1, d2])
|
||||||
|
|
||||||
|
This is really a variation of the fire-and-forget problem above, in that we are
|
||||||
|
firing off ``d1`` and ``d2`` without yielding on them. The difference
|
||||||
|
is that we now have third-party code attached to their callbacks. Anyway either
|
||||||
|
technique given in the `Fire-and-forget`_ section will work.
|
||||||
|
|
||||||
|
Of course, the new Deferred returned by ``gatherResults`` needs to be wrapped
|
||||||
|
in order to make it follow the logcontext rules before we can yield it, as
|
||||||
|
described in `Where you create a new Deferred, make it follow the rules`_.
|
||||||
|
|
||||||
|
So, option one: reset the logcontext before starting the operations to be
|
||||||
|
gathered:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request_handling():
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
d1 = operation1()
|
||||||
|
d2 = operation2()
|
||||||
|
result = yield defer.gatherResults([d1, d2])
|
||||||
|
|
||||||
|
In this case particularly, though, option two, of using
|
||||||
|
``logcontext.preserve_fn`` almost certainly makes more sense, so that
|
||||||
|
``operation1`` and ``operation2`` are both logged against the original
|
||||||
|
logcontext. This looks like:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_request_handling():
|
||||||
|
d1 = logcontext.preserve_fn(operation1)()
|
||||||
|
d2 = logcontext.preserve_fn(operation2)()
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
result = yield defer.gatherResults([d1, d2])
|
||||||
|
|
||||||
|
|
||||||
|
Was all this really necessary?
|
||||||
|
------------------------------
|
||||||
|
|
||||||
|
The conventions used work fine for a linear flow where everything happens in
|
||||||
|
series via ``defer.inlineCallbacks`` and ``yield``, but are certainly tricky to
|
||||||
|
follow for any more exotic flows. It's hard not to wonder if we could have done
|
||||||
|
something else.
|
||||||
|
|
||||||
|
We're not going to rewrite Synapse now, so the following is entirely of
|
||||||
|
academic interest, but I'd like to record some thoughts on an alternative
|
||||||
|
approach.
|
||||||
|
|
||||||
|
I briefly prototyped some code following an alternative set of rules. I think
|
||||||
|
it would work, but I certainly didn't get as far as thinking how it would
|
||||||
|
interact with concepts as complicated as the cache descriptors.
|
||||||
|
|
||||||
|
My alternative rules were:
|
||||||
|
|
||||||
|
* functions always preserve the logcontext of their caller, whether or not they
|
||||||
|
are returning a Deferred.
|
||||||
|
|
||||||
|
* Deferreds returned by synapse functions run their callbacks in the same
|
||||||
|
context as the function was orignally called in.
|
||||||
|
|
||||||
|
The main point of this scheme is that everywhere that sets the logcontext is
|
||||||
|
responsible for clearing it before returning control to the reactor.
|
||||||
|
|
||||||
|
So, for example, if you were the function which started a ``with
|
||||||
|
LoggingContext`` block, you wouldn't ``yield`` within it — instead you'd start
|
||||||
|
off the background process, and then leave the ``with`` block to wait for it:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
def handle_request(request_id):
|
||||||
|
with logcontext.LoggingContext() as request_context:
|
||||||
|
request_context.request = request_id
|
||||||
|
d = do_request_handling()
|
||||||
|
|
||||||
|
def cb(r):
|
||||||
|
logger.debug("finished")
|
||||||
|
|
||||||
|
d.addCallback(cb)
|
||||||
|
return d
|
||||||
|
|
||||||
|
(in general, mixing ``with LoggingContext`` blocks and
|
||||||
|
``defer.inlineCallbacks`` in the same function leads to slighly
|
||||||
|
counter-intuitive code, under this scheme).
|
||||||
|
|
||||||
|
Because we leave the original ``with`` block as soon as the Deferred is
|
||||||
|
returned (as opposed to waiting for it to be resolved, as we do today), the
|
||||||
|
logcontext is cleared before control passes back to the reactor; so if there is
|
||||||
|
some code within ``do_request_handling`` which needs to wait for a Deferred to
|
||||||
|
complete, there is no need for it to worry about clearing the logcontext before
|
||||||
|
doing so:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
def handle_request():
|
||||||
|
r = do_some_stuff()
|
||||||
|
r.addCallback(do_some_more_stuff)
|
||||||
|
return r
|
||||||
|
|
||||||
|
— and provided ``do_some_stuff`` follows the rules of returning a Deferred which
|
||||||
|
runs its callbacks in the original logcontext, all is happy.
|
||||||
|
|
||||||
|
The business of a Deferred which runs its callbacks in the original logcontext
|
||||||
|
isn't hard to achieve — we have it today, in the shape of
|
||||||
|
``logcontext._PreservingContextDeferred``:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
def do_some_stuff():
|
||||||
|
deferred = do_some_io()
|
||||||
|
pcd = _PreservingContextDeferred(LoggingContext.current_context())
|
||||||
|
deferred.chainDeferred(pcd)
|
||||||
|
return pcd
|
||||||
|
|
||||||
|
It turns out that, thanks to the way that Deferreds chain together, we
|
||||||
|
automatically get the property of a context-preserving deferred with
|
||||||
|
``defer.inlineCallbacks``, provided the final Defered the function ``yields``
|
||||||
|
on has that property. So we can just write:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_request():
|
||||||
|
yield do_some_stuff()
|
||||||
|
yield do_some_more_stuff()
|
||||||
|
|
||||||
|
To conclude: I think this scheme would have worked equally well, with less
|
||||||
|
danger of messing it up, and probably made some more esoteric code easier to
|
||||||
|
write. But again — changing the conventions of the entire Synapse codebase is
|
||||||
|
not a sensible option for the marginal improvement offered.
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.19.3"
|
__version__ = "0.20.0"
|
||||||
|
|
|
@ -23,7 +23,7 @@ from synapse import event_auth
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
from synapse.api.errors import AuthError, Codes
|
from synapse.api.errors import AuthError, Codes
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util import logcontext
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -209,8 +209,7 @@ class Auth(object):
|
||||||
default=[""]
|
default=[""]
|
||||||
)[0]
|
)[0]
|
||||||
if user and access_token and ip_addr:
|
if user and access_token and ip_addr:
|
||||||
preserve_context_over_fn(
|
logcontext.preserve_fn(self.store.insert_client_ip)(
|
||||||
self.store.insert_client_ip,
|
|
||||||
user=user,
|
user=user,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -44,6 +45,7 @@ class JoinRules(object):
|
||||||
class LoginType(object):
|
class LoginType(object):
|
||||||
PASSWORD = u"m.login.password"
|
PASSWORD = u"m.login.password"
|
||||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||||
|
MSISDN = u"m.login.msisdn"
|
||||||
RECAPTCHA = u"m.login.recaptcha"
|
RECAPTCHA = u"m.login.recaptcha"
|
||||||
DUMMY = u"m.login.dummy"
|
DUMMY = u"m.login.dummy"
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
"""Contains exceptions and error codes."""
|
"""Contains exceptions and error codes."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -50,27 +51,35 @@ class Codes(object):
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
"""An exception with integer code and message string attributes."""
|
"""An exception with integer code and message string attributes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
code (int): HTTP error code
|
||||||
|
msg (str): string describing the error
|
||||||
|
"""
|
||||||
def __init__(self, code, msg):
|
def __init__(self, code, msg):
|
||||||
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
|
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
|
||||||
self.code = code
|
self.code = code
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.response_code_message = None
|
|
||||||
|
|
||||||
def error_dict(self):
|
def error_dict(self):
|
||||||
return cs_error(self.msg)
|
return cs_error(self.msg)
|
||||||
|
|
||||||
|
|
||||||
class SynapseError(CodeMessageException):
|
class SynapseError(CodeMessageException):
|
||||||
"""A base error which can be caught for all synapse events."""
|
"""A base exception type for matrix errors which have an errcode and error
|
||||||
|
message (as well as an HTTP status code).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
|
||||||
|
"""
|
||||||
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
|
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
|
||||||
"""Constructs a synapse error.
|
"""Constructs a synapse error.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
code (int): The integer error code (an HTTP response code)
|
code (int): The integer error code (an HTTP response code)
|
||||||
msg (str): The human-readable error message.
|
msg (str): The human-readable error message.
|
||||||
err (str): The error code e.g 'M_FORBIDDEN'
|
errcode (str): The matrix error code e.g 'M_FORBIDDEN'
|
||||||
"""
|
"""
|
||||||
super(SynapseError, self).__init__(code, msg)
|
super(SynapseError, self).__init__(code, msg)
|
||||||
self.errcode = errcode
|
self.errcode = errcode
|
||||||
|
@ -81,6 +90,39 @@ class SynapseError(CodeMessageException):
|
||||||
self.errcode,
|
self.errcode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_http_response_exception(cls, err):
|
||||||
|
"""Make a SynapseError based on an HTTPResponseException
|
||||||
|
|
||||||
|
This is useful when a proxied request has failed, and we need to
|
||||||
|
decide how to map the failure onto a matrix error to send back to the
|
||||||
|
client.
|
||||||
|
|
||||||
|
An attempt is made to parse the body of the http response as a matrix
|
||||||
|
error. If that succeeds, the errcode and error message from the body
|
||||||
|
are used as the errcode and error message in the new synapse error.
|
||||||
|
|
||||||
|
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
|
||||||
|
set to the reason code from the HTTP response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
err (HttpResponseException):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SynapseError:
|
||||||
|
"""
|
||||||
|
# try to parse the body as json, to get better errcode/msg, but
|
||||||
|
# default to M_UNKNOWN with the HTTP status as the error text
|
||||||
|
try:
|
||||||
|
j = json.loads(err.response)
|
||||||
|
except ValueError:
|
||||||
|
j = {}
|
||||||
|
errcode = j.get('errcode', Codes.UNKNOWN)
|
||||||
|
errmsg = j.get('error', err.msg)
|
||||||
|
|
||||||
|
res = SynapseError(err.code, errmsg, errcode)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class RegistrationError(SynapseError):
|
class RegistrationError(SynapseError):
|
||||||
"""An error raised when a registration event fails."""
|
"""An error raised when a registration event fails."""
|
||||||
|
@ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError):
|
||||||
|
|
||||||
class NotFoundError(SynapseError):
|
class NotFoundError(SynapseError):
|
||||||
"""An error indicating we can't find the thing you asked for"""
|
"""An error indicating we can't find the thing you asked for"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
|
||||||
if "errcode" not in kwargs:
|
|
||||||
kwargs["errcode"] = Codes.NOT_FOUND
|
|
||||||
super(NotFoundError, self).__init__(
|
super(NotFoundError, self).__init__(
|
||||||
404,
|
404,
|
||||||
"Not found",
|
msg,
|
||||||
**kwargs
|
errcode=errcode
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,7 +213,6 @@ class LimitExceededError(SynapseError):
|
||||||
errcode=Codes.LIMIT_EXCEEDED):
|
errcode=Codes.LIMIT_EXCEEDED):
|
||||||
super(LimitExceededError, self).__init__(code, msg, errcode)
|
super(LimitExceededError, self).__init__(code, msg, errcode)
|
||||||
self.retry_after_ms = retry_after_ms
|
self.retry_after_ms = retry_after_ms
|
||||||
self.response_code_message = "Too Many Requests"
|
|
||||||
|
|
||||||
def error_dict(self):
|
def error_dict(self):
|
||||||
return cs_error(
|
return cs_error(
|
||||||
|
@ -243,6 +282,19 @@ class FederationError(RuntimeError):
|
||||||
|
|
||||||
|
|
||||||
class HttpResponseException(CodeMessageException):
|
class HttpResponseException(CodeMessageException):
|
||||||
|
"""
|
||||||
|
Represents an HTTP-level failure of an outbound request
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
response (str): body of response
|
||||||
|
"""
|
||||||
def __init__(self, code, msg, response):
|
def __init__(self, code, msg, response):
|
||||||
self.response = response
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code (int): HTTP status code
|
||||||
|
msg (str): reason phrase from HTTP response status line
|
||||||
|
response (str): body of response
|
||||||
|
"""
|
||||||
super(HttpResponseException, self).__init__(code, msg)
|
super(HttpResponseException, self).__init__(code, msg)
|
||||||
|
self.response = response
|
||||||
|
|
|
@ -13,11 +13,174 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.types import UserID, RoomID
|
from synapse.types import UserID, RoomID
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
import jsonschema
|
||||||
|
from jsonschema import FormatChecker
|
||||||
|
|
||||||
|
FILTER_SCHEMA = {
|
||||||
|
"additionalProperties": False,
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"limit": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"senders": {
|
||||||
|
"$ref": "#/definitions/user_id_array"
|
||||||
|
},
|
||||||
|
"not_senders": {
|
||||||
|
"$ref": "#/definitions/user_id_array"
|
||||||
|
},
|
||||||
|
# TODO: We don't limit event type values but we probably should...
|
||||||
|
# check types are valid event types
|
||||||
|
"types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"not_types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ROOM_FILTER_SCHEMA = {
|
||||||
|
"additionalProperties": False,
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"not_rooms": {
|
||||||
|
"$ref": "#/definitions/room_id_array"
|
||||||
|
},
|
||||||
|
"rooms": {
|
||||||
|
"$ref": "#/definitions/room_id_array"
|
||||||
|
},
|
||||||
|
"ephemeral": {
|
||||||
|
"$ref": "#/definitions/room_event_filter"
|
||||||
|
},
|
||||||
|
"include_leave": {
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
"state": {
|
||||||
|
"$ref": "#/definitions/room_event_filter"
|
||||||
|
},
|
||||||
|
"timeline": {
|
||||||
|
"$ref": "#/definitions/room_event_filter"
|
||||||
|
},
|
||||||
|
"account_data": {
|
||||||
|
"$ref": "#/definitions/room_event_filter"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ROOM_EVENT_FILTER_SCHEMA = {
|
||||||
|
"additionalProperties": False,
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"limit": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"senders": {
|
||||||
|
"$ref": "#/definitions/user_id_array"
|
||||||
|
},
|
||||||
|
"not_senders": {
|
||||||
|
"$ref": "#/definitions/user_id_array"
|
||||||
|
},
|
||||||
|
"types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"not_types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rooms": {
|
||||||
|
"$ref": "#/definitions/room_id_array"
|
||||||
|
},
|
||||||
|
"not_rooms": {
|
||||||
|
"$ref": "#/definitions/room_id_array"
|
||||||
|
},
|
||||||
|
"contains_url": {
|
||||||
|
"type": "boolean"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
USER_ID_ARRAY_SCHEMA = {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "matrix_user_id"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ROOM_ID_ARRAY_SCHEMA = {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "matrix_room_id"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
USER_FILTER_SCHEMA = {
|
||||||
|
"$schema": "http://json-schema.org/draft-04/schema#",
|
||||||
|
"description": "schema for a Sync filter",
|
||||||
|
"type": "object",
|
||||||
|
"definitions": {
|
||||||
|
"room_id_array": ROOM_ID_ARRAY_SCHEMA,
|
||||||
|
"user_id_array": USER_ID_ARRAY_SCHEMA,
|
||||||
|
"filter": FILTER_SCHEMA,
|
||||||
|
"room_filter": ROOM_FILTER_SCHEMA,
|
||||||
|
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA
|
||||||
|
},
|
||||||
|
"properties": {
|
||||||
|
"presence": {
|
||||||
|
"$ref": "#/definitions/filter"
|
||||||
|
},
|
||||||
|
"account_data": {
|
||||||
|
"$ref": "#/definitions/filter"
|
||||||
|
},
|
||||||
|
"room": {
|
||||||
|
"$ref": "#/definitions/room_filter"
|
||||||
|
},
|
||||||
|
"event_format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["client", "federation"]
|
||||||
|
},
|
||||||
|
"event_fields": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
# Don't allow '\\' in event field filters. This makes matching
|
||||||
|
# events a lot easier as we can then use a negative lookbehind
|
||||||
|
# assertion to split '\.' If we allowed \\ then it would
|
||||||
|
# incorrectly split '\\.' See synapse.events.utils.serialize_event
|
||||||
|
"pattern": "^((?!\\\).)*$"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@FormatChecker.cls_checks('matrix_room_id')
|
||||||
|
def matrix_room_id_validator(room_id_str):
|
||||||
|
return RoomID.from_string(room_id_str)
|
||||||
|
|
||||||
|
|
||||||
|
@FormatChecker.cls_checks('matrix_user_id')
|
||||||
|
def matrix_user_id_validator(user_id_str):
|
||||||
|
return UserID.from_string(user_id_str)
|
||||||
|
|
||||||
|
|
||||||
class Filtering(object):
|
class Filtering(object):
|
||||||
|
@ -52,98 +215,11 @@ class Filtering(object):
|
||||||
# NB: Filters are the complete json blobs. "Definitions" are an
|
# NB: Filters are the complete json blobs. "Definitions" are an
|
||||||
# individual top-level key e.g. public_user_data. Filters are made of
|
# individual top-level key e.g. public_user_data. Filters are made of
|
||||||
# many definitions.
|
# many definitions.
|
||||||
|
try:
|
||||||
top_level_definitions = [
|
jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
|
||||||
"presence", "account_data"
|
format_checker=FormatChecker())
|
||||||
]
|
except jsonschema.ValidationError as e:
|
||||||
|
raise SynapseError(400, e.message)
|
||||||
room_level_definitions = [
|
|
||||||
"state", "timeline", "ephemeral", "account_data"
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in top_level_definitions:
|
|
||||||
if key in user_filter_json:
|
|
||||||
self._check_definition(user_filter_json[key])
|
|
||||||
|
|
||||||
if "room" in user_filter_json:
|
|
||||||
self._check_definition_room_lists(user_filter_json["room"])
|
|
||||||
for key in room_level_definitions:
|
|
||||||
if key in user_filter_json["room"]:
|
|
||||||
self._check_definition(user_filter_json["room"][key])
|
|
||||||
|
|
||||||
if "event_fields" in user_filter_json:
|
|
||||||
if type(user_filter_json["event_fields"]) != list:
|
|
||||||
raise SynapseError(400, "event_fields must be a list of strings")
|
|
||||||
for field in user_filter_json["event_fields"]:
|
|
||||||
if not isinstance(field, basestring):
|
|
||||||
raise SynapseError(400, "Event field must be a string")
|
|
||||||
# Don't allow '\\' in event field filters. This makes matching
|
|
||||||
# events a lot easier as we can then use a negative lookbehind
|
|
||||||
# assertion to split '\.' If we allowed \\ then it would
|
|
||||||
# incorrectly split '\\.' See synapse.events.utils.serialize_event
|
|
||||||
if r'\\' in field:
|
|
||||||
raise SynapseError(
|
|
||||||
400, r'The escape character \ cannot itself be escaped'
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_definition_room_lists(self, definition):
|
|
||||||
"""Check that "rooms" and "not_rooms" are lists of room ids if they
|
|
||||||
are present
|
|
||||||
|
|
||||||
Args:
|
|
||||||
definition(dict): The filter definition
|
|
||||||
Raises:
|
|
||||||
SynapseError: If there was a problem with this definition.
|
|
||||||
"""
|
|
||||||
# check rooms are valid room IDs
|
|
||||||
room_id_keys = ["rooms", "not_rooms"]
|
|
||||||
for key in room_id_keys:
|
|
||||||
if key in definition:
|
|
||||||
if type(definition[key]) != list:
|
|
||||||
raise SynapseError(400, "Expected %s to be a list." % key)
|
|
||||||
for room_id in definition[key]:
|
|
||||||
RoomID.from_string(room_id)
|
|
||||||
|
|
||||||
def _check_definition(self, definition):
|
|
||||||
"""Check if the provided definition is valid.
|
|
||||||
|
|
||||||
This inspects not only the types but also the values to make sure they
|
|
||||||
make sense.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
definition(dict): The filter definition
|
|
||||||
Raises:
|
|
||||||
SynapseError: If there was a problem with this definition.
|
|
||||||
"""
|
|
||||||
# NB: Filters are the complete json blobs. "Definitions" are an
|
|
||||||
# individual top-level key e.g. public_user_data. Filters are made of
|
|
||||||
# many definitions.
|
|
||||||
if type(definition) != dict:
|
|
||||||
raise SynapseError(
|
|
||||||
400, "Expected JSON object, not %s" % (definition,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._check_definition_room_lists(definition)
|
|
||||||
|
|
||||||
# check senders are valid user IDs
|
|
||||||
user_id_keys = ["senders", "not_senders"]
|
|
||||||
for key in user_id_keys:
|
|
||||||
if key in definition:
|
|
||||||
if type(definition[key]) != list:
|
|
||||||
raise SynapseError(400, "Expected %s to be a list." % key)
|
|
||||||
for user_id in definition[key]:
|
|
||||||
UserID.from_string(user_id)
|
|
||||||
|
|
||||||
# TODO: We don't limit event type values but we probably should...
|
|
||||||
# check types are valid event types
|
|
||||||
event_keys = ["types", "not_types"]
|
|
||||||
for key in event_keys:
|
|
||||||
if key in definition:
|
|
||||||
if type(definition[key]) != list:
|
|
||||||
raise SynapseError(400, "Expected %s to be a list." % key)
|
|
||||||
for event_type in definition[key]:
|
|
||||||
if not isinstance(event_type, basestring):
|
|
||||||
raise SynapseError(400, "Event type should be a string")
|
|
||||||
|
|
||||||
|
|
||||||
class FilterCollection(object):
|
class FilterCollection(object):
|
||||||
|
@ -253,19 +329,35 @@ class Filter(object):
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the event matches
|
bool: True if the event matches
|
||||||
"""
|
"""
|
||||||
|
# We usually get the full "events" as dictionaries coming through,
|
||||||
|
# except for presence which actually gets passed around as its own
|
||||||
|
# namedtuple type.
|
||||||
|
if isinstance(event, UserPresenceState):
|
||||||
|
sender = event.user_id
|
||||||
|
room_id = None
|
||||||
|
ev_type = "m.presence"
|
||||||
|
is_url = False
|
||||||
|
else:
|
||||||
sender = event.get("sender", None)
|
sender = event.get("sender", None)
|
||||||
if not sender:
|
if not sender:
|
||||||
# Presence events have their 'sender' in content.user_id
|
# Presence events had their 'sender' in content.user_id, but are
|
||||||
|
# now handled above. We don't know if anything else uses this
|
||||||
|
# form. TODO: Check this and probably remove it.
|
||||||
content = event.get("content")
|
content = event.get("content")
|
||||||
# account_data has been allowed to have non-dict content, so check type first
|
# account_data has been allowed to have non-dict content, so
|
||||||
|
# check type first
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
sender = content.get("user_id")
|
sender = content.get("user_id")
|
||||||
|
|
||||||
|
room_id = event.get("room_id", None)
|
||||||
|
ev_type = event.get("type", None)
|
||||||
|
is_url = "url" in event.get("content", {})
|
||||||
|
|
||||||
return self.check_fields(
|
return self.check_fields(
|
||||||
event.get("room_id", None),
|
room_id,
|
||||||
sender,
|
sender,
|
||||||
event.get("type", None),
|
ev_type,
|
||||||
"url" in event.get("content", {})
|
is_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||||
|
|
|
@ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -157,7 +157,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.appservice"
|
assert config.worker_app == "synapse.app.appservice"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -187,7 +187,11 @@ def start(config_options):
|
||||||
ps.start_listening(config.worker_listeners)
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
from synapse.rest.client.v1.room import PublicRoomListRestServlet
|
from synapse.rest.client.v1.room import PublicRoomListRestServlet
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.client_ips import ClientIpStore
|
from synapse.storage.client_ips import ClientIpStore
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
|
||||||
DirectoryStore,
|
DirectoryStore,
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
|
TransactionStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||||
):
|
):
|
||||||
|
@ -171,7 +173,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.client_reader"
|
assert config.worker_app == "synapse.app.client_reader"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -193,7 +195,11 @@ def start(config_options):
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -31,7 +31,7 @@ from synapse.server import HomeServer
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -162,7 +162,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.federation_reader"
|
assert config.worker_app == "synapse.app.federation_reader"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -184,7 +184,11 @@ def start(config_options):
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -35,7 +35,7 @@ from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -160,7 +160,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.federation_sender"
|
assert config.worker_app == "synapse.app.federation_sender"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -193,7 +193,11 @@ def start(config_options):
|
||||||
ps.start_listening(config.worker_listeners)
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -20,6 +20,8 @@ import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import synapse.config.logger
|
||||||
from synapse.config._base import ConfigError
|
from synapse.config._base import ConfigError
|
||||||
|
|
||||||
from synapse.python_dependencies import (
|
from synapse.python_dependencies import (
|
||||||
|
@ -50,7 +52,7 @@ from synapse.api.urls import (
|
||||||
)
|
)
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.crypto import context_factory
|
from synapse.crypto import context_factory
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.metrics import register_memory_metrics, get_metrics_for
|
from synapse.metrics import register_memory_metrics, get_metrics_for
|
||||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||||
|
@ -286,7 +288,7 @@ def setup(config_options):
|
||||||
# generating config files and shouldn't try to continue.
|
# generating config files and shouldn't try to continue.
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
config.setup_logging()
|
synapse.config.logger.setup_logging(config, use_worker_options=False)
|
||||||
|
|
||||||
# check any extra requirements we have now we have a config
|
# check any extra requirements we have now we have a config
|
||||||
check_requirements(config)
|
check_requirements(config)
|
||||||
|
@ -454,7 +456,12 @@ def run(hs):
|
||||||
def in_thread():
|
def in_thread():
|
||||||
# Uncomment to enable tracing of log context changes.
|
# Uncomment to enable tracing of log context changes.
|
||||||
# sys.settrace(logcontext_tracer)
|
# sys.settrace(logcontext_tracer)
|
||||||
with LoggingContext("run"):
|
|
||||||
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
change_resource_limit(hs.config.soft_file_limit)
|
change_resource_limit(hs.config.soft_file_limit)
|
||||||
if hs.config.gc_thresholds:
|
if hs.config.gc_thresholds:
|
||||||
gc.set_threshold(*hs.config.gc_thresholds)
|
gc.set_threshold(*hs.config.gc_thresholds)
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -32,7 +33,7 @@ from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.media_repository import MediaRepositoryStore
|
from synapse.storage.media_repository import MediaRepositoryStore
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
|
||||||
class MediaRepositorySlavedStore(
|
class MediaRepositorySlavedStore(
|
||||||
SlavedApplicationServiceStore,
|
SlavedApplicationServiceStore,
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
|
TransactionStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
MediaRepositoryStore,
|
MediaRepositoryStore,
|
||||||
ClientIpStore,
|
ClientIpStore,
|
||||||
|
@ -168,7 +170,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.media_repository"
|
assert config.worker_app == "synapse.app.media_repository"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -190,7 +192,11 @@ def start(config_options):
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -31,7 +31,8 @@ from synapse.storage.engines import create_engine
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
from synapse.util.logcontext import LoggingContext, preserve_fn, \
|
||||||
|
PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -245,7 +246,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.pusher"
|
assert config.worker_app == "synapse.app.pusher"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -275,7 +276,11 @@ def start(config_options):
|
||||||
ps.start_listening(config.worker_listeners)
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState
|
||||||
from synapse.config._base import ConfigError
|
from synapse.config._base import ConfigError
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.logger import setup_logging
|
from synapse.config.logger import setup_logging
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.http.site import SynapseSite
|
from synapse.http.site import SynapseSite
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
|
@ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
from synapse.util.logcontext import LoggingContext, preserve_fn, \
|
||||||
|
PreserveLoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
@ -399,8 +399,7 @@ class SynchrotronServer(HomeServer):
|
||||||
position = row[position_index]
|
position = row[position_index]
|
||||||
user_id = row[user_index]
|
user_id = row[user_index]
|
||||||
|
|
||||||
rooms = yield store.get_rooms_for_user(user_id)
|
room_ids = yield store.get_rooms_for_user(user_id)
|
||||||
room_ids = [r.room_id for r in rooms]
|
|
||||||
|
|
||||||
notifier.on_new_event(
|
notifier.on_new_event(
|
||||||
"device_list_key", position, rooms=room_ids,
|
"device_list_key", position, rooms=room_ids,
|
||||||
|
@ -411,11 +410,16 @@ class SynchrotronServer(HomeServer):
|
||||||
stream = result.get("events")
|
stream = result.get("events")
|
||||||
if stream:
|
if stream:
|
||||||
max_position = stream["position"]
|
max_position = stream["position"]
|
||||||
|
|
||||||
|
event_map = yield store.get_events([row[1] for row in stream["rows"]])
|
||||||
|
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
position = row[0]
|
position = row[0]
|
||||||
internal = json.loads(row[1])
|
event_id = row[1]
|
||||||
event_json = json.loads(row[2])
|
event = event_map.get(event_id, None)
|
||||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
if not event:
|
||||||
|
continue
|
||||||
|
|
||||||
extra_users = ()
|
extra_users = ()
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
extra_users = (event.state_key,)
|
extra_users = (event.state_key,)
|
||||||
|
@ -478,7 +482,7 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.synchrotron"
|
assert config.worker_app == "synapse.app.synchrotron"
|
||||||
|
|
||||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
@ -497,7 +501,11 @@ def start(config_options):
|
||||||
ss.start_listening(config.worker_listeners)
|
ss.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
with LoggingContext("run"):
|
# make sure that we run the reactor with the sentinel log context,
|
||||||
|
# otherwise other PreserveLoggingContext instances will get confused
|
||||||
|
# and complain when they see the logcontext arbitrarily swapping
|
||||||
|
# between the sentinel and `run` logcontexts.
|
||||||
|
with PreserveLoggingContext():
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
change_resource_limit(config.soft_file_limit)
|
change_resource_limit(config.soft_file_limit)
|
||||||
if config.gc_thresholds:
|
if config.gc_thresholds:
|
||||||
|
|
|
@ -23,14 +23,27 @@ import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import yaml
|
import yaml
|
||||||
|
import errno
|
||||||
|
import time
|
||||||
|
|
||||||
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
||||||
|
|
||||||
GREEN = "\x1b[1;32m"
|
GREEN = "\x1b[1;32m"
|
||||||
|
YELLOW = "\x1b[1;33m"
|
||||||
RED = "\x1b[1;31m"
|
RED = "\x1b[1;31m"
|
||||||
NORMAL = "\x1b[m"
|
NORMAL = "\x1b[m"
|
||||||
|
|
||||||
|
|
||||||
|
def pid_running(pid):
|
||||||
|
try:
|
||||||
|
os.kill(pid, 0)
|
||||||
|
return True
|
||||||
|
except OSError, err:
|
||||||
|
if err.errno == errno.EPERM:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
def write(message, colour=NORMAL, stream=sys.stdout):
|
||||||
if colour == NORMAL:
|
if colour == NORMAL:
|
||||||
stream.write(message + "\n")
|
stream.write(message + "\n")
|
||||||
|
@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
|
||||||
stream.write(colour + message + NORMAL + "\n")
|
stream.write(colour + message + NORMAL + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def abort(message, colour=RED, stream=sys.stderr):
|
||||||
|
write(message, colour, stream)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def start(configfile):
|
def start(configfile):
|
||||||
write("Starting ...")
|
write("Starting ...")
|
||||||
args = SYNAPSE
|
args = SYNAPSE
|
||||||
|
@ -45,7 +63,8 @@ def start(configfile):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(args)
|
subprocess.check_call(args)
|
||||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
write("started synapse.app.homeserver(%r)" %
|
||||||
|
(configfile,), colour=GREEN)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
write(
|
write(
|
||||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||||
|
@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
|
||||||
def stop(pidfile, app):
|
def stop(pidfile, app):
|
||||||
if os.path.exists(pidfile):
|
if os.path.exists(pidfile):
|
||||||
pid = int(open(pidfile).read())
|
pid = int(open(pidfile).read())
|
||||||
|
try:
|
||||||
os.kill(pid, signal.SIGTERM)
|
os.kill(pid, signal.SIGTERM)
|
||||||
write("stopped %s" % (app,), colour=GREEN)
|
write("stopped %s" % (app,), colour=GREEN)
|
||||||
|
except OSError, err:
|
||||||
|
if err.errno == errno.ESRCH:
|
||||||
|
write("%s not running" % (app,), colour=YELLOW)
|
||||||
|
elif err.errno == errno.EPERM:
|
||||||
|
abort("Cannot stop %s: Operation not permitted" % (app,))
|
||||||
|
else:
|
||||||
|
abort("Cannot stop %s: Unknown error" % (app,))
|
||||||
|
|
||||||
|
|
||||||
Worker = collections.namedtuple("Worker", [
|
Worker = collections.namedtuple("Worker", [
|
||||||
|
@ -190,7 +217,19 @@ def main():
|
||||||
if start_stop_synapse:
|
if start_stop_synapse:
|
||||||
stop(pidfile, "synapse.app.homeserver")
|
stop(pidfile, "synapse.app.homeserver")
|
||||||
|
|
||||||
# TODO: Wait for synapse to actually shutdown before starting it again
|
# Wait for synapse to actually shutdown before starting it again
|
||||||
|
if action == "restart":
|
||||||
|
running_pids = []
|
||||||
|
if start_stop_synapse and os.path.exists(pidfile):
|
||||||
|
running_pids.append(int(open(pidfile).read()))
|
||||||
|
for worker in workers:
|
||||||
|
if os.path.exists(worker.pidfile):
|
||||||
|
running_pids.append(int(open(worker.pidfile).read()))
|
||||||
|
if len(running_pids) > 0:
|
||||||
|
write("Waiting for process to exit before restarting...")
|
||||||
|
for running_pid in running_pids:
|
||||||
|
while pid_running(running_pid):
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
if action == "start" or action == "restart":
|
if action == "start" or action == "restart":
|
||||||
if start_stop_synapse:
|
if start_stop_synapse:
|
||||||
|
|
|
@ -45,7 +45,6 @@ handlers:
|
||||||
maxBytes: 104857600
|
maxBytes: 104857600
|
||||||
backupCount: 10
|
backupCount: 10
|
||||||
filters: [context]
|
filters: [context]
|
||||||
level: INFO
|
|
||||||
console:
|
console:
|
||||||
class: logging.StreamHandler
|
class: logging.StreamHandler
|
||||||
formatter: precise
|
formatter: precise
|
||||||
|
@ -56,6 +55,8 @@ loggers:
|
||||||
level: INFO
|
level: INFO
|
||||||
|
|
||||||
synapse.storage.SQL:
|
synapse.storage.SQL:
|
||||||
|
# beware: increasing this to DEBUG will make synapse log sensitive
|
||||||
|
# information such as access tokens.
|
||||||
level: INFO
|
level: INFO
|
||||||
|
|
||||||
root:
|
root:
|
||||||
|
@ -68,6 +69,7 @@ class LoggingConfig(Config):
|
||||||
|
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.verbosity = config.get("verbose", 0)
|
self.verbosity = config.get("verbose", 0)
|
||||||
|
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
|
||||||
self.log_config = self.abspath(config.get("log_config"))
|
self.log_config = self.abspath(config.get("log_config"))
|
||||||
self.log_file = self.abspath(config.get("log_file"))
|
self.log_file = self.abspath(config.get("log_file"))
|
||||||
|
|
||||||
|
@ -77,10 +79,10 @@ class LoggingConfig(Config):
|
||||||
os.path.join(config_dir_path, server_name + ".log.config")
|
os.path.join(config_dir_path, server_name + ".log.config")
|
||||||
)
|
)
|
||||||
return """
|
return """
|
||||||
# Logging verbosity level.
|
# Logging verbosity level. Ignored if log_config is specified.
|
||||||
verbose: 0
|
verbose: 0
|
||||||
|
|
||||||
# File to write logging to
|
# File to write logging to. Ignored if log_config is specified.
|
||||||
log_file: "%(log_file)s"
|
log_file: "%(log_file)s"
|
||||||
|
|
||||||
# A yaml python logging config file
|
# A yaml python logging config file
|
||||||
|
@ -90,6 +92,8 @@ class LoggingConfig(Config):
|
||||||
def read_arguments(self, args):
|
def read_arguments(self, args):
|
||||||
if args.verbose is not None:
|
if args.verbose is not None:
|
||||||
self.verbosity = args.verbose
|
self.verbosity = args.verbose
|
||||||
|
if args.no_redirect_stdio is not None:
|
||||||
|
self.no_redirect_stdio = args.no_redirect_stdio
|
||||||
if args.log_config is not None:
|
if args.log_config is not None:
|
||||||
self.log_config = args.log_config
|
self.log_config = args.log_config
|
||||||
if args.log_file is not None:
|
if args.log_file is not None:
|
||||||
|
@ -99,16 +103,22 @@ class LoggingConfig(Config):
|
||||||
logging_group = parser.add_argument_group("logging")
|
logging_group = parser.add_argument_group("logging")
|
||||||
logging_group.add_argument(
|
logging_group.add_argument(
|
||||||
'-v', '--verbose', dest="verbose", action='count',
|
'-v', '--verbose', dest="verbose", action='count',
|
||||||
help="The verbosity level."
|
help="The verbosity level. Specify multiple times to increase "
|
||||||
|
"verbosity. (Ignored if --log-config is specified.)"
|
||||||
)
|
)
|
||||||
logging_group.add_argument(
|
logging_group.add_argument(
|
||||||
'-f', '--log-file', dest="log_file",
|
'-f', '--log-file', dest="log_file",
|
||||||
help="File to log to."
|
help="File to log to. (Ignored if --log-config is specified.)"
|
||||||
)
|
)
|
||||||
logging_group.add_argument(
|
logging_group.add_argument(
|
||||||
'--log-config', dest="log_config", default=None,
|
'--log-config', dest="log_config", default=None,
|
||||||
help="Python logging config file"
|
help="Python logging config file"
|
||||||
)
|
)
|
||||||
|
logging_group.add_argument(
|
||||||
|
'-n', '--no-redirect-stdio',
|
||||||
|
action='store_true', default=None,
|
||||||
|
help="Do not redirect stdout/stderr to the log"
|
||||||
|
)
|
||||||
|
|
||||||
def generate_files(self, config):
|
def generate_files(self, config):
|
||||||
log_config = config.get("log_config")
|
log_config = config.get("log_config")
|
||||||
|
@ -118,11 +128,22 @@ class LoggingConfig(Config):
|
||||||
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup_logging(self):
|
|
||||||
setup_logging(self.log_config, self.log_file, self.verbosity)
|
|
||||||
|
|
||||||
|
def setup_logging(config, use_worker_options=False):
|
||||||
|
""" Set up python logging
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (LoggingConfig | synapse.config.workers.WorkerConfig):
|
||||||
|
configuration data
|
||||||
|
|
||||||
|
use_worker_options (bool): True to use 'worker_log_config' and
|
||||||
|
'worker_log_file' options instead of 'log_config' and 'log_file'.
|
||||||
|
"""
|
||||||
|
log_config = (config.worker_log_config if use_worker_options
|
||||||
|
else config.log_config)
|
||||||
|
log_file = (config.worker_log_file if use_worker_options
|
||||||
|
else config.log_file)
|
||||||
|
|
||||||
def setup_logging(log_config=None, log_file=None, verbosity=None):
|
|
||||||
log_format = (
|
log_format = (
|
||||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||||
" - %(message)s"
|
" - %(message)s"
|
||||||
|
@ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||||
|
|
||||||
level = logging.INFO
|
level = logging.INFO
|
||||||
level_for_storage = logging.INFO
|
level_for_storage = logging.INFO
|
||||||
if verbosity:
|
if config.verbosity:
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
if verbosity > 1:
|
if config.verbosity > 1:
|
||||||
level_for_storage = logging.DEBUG
|
level_for_storage = logging.DEBUG
|
||||||
|
|
||||||
# FIXME: we need a logging.WARN for a -q quiet option
|
# FIXME: we need a logging.WARN for a -q quiet option
|
||||||
|
@ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||||
logger.info("Closing log file due to SIGHUP")
|
logger.info("Closing log file due to SIGHUP")
|
||||||
handler.doRollover()
|
handler.doRollover()
|
||||||
logger.info("Opened new log file due to SIGHUP")
|
logger.info("Opened new log file due to SIGHUP")
|
||||||
|
|
||||||
# TODO(paul): obviously this is a terrible mechanism for
|
|
||||||
# stealing SIGHUP, because it means no other part of synapse
|
|
||||||
# can use it instead. If we want to catch SIGHUP anywhere
|
|
||||||
# else as well, I'd suggest we find a nicer way to broadcast
|
|
||||||
# it around.
|
|
||||||
if getattr(signal, "SIGHUP"):
|
|
||||||
signal.signal(signal.SIGHUP, sighup)
|
|
||||||
else:
|
else:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
|
@ -169,9 +182,26 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||||
|
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
else:
|
else:
|
||||||
|
def load_log_config():
|
||||||
with open(log_config, 'r') as f:
|
with open(log_config, 'r') as f:
|
||||||
logging.config.dictConfig(yaml.load(f))
|
logging.config.dictConfig(yaml.load(f))
|
||||||
|
|
||||||
|
def sighup(signum, stack):
|
||||||
|
# it might be better to use a file watcher or something for this.
|
||||||
|
logging.info("Reloading log config from %s due to SIGHUP",
|
||||||
|
log_config)
|
||||||
|
load_log_config()
|
||||||
|
|
||||||
|
load_log_config()
|
||||||
|
|
||||||
|
# TODO(paul): obviously this is a terrible mechanism for
|
||||||
|
# stealing SIGHUP, because it means no other part of synapse
|
||||||
|
# can use it instead. If we want to catch SIGHUP anywhere
|
||||||
|
# else as well, I'd suggest we find a nicer way to broadcast
|
||||||
|
# it around.
|
||||||
|
if getattr(signal, "SIGHUP"):
|
||||||
|
signal.signal(signal.SIGHUP, sighup)
|
||||||
|
|
||||||
# It's critical to point twisted's internal logging somewhere, otherwise it
|
# It's critical to point twisted's internal logging somewhere, otherwise it
|
||||||
# stacks up and leaks kup to 64K object;
|
# stacks up and leaks kup to 64K object;
|
||||||
# see: https://twistedmatrix.com/trac/ticket/8164
|
# see: https://twistedmatrix.com/trac/ticket/8164
|
||||||
|
@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||||
#
|
#
|
||||||
# However this may not be too much of a problem if we are just writing to a file.
|
# However this may not be too much of a problem if we are just writing to a file.
|
||||||
observer = STDLibLogObserver()
|
observer = STDLibLogObserver()
|
||||||
globalLogBeginner.beginLoggingTo([observer])
|
globalLogBeginner.beginLoggingTo(
|
||||||
|
[observer],
|
||||||
|
redirectStandardIO=not config.no_redirect_stdio,
|
||||||
|
)
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
from synapse.crypto.keyclient import fetch_server_key
|
from synapse.crypto.keyclient import fetch_server_key
|
||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
from synapse.util.retryutils import get_retry_limiter
|
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import (
|
from synapse.util.logcontext import (
|
||||||
|
@ -96,10 +95,11 @@ class Keyring(object):
|
||||||
verify_requests = []
|
verify_requests = []
|
||||||
|
|
||||||
for server_name, json_object in server_and_json:
|
for server_name, json_object in server_and_json:
|
||||||
logger.debug("Verifying for %s", server_name)
|
|
||||||
|
|
||||||
key_ids = signature_ids(json_object, server_name)
|
key_ids = signature_ids(json_object, server_name)
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
|
logger.warn("Request from %s: no supported signature keys",
|
||||||
|
server_name)
|
||||||
deferred = defer.fail(SynapseError(
|
deferred = defer.fail(SynapseError(
|
||||||
400,
|
400,
|
||||||
"Not signed with a supported algorithm",
|
"Not signed with a supported algorithm",
|
||||||
|
@ -108,6 +108,9 @@ class Keyring(object):
|
||||||
else:
|
else:
|
||||||
deferred = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
|
|
||||||
|
logger.debug("Verifying for %s with key_ids %s",
|
||||||
|
server_name, key_ids)
|
||||||
|
|
||||||
verify_request = VerifyKeyRequest(
|
verify_request = VerifyKeyRequest(
|
||||||
server_name, key_ids, json_object, deferred
|
server_name, key_ids, json_object, deferred
|
||||||
)
|
)
|
||||||
|
@ -142,6 +145,9 @@ class Keyring(object):
|
||||||
|
|
||||||
json_object = verify_request.json_object
|
json_object = verify_request.json_object
|
||||||
|
|
||||||
|
logger.debug("Got key %s %s:%s for server %s, verifying" % (
|
||||||
|
key_id, verify_key.alg, verify_key.version, server_name,
|
||||||
|
))
|
||||||
try:
|
try:
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
except:
|
except:
|
||||||
|
@ -231,8 +237,14 @@ class Keyring(object):
|
||||||
d.addBoth(rm, server_name)
|
d.addBoth(rm, server_name)
|
||||||
|
|
||||||
def get_server_verify_keys(self, verify_requests):
|
def get_server_verify_keys(self, verify_requests):
|
||||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
"""Tries to find at least one key for each verify request
|
||||||
each group.
|
|
||||||
|
For each verify_request, verify_request.deferred is called back with
|
||||||
|
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
|
||||||
|
with a SynapseError if none of the keys are found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
verify_requests (list[VerifyKeyRequest]): list of verify requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# These are functions that produce keys given a list of key ids
|
# These are functions that produce keys given a list of key ids
|
||||||
|
@ -245,8 +257,11 @@ class Keyring(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_iterations():
|
def do_iterations():
|
||||||
with Measure(self.clock, "get_server_verify_keys"):
|
with Measure(self.clock, "get_server_verify_keys"):
|
||||||
|
# dict[str, dict[str, VerifyKey]]: results so far.
|
||||||
|
# map server_name -> key_id -> VerifyKey
|
||||||
merged_results = {}
|
merged_results = {}
|
||||||
|
|
||||||
|
# dict[str, set(str)]: keys to fetch for each server
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
for verify_request in verify_requests:
|
for verify_request in verify_requests:
|
||||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||||
|
@ -308,6 +323,16 @@ class Keyring(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys_from_store(self, server_name_and_key_ids):
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name_and_key_ids (list[(str, iterable[str])]):
|
||||||
|
list of (server_name, iterable[key_id]) tuples to fetch keys for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
|
||||||
|
server_name -> key_id -> VerifyKey
|
||||||
|
"""
|
||||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self.store.get_server_verify_keys)(
|
preserve_fn(self.store.get_server_verify_keys)(
|
||||||
|
@ -356,12 +381,6 @@ class Keyring(object):
|
||||||
def get_keys_from_server(self, server_name_and_key_ids):
|
def get_keys_from_server(self, server_name_and_key_ids):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_key(server_name, key_ids):
|
def get_key(server_name, key_ids):
|
||||||
limiter = yield get_retry_limiter(
|
|
||||||
server_name,
|
|
||||||
self.clock,
|
|
||||||
self.store,
|
|
||||||
)
|
|
||||||
with limiter:
|
|
||||||
keys = None
|
keys = None
|
||||||
try:
|
try:
|
||||||
keys = yield self.get_server_verify_key_v2_direct(
|
keys = yield self.get_server_verify_key_v2_direct(
|
||||||
|
|
|
@ -15,6 +15,32 @@
|
||||||
|
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
|
"""
|
||||||
|
Attributes:
|
||||||
|
current_state_ids (dict[(str, str), str]):
|
||||||
|
The current state map including the current event.
|
||||||
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
|
prev_state_ids (dict[(str, str), str]):
|
||||||
|
The current state map excluding the current event.
|
||||||
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
|
state_group (int): state group id
|
||||||
|
rejected (bool|str): A rejection reason if the event was rejected, else
|
||||||
|
False
|
||||||
|
|
||||||
|
push_actions (list[(str, list[object])]): list of (user_id, actions)
|
||||||
|
tuples
|
||||||
|
|
||||||
|
prev_group (int): Previously persisted state group. ``None`` for an
|
||||||
|
outlier.
|
||||||
|
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
||||||
|
(type, state_key) -> event_id. ``None`` for an outlier.
|
||||||
|
|
||||||
|
prev_state_events (?): XXX: is this ever set to anything other than
|
||||||
|
the empty list?
|
||||||
|
"""
|
||||||
|
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"current_state_ids",
|
"current_state_ids",
|
||||||
"prev_state_ids",
|
"prev_state_ids",
|
||||||
|
|
|
@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.events import FrozenEvent, builder
|
from synapse.events import FrozenEvent, builder
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -88,7 +88,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def make_query(self, destination, query_type, args,
|
def make_query(self, destination, query_type, args,
|
||||||
retry_on_dns_fail=False):
|
retry_on_dns_fail=False, ignore_backoff=False):
|
||||||
"""Sends a federation Query to a remote homeserver of the given type
|
"""Sends a federation Query to a remote homeserver of the given type
|
||||||
and arguments.
|
and arguments.
|
||||||
|
|
||||||
|
@ -98,6 +98,8 @@ class FederationClient(FederationBase):
|
||||||
handler name used in register_query_handler().
|
handler name used in register_query_handler().
|
||||||
args (dict): Mapping of strings to strings containing the details
|
args (dict): Mapping of strings to strings containing the details
|
||||||
of the query request.
|
of the query request.
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data
|
||||||
|
and try the request anyway.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a Deferred which will eventually yield a JSON object from the
|
a Deferred which will eventually yield a JSON object from the
|
||||||
|
@ -106,7 +108,8 @@ class FederationClient(FederationBase):
|
||||||
sent_queries_counter.inc(query_type)
|
sent_queries_counter.inc(query_type)
|
||||||
|
|
||||||
return self.transport_layer.make_query(
|
return self.transport_layer.make_query(
|
||||||
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
|
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -234,13 +237,6 @@ class FederationClient(FederationBase):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
limiter = yield get_retry_limiter(
|
|
||||||
destination,
|
|
||||||
self._clock,
|
|
||||||
self.store,
|
|
||||||
)
|
|
||||||
|
|
||||||
with limiter:
|
|
||||||
transaction_data = yield self.transport_layer.get_event(
|
transaction_data = yield self.transport_layer.get_event(
|
||||||
destination, event_id, timeout=timeout,
|
destination, event_id, timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
|
@ -52,7 +52,6 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
|
||||||
self._server_linearizer = Linearizer("fed_server")
|
self._server_linearizer = Linearizer("fed_server")
|
||||||
|
|
||||||
# We cache responses to state queries, as they take a while and often
|
# We cache responses to state queries, as they take a while and often
|
||||||
|
@ -147,11 +146,15 @@ class FederationServer(FederationBase):
|
||||||
# check that it's actually being sent from a valid destination to
|
# check that it's actually being sent from a valid destination to
|
||||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||||
if transaction.origin != get_domain_from_id(pdu.event_id):
|
if transaction.origin != get_domain_from_id(pdu.event_id):
|
||||||
|
# We continue to accept join events from any server; this is
|
||||||
|
# necessary for the federation join dance to work correctly.
|
||||||
|
# (When we join over federation, the "helper" server is
|
||||||
|
# responsible for sending out the join event, rather than the
|
||||||
|
# origin. See bug #1893).
|
||||||
if not (
|
if not (
|
||||||
pdu.type == 'm.room.member' and
|
pdu.type == 'm.room.member' and
|
||||||
pdu.content and
|
pdu.content and
|
||||||
pdu.content.get("membership", None) == 'join' and
|
pdu.content.get("membership", None) == 'join'
|
||||||
self.hs.is_mine_id(pdu.state_key)
|
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Discarding PDU %s from invalid origin %s",
|
"Discarding PDU %s from invalid origin %s",
|
||||||
|
@ -165,7 +168,7 @@ class FederationServer(FederationBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_pdu(transaction.origin, pdu)
|
yield self._handle_received_pdu(transaction.origin, pdu)
|
||||||
results.append({})
|
results.append({})
|
||||||
except FederationError as e:
|
except FederationError as e:
|
||||||
self.send_failure(e, transaction.origin)
|
self.send_failure(e, transaction.origin)
|
||||||
|
@ -497,27 +500,16 @@ class FederationServer(FederationBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
def _handle_received_pdu(self, origin, pdu):
|
||||||
def _handle_new_pdu(self, origin, pdu, get_missing=True):
|
""" Process a PDU received in a federation /send/ transaction.
|
||||||
|
|
||||||
# We reprocess pdus when we have seen them only as outliers
|
Args:
|
||||||
existing = yield self._get_persisted_pdu(
|
origin (str): server which sent the pdu
|
||||||
origin, pdu.event_id, do_auth=False
|
pdu (FrozenEvent): received pdu
|
||||||
)
|
|
||||||
|
|
||||||
# FIXME: Currently we fetch an event again when we already have it
|
|
||||||
# if it has been marked as an outlier.
|
|
||||||
|
|
||||||
already_seen = (
|
|
||||||
existing and (
|
|
||||||
not existing.internal_metadata.is_outlier()
|
|
||||||
or pdu.internal_metadata.is_outlier()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if already_seen:
|
|
||||||
logger.debug("Already seen pdu %s", pdu.event_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
Returns (Deferred): completes with None
|
||||||
|
Raises: FederationError if the signatures / hash do not match
|
||||||
|
"""
|
||||||
# Check signature.
|
# Check signature.
|
||||||
try:
|
try:
|
||||||
pdu = yield self._check_sigs_and_hash(pdu)
|
pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
|
@ -529,143 +521,7 @@ class FederationServer(FederationBase):
|
||||||
affected=pdu.event_id,
|
affected=pdu.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
state = None
|
yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
|
||||||
|
|
||||||
auth_chain = []
|
|
||||||
|
|
||||||
have_seen = yield self.store.have_events(
|
|
||||||
[ev for ev, _ in pdu.prev_events]
|
|
||||||
)
|
|
||||||
|
|
||||||
fetch_state = False
|
|
||||||
|
|
||||||
# Get missing pdus if necessary.
|
|
||||||
if not pdu.internal_metadata.is_outlier():
|
|
||||||
# We only backfill backwards to the min depth.
|
|
||||||
min_depth = yield self.handler.get_min_depth_for_context(
|
|
||||||
pdu.room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"_handle_new_pdu min_depth for %s: %d",
|
|
||||||
pdu.room_id, min_depth
|
|
||||||
)
|
|
||||||
|
|
||||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
|
||||||
seen = set(have_seen.keys())
|
|
||||||
|
|
||||||
if min_depth and pdu.depth < min_depth:
|
|
||||||
# This is so that we don't notify the user about this
|
|
||||||
# message, to work around the fact that some events will
|
|
||||||
# reference really really old events we really don't want to
|
|
||||||
# send to the clients.
|
|
||||||
pdu.internal_metadata.outlier = True
|
|
||||||
elif min_depth and pdu.depth > min_depth:
|
|
||||||
if get_missing and prevs - seen:
|
|
||||||
# If we're missing stuff, ensure we only fetch stuff one
|
|
||||||
# at a time.
|
|
||||||
logger.info(
|
|
||||||
"Acquiring lock for room %r to fetch %d missing events: %r...",
|
|
||||||
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
|
|
||||||
)
|
|
||||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
|
||||||
logger.info(
|
|
||||||
"Acquired lock for room %r to fetch %d missing events",
|
|
||||||
pdu.room_id, len(prevs - seen),
|
|
||||||
)
|
|
||||||
|
|
||||||
# We recalculate seen, since it may have changed.
|
|
||||||
have_seen = yield self.store.have_events(prevs)
|
|
||||||
seen = set(have_seen.keys())
|
|
||||||
|
|
||||||
if prevs - seen:
|
|
||||||
latest = yield self.store.get_latest_event_ids_in_room(
|
|
||||||
pdu.room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# We add the prev events that we have seen to the latest
|
|
||||||
# list to ensure the remote server doesn't give them to us
|
|
||||||
latest = set(latest)
|
|
||||||
latest |= seen
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Missing %d events for room %r: %r...",
|
|
||||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
|
||||||
)
|
|
||||||
|
|
||||||
# XXX: we set timeout to 10s to help workaround
|
|
||||||
# https://github.com/matrix-org/synapse/issues/1733.
|
|
||||||
# The reason is to avoid holding the linearizer lock
|
|
||||||
# whilst processing inbound /send transactions, causing
|
|
||||||
# FDs to stack up and block other inbound transactions
|
|
||||||
# which empirically can currently take up to 30 minutes.
|
|
||||||
#
|
|
||||||
# N.B. this explicitly disables retry attempts.
|
|
||||||
#
|
|
||||||
# N.B. this also increases our chances of falling back to
|
|
||||||
# fetching fresh state for the room if the missing event
|
|
||||||
# can't be found, which slightly reduces our security.
|
|
||||||
# it may also increase our DAG extremity count for the room,
|
|
||||||
# causing additional state resolution? See #1760.
|
|
||||||
# However, fetching state doesn't hold the linearizer lock
|
|
||||||
# apparently.
|
|
||||||
#
|
|
||||||
# see https://github.com/matrix-org/synapse/pull/1744
|
|
||||||
|
|
||||||
missing_events = yield self.get_missing_events(
|
|
||||||
origin,
|
|
||||||
pdu.room_id,
|
|
||||||
earliest_events_ids=list(latest),
|
|
||||||
latest_events=[pdu],
|
|
||||||
limit=10,
|
|
||||||
min_depth=min_depth,
|
|
||||||
timeout=10000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# We want to sort these by depth so we process them and
|
|
||||||
# tell clients about them in order.
|
|
||||||
missing_events.sort(key=lambda x: x.depth)
|
|
||||||
|
|
||||||
for e in missing_events:
|
|
||||||
yield self._handle_new_pdu(
|
|
||||||
origin,
|
|
||||||
e,
|
|
||||||
get_missing=False
|
|
||||||
)
|
|
||||||
|
|
||||||
have_seen = yield self.store.have_events(
|
|
||||||
[ev for ev, _ in pdu.prev_events]
|
|
||||||
)
|
|
||||||
|
|
||||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
|
||||||
seen = set(have_seen.keys())
|
|
||||||
if prevs - seen:
|
|
||||||
logger.info(
|
|
||||||
"Still missing %d events for room %r: %r...",
|
|
||||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
|
||||||
)
|
|
||||||
fetch_state = True
|
|
||||||
|
|
||||||
if fetch_state:
|
|
||||||
# We need to get the state at this event, since we haven't
|
|
||||||
# processed all the prev events.
|
|
||||||
logger.debug(
|
|
||||||
"_handle_new_pdu getting state for %s",
|
|
||||||
pdu.room_id
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
state, auth_chain = yield self.get_state_for_room(
|
|
||||||
origin, pdu.room_id, pdu.event_id,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
|
||||||
|
|
||||||
yield self.handler.on_receive_pdu(
|
|
||||||
origin,
|
|
||||||
pdu,
|
|
||||||
state=state,
|
|
||||||
auth_chain=auth_chain,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "<ReplicationLayer(%s)>" % self.server_name
|
return "<ReplicationLayer(%s)>" % self.server_name
|
||||||
|
|
|
@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
self.presence_map = {}
|
self.presence_map = {}
|
||||||
self.presence_changed = sorteddict()
|
self.presence_changed = sorteddict()
|
||||||
|
@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
|
||||||
else:
|
else:
|
||||||
self.edus[pos] = edu
|
self.edus[pos] = edu
|
||||||
|
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_presence(self, destination, states):
|
def send_presence(self, destination, states):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
|
@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
|
||||||
(destination, state.user_id) for state in states
|
(destination, state.user_id) for state in states
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
def send_failure(self, failure, destination):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
|
|
||||||
self.failures[pos] = (destination, str(failure))
|
self.failures[pos] = (destination, str(failure))
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
self.device_messages[pos] = destination
|
self.device_messages[pos] = destination
|
||||||
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self):
|
||||||
return self.pos - 1
|
return self.pos - 1
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import datetime
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -22,9 +22,7 @@ from .units import Transaction, Edu
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
from synapse.util.retryutils import (
|
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||||
get_retry_limiter, NotRetryingDestination,
|
|
||||||
)
|
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
|
@ -99,7 +97,12 @@ class TransactionQueue(object):
|
||||||
# destination -> list of tuple(failure, deferred)
|
# destination -> list of tuple(failure, deferred)
|
||||||
self.pending_failures_by_dest = {}
|
self.pending_failures_by_dest = {}
|
||||||
|
|
||||||
|
# destination -> stream_id of last successfully sent to-device message.
|
||||||
|
# NB: may be a long or an int.
|
||||||
self.last_device_stream_id_by_dest = {}
|
self.last_device_stream_id_by_dest = {}
|
||||||
|
|
||||||
|
# destination -> stream_id of last successfully sent device list
|
||||||
|
# update.
|
||||||
self.last_device_list_stream_id_by_dest = {}
|
self.last_device_list_stream_id_by_dest = {}
|
||||||
|
|
||||||
# HACK to get unique tx id
|
# HACK to get unique tx id
|
||||||
|
@ -300,20 +303,20 @@ class TransactionQueue(object):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
pending_pdus = []
|
||||||
try:
|
try:
|
||||||
self.pending_transactions[destination] = 1
|
self.pending_transactions[destination] = 1
|
||||||
|
|
||||||
|
# This will throw if we wouldn't retry. We do this here so we fail
|
||||||
|
# quickly, but we will later check this again in the http client,
|
||||||
|
# hence why we throw the result away.
|
||||||
|
yield get_retry_limiter(destination, self.clock, self.store)
|
||||||
|
|
||||||
# XXX: what's this for?
|
# XXX: what's this for?
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
pending_pdus = []
|
||||||
while True:
|
while True:
|
||||||
limiter = yield get_retry_limiter(
|
|
||||||
destination,
|
|
||||||
self.clock,
|
|
||||||
self.store,
|
|
||||||
backoff_on_404=True, # If we get a 404 the other side has gone
|
|
||||||
)
|
|
||||||
|
|
||||||
device_message_edus, device_stream_id, dev_list_id = (
|
device_message_edus, device_stream_id, dev_list_id = (
|
||||||
yield self._get_new_device_messages(destination)
|
yield self._get_new_device_messages(destination)
|
||||||
)
|
)
|
||||||
|
@ -369,7 +372,6 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
success = yield self._send_new_transaction(
|
success = yield self._send_new_transaction(
|
||||||
destination, pending_pdus, pending_edus, pending_failures,
|
destination, pending_pdus, pending_edus, pending_failures,
|
||||||
limiter=limiter,
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
# Remove the acknowledged device messages from the database
|
# Remove the acknowledged device messages from the database
|
||||||
|
@ -387,12 +389,24 @@ class TransactionQueue(object):
|
||||||
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
|
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
except NotRetryingDestination:
|
except NotRetryingDestination as e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"TX [%s] not ready for retry yet - "
|
"TX [%s] not ready for retry yet (next retry at %s) - "
|
||||||
"dropping transaction for now",
|
"dropping transaction for now",
|
||||||
destination,
|
destination,
|
||||||
|
datetime.datetime.fromtimestamp(
|
||||||
|
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn(
|
||||||
|
"TX [%s] Failed to send transaction: %s",
|
||||||
|
destination,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
for p, _ in pending_pdus:
|
||||||
|
logger.info("Failed to send event %s to %s", p.event_id,
|
||||||
|
destination)
|
||||||
finally:
|
finally:
|
||||||
# We want to be *very* sure we delete this after we stop processing
|
# We want to be *very* sure we delete this after we stop processing
|
||||||
self.pending_transactions.pop(destination, None)
|
self.pending_transactions.pop(destination, None)
|
||||||
|
@ -432,7 +446,7 @@ class TransactionQueue(object):
|
||||||
@measure_func("_send_new_transaction")
|
@measure_func("_send_new_transaction")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
||||||
pending_failures, limiter):
|
pending_failures):
|
||||||
|
|
||||||
# Sort based on the order field
|
# Sort based on the order field
|
||||||
pending_pdus.sort(key=lambda t: t[1])
|
pending_pdus.sort(key=lambda t: t[1])
|
||||||
|
@ -442,7 +456,6 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||||
|
|
||||||
txn_id = str(self._next_txn_id)
|
txn_id = str(self._next_txn_id)
|
||||||
|
@ -483,7 +496,6 @@ class TransactionQueue(object):
|
||||||
len(failures),
|
len(failures),
|
||||||
)
|
)
|
||||||
|
|
||||||
with limiter:
|
|
||||||
# Actually send the transaction
|
# Actually send the transaction
|
||||||
|
|
||||||
# FIXME (erikj): This is a bit of a hack to make the Pdu age
|
# FIXME (erikj): This is a bit of a hack to make the Pdu age
|
||||||
|
@ -543,31 +555,5 @@ class TransactionQueue(object):
|
||||||
"Failed to send event %s to %s", p.event_id, destination
|
"Failed to send event %s to %s", p.event_id, destination
|
||||||
)
|
)
|
||||||
success = False
|
success = False
|
||||||
except RuntimeError as e:
|
|
||||||
# We capture this here as there as nothing actually listens
|
|
||||||
# for this finishing functions deferred.
|
|
||||||
logger.warn(
|
|
||||||
"TX [%s] Problem in _attempt_transaction: %s",
|
|
||||||
destination,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
success = False
|
|
||||||
|
|
||||||
for p in pdus:
|
|
||||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
|
||||||
except Exception as e:
|
|
||||||
# We capture this here as there as nothing actually listens
|
|
||||||
# for this finishing functions deferred.
|
|
||||||
logger.warn(
|
|
||||||
"TX [%s] Problem in _attempt_transaction: %s",
|
|
||||||
destination,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
|
|
||||||
success = False
|
|
||||||
|
|
||||||
for p in pdus:
|
|
||||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
|
||||||
|
|
||||||
defer.returnValue(success)
|
defer.returnValue(success)
|
||||||
|
|
|
@ -163,6 +163,7 @@ class TransportLayerClient(object):
|
||||||
data=json_data,
|
data=json_data,
|
||||||
json_data_callback=json_data_callback,
|
json_data_callback=json_data_callback,
|
||||||
long_retries=True,
|
long_retries=True,
|
||||||
|
backoff_on_404=True, # If we get a 404 the other side has gone
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -174,7 +175,8 @@ class TransportLayerClient(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def make_query(self, destination, query_type, args, retry_on_dns_fail):
|
def make_query(self, destination, query_type, args, retry_on_dns_fail,
|
||||||
|
ignore_backoff=False):
|
||||||
path = PREFIX + "/query/%s" % query_type
|
path = PREFIX + "/query/%s" % query_type
|
||||||
|
|
||||||
content = yield self.client.get_json(
|
content = yield self.client.get_json(
|
||||||
|
@ -183,6 +185,7 @@ class TransportLayerClient(object):
|
||||||
args=args,
|
args=args,
|
||||||
retry_on_dns_fail=retry_on_dns_fail,
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
timeout=10000,
|
timeout=10000,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
@ -242,6 +245,7 @@ class TransportLayerClient(object):
|
||||||
destination=destination,
|
destination=destination,
|
||||||
path=path,
|
path=path,
|
||||||
data=content,
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
@ -269,6 +273,7 @@ class TransportLayerClient(object):
|
||||||
destination=remote_server,
|
destination=remote_server,
|
||||||
path=path,
|
path=path,
|
||||||
args=args,
|
args=args,
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -47,6 +48,7 @@ class AuthHandler(BaseHandler):
|
||||||
LoginType.PASSWORD: self._check_password_auth,
|
LoginType.PASSWORD: self._check_password_auth,
|
||||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||||
|
LoginType.MSISDN: self._check_msisdn,
|
||||||
LoginType.DUMMY: self._check_dummy_auth,
|
LoginType.DUMMY: self._check_dummy_auth,
|
||||||
}
|
}
|
||||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||||
|
@ -307,31 +309,47 @@ class AuthHandler(BaseHandler):
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_email_identity(self, authdict, _):
|
def _check_email_identity(self, authdict, _):
|
||||||
|
return self._check_threepid('email', authdict)
|
||||||
|
|
||||||
|
def _check_msisdn(self, authdict, _):
|
||||||
|
return self._check_threepid('msisdn', authdict)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_dummy_auth(self, authdict, _):
|
||||||
|
yield run_on_reactor()
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_threepid(self, medium, authdict):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if 'threepid_creds' not in authdict:
|
if 'threepid_creds' not in authdict:
|
||||||
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
threepid_creds = authdict['threepid_creds']
|
threepid_creds = authdict['threepid_creds']
|
||||||
|
|
||||||
identity_handler = self.hs.get_handlers().identity_handler
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
|
|
||||||
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
|
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
|
||||||
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
|
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
|
||||||
|
|
||||||
if not threepid:
|
if not threepid:
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
if threepid['medium'] != medium:
|
||||||
|
raise LoginError(
|
||||||
|
401,
|
||||||
|
"Expecting threepid of type '%s', got '%s'" % (
|
||||||
|
medium, threepid['medium'],
|
||||||
|
),
|
||||||
|
errcode=Codes.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
|
||||||
threepid['threepid_creds'] = authdict['threepid_creds']
|
threepid['threepid_creds'] = authdict['threepid_creds']
|
||||||
|
|
||||||
defer.returnValue(threepid)
|
defer.returnValue(threepid)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_dummy_auth(self, authdict, _):
|
|
||||||
yield run_on_reactor()
|
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
def _get_params_recaptcha(self):
|
def _get_params_recaptcha(self):
|
||||||
return {"public_key": self.hs.config.recaptcha_public_key}
|
return {"public_key": self.hs.config.recaptcha_public_key}
|
||||||
|
|
||||||
|
|
|
@ -169,6 +169,40 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
yield self.notify_device_update(user_id, [device_id])
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_devices(self, user_id, device_ids):
|
||||||
|
""" Delete several devices
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_ids (str): The list of device IDs to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.delete_devices(user_id, device_ids)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
# no match
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Delete access tokens and e2e keys for each device. Not optimised as it is not
|
||||||
|
# considered as part of a critical path.
|
||||||
|
for device_id in device_ids:
|
||||||
|
yield self.store.user_delete_access_tokens(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
delete_refresh_tokens=True,
|
||||||
|
)
|
||||||
|
yield self.store.delete_e2e_keys_by_device(
|
||||||
|
user_id=user_id, device_id=device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.notify_device_update(user_id, device_ids)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_device(self, user_id, device_id, content):
|
def update_device(self, user_id, device_id, content):
|
||||||
""" Update the given device
|
""" Update the given device
|
||||||
|
@ -214,8 +248,7 @@ class DeviceHandler(BaseHandler):
|
||||||
user_id, device_ids, list(hosts)
|
user_id, device_ids, list(hosts)
|
||||||
)
|
)
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
room_ids = [r.room_id for r in rooms]
|
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
yield self.notifier.on_new_event(
|
||||||
"device_list_key", position, rooms=room_ids,
|
"device_list_key", position, rooms=room_ids,
|
||||||
|
@ -236,8 +269,7 @@ class DeviceHandler(BaseHandler):
|
||||||
user_id (str)
|
user_id (str)
|
||||||
from_token (StreamToken)
|
from_token (StreamToken)
|
||||||
"""
|
"""
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
room_ids = set(r.room_id for r in rooms)
|
|
||||||
|
|
||||||
# First we check if any devices have changed
|
# First we check if any devices have changed
|
||||||
changed = yield self.store.get_user_whose_devices_changed(
|
changed = yield self.store.get_user_whose_devices_changed(
|
||||||
|
@ -262,7 +294,7 @@ class DeviceHandler(BaseHandler):
|
||||||
# ordering: treat it the same as a new room
|
# ordering: treat it the same as a new room
|
||||||
event_ids = []
|
event_ids = []
|
||||||
|
|
||||||
current_state_ids = yield self.state.get_current_state_ids(room_id)
|
current_state_ids = yield self.store.get_current_state_ids(room_id)
|
||||||
|
|
||||||
# special-case for an empty prev state: include all members
|
# special-case for an empty prev state: include all members
|
||||||
# in the changed list
|
# in the changed list
|
||||||
|
@ -313,8 +345,8 @@ class DeviceHandler(BaseHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_left_room(self, user, room_id):
|
def user_left_room(self, user, room_id):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
if not rooms:
|
if not room_ids:
|
||||||
# We no longer share rooms with this user, so we'll no longer
|
# We no longer share rooms with this user, so we'll no longer
|
||||||
# receive device updates. Mark this in DB.
|
# receive device updates. Mark this in DB.
|
||||||
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
|
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
|
||||||
|
@ -370,8 +402,8 @@ class DeviceListEduUpdater(object):
|
||||||
logger.warning("Got device list update edu for %r from %r", user_id, origin)
|
logger.warning("Got device list update edu for %r from %r", user_id, origin)
|
||||||
return
|
return
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
if not rooms:
|
if not room_ids:
|
||||||
# We don't share any rooms with this user. Ignore update, as we
|
# We don't share any rooms with this user. Ignore update, as we
|
||||||
# probably won't get any further updates.
|
# probably won't get any further updates.
|
||||||
return
|
return
|
||||||
|
|
|
@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
"room_alias": room_alias.to_string(),
|
"room_alias": room_alias.to_string(),
|
||||||
},
|
},
|
||||||
retry_on_dns_fail=False,
|
retry_on_dns_fail=False,
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
logging.warn("Error retrieving alias")
|
logging.warn("Error retrieving alias")
|
||||||
|
|
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import SynapseError, CodeMessageException
|
from synapse.api.errors import SynapseError, CodeMessageException
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -121,10 +121,6 @@ class E2eKeysHandler(object):
|
||||||
def do_remote_query(destination):
|
def do_remote_query(destination):
|
||||||
destination_query = remote_queries_not_in_cache[destination]
|
destination_query = remote_queries_not_in_cache[destination]
|
||||||
try:
|
try:
|
||||||
limiter = yield get_retry_limiter(
|
|
||||||
destination, self.clock, self.store
|
|
||||||
)
|
|
||||||
with limiter:
|
|
||||||
remote_result = yield self.federation.query_client_keys(
|
remote_result = yield self.federation.query_client_keys(
|
||||||
destination,
|
destination,
|
||||||
{"device_keys": destination_query},
|
{"device_keys": destination_query},
|
||||||
|
@ -239,10 +235,6 @@ class E2eKeysHandler(object):
|
||||||
def claim_client_keys(destination):
|
def claim_client_keys(destination):
|
||||||
device_keys = remote_queries[destination]
|
device_keys = remote_queries[destination]
|
||||||
try:
|
try:
|
||||||
limiter = yield get_retry_limiter(
|
|
||||||
destination, self.clock, self.store
|
|
||||||
)
|
|
||||||
with limiter:
|
|
||||||
remote_result = yield self.federation.claim_client_keys(
|
remote_result = yield self.federation.claim_client_keys(
|
||||||
destination,
|
destination,
|
||||||
{"one_time_keys": device_keys},
|
{"one_time_keys": device_keys},
|
||||||
|
@ -316,7 +308,7 @@ class E2eKeysHandler(object):
|
||||||
# old access_token without an associated device_id. Either way, we
|
# old access_token without an associated device_id. Either way, we
|
||||||
# need to double-check the device is registered to avoid ending up with
|
# need to double-check the device is registered to avoid ending up with
|
||||||
# keys without a corresponding device.
|
# keys without a corresponding device.
|
||||||
self.device_handler.check_device_registered(user_id, device_id)
|
yield self.device_handler.check_device_registered(user_id, device_id)
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains handlers for federation events."""
|
"""Contains handlers for federation events."""
|
||||||
|
import synapse.util.logcontext
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json
|
from signedjson.sign import verify_signed_json
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
@ -31,7 +32,7 @@ from synapse.util.logcontext import (
|
||||||
)
|
)
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor, Linearizer
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
from synapse.crypto.event_signing import (
|
from synapse.crypto.event_signing import (
|
||||||
compute_event_signature, add_hashes_and_signatures,
|
compute_event_signature, add_hashes_and_signatures,
|
||||||
|
@ -79,29 +80,216 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# When joining a room we need to queue any events for that room up
|
# When joining a room we need to queue any events for that room up
|
||||||
self.room_queues = {}
|
self.room_queues = {}
|
||||||
|
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
||||||
|
|
||||||
@log_function
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
@log_function
|
||||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
def on_receive_pdu(self, origin, pdu, get_missing=True):
|
||||||
do auth checks and put it through the StateHandler.
|
""" Process a PDU received via a federation /send/ transaction, or
|
||||||
|
via backfill of missing prev_events
|
||||||
|
|
||||||
auth_chain and state are None if we already have the necessary state
|
Args:
|
||||||
and prev_events in the db
|
origin (str): server which initiated the /send/ transaction. Will
|
||||||
|
be used to fetch missing events or state.
|
||||||
|
pdu (FrozenEvent): received PDU
|
||||||
|
get_missing (bool): True if we should fetch missing prev_events
|
||||||
|
|
||||||
|
Returns (Deferred): completes with None
|
||||||
"""
|
"""
|
||||||
event = pdu
|
|
||||||
|
|
||||||
logger.debug("Got event: %s", event.event_id)
|
# We reprocess pdus when we have seen them only as outliers
|
||||||
|
existing = yield self.get_persisted_pdu(
|
||||||
|
origin, pdu.event_id, do_auth=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# FIXME: Currently we fetch an event again when we already have it
|
||||||
|
# if it has been marked as an outlier.
|
||||||
|
|
||||||
|
already_seen = (
|
||||||
|
existing and (
|
||||||
|
not existing.internal_metadata.is_outlier()
|
||||||
|
or pdu.internal_metadata.is_outlier()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if already_seen:
|
||||||
|
logger.debug("Already seen pdu %s", pdu.event_id)
|
||||||
|
return
|
||||||
|
|
||||||
# If we are currently in the process of joining this room, then we
|
# If we are currently in the process of joining this room, then we
|
||||||
# queue up events for later processing.
|
# queue up events for later processing.
|
||||||
if event.room_id in self.room_queues:
|
if pdu.room_id in self.room_queues:
|
||||||
self.room_queues[event.room_id].append((pdu, origin))
|
logger.info("Ignoring PDU %s for room %s from %s for now; join "
|
||||||
|
"in progress", pdu.event_id, pdu.room_id, origin)
|
||||||
|
self.room_queues[pdu.room_id].append((pdu, origin))
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("Processing event: %s", event.event_id)
|
state = None
|
||||||
|
|
||||||
logger.debug("Event: %s", event)
|
auth_chain = []
|
||||||
|
|
||||||
|
have_seen = yield self.store.have_events(
|
||||||
|
[ev for ev, _ in pdu.prev_events]
|
||||||
|
)
|
||||||
|
|
||||||
|
fetch_state = False
|
||||||
|
|
||||||
|
# Get missing pdus if necessary.
|
||||||
|
if not pdu.internal_metadata.is_outlier():
|
||||||
|
# We only backfill backwards to the min depth.
|
||||||
|
min_depth = yield self.get_min_depth_for_context(
|
||||||
|
pdu.room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"_handle_new_pdu min_depth for %s: %d",
|
||||||
|
pdu.room_id, min_depth
|
||||||
|
)
|
||||||
|
|
||||||
|
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||||
|
seen = set(have_seen.keys())
|
||||||
|
|
||||||
|
if min_depth and pdu.depth < min_depth:
|
||||||
|
# This is so that we don't notify the user about this
|
||||||
|
# message, to work around the fact that some events will
|
||||||
|
# reference really really old events we really don't want to
|
||||||
|
# send to the clients.
|
||||||
|
pdu.internal_metadata.outlier = True
|
||||||
|
elif min_depth and pdu.depth > min_depth:
|
||||||
|
if get_missing and prevs - seen:
|
||||||
|
# If we're missing stuff, ensure we only fetch stuff one
|
||||||
|
# at a time.
|
||||||
|
logger.info(
|
||||||
|
"Acquiring lock for room %r to fetch %d missing events: %r...",
|
||||||
|
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
|
||||||
|
)
|
||||||
|
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||||
|
logger.info(
|
||||||
|
"Acquired lock for room %r to fetch %d missing events",
|
||||||
|
pdu.room_id, len(prevs - seen),
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._get_missing_events_for_pdu(
|
||||||
|
origin, pdu, prevs, min_depth
|
||||||
|
)
|
||||||
|
|
||||||
|
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||||
|
seen = set(have_seen.keys())
|
||||||
|
if prevs - seen:
|
||||||
|
logger.info(
|
||||||
|
"Still missing %d events for room %r: %r...",
|
||||||
|
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||||
|
)
|
||||||
|
fetch_state = True
|
||||||
|
|
||||||
|
if fetch_state:
|
||||||
|
# We need to get the state at this event, since we haven't
|
||||||
|
# processed all the prev events.
|
||||||
|
logger.debug(
|
||||||
|
"_handle_new_pdu getting state for %s",
|
||||||
|
pdu.room_id
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
state, auth_chain = yield self.replication_layer.get_state_for_room(
|
||||||
|
origin, pdu.room_id, pdu.event_id,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
||||||
|
|
||||||
|
yield self._process_received_pdu(
|
||||||
|
origin,
|
||||||
|
pdu,
|
||||||
|
state=state,
|
||||||
|
auth_chain=auth_chain,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||||
|
pdu: received pdu
|
||||||
|
prevs (str[]): List of event ids which we are missing
|
||||||
|
min_depth (int): Minimum depth of events to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred<dict(str, str?)>: updated have_seen dictionary
|
||||||
|
"""
|
||||||
|
# We recalculate seen, since it may have changed.
|
||||||
|
have_seen = yield self.store.have_events(prevs)
|
||||||
|
seen = set(have_seen.keys())
|
||||||
|
|
||||||
|
if not prevs - seen:
|
||||||
|
# nothing left to do
|
||||||
|
defer.returnValue(have_seen)
|
||||||
|
|
||||||
|
latest = yield self.store.get_latest_event_ids_in_room(
|
||||||
|
pdu.room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# We add the prev events that we have seen to the latest
|
||||||
|
# list to ensure the remote server doesn't give them to us
|
||||||
|
latest = set(latest)
|
||||||
|
latest |= seen
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Missing %d events for room %r: %r...",
|
||||||
|
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||||
|
)
|
||||||
|
|
||||||
|
# XXX: we set timeout to 10s to help workaround
|
||||||
|
# https://github.com/matrix-org/synapse/issues/1733.
|
||||||
|
# The reason is to avoid holding the linearizer lock
|
||||||
|
# whilst processing inbound /send transactions, causing
|
||||||
|
# FDs to stack up and block other inbound transactions
|
||||||
|
# which empirically can currently take up to 30 minutes.
|
||||||
|
#
|
||||||
|
# N.B. this explicitly disables retry attempts.
|
||||||
|
#
|
||||||
|
# N.B. this also increases our chances of falling back to
|
||||||
|
# fetching fresh state for the room if the missing event
|
||||||
|
# can't be found, which slightly reduces our security.
|
||||||
|
# it may also increase our DAG extremity count for the room,
|
||||||
|
# causing additional state resolution? See #1760.
|
||||||
|
# However, fetching state doesn't hold the linearizer lock
|
||||||
|
# apparently.
|
||||||
|
#
|
||||||
|
# see https://github.com/matrix-org/synapse/pull/1744
|
||||||
|
|
||||||
|
missing_events = yield self.replication_layer.get_missing_events(
|
||||||
|
origin,
|
||||||
|
pdu.room_id,
|
||||||
|
earliest_events_ids=list(latest),
|
||||||
|
latest_events=[pdu],
|
||||||
|
limit=10,
|
||||||
|
min_depth=min_depth,
|
||||||
|
timeout=10000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We want to sort these by depth so we process them and
|
||||||
|
# tell clients about them in order.
|
||||||
|
missing_events.sort(key=lambda x: x.depth)
|
||||||
|
|
||||||
|
for e in missing_events:
|
||||||
|
yield self.on_receive_pdu(
|
||||||
|
origin,
|
||||||
|
e,
|
||||||
|
get_missing=False
|
||||||
|
)
|
||||||
|
|
||||||
|
have_seen = yield self.store.have_events(
|
||||||
|
[ev for ev, _ in pdu.prev_events]
|
||||||
|
)
|
||||||
|
defer.returnValue(have_seen)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _process_received_pdu(self, origin, pdu, state, auth_chain):
|
||||||
|
""" Called when we have a new pdu. We need to do auth checks and put it
|
||||||
|
through the StateHandler.
|
||||||
|
"""
|
||||||
|
event = pdu
|
||||||
|
|
||||||
|
logger.debug("Processing event: %s", event)
|
||||||
|
|
||||||
# FIXME (erikj): Awful hack to make the case where we are not currently
|
# FIXME (erikj): Awful hack to make the case where we are not currently
|
||||||
# in the room work
|
# in the room work
|
||||||
|
@ -670,8 +858,6 @@ class FederationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
logger.debug("Joining %s to %s", joinee, room_id)
|
logger.debug("Joining %s to %s", joinee, room_id)
|
||||||
|
|
||||||
yield self.store.clean_room_for_join(room_id)
|
|
||||||
|
|
||||||
origin, event = yield self._make_and_verify_event(
|
origin, event = yield self._make_and_verify_event(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -680,7 +866,15 @@ class FederationHandler(BaseHandler):
|
||||||
content,
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This shouldn't happen, because the RoomMemberHandler has a
|
||||||
|
# linearizer lock which only allows one operation per user per room
|
||||||
|
# at a time - so this is just paranoia.
|
||||||
|
assert (room_id not in self.room_queues)
|
||||||
|
|
||||||
self.room_queues[room_id] = []
|
self.room_queues[room_id] = []
|
||||||
|
|
||||||
|
yield self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
handled_events = set()
|
handled_events = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -733,17 +927,36 @@ class FederationHandler(BaseHandler):
|
||||||
room_queue = self.room_queues[room_id]
|
room_queue = self.room_queues[room_id]
|
||||||
del self.room_queues[room_id]
|
del self.room_queues[room_id]
|
||||||
|
|
||||||
for p, origin in room_queue:
|
# we don't need to wait for the queued events to be processed -
|
||||||
if p.event_id in handled_events:
|
# it's just a best-effort thing at this point. We do want to do
|
||||||
continue
|
# them roughly in order, though, otherwise we'll end up making
|
||||||
|
# lots of requests for missing prev_events which we do actually
|
||||||
|
# have. Hence we fire off the deferred, but don't wait for it.
|
||||||
|
|
||||||
try:
|
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
|
||||||
self.on_receive_pdu(origin, p)
|
room_queue
|
||||||
except:
|
)
|
||||||
logger.exception("Couldn't handle pdu")
|
|
||||||
|
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_queued_pdus(self, room_queue):
|
||||||
|
"""Process PDUs which got queued up while we were busy send_joining.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_queue (list[FrozenEvent, str]): list of PDUs to be processed
|
||||||
|
and the servers that sent them
|
||||||
|
"""
|
||||||
|
for p, origin in room_queue:
|
||||||
|
try:
|
||||||
|
logger.info("Processing queued PDU %s which was received "
|
||||||
|
"while we were joining %s", p.event_id, p.room_id)
|
||||||
|
yield self.on_receive_pdu(origin, p)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn(
|
||||||
|
"Error handling queued PDU %s from %s: %s",
|
||||||
|
p.event_id, origin, e)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_make_join_request(self, room_id, user_id):
|
def on_make_join_request(self, room_id, user_id):
|
||||||
|
@ -791,9 +1004,19 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
event.internal_metadata.outlier = False
|
event.internal_metadata.outlier = False
|
||||||
# Send this event on behalf of the origin server since they may not
|
# Send this event on behalf of the origin server.
|
||||||
# have an up to data view of the state of the room at this event so
|
#
|
||||||
# will not know which servers to send the event to.
|
# The reasons we have the destination server rather than the origin
|
||||||
|
# server send it are slightly mysterious: the origin server should have
|
||||||
|
# all the neccessary state once it gets the response to the send_join,
|
||||||
|
# so it could send the event itself if it wanted to. It may be that
|
||||||
|
# doing it this way reduces failure modes, or avoids certain attacks
|
||||||
|
# where a new server selectively tells a subset of the federation that
|
||||||
|
# it has joined.
|
||||||
|
#
|
||||||
|
# The fact is that, as of the current writing, Synapse doesn't send out
|
||||||
|
# the join event over federation after joining, and changing it now
|
||||||
|
# would introduce the danger of backwards-compatibility problems.
|
||||||
event.internal_metadata.send_on_behalf_of = origin
|
event.internal_metadata.send_on_behalf_of = origin
|
||||||
|
|
||||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||||
|
@ -878,15 +1101,15 @@ class FederationHandler(BaseHandler):
|
||||||
user_id,
|
user_id,
|
||||||
"leave"
|
"leave"
|
||||||
)
|
)
|
||||||
signed_event = self._sign_event(event)
|
event = self._sign_event(event)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
raise
|
raise
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
logger.warn("Failed to reject invite: %s", e)
|
logger.warn("Failed to reject invite: %s", e)
|
||||||
raise SynapseError(500, "Failed to reject invite")
|
raise SynapseError(500, "Failed to reject invite")
|
||||||
|
|
||||||
# Try the host we successfully got a response to /make_join/
|
# Try the host that we succesfully called /make_leave/ on first for
|
||||||
# request first.
|
# the /send_leave/ request.
|
||||||
try:
|
try:
|
||||||
target_hosts.remove(origin)
|
target_hosts.remove(origin)
|
||||||
target_hosts.insert(0, origin)
|
target_hosts.insert(0, origin)
|
||||||
|
@ -896,7 +1119,7 @@ class FederationHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
yield self.replication_layer.send_leave(
|
yield self.replication_layer.send_leave(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
signed_event
|
event
|
||||||
)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
raise
|
raise
|
||||||
|
@ -1325,7 +1548,17 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _prep_event(self, origin, event, state=None, auth_events=None):
|
def _prep_event(self, origin, event, state=None, auth_events=None):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin:
|
||||||
|
event:
|
||||||
|
state:
|
||||||
|
auth_events:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred, which resolves to synapse.events.snapshot.EventContext
|
||||||
|
"""
|
||||||
context = yield self.state_handler.compute_event_context(
|
context = yield self.state_handler.compute_event_context(
|
||||||
event, old_state=state,
|
event, old_state=state,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -150,7 +151,7 @@ class IdentityHandler(BaseHandler):
|
||||||
params.update(kwargs)
|
params.update(kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = yield self.http_client.post_urlencoded_get_json(
|
data = yield self.http_client.post_json_get_json(
|
||||||
"https://%s%s" % (
|
"https://%s%s" % (
|
||||||
id_server,
|
id_server,
|
||||||
"/_matrix/identity/api/v1/validate/email/requestToken"
|
"/_matrix/identity/api/v1/validate/email/requestToken"
|
||||||
|
@ -161,3 +162,37 @@ class IdentityHandler(BaseHandler):
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
logger.info("Proxied requestToken failed: %r", e)
|
logger.info("Proxied requestToken failed: %r", e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def requestMsisdnToken(
|
||||||
|
self, id_server, country, phone_number,
|
||||||
|
client_secret, send_attempt, **kwargs
|
||||||
|
):
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if not self._should_trust_id_server(id_server):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Untrusted ID server '%s'" % id_server,
|
||||||
|
Codes.SERVER_NOT_TRUSTED
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'country': country,
|
||||||
|
'phone_number': phone_number,
|
||||||
|
'client_secret': client_secret,
|
||||||
|
'send_attempt': send_attempt,
|
||||||
|
}
|
||||||
|
params.update(kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = yield self.http_client.post_json_get_json(
|
||||||
|
"https://%s%s" % (
|
||||||
|
id_server,
|
||||||
|
"/_matrix/identity/api/v1/validate/msisdn/requestToken"
|
||||||
|
),
|
||||||
|
params
|
||||||
|
)
|
||||||
|
defer.returnValue(data)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
logger.info("Proxied requestToken failed: %r", e)
|
||||||
|
raise e
|
||||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import AuthError, Codes
|
from synapse.api.errors import AuthError, Codes
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
UserID, StreamToken,
|
UserID, StreamToken,
|
||||||
|
@ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler):
|
||||||
"content": content,
|
"content": content,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
ret = {
|
ret = {
|
||||||
"rooms": rooms_ret,
|
"rooms": rooms_ret,
|
||||||
"presence": presence,
|
"presence": [
|
||||||
|
{
|
||||||
|
"type": "m.presence",
|
||||||
|
"content": format_user_presence_state(event, now),
|
||||||
|
}
|
||||||
|
for event in presence
|
||||||
|
],
|
||||||
"account_data": account_data_events,
|
"account_data": account_data_events,
|
||||||
"receipts": receipt,
|
"receipts": receipt,
|
||||||
"end": now_token.to_string(),
|
"end": now_token.to_string(),
|
||||||
|
|
|
@ -29,6 +29,7 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
|
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -556,9 +557,9 @@ class PresenceHandler(object):
|
||||||
room_ids_to_states = {}
|
room_ids_to_states = {}
|
||||||
users_to_states = {}
|
users_to_states = {}
|
||||||
for state in states:
|
for state in states:
|
||||||
events = yield self.store.get_rooms_for_user(state.user_id)
|
room_ids = yield self.store.get_rooms_for_user(state.user_id)
|
||||||
for e in events:
|
for room_id in room_ids:
|
||||||
room_ids_to_states.setdefault(e.room_id, []).append(state)
|
room_ids_to_states.setdefault(room_id, []).append(state)
|
||||||
|
|
||||||
plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
|
plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
|
||||||
for u in plist:
|
for u in plist:
|
||||||
|
@ -574,8 +575,7 @@ class PresenceHandler(object):
|
||||||
if not local_states:
|
if not local_states:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
hosts = yield self.store.get_hosts_in_room(room_id)
|
||||||
hosts = set(get_domain_from_id(u) for u in users)
|
|
||||||
|
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||||
|
@ -719,9 +719,7 @@ class PresenceHandler(object):
|
||||||
for state in updates
|
for state in updates
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
defer.returnValue([
|
defer.returnValue(updates)
|
||||||
format_user_presence_state(state, now) for state in updates
|
|
||||||
])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_state(self, target_user, state, ignore_status_msg=False):
|
def set_state(self, target_user, state, ignore_status_msg=False):
|
||||||
|
@ -795,6 +793,9 @@ class PresenceHandler(object):
|
||||||
as_event=False,
|
as_event=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
results[:] = [format_user_presence_state(r, now) for r in results]
|
||||||
|
|
||||||
is_accepted = {
|
is_accepted = {
|
||||||
row["observed_user_id"]: row["accepted"] for row in presence_list
|
row["observed_user_id"]: row["accepted"] for row in presence_list
|
||||||
}
|
}
|
||||||
|
@ -847,6 +848,7 @@ class PresenceHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
state_dict = yield self.get_state(observed_user, as_event=False)
|
state_dict = yield self.get_state(observed_user, as_event=False)
|
||||||
|
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
|
||||||
|
|
||||||
self.federation.send_edu(
|
self.federation.send_edu(
|
||||||
destination=observer_user.domain,
|
destination=observer_user.domain,
|
||||||
|
@ -910,11 +912,12 @@ class PresenceHandler(object):
|
||||||
def is_visible(self, observed_user, observer_user):
|
def is_visible(self, observed_user, observer_user):
|
||||||
"""Returns whether a user can see another user's presence.
|
"""Returns whether a user can see another user's presence.
|
||||||
"""
|
"""
|
||||||
observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string())
|
observer_room_ids = yield self.store.get_rooms_for_user(
|
||||||
observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string())
|
observer_user.to_string()
|
||||||
|
)
|
||||||
observer_room_ids = set(r.room_id for r in observer_rooms)
|
observed_room_ids = yield self.store.get_rooms_for_user(
|
||||||
observed_room_ids = set(r.room_id for r in observed_rooms)
|
observed_user.to_string()
|
||||||
|
)
|
||||||
|
|
||||||
if observer_room_ids & observed_room_ids:
|
if observer_room_ids & observed_room_ids:
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
@ -979,14 +982,18 @@ def should_notify(old_state, new_state):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def format_user_presence_state(state, now):
|
def format_user_presence_state(state, now, include_user_id=True):
|
||||||
"""Convert UserPresenceState to a format that can be sent down to clients
|
"""Convert UserPresenceState to a format that can be sent down to clients
|
||||||
and to other servers.
|
and to other servers.
|
||||||
|
|
||||||
|
The "user_id" is optional so that this function can be used to format presence
|
||||||
|
updates for client /sync responses and for federation /send requests.
|
||||||
"""
|
"""
|
||||||
content = {
|
content = {
|
||||||
"presence": state.state,
|
"presence": state.state,
|
||||||
"user_id": state.user_id,
|
|
||||||
}
|
}
|
||||||
|
if include_user_id:
|
||||||
|
content["user_id"] = state.user_id
|
||||||
if state.last_active_ts:
|
if state.last_active_ts:
|
||||||
content["last_active_ago"] = now - state.last_active_ts
|
content["last_active_ago"] = now - state.last_active_ts
|
||||||
if state.status_msg and state.state != PresenceState.OFFLINE:
|
if state.status_msg and state.state != PresenceState.OFFLINE:
|
||||||
|
@ -1025,7 +1032,6 @@ class PresenceEventSource(object):
|
||||||
# sending down the rare duplicate is not a concern.
|
# sending down the rare duplicate is not a concern.
|
||||||
|
|
||||||
with Measure(self.clock, "presence.get_new_events"):
|
with Measure(self.clock, "presence.get_new_events"):
|
||||||
user_id = user.to_string()
|
|
||||||
if from_key is not None:
|
if from_key is not None:
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
|
|
||||||
|
@ -1034,18 +1040,7 @@ class PresenceEventSource(object):
|
||||||
|
|
||||||
max_token = self.store.get_current_presence_token()
|
max_token = self.store.get_current_presence_token()
|
||||||
|
|
||||||
plist = yield self.store.get_presence_list_accepted(user.localpart)
|
users_interested_in = yield self._get_interested_in(user, explicit_room_id)
|
||||||
users_interested_in = set(row["observed_user_id"] for row in plist)
|
|
||||||
users_interested_in.add(user_id) # So that we receive our own presence
|
|
||||||
|
|
||||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
|
||||||
user_id
|
|
||||||
)
|
|
||||||
users_interested_in.update(users_who_share_room)
|
|
||||||
|
|
||||||
if explicit_room_id:
|
|
||||||
user_ids = yield self.store.get_users_in_room(explicit_room_id)
|
|
||||||
users_interested_in.update(user_ids)
|
|
||||||
|
|
||||||
user_ids_changed = set()
|
user_ids_changed = set()
|
||||||
changed = None
|
changed = None
|
||||||
|
@ -1073,15 +1068,12 @@ class PresenceEventSource(object):
|
||||||
|
|
||||||
updates = yield presence.current_state_for_users(user_ids_changed)
|
updates = yield presence.current_state_for_users(user_ids_changed)
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
if include_offline:
|
||||||
|
defer.returnValue((updates.values(), max_token))
|
||||||
|
else:
|
||||||
defer.returnValue(([
|
defer.returnValue(([
|
||||||
{
|
s for s in updates.itervalues()
|
||||||
"type": "m.presence",
|
if s.state != PresenceState.OFFLINE
|
||||||
"content": format_user_presence_state(s, now),
|
|
||||||
}
|
|
||||||
for s in updates.values()
|
|
||||||
if include_offline or s.state != PresenceState.OFFLINE
|
|
||||||
], max_token))
|
], max_token))
|
||||||
|
|
||||||
def get_current_key(self):
|
def get_current_key(self):
|
||||||
|
@ -1090,6 +1082,31 @@ class PresenceEventSource(object):
|
||||||
def get_pagination_rows(self, user, pagination_config, key):
|
def get_pagination_rows(self, user, pagination_config, key):
|
||||||
return self.get_new_events(user, from_key=None, include_offline=False)
|
return self.get_new_events(user, from_key=None, include_offline=False)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
|
def _get_interested_in(self, user, explicit_room_id, cache_context):
|
||||||
|
"""Returns the set of users that the given user should see presence
|
||||||
|
updates for
|
||||||
|
"""
|
||||||
|
user_id = user.to_string()
|
||||||
|
plist = yield self.store.get_presence_list_accepted(
|
||||||
|
user.localpart, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
users_interested_in = set(row["observed_user_id"] for row in plist)
|
||||||
|
users_interested_in.add(user_id) # So that we receive our own presence
|
||||||
|
|
||||||
|
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||||
|
user_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
users_interested_in.update(users_who_share_room)
|
||||||
|
|
||||||
|
if explicit_room_id:
|
||||||
|
user_ids = yield self.store.get_users_in_room(
|
||||||
|
explicit_room_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
users_interested_in.update(user_ids)
|
||||||
|
|
||||||
|
defer.returnValue(users_interested_in)
|
||||||
|
|
||||||
|
|
||||||
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
|
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
|
||||||
"""Checks the presence of users that have timed out and updates as
|
"""Checks the presence of users that have timed out and updates as
|
||||||
|
@ -1157,7 +1174,10 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
|
||||||
# If there are have been no sync for a while (and none ongoing),
|
# If there are have been no sync for a while (and none ongoing),
|
||||||
# set presence to offline
|
# set presence to offline
|
||||||
if user_id not in syncing_user_ids:
|
if user_id not in syncing_user_ids:
|
||||||
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
|
# If the user has done something recently but hasn't synced,
|
||||||
|
# don't set them as offline.
|
||||||
|
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
|
||||||
|
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
|
||||||
state = state.copy_and_replace(
|
state = state.copy_and_replace(
|
||||||
state=PresenceState.OFFLINE,
|
state=PresenceState.OFFLINE,
|
||||||
status_msg=None,
|
status_msg=None,
|
||||||
|
|
|
@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
|
||||||
args={
|
args={
|
||||||
"user_id": target_user.to_string(),
|
"user_id": target_user.to_string(),
|
||||||
"field": "displayname",
|
"field": "displayname",
|
||||||
}
|
},
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
if e.code != 404:
|
if e.code != 404:
|
||||||
|
@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
|
||||||
args={
|
args={
|
||||||
"user_id": target_user.to_string(),
|
"user_id": target_user.to_string(),
|
||||||
"field": "avatar_url",
|
"field": "avatar_url",
|
||||||
}
|
},
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
if e.code != 404:
|
if e.code != 404:
|
||||||
|
@ -156,11 +158,11 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
self.ratelimit(requester)
|
self.ratelimit(requester)
|
||||||
|
|
||||||
joins = yield self.store.get_rooms_for_user(
|
room_ids = yield self.store.get_rooms_for_user(
|
||||||
user.to_string(),
|
user.to_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
for j in joins:
|
for room_id in room_ids:
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_handlers().room_member_handler
|
||||||
try:
|
try:
|
||||||
# Assume the user isn't a guest because we don't let guests set
|
# Assume the user isn't a guest because we don't let guests set
|
||||||
|
@ -171,12 +173,12 @@ class ProfileHandler(BaseHandler):
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
user,
|
user,
|
||||||
j.room_id,
|
room_id,
|
||||||
"join", # We treat a profile update like a join.
|
"join", # We treat a profile update like a join.
|
||||||
ratelimit=False, # Try to hide that these events aren't atomic.
|
ratelimit=False, # Try to hide that these events aren't atomic.
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Failed to update join event for room %s - %s",
|
"Failed to update join event for room %s - %s",
|
||||||
j.room_id, str(e.message)
|
room_id, str(e.message)
|
||||||
)
|
)
|
||||||
|
|
|
@ -210,10 +210,9 @@ class ReceiptEventSource(object):
|
||||||
else:
|
else:
|
||||||
from_key = None
|
from_key = None
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
room_ids = yield self.store.get_rooms_for_user(user.to_string())
|
||||||
rooms = [room.room_id for room in rooms]
|
|
||||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||||
rooms,
|
room_ids,
|
||||||
from_key=from_key,
|
from_key=from_key,
|
||||||
to_key=to_key,
|
to_key=to_key,
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ from synapse.api.constants import (
|
||||||
EventTypes, JoinRules,
|
EventTypes, JoinRules,
|
||||||
)
|
)
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.types import ThirdPartyInstanceID
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
|
@ -62,6 +63,10 @@ class RoomListHandler(BaseHandler):
|
||||||
appservice and network id to use an appservice specific one.
|
appservice and network id to use an appservice specific one.
|
||||||
Setting to None returns all public rooms across all lists.
|
Setting to None returns all public rooms across all lists.
|
||||||
"""
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
|
||||||
|
limit, since_token, bool(search_filter), network_tuple,
|
||||||
|
)
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# We explicitly don't bother caching searches or requests for
|
# We explicitly don't bother caching searches or requests for
|
||||||
# appservice specific lists.
|
# appservice specific lists.
|
||||||
|
@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
rooms_to_order_value = {}
|
rooms_to_order_value = {}
|
||||||
rooms_to_num_joined = {}
|
rooms_to_num_joined = {}
|
||||||
rooms_to_latest_event_ids = {}
|
|
||||||
|
|
||||||
newly_visible = []
|
newly_visible = []
|
||||||
newly_unpublished = []
|
newly_unpublished = []
|
||||||
|
@ -116,12 +120,18 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_order_for_room(room_id):
|
def get_order_for_room(room_id):
|
||||||
latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
|
# Most of the rooms won't have changed between the since token and
|
||||||
if not latest_event_ids:
|
# now (especially if the since token is "now"). So, we can ask what
|
||||||
|
# the current users are in a room (that will hit a cache) and then
|
||||||
|
# check if the room has changed since the since token. (We have to
|
||||||
|
# do it in that order to avoid races).
|
||||||
|
# If things have changed then fall back to getting the current state
|
||||||
|
# at the since token.
|
||||||
|
joined_users = yield self.store.get_users_in_room(room_id)
|
||||||
|
if self.store.has_room_changed_since(room_id, stream_token):
|
||||||
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
|
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
|
||||||
room_id, stream_token
|
room_id, stream_token
|
||||||
)
|
)
|
||||||
rooms_to_latest_event_ids[room_id] = latest_event_ids
|
|
||||||
|
|
||||||
if not latest_event_ids:
|
if not latest_event_ids:
|
||||||
return
|
return
|
||||||
|
@ -129,6 +139,7 @@ class RoomListHandler(BaseHandler):
|
||||||
joined_users = yield self.state_handler.get_current_user_in_room(
|
joined_users = yield self.state_handler.get_current_user_in_room(
|
||||||
room_id, latest_event_ids,
|
room_id, latest_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_joined_users = len(joined_users)
|
num_joined_users = len(joined_users)
|
||||||
rooms_to_num_joined[room_id] = num_joined_users
|
rooms_to_num_joined[room_id] = num_joined_users
|
||||||
|
|
||||||
|
@ -165,19 +176,19 @@ class RoomListHandler(BaseHandler):
|
||||||
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
|
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
|
||||||
rooms_to_scan.reverse()
|
rooms_to_scan.reverse()
|
||||||
|
|
||||||
# Actually generate the entries. _generate_room_entry will append to
|
# Actually generate the entries. _append_room_entry_to_chunk will append to
|
||||||
# chunk but will stop if len(chunk) > limit
|
# chunk but will stop if len(chunk) > limit
|
||||||
chunk = []
|
chunk = []
|
||||||
if limit and not search_filter:
|
if limit and not search_filter:
|
||||||
step = limit + 1
|
step = limit + 1
|
||||||
for i in xrange(0, len(rooms_to_scan), step):
|
for i in xrange(0, len(rooms_to_scan), step):
|
||||||
# We iterate here because the vast majority of cases we'll stop
|
# We iterate here because the vast majority of cases we'll stop
|
||||||
# at first iteration, but occaisonally _generate_room_entry
|
# at first iteration, but occaisonally _append_room_entry_to_chunk
|
||||||
# won't append to the chunk and so we need to loop again.
|
# won't append to the chunk and so we need to loop again.
|
||||||
# We don't want to scan over the entire range either as that
|
# We don't want to scan over the entire range either as that
|
||||||
# would potentially waste a lot of work.
|
# would potentially waste a lot of work.
|
||||||
yield concurrently_execute(
|
yield concurrently_execute(
|
||||||
lambda r: self._generate_room_entry(
|
lambda r: self._append_room_entry_to_chunk(
|
||||||
r, rooms_to_num_joined[r],
|
r, rooms_to_num_joined[r],
|
||||||
chunk, limit, search_filter
|
chunk, limit, search_filter
|
||||||
),
|
),
|
||||||
|
@ -187,7 +198,7 @@ class RoomListHandler(BaseHandler):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
yield concurrently_execute(
|
yield concurrently_execute(
|
||||||
lambda r: self._generate_room_entry(
|
lambda r: self._append_room_entry_to_chunk(
|
||||||
r, rooms_to_num_joined[r],
|
r, rooms_to_num_joined[r],
|
||||||
chunk, limit, search_filter
|
chunk, limit, search_filter
|
||||||
),
|
),
|
||||||
|
@ -256,21 +267,35 @@ class RoomListHandler(BaseHandler):
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
|
def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
|
||||||
search_filter):
|
search_filter):
|
||||||
|
"""Generate the entry for a room in the public room list and append it
|
||||||
|
to the `chunk` if it matches the search filter
|
||||||
|
"""
|
||||||
if limit and len(chunk) > limit + 1:
|
if limit and len(chunk) > limit + 1:
|
||||||
# We've already got enough, so lets just drop it.
|
# We've already got enough, so lets just drop it.
|
||||||
return
|
return
|
||||||
|
|
||||||
|
result = yield self._generate_room_entry(room_id, num_joined_users)
|
||||||
|
|
||||||
|
if result and _matches_room_entry(result, search_filter):
|
||||||
|
chunk.append(result)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||||
|
def _generate_room_entry(self, room_id, num_joined_users, cache_context):
|
||||||
|
"""Returns the entry for a room
|
||||||
|
"""
|
||||||
result = {
|
result = {
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"num_joined_members": num_joined_users,
|
"num_joined_members": num_joined_users,
|
||||||
}
|
}
|
||||||
|
|
||||||
current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
|
current_state_ids = yield self.store.get_current_state_ids(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
|
||||||
event_map = yield self.store.get_events([
|
event_map = yield self.store.get_events([
|
||||||
event_id for key, event_id in current_state_ids.items()
|
event_id for key, event_id in current_state_ids.iteritems()
|
||||||
if key[0] in (
|
if key[0] in (
|
||||||
EventTypes.JoinRules,
|
EventTypes.JoinRules,
|
||||||
EventTypes.Name,
|
EventTypes.Name,
|
||||||
|
@ -294,7 +319,9 @@ class RoomListHandler(BaseHandler):
|
||||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
aliases = yield self.store.get_aliases_for_room(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
|
)
|
||||||
if aliases:
|
if aliases:
|
||||||
result["aliases"] = aliases
|
result["aliases"] = aliases
|
||||||
|
|
||||||
|
@ -334,8 +361,7 @@ class RoomListHandler(BaseHandler):
|
||||||
if avatar_url:
|
if avatar_url:
|
||||||
result["avatar_url"] = avatar_url
|
result["avatar_url"] = avatar_url
|
||||||
|
|
||||||
if _matches_room_entry(result, search_filter):
|
defer.returnValue(result)
|
||||||
chunk.append(result)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
||||||
|
|
|
@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
from synapse.types import RoomStreamToken
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -225,8 +226,7 @@ class SyncHandler(object):
|
||||||
with Measure(self.clock, "ephemeral_by_room"):
|
with Measure(self.clock, "ephemeral_by_room"):
|
||||||
typing_key = since_token.typing_key if since_token else "0"
|
typing_key = since_token.typing_key if since_token else "0"
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||||
room_ids = [room.room_id for room in rooms]
|
|
||||||
|
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources["typing"]
|
||||||
typing, typing_key = yield typing_source.get_new_events(
|
typing, typing_key = yield typing_source.get_new_events(
|
||||||
|
@ -568,16 +568,15 @@ class SyncHandler(object):
|
||||||
since_token = sync_result_builder.since_token
|
since_token = sync_result_builder.since_token
|
||||||
|
|
||||||
if since_token and since_token.device_list_key:
|
if since_token and since_token.device_list_key:
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
room_ids = set(r.room_id for r in rooms)
|
|
||||||
|
|
||||||
user_ids_changed = set()
|
user_ids_changed = set()
|
||||||
changed = yield self.store.get_user_whose_devices_changed(
|
changed = yield self.store.get_user_whose_devices_changed(
|
||||||
since_token.device_list_key
|
since_token.device_list_key
|
||||||
)
|
)
|
||||||
for other_user_id in changed:
|
for other_user_id in changed:
|
||||||
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
|
other_room_ids = yield self.store.get_rooms_for_user(other_user_id)
|
||||||
if room_ids.intersection(e.room_id for e in other_rooms):
|
if room_ids.intersection(other_room_ids):
|
||||||
user_ids_changed.add(other_user_id)
|
user_ids_changed.add(other_user_id)
|
||||||
|
|
||||||
defer.returnValue(user_ids_changed)
|
defer.returnValue(user_ids_changed)
|
||||||
|
@ -721,14 +720,14 @@ class SyncHandler(object):
|
||||||
extra_users_ids.update(users)
|
extra_users_ids.update(users)
|
||||||
extra_users_ids.discard(user.to_string())
|
extra_users_ids.discard(user.to_string())
|
||||||
|
|
||||||
|
if extra_users_ids:
|
||||||
states = yield self.presence_handler.get_states(
|
states = yield self.presence_handler.get_states(
|
||||||
extra_users_ids,
|
extra_users_ids,
|
||||||
as_event=True,
|
|
||||||
)
|
)
|
||||||
presence.extend(states)
|
presence.extend(states)
|
||||||
|
|
||||||
# Deduplicate the presence entries so that there's at most one per user
|
# Deduplicate the presence entries so that there's at most one per user
|
||||||
presence = {p["content"]["user_id"]: p for p in presence}.values()
|
presence = {p.user_id: p for p in presence}.values()
|
||||||
|
|
||||||
presence = sync_config.filter_collection.filter_presence(
|
presence = sync_config.filter_collection.filter_presence(
|
||||||
presence
|
presence
|
||||||
|
@ -765,6 +764,21 @@ class SyncHandler(object):
|
||||||
)
|
)
|
||||||
sync_result_builder.now_token = now_token
|
sync_result_builder.now_token = now_token
|
||||||
|
|
||||||
|
# We check up front if anything has changed, if it hasn't then there is
|
||||||
|
# no point in going futher.
|
||||||
|
since_token = sync_result_builder.since_token
|
||||||
|
if not sync_result_builder.full_state:
|
||||||
|
if since_token and not ephemeral_by_room and not account_data_by_room:
|
||||||
|
have_changed = yield self._have_rooms_changed(sync_result_builder)
|
||||||
|
if not have_changed:
|
||||||
|
tags_by_room = yield self.store.get_updated_tags(
|
||||||
|
user_id,
|
||||||
|
since_token.account_data_key,
|
||||||
|
)
|
||||||
|
if not tags_by_room:
|
||||||
|
logger.debug("no-oping sync")
|
||||||
|
defer.returnValue(([], []))
|
||||||
|
|
||||||
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
|
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
|
||||||
"m.ignored_user_list", user_id=user_id,
|
"m.ignored_user_list", user_id=user_id,
|
||||||
)
|
)
|
||||||
|
@ -774,13 +788,12 @@ class SyncHandler(object):
|
||||||
else:
|
else:
|
||||||
ignored_users = frozenset()
|
ignored_users = frozenset()
|
||||||
|
|
||||||
if sync_result_builder.since_token:
|
if since_token:
|
||||||
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
|
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
|
||||||
room_entries, invited, newly_joined_rooms = res
|
room_entries, invited, newly_joined_rooms = res
|
||||||
|
|
||||||
tags_by_room = yield self.store.get_updated_tags(
|
tags_by_room = yield self.store.get_updated_tags(
|
||||||
user_id,
|
user_id, since_token.account_data_key,
|
||||||
sync_result_builder.since_token.account_data_key,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
|
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
|
||||||
|
@ -805,7 +818,7 @@ class SyncHandler(object):
|
||||||
|
|
||||||
# Now we want to get any newly joined users
|
# Now we want to get any newly joined users
|
||||||
newly_joined_users = set()
|
newly_joined_users = set()
|
||||||
if sync_result_builder.since_token:
|
if since_token:
|
||||||
for joined_sync in sync_result_builder.joined:
|
for joined_sync in sync_result_builder.joined:
|
||||||
it = itertools.chain(
|
it = itertools.chain(
|
||||||
joined_sync.timeline.events, joined_sync.state.values()
|
joined_sync.timeline.events, joined_sync.state.values()
|
||||||
|
@ -817,6 +830,38 @@ class SyncHandler(object):
|
||||||
|
|
||||||
defer.returnValue((newly_joined_rooms, newly_joined_users))
|
defer.returnValue((newly_joined_rooms, newly_joined_users))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _have_rooms_changed(self, sync_result_builder):
|
||||||
|
"""Returns whether there may be any new events that should be sent down
|
||||||
|
the sync. Returns True if there are.
|
||||||
|
"""
|
||||||
|
user_id = sync_result_builder.sync_config.user.to_string()
|
||||||
|
since_token = sync_result_builder.since_token
|
||||||
|
now_token = sync_result_builder.now_token
|
||||||
|
|
||||||
|
assert since_token
|
||||||
|
|
||||||
|
# Get a list of membership change events that have happened.
|
||||||
|
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||||
|
user_id, since_token.room_key, now_token.room_key
|
||||||
|
)
|
||||||
|
|
||||||
|
if rooms_changed:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||||
|
if app_service:
|
||||||
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
|
joined_room_ids = set(r.room_id for r in rooms)
|
||||||
|
else:
|
||||||
|
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
|
|
||||||
|
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
|
||||||
|
for room_id in joined_room_ids:
|
||||||
|
if self.store.has_room_changed_since(room_id, stream_id):
|
||||||
|
defer.returnValue(True)
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_rooms_changed(self, sync_result_builder, ignored_users):
|
def _get_rooms_changed(self, sync_result_builder, ignored_users):
|
||||||
"""Gets the the changes that have happened since the last sync.
|
"""Gets the the changes that have happened since the last sync.
|
||||||
|
@ -841,8 +886,7 @@ class SyncHandler(object):
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
joined_room_ids = set(r.room_id for r in rooms)
|
joined_room_ids = set(r.room_id for r in rooms)
|
||||||
else:
|
else:
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
joined_room_ids = set(r.room_id for r in rooms)
|
|
||||||
|
|
||||||
# Get a list of membership change events that have happened.
|
# Get a list of membership change events that have happened.
|
||||||
rooms_changed = yield self.store.get_membership_changes_for_user(
|
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||||
|
|
|
@ -12,8 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import synapse.util.retryutils
|
||||||
|
|
||||||
from twisted.internet import defer, reactor, protocol
|
from twisted.internet import defer, reactor, protocol
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.web.client import readBody, HTTPConnectionPool, Agent
|
from twisted.web.client import readBody, HTTPConnectionPool, Agent
|
||||||
|
@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
|
||||||
|
|
||||||
from synapse.http.endpoint import matrix_federation_endpoint
|
from synapse.http.endpoint import matrix_federation_endpoint
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util import logcontext
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
|
||||||
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
||||||
)
|
)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self._store = hs.get_datastore()
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
|
|
||||||
|
@ -103,12 +103,40 @@ class MatrixFederationHttpClient(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_request(self, destination, method, path_bytes,
|
def _request(self, destination, method, path,
|
||||||
body_callback, headers_dict={}, param_bytes=b"",
|
body_callback, headers_dict={}, param_bytes=b"",
|
||||||
query_bytes=b"", retry_on_dns_fail=True,
|
query_bytes=b"", retry_on_dns_fail=True,
|
||||||
timeout=None, long_retries=False):
|
timeout=None, long_retries=False,
|
||||||
""" Creates and sends a request to the given url
|
ignore_backoff=False,
|
||||||
|
backoff_on_404=False):
|
||||||
|
""" Creates and sends a request to the given server
|
||||||
|
Args:
|
||||||
|
destination (str): The remote server to send the HTTP request to.
|
||||||
|
method (str): HTTP method
|
||||||
|
path (str): The HTTP path
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data
|
||||||
|
and try the request anyway.
|
||||||
|
backoff_on_404 (bool): Back off if we get a 404
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: resolves with the http response object on success.
|
||||||
|
|
||||||
|
Fails with ``HTTPRequestException``: if we get an HTTP response
|
||||||
|
code >= 300.
|
||||||
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
|
to retry this server.
|
||||||
"""
|
"""
|
||||||
|
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||||
|
destination,
|
||||||
|
self.clock,
|
||||||
|
self._store,
|
||||||
|
backoff_on_404=backoff_on_404,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
|
)
|
||||||
|
|
||||||
|
destination = destination.encode("ascii")
|
||||||
|
path_bytes = path.encode("ascii")
|
||||||
|
with limiter:
|
||||||
headers_dict[b"User-Agent"] = [self.version_string]
|
headers_dict[b"User-Agent"] = [self.version_string]
|
||||||
headers_dict[b"Host"] = [destination]
|
headers_dict[b"Host"] = [destination]
|
||||||
|
|
||||||
|
@ -144,8 +172,7 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
def send_request():
|
def send_request():
|
||||||
request_deferred = preserve_context_over_fn(
|
request_deferred = self.agent.request(
|
||||||
self.agent.request,
|
|
||||||
method,
|
method,
|
||||||
url_bytes,
|
url_bytes,
|
||||||
Headers(headers_dict),
|
Headers(headers_dict),
|
||||||
|
@ -157,7 +184,8 @@ class MatrixFederationHttpClient(object):
|
||||||
time_out=timeout / 1000. if timeout else 60,
|
time_out=timeout / 1000. if timeout else 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = yield preserve_context_over_fn(send_request)
|
with logcontext.PreserveLoggingContext():
|
||||||
|
response = yield send_request()
|
||||||
|
|
||||||
log_result = "%d %s" % (response.code, response.phrase,)
|
log_result = "%d %s" % (response.code, response.phrase,)
|
||||||
break
|
break
|
||||||
|
@ -214,7 +242,8 @@ class MatrixFederationHttpClient(object):
|
||||||
else:
|
else:
|
||||||
# :'(
|
# :'(
|
||||||
# Update transactions table?
|
# Update transactions table?
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
with logcontext.PreserveLoggingContext():
|
||||||
|
body = yield readBody(response)
|
||||||
raise HttpResponseException(
|
raise HttpResponseException(
|
||||||
response.code, response.phrase, body
|
response.code, response.phrase, body
|
||||||
)
|
)
|
||||||
|
@ -248,7 +277,9 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def put_json(self, destination, path, data={}, json_data_callback=None,
|
def put_json(self, destination, path, data={}, json_data_callback=None,
|
||||||
long_retries=False, timeout=None):
|
long_retries=False, timeout=None,
|
||||||
|
ignore_backoff=False,
|
||||||
|
backoff_on_404=False):
|
||||||
""" Sends the specifed json data using PUT
|
""" Sends the specifed json data using PUT
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -263,11 +294,19 @@ class MatrixFederationHttpClient(object):
|
||||||
retry for a short or long time.
|
retry for a short or long time.
|
||||||
timeout(int): How long to try (in ms) the destination for before
|
timeout(int): How long to try (in ms) the destination for before
|
||||||
giving up. None indicates no timeout.
|
giving up. None indicates no timeout.
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data
|
||||||
|
and try the request anyway.
|
||||||
|
backoff_on_404 (bool): True if we should count a 404 response as
|
||||||
|
a failure of the server (and should therefore back off future
|
||||||
|
requests)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||||
will be the decoded JSON body. On a 4xx or 5xx error response a
|
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||||
CodeMessageException is raised.
|
CodeMessageException is raised.
|
||||||
|
|
||||||
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
|
to retry this server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not json_data_callback:
|
if not json_data_callback:
|
||||||
|
@ -282,26 +321,29 @@ class MatrixFederationHttpClient(object):
|
||||||
producer = _JsonProducer(json_data)
|
producer = _JsonProducer(json_data)
|
||||||
return producer
|
return producer
|
||||||
|
|
||||||
response = yield self._create_request(
|
response = yield self._request(
|
||||||
destination.encode("ascii"),
|
destination,
|
||||||
"PUT",
|
"PUT",
|
||||||
path.encode("ascii"),
|
path,
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
headers_dict={"Content-Type": ["application/json"]},
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
long_retries=long_retries,
|
long_retries=long_retries,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
|
backoff_on_404=backoff_on_404,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
check_content_type_is_json(response.headers)
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
with logcontext.PreserveLoggingContext():
|
||||||
|
body = yield readBody(response)
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_json(self, destination, path, data={}, long_retries=False,
|
def post_json(self, destination, path, data={}, long_retries=False,
|
||||||
timeout=None):
|
timeout=None, ignore_backoff=False):
|
||||||
""" Sends the specifed json data using POST
|
""" Sends the specifed json data using POST
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -314,11 +356,15 @@ class MatrixFederationHttpClient(object):
|
||||||
retry for a short or long time.
|
retry for a short or long time.
|
||||||
timeout(int): How long to try (in ms) the destination for before
|
timeout(int): How long to try (in ms) the destination for before
|
||||||
giving up. None indicates no timeout.
|
giving up. None indicates no timeout.
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||||
|
try the request anyway.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||||
will be the decoded JSON body. On a 4xx or 5xx error response a
|
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||||
CodeMessageException is raised.
|
CodeMessageException is raised.
|
||||||
|
|
||||||
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
|
to retry this server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def body_callback(method, url_bytes, headers_dict):
|
def body_callback(method, url_bytes, headers_dict):
|
||||||
|
@ -327,27 +373,29 @@ class MatrixFederationHttpClient(object):
|
||||||
)
|
)
|
||||||
return _JsonProducer(data)
|
return _JsonProducer(data)
|
||||||
|
|
||||||
response = yield self._create_request(
|
response = yield self._request(
|
||||||
destination.encode("ascii"),
|
destination,
|
||||||
"POST",
|
"POST",
|
||||||
path.encode("ascii"),
|
path,
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
headers_dict={"Content-Type": ["application/json"]},
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
long_retries=long_retries,
|
long_retries=long_retries,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
check_content_type_is_json(response.headers)
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
with logcontext.PreserveLoggingContext():
|
||||||
|
body = yield readBody(response)
|
||||||
|
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
|
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
|
||||||
timeout=None):
|
timeout=None, ignore_backoff=False):
|
||||||
""" GETs some json from the given host homeserver and path
|
""" GETs some json from the given host homeserver and path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -359,11 +407,16 @@ class MatrixFederationHttpClient(object):
|
||||||
timeout (int): How long to try (in ms) the destination for before
|
timeout (int): How long to try (in ms) the destination for before
|
||||||
giving up. None indicates no timeout and that the request will
|
giving up. None indicates no timeout and that the request will
|
||||||
be retried.
|
be retried.
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data
|
||||||
|
and try the request anyway.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get *any* HTTP response.
|
Deferred: Succeeds when we get *any* HTTP response.
|
||||||
|
|
||||||
The result of the deferred is a tuple of `(code, response)`,
|
The result of the deferred is a tuple of `(code, response)`,
|
||||||
where `response` is a dict representing the decoded JSON body.
|
where `response` is a dict representing the decoded JSON body.
|
||||||
|
|
||||||
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
|
to retry this server.
|
||||||
"""
|
"""
|
||||||
logger.debug("get_json args: %s", args)
|
logger.debug("get_json args: %s", args)
|
||||||
|
|
||||||
|
@ -380,36 +433,47 @@ class MatrixFederationHttpClient(object):
|
||||||
self.sign_request(destination, method, url_bytes, headers_dict)
|
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
response = yield self._create_request(
|
response = yield self._request(
|
||||||
destination.encode("ascii"),
|
destination,
|
||||||
"GET",
|
"GET",
|
||||||
path.encode("ascii"),
|
path,
|
||||||
query_bytes=query_bytes,
|
query_bytes=query_bytes,
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
retry_on_dns_fail=retry_on_dns_fail,
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
check_content_type_is_json(response.headers)
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
with logcontext.PreserveLoggingContext():
|
||||||
|
body = yield readBody(response)
|
||||||
|
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_file(self, destination, path, output_stream, args={},
|
def get_file(self, destination, path, output_stream, args={},
|
||||||
retry_on_dns_fail=True, max_size=None):
|
retry_on_dns_fail=True, max_size=None,
|
||||||
|
ignore_backoff=False):
|
||||||
"""GETs a file from a given homeserver
|
"""GETs a file from a given homeserver
|
||||||
Args:
|
Args:
|
||||||
destination (str): The remote server to send the HTTP request to.
|
destination (str): The remote server to send the HTTP request to.
|
||||||
path (str): The HTTP path to GET.
|
path (str): The HTTP path to GET.
|
||||||
output_stream (file): File to write the response body to.
|
output_stream (file): File to write the response body to.
|
||||||
args (dict): Optional dictionary used to create the query string.
|
args (dict): Optional dictionary used to create the query string.
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data
|
||||||
|
and try the request anyway.
|
||||||
Returns:
|
Returns:
|
||||||
A (int,dict) tuple of the file length and a dict of the response
|
Deferred: resolves with an (int,dict) tuple of the file length and
|
||||||
headers.
|
a dict of the response headers.
|
||||||
|
|
||||||
|
Fails with ``HTTPRequestException`` if we get an HTTP response code
|
||||||
|
>= 300
|
||||||
|
|
||||||
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
|
to retry this server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoded_args = {}
|
encoded_args = {}
|
||||||
|
@ -419,26 +483,27 @@ class MatrixFederationHttpClient(object):
|
||||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||||
|
|
||||||
query_bytes = urllib.urlencode(encoded_args, True)
|
query_bytes = urllib.urlencode(encoded_args, True)
|
||||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
|
||||||
|
|
||||||
def body_callback(method, url_bytes, headers_dict):
|
def body_callback(method, url_bytes, headers_dict):
|
||||||
self.sign_request(destination, method, url_bytes, headers_dict)
|
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
response = yield self._create_request(
|
response = yield self._request(
|
||||||
destination.encode("ascii"),
|
destination,
|
||||||
"GET",
|
"GET",
|
||||||
path.encode("ascii"),
|
path,
|
||||||
query_bytes=query_bytes,
|
query_bytes=query_bytes,
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
retry_on_dns_fail=retry_on_dns_fail
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
|
ignore_backoff=ignore_backoff,
|
||||||
)
|
)
|
||||||
|
|
||||||
headers = dict(response.headers.getAllRawHeaders())
|
headers = dict(response.headers.getAllRawHeaders())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
length = yield preserve_context_over_fn(
|
with logcontext.PreserveLoggingContext():
|
||||||
_readBodyToFile,
|
length = yield _readBodyToFile(
|
||||||
response, output_stream, max_size
|
response, output_stream, max_size
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
|
|
|
@ -192,6 +192,16 @@ def parse_json_object_from_request(request):
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def assert_params_in_request(body, required):
|
||||||
|
absent = []
|
||||||
|
for k in required:
|
||||||
|
if k not in body:
|
||||||
|
absent.append(k)
|
||||||
|
|
||||||
|
if len(absent) > 0:
|
||||||
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
|
||||||
class RestServlet(object):
|
class RestServlet(object):
|
||||||
|
|
||||||
""" A Synapse REST Servlet.
|
""" A Synapse REST Servlet.
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
|
|
||||||
from synapse.util import DeferredTimedOutError
|
from synapse.util import DeferredTimedOutError
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
@ -37,6 +38,10 @@ metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
notified_events_counter = metrics.register_counter("notified_events")
|
notified_events_counter = metrics.register_counter("notified_events")
|
||||||
|
|
||||||
|
users_woken_by_stream_counter = metrics.register_counter(
|
||||||
|
"users_woken_by_stream", labels=["stream"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul): Should be shared somewhere
|
# TODO(paul): Should be shared somewhere
|
||||||
def count(func, l):
|
def count(func, l):
|
||||||
|
@ -73,6 +78,13 @@ class _NotifierUserStream(object):
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.rooms = set(rooms)
|
self.rooms = set(rooms)
|
||||||
self.current_token = current_token
|
self.current_token = current_token
|
||||||
|
|
||||||
|
# The last token for which we should wake up any streams that have a
|
||||||
|
# token that comes before it. This gets updated everytime we get poked.
|
||||||
|
# We start it at the current token since if we get any streams
|
||||||
|
# that have a token from before we have no idea whether they should be
|
||||||
|
# woken up or not, so lets just wake them up.
|
||||||
|
self.last_notified_token = current_token
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
@ -89,9 +101,12 @@ class _NotifierUserStream(object):
|
||||||
self.current_token = self.current_token.copy_and_advance(
|
self.current_token = self.current_token.copy_and_advance(
|
||||||
stream_key, stream_id
|
stream_key, stream_id
|
||||||
)
|
)
|
||||||
|
self.last_notified_token = self.current_token
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
noify_deferred = self.notify_deferred
|
noify_deferred = self.notify_deferred
|
||||||
|
|
||||||
|
users_woken_by_stream_counter.inc(stream_key)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||||
noify_deferred.callback(self.current_token)
|
noify_deferred.callback(self.current_token)
|
||||||
|
@ -113,8 +128,14 @@ class _NotifierUserStream(object):
|
||||||
def new_listener(self, token):
|
def new_listener(self, token):
|
||||||
"""Returns a deferred that is resolved when there is a new token
|
"""Returns a deferred that is resolved when there is a new token
|
||||||
greater than the given token.
|
greater than the given token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token from which we are streaming from, i.e. we shouldn't
|
||||||
|
notify for things that happened before this.
|
||||||
"""
|
"""
|
||||||
if self.current_token.is_after(token):
|
# Immediately wake up stream if something has already since happened
|
||||||
|
# since their last token.
|
||||||
|
if self.last_notified_token.is_after(token):
|
||||||
return _NotificationListener(defer.succeed(self.current_token))
|
return _NotificationListener(defer.succeed(self.current_token))
|
||||||
else:
|
else:
|
||||||
return _NotificationListener(self.notify_deferred.observe())
|
return _NotificationListener(self.notify_deferred.observe())
|
||||||
|
@ -283,8 +304,7 @@ class Notifier(object):
|
||||||
if user_stream is None:
|
if user_stream is None:
|
||||||
current_token = yield self.event_sources.get_current_token()
|
current_token = yield self.event_sources.get_current_token()
|
||||||
if room_ids is None:
|
if room_ids is None:
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
room_ids = [room.room_id for room in rooms]
|
|
||||||
user_stream = _NotifierUserStream(
|
user_stream = _NotifierUserStream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
rooms=room_ids,
|
rooms=room_ids,
|
||||||
|
@ -294,40 +314,44 @@ class Notifier(object):
|
||||||
self._register_with_keys(user_stream)
|
self._register_with_keys(user_stream)
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
|
prev_token = from_token
|
||||||
if timeout:
|
if timeout:
|
||||||
end_time = self.clock.time_msec() + timeout
|
end_time = self.clock.time_msec() + timeout
|
||||||
|
|
||||||
prev_token = from_token
|
|
||||||
while not result:
|
while not result:
|
||||||
try:
|
try:
|
||||||
current_token = user_stream.current_token
|
|
||||||
|
|
||||||
result = yield callback(prev_token, current_token)
|
|
||||||
if result:
|
|
||||||
break
|
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
if end_time <= now:
|
if end_time <= now:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Now we wait for the _NotifierUserStream to be told there
|
# Now we wait for the _NotifierUserStream to be told there
|
||||||
# is a new token.
|
# is a new token.
|
||||||
# We need to supply the token we supplied to callback so
|
|
||||||
# that we don't miss any current_token updates.
|
|
||||||
prev_token = current_token
|
|
||||||
listener = user_stream.new_listener(prev_token)
|
listener = user_stream.new_listener(prev_token)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
yield self.clock.time_bound_deferred(
|
yield self.clock.time_bound_deferred(
|
||||||
listener.deferred,
|
listener.deferred,
|
||||||
time_out=(end_time - now) / 1000.
|
time_out=(end_time - now) / 1000.
|
||||||
)
|
)
|
||||||
|
|
||||||
|
current_token = user_stream.current_token
|
||||||
|
|
||||||
|
result = yield callback(prev_token, current_token)
|
||||||
|
if result:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Update the prev_token to the current_token since nothing
|
||||||
|
# has happened between the old prev_token and the current_token
|
||||||
|
prev_token = current_token
|
||||||
except DeferredTimedOutError:
|
except DeferredTimedOutError:
|
||||||
break
|
break
|
||||||
except defer.CancelledError:
|
except defer.CancelledError:
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
|
if result is None:
|
||||||
|
# This happened if there was no timeout or if the timeout had
|
||||||
|
# already expired.
|
||||||
current_token = user_stream.current_token
|
current_token = user_stream.current_token
|
||||||
result = yield callback(from_token, current_token)
|
result = yield callback(prev_token, current_token)
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@ -388,6 +412,15 @@ class Notifier(object):
|
||||||
new_events,
|
new_events,
|
||||||
is_peeking=is_peeking,
|
is_peeking=is_peeking,
|
||||||
)
|
)
|
||||||
|
elif name == "presence":
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
new_events[:] = [
|
||||||
|
{
|
||||||
|
"type": "m.presence",
|
||||||
|
"content": format_user_presence_state(event, now),
|
||||||
|
}
|
||||||
|
for event in new_events
|
||||||
|
]
|
||||||
|
|
||||||
events.extend(new_events)
|
events.extend(new_events)
|
||||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||||
|
@ -420,8 +453,7 @@ class Notifier(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_room_ids(self, user, explicit_room_id):
|
def _get_room_ids(self, user, explicit_room_id):
|
||||||
joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
|
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
|
||||||
joined_room_ids = map(lambda r: r.room_id, joined_rooms)
|
|
||||||
if explicit_room_id:
|
if explicit_room_id:
|
||||||
if explicit_room_id in joined_room_ids:
|
if explicit_room_id in joined_room_ids:
|
||||||
defer.returnValue(([explicit_room_id], True))
|
defer.returnValue(([explicit_room_id], True))
|
||||||
|
|
|
@ -139,7 +139,7 @@ class Mailer(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _fetch_room_state(room_id):
|
def _fetch_room_state(room_id):
|
||||||
room_state = yield self.state_handler.get_current_state_ids(room_id)
|
room_state = yield self.store.get_current_state_ids(room_id)
|
||||||
state_by_room[room_id] = room_state
|
state_by_room[room_id] = room_state
|
||||||
|
|
||||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||||
|
|
|
@ -17,6 +17,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
|
||||||
return self._value_cache.get(dotted_key, None)
|
return self._value_cache.get(dotted_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
|
||||||
|
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
|
||||||
|
register_cache("regex_push_cache", regex_cache)
|
||||||
|
|
||||||
|
|
||||||
def _glob_matches(glob, value, word_boundary=False):
|
def _glob_matches(glob, value, word_boundary=False):
|
||||||
"""Tests if value matches glob.
|
"""Tests if value matches glob.
|
||||||
|
|
||||||
|
@ -137,7 +143,29 @@ def _glob_matches(glob, value, word_boundary=False):
|
||||||
Returns:
|
Returns:
|
||||||
bool
|
bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
r = regex_cache.get((glob, word_boundary), None)
|
||||||
|
if not r:
|
||||||
|
r = _glob_to_re(glob, word_boundary)
|
||||||
|
regex_cache[(glob, word_boundary)] = r
|
||||||
|
return r.search(value)
|
||||||
|
except re.error:
|
||||||
|
logger.warn("Failed to parse glob to regex: %r", glob)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _glob_to_re(glob, word_boundary):
|
||||||
|
"""Generates regex for a given glob.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
glob (string)
|
||||||
|
word_boundary (bool): Whether to match against word boundaries or entire
|
||||||
|
string. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
regex object
|
||||||
|
"""
|
||||||
if IS_GLOB.search(glob):
|
if IS_GLOB.search(glob):
|
||||||
r = re.escape(glob)
|
r = re.escape(glob)
|
||||||
|
|
||||||
|
@ -156,25 +184,20 @@ def _glob_matches(glob, value, word_boundary=False):
|
||||||
)
|
)
|
||||||
if word_boundary:
|
if word_boundary:
|
||||||
r = r"\b%s\b" % (r,)
|
r = r"\b%s\b" % (r,)
|
||||||
r = _compile_regex(r)
|
|
||||||
|
|
||||||
return r.search(value)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
else:
|
else:
|
||||||
r = r + "$"
|
r = "^" + r + "$"
|
||||||
r = _compile_regex(r)
|
|
||||||
|
|
||||||
return r.match(value)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
elif word_boundary:
|
elif word_boundary:
|
||||||
r = re.escape(glob)
|
r = re.escape(glob)
|
||||||
r = r"\b%s\b" % (r,)
|
r = r"\b%s\b" % (r,)
|
||||||
r = _compile_regex(r)
|
|
||||||
|
|
||||||
return r.search(value)
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
else:
|
else:
|
||||||
return value.lower() == glob.lower()
|
r = "^" + re.escape(glob) + "$"
|
||||||
except re.error:
|
return re.compile(r, flags=re.IGNORECASE)
|
||||||
logger.warn("Failed to parse glob to regex: %r", glob)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _flatten_dict(d, prefix=[], result={}):
|
def _flatten_dict(d, prefix=[], result={}):
|
||||||
|
@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
|
||||||
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
regex_cache = LruCache(5000)
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_regex(regex_str):
|
|
||||||
r = regex_cache.get(regex_str, None)
|
|
||||||
if r:
|
|
||||||
return r
|
|
||||||
|
|
||||||
r = re.compile(regex_str, flags=re.IGNORECASE)
|
|
||||||
regex_cache[regex_str] = r
|
|
||||||
return r
|
|
||||||
|
|
|
@ -33,13 +33,13 @@ def get_badge_count(store, user_id):
|
||||||
|
|
||||||
badge = len(invites)
|
badge = len(invites)
|
||||||
|
|
||||||
for r in joins:
|
for room_id in joins:
|
||||||
if r.room_id in my_receipts_by_room:
|
if room_id in my_receipts_by_room:
|
||||||
last_unread_event_id = my_receipts_by_room[r.room_id]
|
last_unread_event_id = my_receipts_by_room[room_id]
|
||||||
|
|
||||||
notifs = yield (
|
notifs = yield (
|
||||||
store.get_unread_event_push_actions_by_room_for_user(
|
store.get_unread_event_push_actions_by_room_for_user(
|
||||||
r.room_id, user_id, last_unread_event_id
|
room_id, user_id, last_unread_event_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# return one badge count per conversation, as count per
|
# return one badge count per conversation, as count per
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -18,6 +19,7 @@ from distutils.version import LooseVersion
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
REQUIREMENTS = {
|
REQUIREMENTS = {
|
||||||
|
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
|
||||||
"frozendict>=0.4": ["frozendict"],
|
"frozendict>=0.4": ["frozendict"],
|
||||||
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
|
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
|
||||||
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
||||||
|
@ -37,6 +39,7 @@ REQUIREMENTS = {
|
||||||
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
||||||
"pymacaroons-pynacl": ["pymacaroons"],
|
"pymacaroons-pynacl": ["pymacaroons"],
|
||||||
"msgpack-python>=0.3.0": ["msgpack"],
|
"msgpack-python>=0.3.0": ["msgpack"],
|
||||||
|
"phonenumbers>=8.2.0": ["phonenumbers"],
|
||||||
}
|
}
|
||||||
CONDITIONAL_REQUIREMENTS = {
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
"web_client": {
|
"web_client": {
|
||||||
|
|
|
@ -283,12 +283,12 @@ class ReplicationResource(Resource):
|
||||||
|
|
||||||
if request_events != upto_events_token:
|
if request_events != upto_events_token:
|
||||||
writer.write_header_and_rows("events", res.new_forward_events, (
|
writer.write_header_and_rows("events", res.new_forward_events, (
|
||||||
"position", "internal", "json", "state_group"
|
"position", "event_id", "room_id", "type", "state_key",
|
||||||
), position=upto_events_token)
|
), position=upto_events_token)
|
||||||
|
|
||||||
if request_backfill != upto_backfill_token:
|
if request_backfill != upto_backfill_token:
|
||||||
writer.write_header_and_rows("backfill", res.new_backfill_events, (
|
writer.write_header_and_rows("backfill", res.new_backfill_events, (
|
||||||
"position", "internal", "json", "state_group",
|
"position", "event_id", "room_id", "type", "state_key", "redacts",
|
||||||
), position=upto_backfill_token)
|
), position=upto_backfill_token)
|
||||||
|
|
||||||
writer.write_header_and_rows(
|
writer.write_header_and_rows(
|
||||||
|
|
|
@ -27,4 +27,9 @@ class SlavedIdTracker(object):
|
||||||
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
|
"""
|
||||||
return self._current
|
return self._current
|
||||||
|
|
|
@ -16,7 +16,6 @@ from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
|
@ -25,7 +24,6 @@ from synapse.storage.state import StateStore
|
||||||
from synapse.storage.stream import StreamStore
|
from synapse.storage.stream import StreamStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
import ujson as json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,6 +107,10 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
get_recent_event_ids_for_room = (
|
get_recent_event_ids_for_room = (
|
||||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||||
)
|
)
|
||||||
|
get_current_state_ids = (
|
||||||
|
StateStore.__dict__["get_current_state_ids"]
|
||||||
|
)
|
||||||
|
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
||||||
|
|
||||||
get_unread_push_actions_for_user_in_range_for_http = (
|
get_unread_push_actions_for_user_in_range_for_http = (
|
||||||
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||||
|
@ -165,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
_get_rooms_for_user_where_membership_is_txn = (
|
_get_rooms_for_user_where_membership_is_txn = (
|
||||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
||||||
)
|
)
|
||||||
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
|
|
||||||
_get_state_for_groups = DataStore._get_state_for_groups.__func__
|
_get_state_for_groups = DataStore._get_state_for_groups.__func__
|
||||||
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
|
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
|
||||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||||
|
@ -238,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
return super(SlavedEventStore, self).process_replication(result)
|
return super(SlavedEventStore, self).process_replication(result)
|
||||||
|
|
||||||
def _process_replication_row(self, row, backfilled):
|
def _process_replication_row(self, row, backfilled):
|
||||||
internal = json.loads(row[1])
|
stream_ordering = row[0] if not backfilled else -row[0]
|
||||||
event_json = json.loads(row[2])
|
|
||||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
|
||||||
self.invalidate_caches_for_event(
|
self.invalidate_caches_for_event(
|
||||||
event, backfilled,
|
stream_ordering, row[1], row[2], row[3], row[4], row[5],
|
||||||
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invalidate_caches_for_event(self, event, backfilled):
|
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
|
||||||
self._invalidate_get_event_cache(event.event_id)
|
etype, state_key, redacts, backfilled):
|
||||||
|
self._invalidate_get_event_cache(event_id)
|
||||||
|
|
||||||
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
|
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
||||||
|
|
||||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
|
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
|
||||||
(event.room_id,)
|
(room_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not backfilled:
|
if not backfilled:
|
||||||
self._events_stream_cache.entity_has_changed(
|
self._events_stream_cache.entity_has_changed(
|
||||||
event.room_id, event.internal_metadata.stream_ordering
|
room_id, stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
|
if redacts:
|
||||||
# (event.room_id,)
|
self._invalidate_get_event_cache(redacts)
|
||||||
# )
|
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if etype == EventTypes.Member:
|
||||||
self._invalidate_get_event_cache(event.redacts)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
self._membership_stream_cache.entity_has_changed(
|
self._membership_stream_cache.entity_has_changed(
|
||||||
event.state_key, event.internal_metadata.stream_ordering
|
state_key, stream_ordering
|
||||||
)
|
)
|
||||||
self.get_invited_rooms_for_user.invalidate((event.state_key,))
|
self.get_invited_rooms_for_user.invalidate((state_key,))
|
||||||
|
|
||||||
if not event.is_state():
|
|
||||||
return
|
|
||||||
|
|
||||||
if backfilled:
|
|
||||||
return
|
|
||||||
|
|
||||||
if (not event.internal_metadata.is_invite_from_remote()
|
|
||||||
and event.internal_metadata.is_outlier()):
|
|
||||||
return
|
|
||||||
|
|
|
@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||||
self.presence_stream_cache.entity_has_changed(
|
self.presence_stream_cache.entity_has_changed(
|
||||||
user_id, position
|
user_id, position
|
||||||
)
|
)
|
||||||
|
self._get_presence_for_user.invalidate((user_id,))
|
||||||
|
|
||||||
return super(SlavedPresenceStore, self).process_replication(result)
|
return super(SlavedPresenceStore, self).process_replication(result)
|
||||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.http.server import finish_request
|
from synapse.http.server import finish_request
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
|
|
||||||
|
@ -33,10 +34,55 @@ from saml2.client import Saml2Client
|
||||||
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def login_submission_legacy_convert(submission):
|
||||||
|
"""
|
||||||
|
If the input login submission is an old style object
|
||||||
|
(ie. with top-level user / medium / address) convert it
|
||||||
|
to a typed object.
|
||||||
|
"""
|
||||||
|
if "user" in submission:
|
||||||
|
submission["identifier"] = {
|
||||||
|
"type": "m.id.user",
|
||||||
|
"user": submission["user"],
|
||||||
|
}
|
||||||
|
del submission["user"]
|
||||||
|
|
||||||
|
if "medium" in submission and "address" in submission:
|
||||||
|
submission["identifier"] = {
|
||||||
|
"type": "m.id.thirdparty",
|
||||||
|
"medium": submission["medium"],
|
||||||
|
"address": submission["address"],
|
||||||
|
}
|
||||||
|
del submission["medium"]
|
||||||
|
del submission["address"]
|
||||||
|
|
||||||
|
|
||||||
|
def login_id_thirdparty_from_phone(identifier):
|
||||||
|
"""
|
||||||
|
Convert a phone login identifier type to a generic threepid identifier
|
||||||
|
Args:
|
||||||
|
identifier(dict): Login identifier dict of type 'm.id.phone'
|
||||||
|
|
||||||
|
Returns: Login identifier dict of type 'm.id.threepid'
|
||||||
|
"""
|
||||||
|
if "country" not in identifier or "number" not in identifier:
|
||||||
|
raise SynapseError(400, "Invalid phone-type identifier")
|
||||||
|
|
||||||
|
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "m.id.thirdparty",
|
||||||
|
"medium": "msisdn",
|
||||||
|
"address": msisdn,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LoginRestServlet(ClientV1RestServlet):
|
class LoginRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login$")
|
PATTERNS = client_path_patterns("/login$")
|
||||||
PASS_TYPE = "m.login.password"
|
PASS_TYPE = "m.login.password"
|
||||||
|
@ -117,20 +163,52 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_password_login(self, login_submission):
|
def do_password_login(self, login_submission):
|
||||||
if 'medium' in login_submission and 'address' in login_submission:
|
if "password" not in login_submission:
|
||||||
address = login_submission['address']
|
raise SynapseError(400, "Missing parameter: password")
|
||||||
if login_submission['medium'] == 'email':
|
|
||||||
|
login_submission_legacy_convert(login_submission)
|
||||||
|
|
||||||
|
if "identifier" not in login_submission:
|
||||||
|
raise SynapseError(400, "Missing param: identifier")
|
||||||
|
|
||||||
|
identifier = login_submission["identifier"]
|
||||||
|
if "type" not in identifier:
|
||||||
|
raise SynapseError(400, "Login identifier has no type")
|
||||||
|
|
||||||
|
# convert phone type identifiers to generic threepids
|
||||||
|
if identifier["type"] == "m.id.phone":
|
||||||
|
identifier = login_id_thirdparty_from_phone(identifier)
|
||||||
|
|
||||||
|
# convert threepid identifiers to user IDs
|
||||||
|
if identifier["type"] == "m.id.thirdparty":
|
||||||
|
if 'medium' not in identifier or 'address' not in identifier:
|
||||||
|
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||||
|
|
||||||
|
address = identifier['address']
|
||||||
|
if identifier['medium'] == 'email':
|
||||||
# For emails, transform the address to lowercase.
|
# For emails, transform the address to lowercase.
|
||||||
# We store all email addreses as lowercase in the DB.
|
# We store all email addreses as lowercase in the DB.
|
||||||
# (See add_threepid in synapse/handlers/auth.py)
|
# (See add_threepid in synapse/handlers/auth.py)
|
||||||
address = address.lower()
|
address = address.lower()
|
||||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
login_submission['medium'], address
|
identifier['medium'], address
|
||||||
)
|
)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
else:
|
|
||||||
user_id = login_submission['user']
|
identifier = {
|
||||||
|
"type": "m.id.user",
|
||||||
|
"user": user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# by this point, the identifier should be an m.id.user: if it's anything
|
||||||
|
# else, we haven't understood it.
|
||||||
|
if identifier["type"] != "m.id.user":
|
||||||
|
raise SynapseError(400, "Unknown login identifier type")
|
||||||
|
if "user" not in identifier:
|
||||||
|
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||||
|
|
||||||
|
user_id = identifier["user"]
|
||||||
|
|
||||||
if not user_id.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user_id = UserID.create(
|
user_id = UserID.create(
|
||||||
|
@ -341,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
"ticket": request.args["ticket"],
|
"ticket": request.args["ticket"],
|
||||||
"service": self.cas_service_url
|
"service": self.cas_service_url
|
||||||
}
|
}
|
||||||
|
try:
|
||||||
body = yield http_client.get_raw(uri, args)
|
body = yield http_client.get_raw(uri, args)
|
||||||
|
except PartialDownloadError as pde:
|
||||||
|
# Twisted raises this error if the connection is closed,
|
||||||
|
# even if that's being used old-http style to signal end-of-data
|
||||||
|
body = pde.response
|
||||||
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
|
|
||||||
|
@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(PresenceStatusRestServlet, self).__init__(hs)
|
super(PresenceStatusRestServlet, self).__init__(hs)
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
|
@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
raise AuthError(403, "You are not allowed to see their presence.")
|
raise AuthError(403, "You are not allowed to see their presence.")
|
||||||
|
|
||||||
state = yield self.presence_handler.get_state(target_user=user)
|
state = yield self.presence_handler.get_state(target_user=user)
|
||||||
|
state = format_user_presence_state(state, self.clock.time_msec())
|
||||||
|
|
||||||
defer.returnValue((200, state))
|
defer.returnValue((200, state))
|
||||||
|
|
||||||
|
|
|
@ -748,8 +748,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet):
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
|
room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
|
||||||
room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
|
|
||||||
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
|
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,8 +18,11 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import LoginError, SynapseError, Codes
|
from synapse.api.errors import LoginError, SynapseError, Codes
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import (
|
||||||
|
RestServlet, parse_json_object_from_request, assert_params_in_request
|
||||||
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
@ -28,11 +32,11 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PasswordRequestTokenRestServlet(RestServlet):
|
class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
|
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(PasswordRequestTokenRestServlet, self).__init__()
|
super(EmailPasswordRequestTokenRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
assert_params_in_request(body, [
|
||||||
absent = []
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
for k in required:
|
])
|
||||||
if k not in body:
|
|
||||||
absent.append(k)
|
|
||||||
|
|
||||||
if absent:
|
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
|
@ -60,6 +59,37 @@ class PasswordRequestTokenRestServlet(RestServlet):
|
||||||
defer.returnValue((200, ret))
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.datastore = self.hs.get_datastore()
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
assert_params_in_request(body, [
|
||||||
|
'id_server', 'client_secret',
|
||||||
|
'country', 'phone_number', 'send_attempt',
|
||||||
|
])
|
||||||
|
|
||||||
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
|
'msisdn', msisdn
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is None:
|
||||||
|
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestMsisdnToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
class PasswordRestServlet(RestServlet):
|
class PasswordRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/password$")
|
PATTERNS = client_v2_patterns("/account/password$")
|
||||||
|
|
||||||
|
@ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
self.datastore = self.hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet):
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[LoginType.PASSWORD],
|
[LoginType.PASSWORD],
|
||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY],
|
||||||
|
[LoginType.MSISDN],
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
if not authed:
|
if not authed:
|
||||||
|
@ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet):
|
||||||
# (See add_threepid in synapse/handlers/auth.py)
|
# (See add_threepid in synapse/handlers/auth.py)
|
||||||
threepid['address'] = threepid['address'].lower()
|
threepid['address'] = threepid['address'].lower()
|
||||||
# if using email, we must know about the email they're authing with!
|
# if using email, we must know about the email they're authing with!
|
||||||
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
||||||
threepid['medium'], threepid['address']
|
threepid['medium'], threepid['address']
|
||||||
)
|
)
|
||||||
if not threepid_user_id:
|
if not threepid_user_id:
|
||||||
|
@ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
class ThreepidRequestTokenRestServlet(RestServlet):
|
class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
|
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
super(ThreepidRequestTokenRestServlet, self).__init__()
|
super(EmailThreepidRequestTokenRestServlet, self).__init__()
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
self.datastore = self.hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet):
|
||||||
if absent:
|
if absent:
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -201,6 +234,44 @@ class ThreepidRequestTokenRestServlet(RestServlet):
|
||||||
defer.returnValue((200, ret))
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
self.datastore = self.hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
required = [
|
||||||
|
'id_server', 'client_secret',
|
||||||
|
'country', 'phone_number', 'send_attempt',
|
||||||
|
]
|
||||||
|
absent = []
|
||||||
|
for k in required:
|
||||||
|
if k not in body:
|
||||||
|
absent.append(k)
|
||||||
|
|
||||||
|
if absent:
|
||||||
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
|
'msisdn', msisdn
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is not None:
|
||||||
|
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestMsisdnToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
class ThreepidRestServlet(RestServlet):
|
class ThreepidRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/3pid$")
|
PATTERNS = client_v2_patterns("/account/3pid$")
|
||||||
|
|
||||||
|
@ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
self.datastore = self.hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
@ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
threepids = yield self.datastore.user_get_threepids(
|
||||||
requester.user.to_string()
|
requester.user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
|
|
||||||
if 'bind' in body and body['bind']:
|
if 'bind' in body and body['bind']:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Binding emails %s to %s",
|
"Binding threepid %s to %s",
|
||||||
threepid, user_id
|
threepid, user_id
|
||||||
)
|
)
|
||||||
yield self.identity_handler.bind_threepid(
|
yield self.identity_handler.bind_threepid(
|
||||||
|
@ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
PasswordRequestTokenRestServlet(hs).register(http_server)
|
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||||
|
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||||
PasswordRestServlet(hs).register(http_server)
|
PasswordRestServlet(hs).register(http_server)
|
||||||
DeactivateAccountRestServlet(hs).register(http_server)
|
DeactivateAccountRestServlet(hs).register(http_server)
|
||||||
ThreepidRequestTokenRestServlet(hs).register(http_server)
|
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||||
|
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||||
ThreepidRestServlet(hs).register(http_server)
|
ThreepidRestServlet(hs).register(http_server)
|
||||||
ThreepidDeleteRestServlet(hs).register(http_server)
|
ThreepidDeleteRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet):
|
||||||
defer.returnValue((200, {"devices": devices}))
|
defer.returnValue((200, {"devices": devices}))
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||||
|
"""
|
||||||
|
API for bulk deletion of devices. Accepts a JSON object with a devices
|
||||||
|
key which lists the device_ids to delete. Requires user interactive auth.
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(DeleteDevicesRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
try:
|
||||||
|
body = servlet.parse_json_object_from_request(request)
|
||||||
|
except errors.SynapseError as e:
|
||||||
|
if e.errcode == errors.Codes.NOT_JSON:
|
||||||
|
# deal with older clients which didn't pass a J*DELETESON dict
|
||||||
|
# the same as those that pass an empty dict
|
||||||
|
body = {}
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if 'devices' not in body:
|
||||||
|
raise errors.SynapseError(
|
||||||
|
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||||
|
[constants.LoginType.PASSWORD],
|
||||||
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
|
if not authed:
|
||||||
|
defer.returnValue((401, result))
|
||||||
|
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
yield self.device_handler.delete_devices(
|
||||||
|
requester.user.to_string(),
|
||||||
|
body['devices'],
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
class DeviceRestServlet(servlet.RestServlet):
|
class DeviceRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||||
releases=[], v2_alpha=False)
|
releases=[], v2_alpha=False)
|
||||||
|
@ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
DeleteDevicesRestServlet(hs).register(http_server)
|
||||||
DevicesRestServlet(hs).register(http_server)
|
DevicesRestServlet(hs).register(http_server)
|
||||||
DeviceRestServlet(hs).register(http_server)
|
DeviceRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015 - 2016 OpenMarket Ltd
|
# Copyright 2015 - 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -19,7 +20,10 @@ import synapse
|
||||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import (
|
||||||
|
RestServlet, parse_json_object_from_request, assert_params_in_request
|
||||||
|
)
|
||||||
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
@ -43,7 +47,7 @@ else:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequestTokenRestServlet(RestServlet):
|
class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/register/email/requestToken$")
|
PATTERNS = client_v2_patterns("/register/email/requestToken$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer): server
|
hs (synapse.server.HomeServer): server
|
||||||
"""
|
"""
|
||||||
super(RegisterRequestTokenRestServlet, self).__init__()
|
super(EmailRegisterRequestTokenRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
assert_params_in_request(body, [
|
||||||
absent = []
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
for k in required:
|
])
|
||||||
if k not in body:
|
|
||||||
absent.append(k)
|
|
||||||
|
|
||||||
if len(absent) > 0:
|
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
|
@ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet):
|
||||||
defer.returnValue((200, ret))
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
|
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
assert_params_in_request(body, [
|
||||||
|
'id_server', 'client_secret',
|
||||||
|
'country', 'phone_number',
|
||||||
|
'send_attempt',
|
||||||
|
])
|
||||||
|
|
||||||
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
|
'msisdn', msisdn
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is not None:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Phone number is already in use", Codes.THREEPID_IN_USE
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestMsisdnToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
class RegisterRestServlet(RestServlet):
|
class RegisterRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/register$")
|
PATTERNS = client_v2_patterns("/register$")
|
||||||
|
|
||||||
|
@ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet):
|
||||||
assigned_user_id=registered_user_id,
|
assigned_user_id=registered_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Only give msisdn flows if the x_show_msisdn flag is given:
|
||||||
|
# this is a hack to work around the fact that clients were shipped
|
||||||
|
# that use fallback registration if they see any flows that they don't
|
||||||
|
# recognise, which means we break registration for these clients if we
|
||||||
|
# advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
|
||||||
|
# Android <=0.6.9 have fallen below an acceptable threshold, this
|
||||||
|
# parameter should go away and we should always advertise msisdn flows.
|
||||||
|
show_msisdn = False
|
||||||
|
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||||
|
show_msisdn = True
|
||||||
|
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
flows = [
|
flows = [
|
||||||
[LoginType.RECAPTCHA],
|
[LoginType.RECAPTCHA],
|
||||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
|
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||||
]
|
]
|
||||||
|
if show_msisdn:
|
||||||
|
flows.extend([
|
||||||
|
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
||||||
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||||
|
])
|
||||||
else:
|
else:
|
||||||
flows = [
|
flows = [
|
||||||
[LoginType.DUMMY],
|
[LoginType.DUMMY],
|
||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY],
|
||||||
]
|
]
|
||||||
|
if show_msisdn:
|
||||||
|
flows.extend([
|
||||||
|
[LoginType.MSISDN],
|
||||||
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||||
|
])
|
||||||
|
|
||||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
|
@ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet):
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
registered_user_id
|
registered_user_id
|
||||||
)
|
)
|
||||||
# don't re-register the email address
|
# don't re-register the threepids
|
||||||
add_email = False
|
add_email = False
|
||||||
|
add_msisdn = False
|
||||||
else:
|
else:
|
||||||
# NB: This may be from the auth handler and NOT from the POST
|
# NB: This may be from the auth handler and NOT from the POST
|
||||||
if 'password' not in params:
|
if 'password' not in params:
|
||||||
|
@ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
add_email = True
|
add_email = True
|
||||||
|
add_msisdn = True
|
||||||
|
|
||||||
return_dict = yield self._create_registration_details(
|
return_dict = yield self._create_registration_details(
|
||||||
registered_user_id, params
|
registered_user_id, params
|
||||||
|
@ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet):
|
||||||
params.get("bind_email")
|
params.get("bind_email")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
|
||||||
|
threepid = auth_result[LoginType.MSISDN]
|
||||||
|
yield self._register_msisdn_threepid(
|
||||||
|
registered_user_id, threepid, return_dict["access_token"],
|
||||||
|
params.get("bind_msisdn")
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, return_dict))
|
defer.returnValue((200, return_dict))
|
||||||
|
|
||||||
def on_OPTIONS(self, _):
|
def on_OPTIONS(self, _):
|
||||||
|
@ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
reqd = ('medium', 'address', 'validated_at')
|
reqd = ('medium', 'address', 'validated_at')
|
||||||
if any(x not in threepid for x in reqd):
|
if any(x not in threepid for x in reqd):
|
||||||
|
# This will only happen if the ID server returns a malformed response
|
||||||
logger.info("Can't add incomplete 3pid")
|
logger.info("Can't add incomplete 3pid")
|
||||||
defer.returnValue()
|
return
|
||||||
|
|
||||||
yield self.auth_handler.add_threepid(
|
yield self.auth_handler.add_threepid(
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -371,6 +438,43 @@ class RegisterRestServlet(RestServlet):
|
||||||
else:
|
else:
|
||||||
logger.info("bind_email not specified: not binding email")
|
logger.info("bind_email not specified: not binding email")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
|
||||||
|
"""Add a phone number as a 3pid identifier
|
||||||
|
|
||||||
|
Also optionally binds msisdn to the given user_id on the identity server
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): id of user
|
||||||
|
threepid (object): m.login.msisdn auth response
|
||||||
|
token (str): access_token for the user
|
||||||
|
bind_email (bool): true if the client requested the email to be
|
||||||
|
bound at the identity server
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
reqd = ('medium', 'address', 'validated_at')
|
||||||
|
if any(x not in threepid for x in reqd):
|
||||||
|
# This will only happen if the ID server returns a malformed response
|
||||||
|
logger.info("Can't add incomplete 3pid")
|
||||||
|
defer.returnValue()
|
||||||
|
|
||||||
|
yield self.auth_handler.add_threepid(
|
||||||
|
user_id,
|
||||||
|
threepid['medium'],
|
||||||
|
threepid['address'],
|
||||||
|
threepid['validated_at'],
|
||||||
|
)
|
||||||
|
|
||||||
|
if bind_msisdn:
|
||||||
|
logger.info("bind_msisdn specified: binding")
|
||||||
|
logger.debug("Binding msisdn %s to %s", threepid, user_id)
|
||||||
|
yield self.identity_handler.bind_threepid(
|
||||||
|
threepid['threepid_creds'], user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("bind_msisdn not specified: not binding msisdn")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_registration_details(self, user_id, params):
|
def _create_registration_details(self, user_id, params):
|
||||||
"""Complete registration of newly-registered user
|
"""Complete registration of newly-registered user
|
||||||
|
@ -433,7 +537,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
# we have nowhere to store it.
|
# we have nowhere to store it.
|
||||||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
self.device_handler.check_device_registered(
|
yield self.device_handler.check_device_registered(
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
RegisterRequestTokenRestServlet(hs).register(http_server)
|
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||||
|
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet, parse_string, parse_integer, parse_boolean
|
RestServlet, parse_string, parse_integer, parse_boolean
|
||||||
)
|
)
|
||||||
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
from synapse.handlers.sync import SyncConfig
|
from synapse.handlers.sync import SyncConfig
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
from synapse.events.utils import (
|
from synapse.events.utils import (
|
||||||
|
@ -28,7 +29,6 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
import copy
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -194,12 +194,18 @@ class SyncRestServlet(RestServlet):
|
||||||
defer.returnValue((200, response_content))
|
defer.returnValue((200, response_content))
|
||||||
|
|
||||||
def encode_presence(self, events, time_now):
|
def encode_presence(self, events, time_now):
|
||||||
formatted = []
|
return {
|
||||||
for event in events:
|
"events": [
|
||||||
event = copy.deepcopy(event)
|
{
|
||||||
event['sender'] = event['content'].pop('user_id')
|
"type": "m.presence",
|
||||||
formatted.append(event)
|
"sender": event.user_id,
|
||||||
return {"events": formatted}
|
"content": format_user_presence_state(
|
||||||
|
event, time_now, include_user_id=False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for event in events
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
def encode_joined(self, rooms, time_now, token_id, event_fields):
|
def encode_joined(self, rooms, time_now, token_id, event_fields):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import synapse.http.servlet
|
||||||
|
|
||||||
from ._base import parse_media_id, respond_with_file, respond_404
|
from ._base import parse_media_id, respond_with_file, respond_404
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
@ -81,6 +82,17 @@ class DownloadResource(Resource):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_remote_file(self, request, server_name, media_id, name):
|
def _respond_remote_file(self, request, server_name, media_id, name):
|
||||||
|
# don't forward requests for remote media if allow_remote is false
|
||||||
|
allow_remote = synapse.http.servlet.parse_boolean(
|
||||||
|
request, "allow_remote", default=True)
|
||||||
|
if not allow_remote:
|
||||||
|
logger.info(
|
||||||
|
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||||
|
server_name, media_id,
|
||||||
|
)
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||||
|
|
||||||
media_type = media_info["media_type"]
|
media_type = media_info["media_type"]
|
||||||
|
|
|
@ -13,22 +13,23 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer, threads
|
||||||
|
import twisted.internet.error
|
||||||
|
import twisted.web.http
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from .upload_resource import UploadResource
|
from .upload_resource import UploadResource
|
||||||
from .download_resource import DownloadResource
|
from .download_resource import DownloadResource
|
||||||
from .thumbnail_resource import ThumbnailResource
|
from .thumbnail_resource import ThumbnailResource
|
||||||
from .identicon_resource import IdenticonResource
|
from .identicon_resource import IdenticonResource
|
||||||
from .preview_url_resource import PreviewUrlResource
|
from .preview_url_resource import PreviewUrlResource
|
||||||
from .filepath import MediaFilePaths
|
from .filepath import MediaFilePaths
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
|
||||||
|
|
||||||
from .thumbnailer import Thumbnailer
|
from .thumbnailer import Thumbnailer
|
||||||
|
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError, HttpResponseException, \
|
||||||
|
NotFoundError
|
||||||
from twisted.internet import defer, threads
|
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
@ -157,11 +158,34 @@ class MediaRepository(object):
|
||||||
try:
|
try:
|
||||||
length, headers = yield self.client.get_file(
|
length, headers = yield self.client.get_file(
|
||||||
server_name, request_path, output_stream=f,
|
server_name, request_path, output_stream=f,
|
||||||
max_size=self.max_upload_size,
|
max_size=self.max_upload_size, args={
|
||||||
|
# tell the remote server to 404 if it doesn't
|
||||||
|
# recognise the server_name, to make sure we don't
|
||||||
|
# end up with a routing loop.
|
||||||
|
"allow_remote": "false",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except twisted.internet.error.DNSLookupError as e:
|
||||||
logger.warn("Failed to fetch remoted media %r", e)
|
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||||
raise SynapseError(502, "Failed to fetch remoted media")
|
server_name, media_id, e)
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
except HttpResponseException as e:
|
||||||
|
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||||
|
server_name, media_id, e.response)
|
||||||
|
if e.code == twisted.web.http.NOT_FOUND:
|
||||||
|
raise SynapseError.from_http_response_exception(e)
|
||||||
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
|
except SynapseError:
|
||||||
|
logger.exception("Failed to fetch remote media %s/%s",
|
||||||
|
server_name, media_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to fetch remote media %s/%s",
|
||||||
|
server_name, media_id)
|
||||||
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
media_type = headers["Content-Type"][0]
|
media_type = headers["Content-Type"][0]
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
|
@ -177,17 +177,12 @@ class StateHandler(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def compute_event_context(self, event, old_state=None):
|
def compute_event_context(self, event, old_state=None):
|
||||||
""" Fills out the context with the `current state` of the graph. The
|
"""Build an EventContext structure for the event.
|
||||||
`current state` here is defined to be the state of the event graph
|
|
||||||
just before the event - i.e. it never includes `event`
|
|
||||||
|
|
||||||
If `event` has `auth_events` then this will also fill out the
|
|
||||||
`auth_events` field on `context` from the `current_state`.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (EventBase)
|
event (synapse.events.EventBase):
|
||||||
Returns:
|
Returns:
|
||||||
an EventContext
|
synapse.events.snapshot.EventContext:
|
||||||
"""
|
"""
|
||||||
context = EventContext()
|
context = EventContext()
|
||||||
|
|
||||||
|
@ -200,11 +195,11 @@ class StateHandler(object):
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
context.current_state_events = dict(context.prev_state_ids)
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
context.current_state_events[key] = event.event_id
|
context.current_state_ids[key] = event.event_id
|
||||||
else:
|
else:
|
||||||
context.current_state_events = context.prev_state_ids
|
context.current_state_ids = context.prev_state_ids
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = {}
|
context.current_state_ids = {}
|
||||||
context.prev_state_ids = {}
|
context.prev_state_ids = {}
|
||||||
|
|
|
@ -73,6 +73,9 @@ class LoggingTransaction(object):
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
setattr(self.txn, name, value)
|
setattr(self.txn, name, value)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self.txn.__iter__()
|
||||||
|
|
||||||
def execute(self, sql, *args):
|
def execute(self, sql, *args):
|
||||||
self._do_execute(self.txn.execute, sql, *args)
|
self._do_execute(self.txn.execute, sql, *args)
|
||||||
|
|
||||||
|
@ -132,7 +135,7 @@ class PerformanceCounters(object):
|
||||||
|
|
||||||
def interval(self, interval_duration, limit=3):
|
def interval(self, interval_duration, limit=3):
|
||||||
counters = []
|
counters = []
|
||||||
for name, (count, cum_time) in self.current_counters.items():
|
for name, (count, cum_time) in self.current_counters.iteritems():
|
||||||
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
||||||
counters.append((
|
counters.append((
|
||||||
(cum_time - prev_time) / interval_duration,
|
(cum_time - prev_time) / interval_duration,
|
||||||
|
@ -357,7 +360,7 @@ class SQLBaseStore(object):
|
||||||
"""
|
"""
|
||||||
col_headers = list(intern(column[0]) for column in cursor.description)
|
col_headers = list(intern(column[0]) for column in cursor.description)
|
||||||
results = list(
|
results = list(
|
||||||
dict(zip(col_headers, row)) for row in cursor.fetchall()
|
dict(zip(col_headers, row)) for row in cursor
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -565,7 +568,7 @@ class SQLBaseStore(object):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||||
if keyvalues:
|
if keyvalues:
|
||||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||||
else:
|
else:
|
||||||
where = ""
|
where = ""
|
||||||
|
|
||||||
|
@ -579,7 +582,7 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
txn.execute(sql, keyvalues.values())
|
txn.execute(sql, keyvalues.values())
|
||||||
|
|
||||||
return [r[0] for r in txn.fetchall()]
|
return [r[0] for r in txn]
|
||||||
|
|
||||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||||
desc="_simple_select_onecol"):
|
desc="_simple_select_onecol"):
|
||||||
|
@ -712,7 +715,7 @@ class SQLBaseStore(object):
|
||||||
)
|
)
|
||||||
values.extend(iterable)
|
values.extend(iterable)
|
||||||
|
|
||||||
for key, value in keyvalues.items():
|
for key, value in keyvalues.iteritems():
|
||||||
clauses.append("%s = ?" % (key,))
|
clauses.append("%s = ?" % (key,))
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
|
@ -753,7 +756,7 @@ class SQLBaseStore(object):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||||
if keyvalues:
|
if keyvalues:
|
||||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||||
else:
|
else:
|
||||||
where = ""
|
where = ""
|
||||||
|
|
||||||
|
@ -840,6 +843,47 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
return txn.execute(sql, keyvalues.values())
|
return txn.execute(sql, keyvalues.values())
|
||||||
|
|
||||||
|
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
|
||||||
|
return self.runInteraction(
|
||||||
|
desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
|
||||||
|
"""Executes a DELETE query on the named table.
|
||||||
|
|
||||||
|
Filters rows by if value of `column` is in `iterable`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn : Transaction object
|
||||||
|
table : string giving the table name
|
||||||
|
column : column name to test for inclusion against `iterable`
|
||||||
|
iterable : list
|
||||||
|
keyvalues : dict of column names and values to select the rows with
|
||||||
|
"""
|
||||||
|
if not iterable:
|
||||||
|
return
|
||||||
|
|
||||||
|
sql = "DELETE FROM %s" % table
|
||||||
|
|
||||||
|
clauses = []
|
||||||
|
values = []
|
||||||
|
clauses.append(
|
||||||
|
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||||
|
)
|
||||||
|
values.extend(iterable)
|
||||||
|
|
||||||
|
for key, value in keyvalues.iteritems():
|
||||||
|
clauses.append("%s = ?" % (key,))
|
||||||
|
values.append(value)
|
||||||
|
|
||||||
|
if clauses:
|
||||||
|
sql = "%s WHERE %s" % (
|
||||||
|
sql,
|
||||||
|
" AND ".join(clauses),
|
||||||
|
)
|
||||||
|
return txn.execute(sql, values)
|
||||||
|
|
||||||
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
|
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
|
||||||
max_value, limit=100000):
|
max_value, limit=100000):
|
||||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||||
|
@ -860,16 +904,16 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
txn = db_conn.cursor()
|
txn = db_conn.cursor()
|
||||||
txn.execute(sql, (int(max_value),))
|
txn.execute(sql, (int(max_value),))
|
||||||
rows = txn.fetchall()
|
|
||||||
txn.close()
|
|
||||||
|
|
||||||
cache = {
|
cache = {
|
||||||
row[0]: int(row[1])
|
row[0]: int(row[1])
|
||||||
for row in rows
|
for row in txn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
txn.close()
|
||||||
|
|
||||||
if cache:
|
if cache:
|
||||||
min_val = min(cache.values())
|
min_val = min(cache.itervalues())
|
||||||
else:
|
else:
|
||||||
min_val = max_value
|
min_val = max_value
|
||||||
|
|
||||||
|
|
|
@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
|
||||||
txn.execute(sql, (user_id, stream_id))
|
txn.execute(sql, (user_id, stream_id))
|
||||||
|
|
||||||
global_account_data = {
|
global_account_data = {
|
||||||
row[0]: json.loads(row[1]) for row in txn.fetchall()
|
row[0]: json.loads(row[1]) for row in txn
|
||||||
}
|
}
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
|
||||||
txn.execute(sql, (user_id, stream_id))
|
txn.execute(sql, (user_id, stream_id))
|
||||||
|
|
||||||
account_data_by_room = {}
|
account_data_by_room = {}
|
||||||
for row in txn.fetchall():
|
for row in txn:
|
||||||
room_account_data = account_data_by_room.setdefault(row[0], {})
|
room_account_data = account_data_by_room.setdefault(row[0], {})
|
||||||
room_account_data[row[1]] = json.loads(row[2])
|
room_account_data[row[1]] = json.loads(row[2])
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import synapse.util.async
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from . import engines
|
from . import engines
|
||||||
|
@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_update_performance = {}
|
self._background_update_performance = {}
|
||||||
self._background_update_queue = []
|
self._background_update_queue = []
|
||||||
self._background_update_handlers = {}
|
self._background_update_handlers = {}
|
||||||
self._background_update_timer = None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def start_doing_background_updates(self):
|
def start_doing_background_updates(self):
|
||||||
assert self._background_update_timer is None, \
|
|
||||||
"background updates already running"
|
|
||||||
|
|
||||||
logger.info("Starting background schema updates")
|
logger.info("Starting background schema updates")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
sleep = defer.Deferred()
|
yield synapse.util.async.sleep(
|
||||||
self._background_update_timer = self._clock.call_later(
|
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
yield sleep
|
|
||||||
finally:
|
|
||||||
self._background_update_timer = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = yield self.do_next_background_update(
|
result = yield self.do_next_background_update(
|
||||||
|
|
|
@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
message_json = ujson.dumps(messages_by_device["*"])
|
message_json = ujson.dumps(messages_by_device["*"])
|
||||||
for row in txn.fetchall():
|
for row in txn:
|
||||||
# Add the message for all devices for this user on this
|
# Add the message for all devices for this user on this
|
||||||
# server.
|
# server.
|
||||||
device = row[0]
|
device = row[0]
|
||||||
|
@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
# TODO: Maybe this needs to be done in batches if there are
|
# TODO: Maybe this needs to be done in batches if there are
|
||||||
# too many local devices for a given user.
|
# too many local devices for a given user.
|
||||||
txn.execute(sql, [user_id] + devices)
|
txn.execute(sql, [user_id] + devices)
|
||||||
for row in txn.fetchall():
|
for row in txn:
|
||||||
# Only insert into the local inbox if the device exists on
|
# Only insert into the local inbox if the device exists on
|
||||||
# this server
|
# this server
|
||||||
device = row[0]
|
device = row[0]
|
||||||
|
@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
user_id, device_id, last_stream_id, current_stream_id, limit
|
user_id, device_id, last_stream_id, current_stream_id, limit
|
||||||
))
|
))
|
||||||
messages = []
|
messages = []
|
||||||
for row in txn.fetchall():
|
for row in txn:
|
||||||
stream_pos = row[0]
|
stream_pos = row[0]
|
||||||
messages.append(ujson.loads(row[1]))
|
messages.append(ujson.loads(row[1]))
|
||||||
if len(messages) < limit:
|
if len(messages) < limit:
|
||||||
|
@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
" ORDER BY stream_id ASC"
|
" ORDER BY stream_id ASC"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (last_pos, upper_pos))
|
txn.execute(sql, (last_pos, upper_pos))
|
||||||
rows.extend(txn.fetchall())
|
rows.extend(txn)
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
@ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
destination(str): The name of the remote server.
|
destination(str): The name of the remote server.
|
||||||
last_stream_id(int): The last position of the device message stream
|
last_stream_id(int|long): The last position of the device message stream
|
||||||
that the server sent up to.
|
that the server sent up to.
|
||||||
current_stream_id(int): The current position of the device
|
current_stream_id(int|long): The current position of the device
|
||||||
message stream.
|
message stream.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred ([dict], int): List of messages for the device and where
|
Deferred ([dict], int|long): List of messages for the device and where
|
||||||
in the stream the messages got to.
|
in the stream the messages got to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
destination, last_stream_id, current_stream_id, limit
|
destination, last_stream_id, current_stream_id, limit
|
||||||
))
|
))
|
||||||
messages = []
|
messages = []
|
||||||
for row in txn.fetchall():
|
for row in txn:
|
||||||
stream_pos = row[0]
|
stream_pos = row[0]
|
||||||
messages.append(ujson.loads(row[1]))
|
messages.append(ujson.loads(row[1]))
|
||||||
if len(messages) < limit:
|
if len(messages) < limit:
|
||||||
|
|
|
@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
|
||||||
desc="delete_device",
|
desc="delete_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_devices(self, user_id, device_ids):
|
||||||
|
"""Deletes several devices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the devices
|
||||||
|
device_ids (list): The IDs of the devices to delete
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
return self._simple_delete_many(
|
||||||
|
table="devices",
|
||||||
|
column="device_id",
|
||||||
|
iterable=device_ids,
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
desc="delete_devices",
|
||||||
|
)
|
||||||
|
|
||||||
def update_device(self, user_id, device_id, new_display_name=None):
|
def update_device(self, user_id, device_id, new_display_name=None):
|
||||||
"""Update a device.
|
"""Update a device.
|
||||||
|
|
||||||
|
@ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore):
|
||||||
"""Get stream of updates to send to remote servers
|
"""Get stream of updates to send to remote servers
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(now_stream_id, [ { updates }, .. ])
|
(int, list[dict]): current stream id and list of updates
|
||||||
"""
|
"""
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore):
|
||||||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
||||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||||
GROUP BY user_id, device_id
|
GROUP BY user_id, device_id
|
||||||
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql, (destination, from_stream_id, now_stream_id, False)
|
sql, (destination, from_stream_id, now_stream_id, False)
|
||||||
)
|
)
|
||||||
rows = txn.fetchall()
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
return (now_stream_id, [])
|
|
||||||
|
|
||||||
# maps (user_id, device_id) -> stream_id
|
# maps (user_id, device_id) -> stream_id
|
||||||
query_map = {(r[0], r[1]): r[2] for r in rows}
|
query_map = {(r[0], r[1]): r[2] for r in txn}
|
||||||
|
if not query_map:
|
||||||
|
return (now_stream_id, [])
|
||||||
|
|
||||||
|
if len(query_map) >= 20:
|
||||||
|
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
|
||||||
|
|
||||||
devices = self._get_e2e_device_keys_txn(
|
devices = self._get_e2e_device_keys_txn(
|
||||||
txn, query_map.keys(), include_all_devices=True
|
txn, query_map.keys(), include_all_devices=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
|
||||||
|
@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||||
def _add_e2e_one_time_keys(txn):
|
"""Insert some new one time keys for a device.
|
||||||
for (algorithm, key_id, json_bytes) in key_list:
|
|
||||||
self._simple_upsert_txn(
|
Checks if any of the keys are already inserted, if they are then check
|
||||||
txn, table="e2e_one_time_keys_json",
|
if they match. If they don't then we raise an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# First we check if we have already persisted any of the keys.
|
||||||
|
rows = yield self._simple_select_many_batch(
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
column="key_id",
|
||||||
|
iterable=[key_id for _, key_id, _ in key_list],
|
||||||
|
retcols=("algorithm", "key_id", "key_json",),
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
desc="add_e2e_one_time_keys_check",
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_key_map = {
|
||||||
|
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
new_keys = [] # Keys that we need to insert
|
||||||
|
for algorithm, key_id, json_bytes in key_list:
|
||||||
|
ex_bytes = existing_key_map.get((algorithm, key_id), None)
|
||||||
|
if ex_bytes:
|
||||||
|
if json_bytes != ex_bytes:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "One time key with key_id %r already exists" % (key_id,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_keys.append((algorithm, key_id, json_bytes))
|
||||||
|
|
||||||
|
def _add_e2e_one_time_keys(txn):
|
||||||
|
# We are protected from race between lookup and insertion due to
|
||||||
|
# a unique constraint. If there is a race of two calls to
|
||||||
|
# `add_e2e_one_time_keys` then they'll conflict and we will only
|
||||||
|
# insert one set.
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn, table="e2e_one_time_keys_json",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
"algorithm": algorithm,
|
"algorithm": algorithm,
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
},
|
|
||||||
values={
|
|
||||||
"ts_added_ms": time_now,
|
"ts_added_ms": time_now,
|
||||||
"key_json": json_bytes,
|
"key_json": json_bytes,
|
||||||
}
|
}
|
||||||
|
for algorithm, key_id, json_bytes in new_keys
|
||||||
|
],
|
||||||
)
|
)
|
||||||
return self.runInteraction(
|
yield self.runInteraction(
|
||||||
"add_e2e_one_time_keys", _add_e2e_one_time_keys
|
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||||
|
@ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id, device_id))
|
txn.execute(sql, (user_id, device_id))
|
||||||
result = {}
|
result = {}
|
||||||
for algorithm, key_count in txn.fetchall():
|
for algorithm, key_count in txn:
|
||||||
result[algorithm] = key_count
|
result[algorithm] = key_count
|
||||||
return result
|
return result
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
|
@ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
user_result = result.setdefault(user_id, {})
|
user_result = result.setdefault(user_id, {})
|
||||||
device_result = user_result.setdefault(device_id, {})
|
device_result = user_result.setdefault(device_id, {})
|
||||||
txn.execute(sql, (user_id, device_id, algorithm))
|
txn.execute(sql, (user_id, device_id, algorithm))
|
||||||
for key_id, key_json in txn.fetchall():
|
for key_id, key_json in txn:
|
||||||
device_result[algorithm + ":" + key_id] = key_json
|
device_result[algorithm + ":" + key_id] = key_json
|
||||||
delete.append((user_id, device_id, algorithm, key_id))
|
delete.append((user_id, device_id, algorithm, key_id))
|
||||||
sql = (
|
sql = (
|
||||||
|
|
|
@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
base_sql % (",".join(["?"] * len(chunk)),),
|
base_sql % (",".join(["?"] * len(chunk)),),
|
||||||
chunk
|
chunk
|
||||||
)
|
)
|
||||||
new_front.update([r[0] for r in txn.fetchall()])
|
new_front.update([r[0] for r in txn])
|
||||||
|
|
||||||
new_front -= results
|
new_front -= results
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
|
|
||||||
txn.execute(sql, (room_id, False,))
|
txn.execute(sql, (room_id, False,))
|
||||||
|
|
||||||
return dict(txn.fetchall())
|
return dict(txn)
|
||||||
|
|
||||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||||
return self._simple_select_onecol_txn(
|
return self._simple_select_onecol_txn(
|
||||||
|
@ -201,9 +201,9 @@ class EventFederationStore(SQLBaseStore):
|
||||||
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
||||||
min_depth = self._get_min_depth_interaction(txn, room_id)
|
min_depth = self._get_min_depth_interaction(txn, room_id)
|
||||||
|
|
||||||
do_insert = depth < min_depth if min_depth else True
|
if min_depth and depth >= min_depth:
|
||||||
|
return
|
||||||
|
|
||||||
if do_insert:
|
|
||||||
self._simple_upsert_txn(
|
self._simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="room_depth",
|
table="room_depth",
|
||||||
|
@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
|
|
||||||
def get_forward_extremeties_for_room_txn(txn):
|
def get_forward_extremeties_for_room_txn(txn):
|
||||||
txn.execute(sql, (stream_ordering, room_id))
|
txn.execute(sql, (stream_ordering, room_id))
|
||||||
rows = txn.fetchall()
|
return [event_id for event_id, in txn]
|
||||||
return [event_id for event_id, in rows]
|
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_forward_extremeties_for_room",
|
"get_forward_extremeties_for_room",
|
||||||
|
@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
(room_id, event_id, False, limit - len(event_results))
|
(room_id, event_id, False, limit - len(event_results))
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in txn.fetchall():
|
for row in txn:
|
||||||
if row[1] not in event_results:
|
if row[1] not in event_results:
|
||||||
queue.put((-row[0], row[1]))
|
queue.put((-row[0], row[1]))
|
||||||
|
|
||||||
|
@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
(room_id, event_id, False, limit - len(event_results))
|
(room_id, event_id, False, limit - len(event_results))
|
||||||
)
|
)
|
||||||
|
|
||||||
for e_id, in txn.fetchall():
|
for e_id, in txn:
|
||||||
new_front.add(e_id)
|
new_front.add(e_id)
|
||||||
|
|
||||||
new_front -= earliest_events
|
new_front -= earliest_events
|
||||||
|
|
|
@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
" stream_ordering >= ? AND stream_ordering <= ?"
|
" stream_ordering >= ? AND stream_ordering <= ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
||||||
return [r[0] for r in txn.fetchall()]
|
return [r[0] for r in txn]
|
||||||
ret = yield self.runInteraction("get_push_action_users_in_range", f)
|
ret = yield self.runInteraction("get_push_action_users_in_range", f)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
|
|
@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json
|
||||||
from collections import deque, namedtuple, OrderedDict
|
from collections import deque, namedtuple, OrderedDict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import synapse
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
|
||||||
|
# these are only included to make the type annotations work
|
||||||
|
from synapse.events import EventBase # noqa: F401
|
||||||
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
|
||||||
|
|
||||||
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
||||||
"""Add events to the queue, with the given persist_event options.
|
"""Add events to the queue, with the given persist_event options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str):
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]):
|
||||||
|
backfilled (bool):
|
||||||
"""
|
"""
|
||||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||||
if queue:
|
if queue:
|
||||||
|
@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
|
||||||
partitioned.setdefault(event.room_id, []).append((event, ctx))
|
partitioned.setdefault(event.room_id, []).append((event, ctx))
|
||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for room_id, evs_ctxs in partitioned.items():
|
for room_id, evs_ctxs in partitioned.iteritems():
|
||||||
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
||||||
room_id, evs_ctxs,
|
room_id, evs_ctxs,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
deferreds.append(d)
|
deferreds.append(d)
|
||||||
|
|
||||||
for room_id in partitioned.keys():
|
for room_id in partitioned:
|
||||||
self._maybe_start_persisting(room_id)
|
self._maybe_start_persisting(room_id)
|
||||||
|
|
||||||
return preserve_context_over_deferred(
|
return preserve_context_over_deferred(
|
||||||
|
@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def persist_event(self, event, context, backfilled=False):
|
def persist_event(self, event, context, backfilled=False):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (EventBase):
|
||||||
|
context (EventContext):
|
||||||
|
backfilled (bool):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: resolves to (int, int): the stream ordering of ``event``,
|
||||||
|
and the stream ordering of the latest persisted event
|
||||||
|
"""
|
||||||
deferred = self._event_persist_queue.add_to_queue(
|
deferred = self._event_persist_queue.add_to_queue(
|
||||||
event.room_id, [(event, context)],
|
event.room_id, [(event, context)],
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
|
@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _persist_events(self, events_and_contexts, backfilled=False,
|
def _persist_events(self, events_and_contexts, backfilled=False,
|
||||||
delete_existing=False):
|
delete_existing=False):
|
||||||
|
"""Persist events to db
|
||||||
|
|
||||||
|
Args:
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]):
|
||||||
|
backfilled (bool):
|
||||||
|
delete_existing (bool):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: resolves when the events have been persisted
|
||||||
|
"""
|
||||||
if not events_and_contexts:
|
if not events_and_contexts:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
|
||||||
(event, context)
|
(event, context)
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id, ev_ctx_rm in events_by_room.items():
|
for room_id, ev_ctx_rm in events_by_room.iteritems():
|
||||||
# Work out new extremities by recursively adding and removing
|
# Work out new extremities by recursively adding and removing
|
||||||
# the new events.
|
# the new events.
|
||||||
latest_event_ids = yield self.get_latest_event_ids_in_room(
|
latest_event_ids = yield self.get_latest_event_ids_in_room(
|
||||||
|
@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
|
||||||
# Now we need to work out the different state sets for
|
# Now we need to work out the different state sets for
|
||||||
# each state extremities
|
# each state extremities
|
||||||
state_sets = []
|
state_sets = []
|
||||||
|
state_groups = set()
|
||||||
missing_event_ids = []
|
missing_event_ids = []
|
||||||
was_updated = False
|
was_updated = False
|
||||||
for event_id in new_latest_event_ids:
|
for event_id in new_latest_event_ids:
|
||||||
|
@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
|
||||||
if event_id == ev.event_id:
|
if event_id == ev.event_id:
|
||||||
if ctx.current_state_ids is None:
|
if ctx.current_state_ids is None:
|
||||||
raise Exception("Unknown current state")
|
raise Exception("Unknown current state")
|
||||||
|
|
||||||
|
# If we've already seen the state group don't bother adding
|
||||||
|
# it to the state sets again
|
||||||
|
if ctx.state_group not in state_groups:
|
||||||
state_sets.append(ctx.current_state_ids)
|
state_sets.append(ctx.current_state_ids)
|
||||||
if ctx.delta_ids or hasattr(ev, "state_key"):
|
if ctx.delta_ids or hasattr(ev, "state_key"):
|
||||||
was_updated = True
|
was_updated = True
|
||||||
|
if ctx.state_group:
|
||||||
|
# Add this as a seen state group (if it has a state
|
||||||
|
# group)
|
||||||
|
state_groups.add(ctx.state_group)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# If we couldn't find it, then we'll need to pull
|
# If we couldn't find it, then we'll need to pull
|
||||||
|
@ -425,31 +462,57 @@ class EventsStore(SQLBaseStore):
|
||||||
missing_event_ids,
|
missing_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.itervalues()) - state_groups
|
||||||
group_to_state = yield self._get_state_for_groups(groups)
|
|
||||||
|
|
||||||
state_sets.extend(group_to_state.values())
|
if groups:
|
||||||
|
group_to_state = yield self._get_state_for_groups(groups)
|
||||||
|
state_sets.extend(group_to_state.itervalues())
|
||||||
|
|
||||||
if not new_latest_event_ids:
|
if not new_latest_event_ids:
|
||||||
current_state = {}
|
current_state = {}
|
||||||
elif was_updated:
|
elif was_updated:
|
||||||
|
if len(state_sets) == 1:
|
||||||
|
# If there is only one state set, then we know what the current
|
||||||
|
# state is.
|
||||||
|
current_state = state_sets[0]
|
||||||
|
else:
|
||||||
|
# We work out the current state by passing the state sets to the
|
||||||
|
# state resolution algorithm. It may ask for some events, including
|
||||||
|
# the events we have yet to persist, so we need a slightly more
|
||||||
|
# complicated event lookup function than simply looking the events
|
||||||
|
# up in the db.
|
||||||
|
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_events(ev_ids):
|
||||||
|
# We get the events by first looking at the list of events we
|
||||||
|
# are trying to persist, and then fetching the rest from the DB.
|
||||||
|
db = []
|
||||||
|
to_return = {}
|
||||||
|
for ev_id in ev_ids:
|
||||||
|
ev = events_map.get(ev_id, None)
|
||||||
|
if ev:
|
||||||
|
to_return[ev_id] = ev
|
||||||
|
else:
|
||||||
|
db.append(ev_id)
|
||||||
|
|
||||||
|
if db:
|
||||||
|
evs = yield self.get_events(
|
||||||
|
ev_ids, get_prev_content=False, check_redacted=False,
|
||||||
|
)
|
||||||
|
to_return.update(evs)
|
||||||
|
defer.returnValue(to_return)
|
||||||
|
|
||||||
current_state = yield resolve_events(
|
current_state = yield resolve_events(
|
||||||
state_sets,
|
state_sets,
|
||||||
state_map_factory=lambda ev_ids: self.get_events(
|
state_map_factory=get_events,
|
||||||
ev_ids, get_prev_content=False, check_redacted=False,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
existing_state_rows = yield self._simple_select_list(
|
existing_state = yield self.get_current_state_ids(room_id)
|
||||||
table="current_state_events",
|
|
||||||
keyvalues={"room_id": room_id},
|
|
||||||
retcols=["event_id", "type", "state_key"],
|
|
||||||
desc="_calculate_state_delta",
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_events = set(row["event_id"] for row in existing_state_rows)
|
existing_events = set(existing_state.itervalues())
|
||||||
new_events = set(ev_id for ev_id in current_state.itervalues())
|
new_events = set(ev_id for ev_id in current_state.itervalues())
|
||||||
changed_events = existing_events ^ new_events
|
changed_events = existing_events ^ new_events
|
||||||
|
|
||||||
|
@ -457,9 +520,8 @@ class EventsStore(SQLBaseStore):
|
||||||
return
|
return
|
||||||
|
|
||||||
to_delete = {
|
to_delete = {
|
||||||
(row["type"], row["state_key"]): row["event_id"]
|
key: ev_id for key, ev_id in existing_state.iteritems()
|
||||||
for row in existing_state_rows
|
if ev_id in changed_events
|
||||||
if row["event_id"] in changed_events
|
|
||||||
}
|
}
|
||||||
events_to_insert = (new_events - existing_events)
|
events_to_insert = (new_events - existing_events)
|
||||||
to_insert = {
|
to_insert = {
|
||||||
|
@ -535,11 +597,91 @@ class EventsStore(SQLBaseStore):
|
||||||
and the rejections table. Things reading from those table will need to check
|
and the rejections table. Things reading from those table will need to check
|
||||||
whether the event was rejected.
|
whether the event was rejected.
|
||||||
|
|
||||||
If delete_existing is True then existing events will be purged from the
|
Args:
|
||||||
database before insertion. This is useful when retrying due to IntegrityError.
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]):
|
||||||
|
events to persist
|
||||||
|
backfilled (bool): True if the events were backfilled
|
||||||
|
delete_existing (bool): True to purge existing table rows for the
|
||||||
|
events from the database. This is useful when retrying due to
|
||||||
|
IntegrityError.
|
||||||
|
current_state_for_room (dict[str, (list[str], list[str])]):
|
||||||
|
The current-state delta for each room. For each room, a tuple
|
||||||
|
(to_delete, to_insert), being a list of event ids to be removed
|
||||||
|
from the current state, and a list of event ids to be added to
|
||||||
|
the current state.
|
||||||
|
new_forward_extremeties (dict[str, list[str]]):
|
||||||
|
The new forward extremities for each room. For each room, a
|
||||||
|
list of the event ids which are the forward extremities.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
self._update_current_state_txn(txn, current_state_for_room)
|
||||||
|
|
||||||
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||||
for room_id, current_state_tuple in current_state_for_room.iteritems():
|
self._update_forward_extremities_txn(
|
||||||
|
txn,
|
||||||
|
new_forward_extremities=new_forward_extremeties,
|
||||||
|
max_stream_order=max_stream_order,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure that we don't have the same event twice.
|
||||||
|
events_and_contexts = self._filter_events_and_contexts_for_duplicates(
|
||||||
|
events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._update_room_depths_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
backfilled=backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# _update_outliers_txn filters out any events which have already been
|
||||||
|
# persisted, and returns the filtered list.
|
||||||
|
events_and_contexts = self._update_outliers_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# From this point onwards the events are only events that we haven't
|
||||||
|
# seen before.
|
||||||
|
|
||||||
|
if delete_existing:
|
||||||
|
# For paranoia reasons, we go and delete all the existing entries
|
||||||
|
# for these events so we can reinsert them.
|
||||||
|
# This gets around any problems with some tables already having
|
||||||
|
# entries.
|
||||||
|
self._delete_existing_rows_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._store_event_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert into the state_groups, state_groups_state, and
|
||||||
|
# event_to_state_groups tables.
|
||||||
|
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||||
|
|
||||||
|
# _store_rejected_events_txn filters out any events which were
|
||||||
|
# rejected, and returns the filtered list.
|
||||||
|
events_and_contexts = self._store_rejected_events_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# From this point onwards the events are only ones that weren't
|
||||||
|
# rejected.
|
||||||
|
|
||||||
|
self._update_metadata_tables_txn(
|
||||||
|
txn,
|
||||||
|
events_and_contexts=events_and_contexts,
|
||||||
|
backfilled=backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_current_state_txn(self, txn, state_delta_by_room):
|
||||||
|
for room_id, current_state_tuple in state_delta_by_room.iteritems():
|
||||||
to_delete, to_insert = current_state_tuple
|
to_delete, to_insert = current_state_tuple
|
||||||
txn.executemany(
|
txn.executemany(
|
||||||
"DELETE FROM current_state_events WHERE event_id = ?",
|
"DELETE FROM current_state_events WHERE event_id = ?",
|
||||||
|
@ -585,7 +727,13 @@ class EventsStore(SQLBaseStore):
|
||||||
txn, self.get_users_in_room, (room_id,)
|
txn, self.get_users_in_room, (room_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id, new_extrem in new_forward_extremeties.items():
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_current_state_ids, (room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
|
||||||
|
max_stream_order):
|
||||||
|
for room_id, new_extrem in new_forward_extremities.iteritems():
|
||||||
self._simple_delete_txn(
|
self._simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_forward_extremities",
|
table="event_forward_extremities",
|
||||||
|
@ -603,7 +751,7 @@ class EventsStore(SQLBaseStore):
|
||||||
"event_id": ev_id,
|
"event_id": ev_id,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
}
|
}
|
||||||
for room_id, new_extrem in new_forward_extremeties.items()
|
for room_id, new_extrem in new_forward_extremities.iteritems()
|
||||||
for ev_id in new_extrem
|
for ev_id in new_extrem
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -620,13 +768,22 @@ class EventsStore(SQLBaseStore):
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
"stream_ordering": max_stream_order,
|
"stream_ordering": max_stream_order,
|
||||||
}
|
}
|
||||||
for room_id, new_extrem in new_forward_extremeties.items()
|
for room_id, new_extrem in new_forward_extremities.iteritems()
|
||||||
for event_id in new_extrem
|
for event_id in new_extrem
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure that we don't have the same event twice.
|
@classmethod
|
||||||
# Pick the earliest non-outlier if there is one, else the earliest one.
|
def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
|
||||||
|
"""Ensure that we don't have the same event twice.
|
||||||
|
|
||||||
|
Pick the earliest non-outlier if there is one, else the earliest one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]):
|
||||||
|
Returns:
|
||||||
|
list[(EventBase, EventContext)]: filtered list
|
||||||
|
"""
|
||||||
new_events_and_contexts = OrderedDict()
|
new_events_and_contexts = OrderedDict()
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
prev_event_context = new_events_and_contexts.get(event.event_id)
|
prev_event_context = new_events_and_contexts.get(event.event_id)
|
||||||
|
@ -639,9 +796,17 @@ class EventsStore(SQLBaseStore):
|
||||||
new_events_and_contexts[event.event_id] = (event, context)
|
new_events_and_contexts[event.event_id] = (event, context)
|
||||||
else:
|
else:
|
||||||
new_events_and_contexts[event.event_id] = (event, context)
|
new_events_and_contexts[event.event_id] = (event, context)
|
||||||
|
return new_events_and_contexts.values()
|
||||||
|
|
||||||
events_and_contexts = new_events_and_contexts.values()
|
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
|
||||||
|
"""Update min_depth for each room
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
|
we are persisting
|
||||||
|
backfilled (bool): True if the events were backfilled
|
||||||
|
"""
|
||||||
depth_updates = {}
|
depth_updates = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Remove the any existing cache entries for the event_ids
|
# Remove the any existing cache entries for the event_ids
|
||||||
|
@ -657,9 +822,24 @@ class EventsStore(SQLBaseStore):
|
||||||
event.depth, depth_updates.get(event.room_id, event.depth)
|
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id, depth in depth_updates.items():
|
for room_id, depth in depth_updates.iteritems():
|
||||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||||
|
|
||||||
|
def _update_outliers_txn(self, txn, events_and_contexts):
|
||||||
|
"""Update any outliers with new event info.
|
||||||
|
|
||||||
|
This turns outliers into ex-outliers (unless the new event was
|
||||||
|
rejected).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
|
we are persisting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[(EventBase, EventContext)] new list, without events which
|
||||||
|
are already in the events table.
|
||||||
|
"""
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
|
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
|
||||||
",".join(["?"] * len(events_and_contexts)),
|
",".join(["?"] * len(events_and_contexts)),
|
||||||
|
@ -669,24 +849,21 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
have_persisted = {
|
have_persisted = {
|
||||||
event_id: outlier
|
event_id: outlier
|
||||||
for event_id, outlier in txn.fetchall()
|
for event_id, outlier in txn
|
||||||
}
|
}
|
||||||
|
|
||||||
to_remove = set()
|
to_remove = set()
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
if context.rejected:
|
|
||||||
# If the event is rejected then we don't care if the event
|
|
||||||
# was an outlier or not.
|
|
||||||
if event.event_id in have_persisted:
|
|
||||||
# If we have already seen the event then ignore it.
|
|
||||||
to_remove.add(event)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if event.event_id not in have_persisted:
|
if event.event_id not in have_persisted:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_remove.add(event)
|
to_remove.add(event)
|
||||||
|
|
||||||
|
if context.rejected:
|
||||||
|
# If the event is rejected then we don't care if the event
|
||||||
|
# was an outlier or not.
|
||||||
|
continue
|
||||||
|
|
||||||
outlier_persisted = have_persisted[event.event_id]
|
outlier_persisted = have_persisted[event.event_id]
|
||||||
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||||
# We received a copy of an event that we had already stored as
|
# We received a copy of an event that we had already stored as
|
||||||
|
@ -741,34 +918,16 @@ class EventsStore(SQLBaseStore):
|
||||||
# event isn't an outlier any more.
|
# event isn't an outlier any more.
|
||||||
self._update_backward_extremeties(txn, [event])
|
self._update_backward_extremeties(txn, [event])
|
||||||
|
|
||||||
events_and_contexts = [
|
return [
|
||||||
ec for ec in events_and_contexts if ec[0] not in to_remove
|
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _delete_existing_rows_txn(cls, txn, events_and_contexts):
|
||||||
if not events_and_contexts:
|
if not events_and_contexts:
|
||||||
# Make sure we don't pass an empty list to functions that expect to
|
# nothing to do here
|
||||||
# be storing at least one element.
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# From this point onwards the events are only events that we haven't
|
|
||||||
# seen before.
|
|
||||||
|
|
||||||
def event_dict(event):
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in event.get_dict().items()
|
|
||||||
if k not in [
|
|
||||||
"redacted",
|
|
||||||
"redacted_because",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
if delete_existing:
|
|
||||||
# For paranoia reasons, we go and delete all the existing entries
|
|
||||||
# for these events so we can reinsert them.
|
|
||||||
# This gets around any problems with some tables already having
|
|
||||||
# entries.
|
|
||||||
|
|
||||||
logger.info("Deleting existing")
|
logger.info("Deleting existing")
|
||||||
|
|
||||||
for table in (
|
for table in (
|
||||||
|
@ -800,6 +959,25 @@ class EventsStore(SQLBaseStore):
|
||||||
[(ev.event_id,) for ev, _ in events_and_contexts]
|
[(ev.event_id,) for ev, _ in events_and_contexts]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _store_event_txn(self, txn, events_and_contexts):
|
||||||
|
"""Insert new events into the event and event_json tables
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
|
we are persisting
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not events_and_contexts:
|
||||||
|
# nothing to do here
|
||||||
|
return
|
||||||
|
|
||||||
|
def event_dict(event):
|
||||||
|
d = event.get_dict()
|
||||||
|
d.pop("redacted", None)
|
||||||
|
d.pop("redacted_because", None)
|
||||||
|
return d
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_json",
|
table="event_json",
|
||||||
|
@ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _store_rejected_events_txn(self, txn, events_and_contexts):
|
||||||
|
"""Add rows to the 'rejections' table for received events which were
|
||||||
|
rejected
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
|
we are persisting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[(EventBase, EventContext)] new list, without the rejected
|
||||||
|
events.
|
||||||
|
"""
|
||||||
# Remove the rejected events from the list now that we've added them
|
# Remove the rejected events from the list now that we've added them
|
||||||
# to the events table and the events_json table.
|
# to the events table and the events_json table.
|
||||||
to_remove = set()
|
to_remove = set()
|
||||||
|
@ -853,16 +1044,23 @@ class EventsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
to_remove.add(event)
|
to_remove.add(event)
|
||||||
|
|
||||||
events_and_contexts = [
|
return [
|
||||||
ec for ec in events_and_contexts if ec[0] not in to_remove
|
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||||
]
|
]
|
||||||
|
|
||||||
if not events_and_contexts:
|
def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
|
||||||
# Make sure we don't pass an empty list to functions that expect to
|
"""Update all the miscellaneous tables for new events
|
||||||
# be storing at least one element.
|
|
||||||
return
|
|
||||||
|
|
||||||
# From this point onwards the events are only ones that weren't rejected.
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||||
|
we are persisting
|
||||||
|
backfilled (bool): True if the events were backfilled
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not events_and_contexts:
|
||||||
|
# nothing to do here
|
||||||
|
return
|
||||||
|
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Insert all the push actions into the event_push_actions table.
|
# Insert all the push actions into the event_push_actions table.
|
||||||
|
@ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert into the state_groups, state_groups_state, and
|
|
||||||
# event_to_state_groups tables.
|
|
||||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
|
||||||
|
|
||||||
# Update the event_forward_extremities, event_backward_extremities and
|
# Update the event_forward_extremities, event_backward_extremities and
|
||||||
# event_edges tables.
|
# event_edges tables.
|
||||||
self._handle_mult_prev_events(
|
self._handle_mult_prev_events(
|
||||||
|
@ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore):
|
||||||
# Prefill the event cache
|
# Prefill the event cache
|
||||||
self._add_to_cache(txn, events_and_contexts)
|
self._add_to_cache(txn, events_and_contexts)
|
||||||
|
|
||||||
if backfilled:
|
|
||||||
# Backfilled events come before the current state so we don't need
|
|
||||||
# to update the current state table
|
|
||||||
return
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
def _add_to_cache(self, txn, events_and_contexts):
|
def _add_to_cache(self, txn, events_and_contexts):
|
||||||
to_prefill = []
|
to_prefill = []
|
||||||
|
|
||||||
|
@ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
def get_all_new_events_txn(txn):
|
def get_all_new_events_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
|
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||||
" FROM events as e"
|
" state_key, redacts"
|
||||||
" JOIN event_json as ej"
|
" FROM events AS e"
|
||||||
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
|
" LEFT JOIN redactions USING (event_id)"
|
||||||
" LEFT JOIN event_to_state_groups as eg"
|
" LEFT JOIN state_events USING (event_id)"
|
||||||
" ON e.event_id = eg.event_id"
|
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||||
" WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
|
" ORDER BY stream_ordering ASC"
|
||||||
" ORDER BY e.stream_ordering ASC"
|
|
||||||
" LIMIT ?"
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
if have_forward_events:
|
if have_forward_events:
|
||||||
|
@ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore):
|
||||||
forward_ex_outliers = []
|
forward_ex_outliers = []
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
|
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||||
" eg.state_group"
|
" state_key, redacts"
|
||||||
" FROM events as e"
|
" FROM events AS e"
|
||||||
" JOIN event_json as ej"
|
" LEFT JOIN redactions USING (event_id)"
|
||||||
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
|
" LEFT JOIN state_events USING (event_id)"
|
||||||
" LEFT JOIN event_to_state_groups as eg"
|
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||||
" ON e.event_id = eg.event_id"
|
" ORDER BY stream_ordering DESC"
|
||||||
" WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
|
|
||||||
" ORDER BY e.stream_ordering DESC"
|
|
||||||
" LIMIT ?"
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
if have_backfill_events:
|
if have_backfill_events:
|
||||||
|
@ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore):
|
||||||
"state_key": key[1],
|
"state_key": key[1],
|
||||||
"event_id": state_id,
|
"event_id": state_id,
|
||||||
}
|
}
|
||||||
for key, state_id in curr_state.items()
|
for key, state_id in curr_state.iteritems()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
|
||||||
key_ids
|
key_ids
|
||||||
Args:
|
Args:
|
||||||
server_name (str): The name of the server.
|
server_name (str): The name of the server.
|
||||||
key_ids (list of str): List of key_ids to try and look up.
|
key_ids (iterable[str]): key_ids to try and look up.
|
||||||
Returns:
|
Returns:
|
||||||
(list of VerifyKey): The verification keys.
|
Deferred: resolves to dict[str, VerifyKey]: map from
|
||||||
|
key_id to verification key.
|
||||||
"""
|
"""
|
||||||
keys = {}
|
keys = {}
|
||||||
for key_id in key_ids:
|
for key_id in key_ids:
|
||||||
|
|
|
@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
|
||||||
),
|
),
|
||||||
(current_version,)
|
(current_version,)
|
||||||
)
|
)
|
||||||
applied_deltas = [d for d, in txn.fetchall()]
|
applied_deltas = [d for d, in txn]
|
||||||
return current_version, applied_deltas, upgraded
|
return current_version, applied_deltas, upgraded
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
|
||||||
self.presence_stream_cache.entity_has_changed,
|
self.presence_stream_cache.entity_has_changed,
|
||||||
state.user_id, stream_id,
|
state.user_id, stream_id,
|
||||||
)
|
)
|
||||||
self._invalidate_cache_and_stream(
|
txn.call_after(
|
||||||
txn, self._get_presence_for_user, (state.user_id,)
|
self._get_presence_for_user.invalidate, (state.user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actually insert new rows
|
# Actually insert new rows
|
||||||
|
|
|
@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(sql, (room_id, receipt_type, user_id))
|
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||||
results = txn.fetchall()
|
|
||||||
|
|
||||||
if results and topological_ordering:
|
if topological_ordering:
|
||||||
for to, so, _ in results:
|
for to, so, _ in txn:
|
||||||
if int(to) > topological_ordering:
|
if int(to) > topological_ordering:
|
||||||
return False
|
return False
|
||||||
elif int(to) == topological_ordering and int(so) >= stream_ordering:
|
elif int(to) == topological_ordering and int(so) >= stream_ordering:
|
||||||
|
|
|
@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
" WHERE lower(name) = lower(?)"
|
" WHERE lower(name) = lower(?)"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
return dict(txn.fetchall())
|
return dict(txn)
|
||||||
|
|
||||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
|
|
|
@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
|
||||||
sql % ("AND appservice_id IS NULL",),
|
sql % ("AND appservice_id IS NULL",),
|
||||||
(stream_id,)
|
(stream_id,)
|
||||||
)
|
)
|
||||||
return dict(txn.fetchall())
|
return dict(txn)
|
||||||
else:
|
else:
|
||||||
# We want to get from all lists, so we need to aggregate the results
|
# We want to get from all lists, so we need to aggregate the results
|
||||||
|
|
||||||
|
@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
# A room is visible if its visible on any list.
|
# A room is visible if its visible on any list.
|
||||||
for room_id, visibility in txn.fetchall():
|
for room_id, visibility in txn:
|
||||||
results[room_id] = bool(visibility) or results.get(room_id, False)
|
results[room_id] = bool(visibility) or results.get(room_id, False)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
with self._stream_id_gen.get_next() as stream_ordering:
|
with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
|
||||||
|
def get_hosts_in_room(self, room_id, cache_context):
|
||||||
|
"""Returns the set of all hosts currently in the room
|
||||||
|
"""
|
||||||
|
user_ids = yield self.get_users_in_room(
|
||||||
|
room_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
|
||||||
|
defer.returnValue(hosts)
|
||||||
|
|
||||||
@cached(max_entries=500000, iterable=True)
|
@cached(max_entries=500000, iterable=True)
|
||||||
def get_users_in_room(self, room_id):
|
def get_users_in_room(self, room_id):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
sql = (
|
||||||
rows = self._get_members_rows_txn(
|
"SELECT m.user_id FROM room_memberships as m"
|
||||||
txn,
|
" INNER JOIN current_state_events as c"
|
||||||
room_id=room_id,
|
" ON m.event_id = c.event_id "
|
||||||
membership=Membership.JOIN,
|
" AND m.room_id = c.room_id "
|
||||||
|
" AND m.user_id = c.state_key"
|
||||||
|
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
|
||||||
)
|
)
|
||||||
|
|
||||||
return [r["user_id"] for r in rows]
|
txn.execute(sql, (room_id, Membership.JOIN,))
|
||||||
|
return [r[0] for r in txn]
|
||||||
return self.runInteraction("get_users_in_room", f)
|
return self.runInteraction("get_users_in_room", f)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
|
@ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
@cachedInlineCallbacks(max_entries=500000, iterable=True)
|
||||||
where_clause = "c.room_id = ?"
|
|
||||||
where_values = [room_id]
|
|
||||||
|
|
||||||
if membership:
|
|
||||||
where_clause += " AND m.membership = ?"
|
|
||||||
where_values.append(membership)
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
where_clause += " AND m.user_id = ?"
|
|
||||||
where_values.append(user_id)
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT m.* FROM room_memberships as m"
|
|
||||||
" INNER JOIN current_state_events as c"
|
|
||||||
" ON m.event_id = c.event_id "
|
|
||||||
" AND m.room_id = c.room_id "
|
|
||||||
" AND m.user_id = c.state_key"
|
|
||||||
" WHERE c.type = 'm.room.member' AND %(where)s"
|
|
||||||
) % {
|
|
||||||
"where": where_clause,
|
|
||||||
}
|
|
||||||
|
|
||||||
txn.execute(sql, where_values)
|
|
||||||
rows = self.cursor_to_dict(txn)
|
|
||||||
|
|
||||||
return rows
|
|
||||||
|
|
||||||
@cached(max_entries=500000, iterable=True)
|
|
||||||
def get_rooms_for_user(self, user_id):
|
def get_rooms_for_user(self, user_id):
|
||||||
return self.get_rooms_for_user_where_membership_is(
|
"""Returns a set of room_ids the user is currently joined to
|
||||||
|
"""
|
||||||
|
rooms = yield self.get_rooms_for_user_where_membership_is(
|
||||||
user_id, membership_list=[Membership.JOIN],
|
user_id, membership_list=[Membership.JOIN],
|
||||||
)
|
)
|
||||||
|
defer.returnValue(frozenset(r.room_id for r in rooms))
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
||||||
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
||||||
"""Returns the set of users who share a room with `user_id`
|
"""Returns the set of users who share a room with `user_id`
|
||||||
"""
|
"""
|
||||||
rooms = yield self.get_rooms_for_user(
|
room_ids = yield self.get_rooms_for_user(
|
||||||
user_id, on_invalidate=cache_context.invalidate,
|
user_id, on_invalidate=cache_context.invalidate,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_who_share_room = set()
|
user_who_share_room = set()
|
||||||
for room in rooms:
|
for room_id in room_ids:
|
||||||
user_ids = yield self.get_users_in_room(
|
user_ids = yield self.get_users_in_room(
|
||||||
room.room_id, on_invalidate=cache_context.invalidate,
|
room_id, on_invalidate=cache_context.invalidate,
|
||||||
)
|
)
|
||||||
user_who_share_room.update(user_ids)
|
user_who_share_room.update(user_ids)
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
|
||||||
" WHERE event_id = ?"
|
" WHERE event_id = ?"
|
||||||
)
|
)
|
||||||
txn.execute(query, (event_id, ))
|
txn.execute(query, (event_id, ))
|
||||||
return {k: v for k, v in txn.fetchall()}
|
return {k: v for k, v in txn}
|
||||||
|
|
||||||
def _store_event_reference_hashes_txn(self, txn, events):
|
def _store_event_reference_hashes_txn(self, txn, events):
|
||||||
"""Store a hash for a PDU
|
"""Store a hash for a PDU
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
|
@ -69,6 +69,18 @@ class StateStore(SQLBaseStore):
|
||||||
where_clause="type='m.room.member'",
|
where_clause="type='m.room.member'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(max_entries=100000, iterable=True)
|
||||||
|
def get_current_state_ids(self, room_id):
|
||||||
|
rows = yield self._simple_select_list(
|
||||||
|
table="current_state_events",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=["event_id", "type", "state_key"],
|
||||||
|
desc="_calculate_state_delta",
|
||||||
|
)
|
||||||
|
defer.returnValue({
|
||||||
|
(r["type"], r["state_key"]): r["event_id"] for r in rows
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_groups_ids(self, room_id, event_ids):
|
def get_state_groups_ids(self, room_id, event_ids):
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
|
@ -78,7 +90,7 @@ class StateStore(SQLBaseStore):
|
||||||
event_ids,
|
event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.itervalues())
|
||||||
group_to_state = yield self._get_state_for_groups(groups)
|
group_to_state = yield self._get_state_for_groups(groups)
|
||||||
|
|
||||||
defer.returnValue(group_to_state)
|
defer.returnValue(group_to_state)
|
||||||
|
@ -96,17 +108,18 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
state_event_map = yield self.get_events(
|
state_event_map = yield self.get_events(
|
||||||
[
|
[
|
||||||
ev_id for group_ids in group_to_ids.values()
|
ev_id for group_ids in group_to_ids.itervalues()
|
||||||
for ev_id in group_ids.values()
|
for ev_id in group_ids.itervalues()
|
||||||
],
|
],
|
||||||
get_prev_content=False
|
get_prev_content=False
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
group: [
|
group: [
|
||||||
state_event_map[v] for v in event_id_map.values() if v in state_event_map
|
state_event_map[v] for v in event_id_map.itervalues()
|
||||||
|
if v in state_event_map
|
||||||
]
|
]
|
||||||
for group, event_id_map in group_to_ids.items()
|
for group, event_id_map in group_to_ids.iteritems()
|
||||||
})
|
})
|
||||||
|
|
||||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||||
|
@ -124,6 +137,16 @@ class StateStore(SQLBaseStore):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.current_state_ids is None:
|
if context.current_state_ids is None:
|
||||||
|
# AFAIK, this can never happen
|
||||||
|
logger.error(
|
||||||
|
"Non-outlier event %s had current_state_ids==None",
|
||||||
|
event.event_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if the event was rejected, just give it the same state as its
|
||||||
|
# predecessor.
|
||||||
|
if context.rejected:
|
||||||
|
state_groups[event.event_id] = context.prev_group
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state_groups[event.event_id] = context.state_group
|
state_groups[event.event_id] = context.state_group
|
||||||
|
@ -168,7 +191,7 @@ class StateStore(SQLBaseStore):
|
||||||
"state_key": key[1],
|
"state_key": key[1],
|
||||||
"event_id": state_id,
|
"event_id": state_id,
|
||||||
}
|
}
|
||||||
for key, state_id in context.delta_ids.items()
|
for key, state_id in context.delta_ids.iteritems()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -183,7 +206,7 @@ class StateStore(SQLBaseStore):
|
||||||
"state_key": key[1],
|
"state_key": key[1],
|
||||||
"event_id": state_id,
|
"event_id": state_id,
|
||||||
}
|
}
|
||||||
for key, state_id in context.current_state_ids.items()
|
for key, state_id in context.current_state_ids.iteritems()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -195,7 +218,7 @@ class StateStore(SQLBaseStore):
|
||||||
"state_group": state_group_id,
|
"state_group": state_group_id,
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
}
|
}
|
||||||
for event_id, state_group_id in state_groups.items()
|
for event_id, state_group_id in state_groups.iteritems()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -319,10 +342,10 @@ class StateStore(SQLBaseStore):
|
||||||
args.extend(where_args)
|
args.extend(where_args)
|
||||||
|
|
||||||
txn.execute(sql % (where_clause,), args)
|
txn.execute(sql % (where_clause,), args)
|
||||||
rows = self.cursor_to_dict(txn)
|
for row in txn:
|
||||||
for row in rows:
|
typ, state_key, event_id = row
|
||||||
key = (row["type"], row["state_key"])
|
key = (typ, state_key)
|
||||||
results[group][key] = row["event_id"]
|
results[group][key] = event_id
|
||||||
else:
|
else:
|
||||||
if types is not None:
|
if types is not None:
|
||||||
where_clause = "AND (%s)" % (
|
where_clause = "AND (%s)" % (
|
||||||
|
@ -351,12 +374,11 @@ class StateStore(SQLBaseStore):
|
||||||
" WHERE state_group = ? %s" % (where_clause,),
|
" WHERE state_group = ? %s" % (where_clause,),
|
||||||
args
|
args
|
||||||
)
|
)
|
||||||
rows = txn.fetchall()
|
results[group].update(
|
||||||
results[group].update({
|
((typ, state_key), event_id)
|
||||||
(typ, state_key): event_id
|
for typ, state_key, event_id in txn
|
||||||
for typ, state_key, event_id in rows
|
|
||||||
if (typ, state_key) not in results[group]
|
if (typ, state_key) not in results[group]
|
||||||
})
|
)
|
||||||
|
|
||||||
# If the lengths match then we must have all the types,
|
# If the lengths match then we must have all the types,
|
||||||
# so no need to go walk further down the tree.
|
# so no need to go walk further down the tree.
|
||||||
|
@ -393,21 +415,21 @@ class StateStore(SQLBaseStore):
|
||||||
event_ids,
|
event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.itervalues())
|
||||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||||
|
|
||||||
state_event_map = yield self.get_events(
|
state_event_map = yield self.get_events(
|
||||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
[ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
|
||||||
get_prev_content=False
|
get_prev_content=False
|
||||||
)
|
)
|
||||||
|
|
||||||
event_to_state = {
|
event_to_state = {
|
||||||
event_id: {
|
event_id: {
|
||||||
k: state_event_map[v]
|
k: state_event_map[v]
|
||||||
for k, v in group_to_state[group].items()
|
for k, v in group_to_state[group].iteritems()
|
||||||
if v in state_event_map
|
if v in state_event_map
|
||||||
}
|
}
|
||||||
for event_id, group in event_to_groups.items()
|
for event_id, group in event_to_groups.iteritems()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||||
|
@ -430,12 +452,12 @@ class StateStore(SQLBaseStore):
|
||||||
event_ids,
|
event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.itervalues())
|
||||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||||
|
|
||||||
event_to_state = {
|
event_to_state = {
|
||||||
event_id: group_to_state[group]
|
event_id: group_to_state[group]
|
||||||
for event_id, group in event_to_groups.items()
|
for event_id, group in event_to_groups.iteritems()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||||
|
@ -474,7 +496,7 @@ class StateStore(SQLBaseStore):
|
||||||
state_map = yield self.get_state_ids_for_events([event_id], types)
|
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||||
defer.returnValue(state_map[event_id])
|
defer.returnValue(state_map[event_id])
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=10000)
|
@cached(num_args=2, max_entries=100000)
|
||||||
def _get_state_group_for_event(self, room_id, event_id):
|
def _get_state_group_for_event(self, room_id, event_id):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
|
@ -547,7 +569,7 @@ class StateStore(SQLBaseStore):
|
||||||
got_all = not (missing_types or types is None)
|
got_all = not (missing_types or types is None)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
k: v for k, v in state_dict_ids.items()
|
k: v for k, v in state_dict_ids.iteritems()
|
||||||
if include(k[0], k[1])
|
if include(k[0], k[1])
|
||||||
}, missing_types, got_all
|
}, missing_types, got_all
|
||||||
|
|
||||||
|
@ -606,7 +628,7 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
# Now we want to update the cache with all the things we fetched
|
# Now we want to update the cache with all the things we fetched
|
||||||
# from the database.
|
# from the database.
|
||||||
for group, group_state_dict in group_to_state_dict.items():
|
for group, group_state_dict in group_to_state_dict.iteritems():
|
||||||
if types:
|
if types:
|
||||||
# We delibrately put key -> None mappings into the cache to
|
# We delibrately put key -> None mappings into the cache to
|
||||||
# cache absence of the key, on the assumption that if we've
|
# cache absence of the key, on the assumption that if we've
|
||||||
|
@ -621,10 +643,10 @@ class StateStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
state_dict = results[group]
|
state_dict = results[group]
|
||||||
|
|
||||||
state_dict.update({
|
state_dict.update(
|
||||||
(intern_string(k[0]), intern_string(k[1])): v
|
((intern_string(k[0]), intern_string(k[1])), v)
|
||||||
for k, v in group_state_dict.items()
|
for k, v in group_state_dict.iteritems()
|
||||||
})
|
)
|
||||||
|
|
||||||
self._state_group_cache.update(
|
self._state_group_cache.update(
|
||||||
cache_seq_num,
|
cache_seq_num,
|
||||||
|
@ -635,10 +657,10 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
# Remove all the entries with None values. The None values were just
|
# Remove all the entries with None values. The None values were just
|
||||||
# used for bookkeeping in the cache.
|
# used for bookkeeping in the cache.
|
||||||
for group, state_dict in results.items():
|
for group, state_dict in results.iteritems():
|
||||||
results[group] = {
|
results[group] = {
|
||||||
key: event_id
|
key: event_id
|
||||||
for key, event_id in state_dict.items()
|
for key, event_id in state_dict.iteritems()
|
||||||
if event_id
|
if event_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -727,7 +749,7 @@ class StateStore(SQLBaseStore):
|
||||||
# of keys
|
# of keys
|
||||||
|
|
||||||
delta_state = {
|
delta_state = {
|
||||||
key: value for key, value in curr_state.items()
|
key: value for key, value in curr_state.iteritems()
|
||||||
if prev_state.get(key, None) != value
|
if prev_state.get(key, None) != value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -767,7 +789,7 @@ class StateStore(SQLBaseStore):
|
||||||
"state_key": key[1],
|
"state_key": key[1],
|
||||||
"event_id": state_id,
|
"event_id": state_id,
|
||||||
}
|
}
|
||||||
for key, state_id in delta_state.items()
|
for key, state_id in delta_state.iteritems()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -829,3 +829,6 @@ class StreamStore(SQLBaseStore):
|
||||||
updatevalues={"stream_id": stream_id},
|
updatevalues={"stream_id": stream_id},
|
||||||
desc="update_federation_out_pos",
|
desc="update_federation_out_pos",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def has_room_changed_since(self, room_id, stream_id):
|
||||||
|
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
|
||||||
|
|
|
@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
|
||||||
for stream_id, user_id, room_id in tag_ids:
|
for stream_id, user_id, room_id in tag_ids:
|
||||||
txn.execute(sql, (user_id, room_id))
|
txn.execute(sql, (user_id, room_id))
|
||||||
tags = []
|
tags = []
|
||||||
for tag, content in txn.fetchall():
|
for tag, content in txn:
|
||||||
tags.append(json.dumps(tag) + ":" + content)
|
tags.append(json.dumps(tag) + ":" + content)
|
||||||
tag_json = "{" + ",".join(tags) + "}"
|
tag_json = "{" + ",".join(tags) + "}"
|
||||||
results.append((stream_id, user_id, room_id, tag_json))
|
results.append((stream_id, user_id, room_id, tag_json))
|
||||||
|
@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
|
||||||
" WHERE user_id = ? AND stream_id > ?"
|
" WHERE user_id = ? AND stream_id > ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id, stream_id))
|
txn.execute(sql, (user_id, stream_id))
|
||||||
room_ids = [row[0] for row in txn.fetchall()]
|
room_ids = [row[0] for row in txn]
|
||||||
return room_ids
|
return room_ids
|
||||||
|
|
||||||
changed = self._account_data_stream_cache.has_entity_changed(
|
changed = self._account_data_stream_cache.has_entity_changed(
|
||||||
|
|
|
@ -30,6 +30,17 @@ class IdGenerator(object):
|
||||||
|
|
||||||
|
|
||||||
def _load_current_id(db_conn, table, column, step=1):
|
def _load_current_id(db_conn, table, column, step=1):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_conn (object):
|
||||||
|
table (str):
|
||||||
|
column (str):
|
||||||
|
step (int):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
|
"""
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
if step == 1:
|
if step == 1:
|
||||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
|
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
|
||||||
|
@ -131,6 +142,9 @@ class StreamIdGenerator(object):
|
||||||
def get_current_token(self):
|
def get_current_token(self):
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
equal to it have been successfully persisted.
|
equal to it have been successfully persisted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._unfinished_ids:
|
if self._unfinished_ids:
|
||||||
|
|
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DeferredTimedOutError(SynapseError):
|
class DeferredTimedOutError(SynapseError):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SynapseError).__init__(504, "Timed out")
|
super(SynapseError, self).__init__(504, "Timed out")
|
||||||
|
|
||||||
|
|
||||||
def unwrapFirstError(failure):
|
def unwrapFirstError(failure):
|
||||||
|
@ -93,8 +93,10 @@ class Clock(object):
|
||||||
ret_deferred = defer.Deferred()
|
ret_deferred = defer.Deferred()
|
||||||
|
|
||||||
def timed_out_fn():
|
def timed_out_fn():
|
||||||
|
e = DeferredTimedOutError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret_deferred.errback(DeferredTimedOutError())
|
ret_deferred.errback(e)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -114,7 +116,7 @@ class Clock(object):
|
||||||
|
|
||||||
ret_deferred.addBoth(cancel)
|
ret_deferred.addBoth(cancel)
|
||||||
|
|
||||||
def sucess(res):
|
def success(res):
|
||||||
try:
|
try:
|
||||||
ret_deferred.callback(res)
|
ret_deferred.callback(res)
|
||||||
except:
|
except:
|
||||||
|
@ -128,7 +130,7 @@ class Clock(object):
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
given_deferred.addCallbacks(callback=sucess, errback=err)
|
given_deferred.addCallbacks(callback=success, errback=err)
|
||||||
|
|
||||||
timer = self.call_later(time_out, timed_out_fn)
|
timer = self.call_later(time_out, timed_out_fn)
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,9 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
from synapse.util.logcontext import (
|
|
||||||
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
from . import DEBUG_CACHES, register_cache
|
from . import DEBUG_CACHES, register_cache
|
||||||
|
|
||||||
|
@ -189,7 +186,55 @@ class Cache(object):
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
class CacheDescriptor(object):
|
class _CacheDescriptorBase(object):
|
||||||
|
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
|
||||||
|
self.orig = orig
|
||||||
|
|
||||||
|
if inlineCallbacks:
|
||||||
|
self.function_to_call = defer.inlineCallbacks(orig)
|
||||||
|
else:
|
||||||
|
self.function_to_call = orig
|
||||||
|
|
||||||
|
arg_spec = inspect.getargspec(orig)
|
||||||
|
all_args = arg_spec.args
|
||||||
|
|
||||||
|
if "cache_context" in all_args:
|
||||||
|
if not cache_context:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have a 'cache_context' arg without setting"
|
||||||
|
" cache_context=True"
|
||||||
|
)
|
||||||
|
elif cache_context:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have cache_context=True without having an arg"
|
||||||
|
" named `cache_context`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_args is None:
|
||||||
|
num_args = len(all_args) - 1
|
||||||
|
if cache_context:
|
||||||
|
num_args -= 1
|
||||||
|
|
||||||
|
if len(all_args) < num_args + 1:
|
||||||
|
raise Exception(
|
||||||
|
"Not enough explicit positional arguments to key off for %r: "
|
||||||
|
"got %i args, but wanted %i. (@cached cannot key off *args or "
|
||||||
|
"**kwargs)"
|
||||||
|
% (orig.__name__, len(all_args), num_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_args = num_args
|
||||||
|
self.arg_names = all_args[1:num_args + 1]
|
||||||
|
|
||||||
|
if "cache_context" in self.arg_names:
|
||||||
|
raise Exception(
|
||||||
|
"cache_context arg cannot be included among the cache keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_cache_context = cache_context
|
||||||
|
|
||||||
|
|
||||||
|
class CacheDescriptor(_CacheDescriptorBase):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
This caches deferreds, rather than the results themselves. Deferreds that
|
This caches deferreds, rather than the results themselves. Deferreds that
|
||||||
|
@ -217,52 +262,24 @@ class CacheDescriptor(object):
|
||||||
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||||
defer.returnValue(r1 + r2)
|
defer.returnValue(r1 + r2)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_args (int): number of positional arguments (excluding ``self`` and
|
||||||
|
``cache_context``) to use as cache keys. Defaults to all named
|
||||||
|
args of the function.
|
||||||
"""
|
"""
|
||||||
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
|
def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
|
||||||
inlineCallbacks=False, cache_context=False, iterable=False):
|
inlineCallbacks=False, cache_context=False, iterable=False):
|
||||||
|
|
||||||
|
super(CacheDescriptor, self).__init__(
|
||||||
|
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
|
||||||
|
cache_context=cache_context)
|
||||||
|
|
||||||
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
||||||
|
|
||||||
self.orig = orig
|
|
||||||
|
|
||||||
if inlineCallbacks:
|
|
||||||
self.function_to_call = defer.inlineCallbacks(orig)
|
|
||||||
else:
|
|
||||||
self.function_to_call = orig
|
|
||||||
|
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
self.num_args = num_args
|
|
||||||
self.tree = tree
|
self.tree = tree
|
||||||
|
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
|
|
||||||
all_args = inspect.getargspec(orig)
|
|
||||||
self.arg_names = all_args.args[1:num_args + 1]
|
|
||||||
|
|
||||||
if "cache_context" in all_args.args:
|
|
||||||
if not cache_context:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have a 'cache_context' arg without setting"
|
|
||||||
" cache_context=True"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.arg_names.remove("cache_context")
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
elif cache_context:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have cache_context=True without having an arg"
|
|
||||||
" named `cache_context`"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.add_cache_context = cache_context
|
|
||||||
|
|
||||||
if len(self.arg_names) < self.num_args:
|
|
||||||
raise Exception(
|
|
||||||
"Not enough explicit positional arguments to key off of for %r."
|
|
||||||
" (@cached cannot key off of *args or **kwargs)"
|
|
||||||
% (orig.__name__,)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(self, obj, objtype=None):
|
||||||
cache = Cache(
|
cache = Cache(
|
||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
|
@ -308,11 +325,9 @@ class CacheDescriptor(object):
|
||||||
defer.returnValue(cached_result)
|
defer.returnValue(cached_result)
|
||||||
observer.addCallback(check_result)
|
observer.addCallback(check_result)
|
||||||
|
|
||||||
return preserve_context_over_deferred(observer)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
ret = defer.maybeDeferred(
|
ret = defer.maybeDeferred(
|
||||||
preserve_context_over_fn,
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
self.function_to_call,
|
|
||||||
obj, *args, **kwargs
|
obj, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -322,10 +337,11 @@ class CacheDescriptor(object):
|
||||||
|
|
||||||
ret.addErrback(onErr)
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
result_d = ObservableDeferred(ret, consumeErrors=True)
|
||||||
cache.set(cache_key, ret, callback=invalidate_callback)
|
cache.set(cache_key, result_d, callback=invalidate_callback)
|
||||||
|
observer = result_d.observe()
|
||||||
|
|
||||||
return preserve_context_over_deferred(ret.observe())
|
return logcontext.make_deferred_yieldable(observer)
|
||||||
|
|
||||||
wrapped.invalidate = cache.invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
wrapped.invalidate_all = cache.invalidate_all
|
||||||
|
@ -338,48 +354,40 @@ class CacheDescriptor(object):
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class CacheListDescriptor(object):
|
class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given a list of keys it looks in the cache to find any hits, then passes
|
||||||
the list of missing keys to the wrapped fucntion.
|
the list of missing keys to the wrapped function.
|
||||||
|
|
||||||
|
Once wrapped, the function returns either a Deferred which resolves to
|
||||||
|
the list of results, or (if all results were cached), just the list of
|
||||||
|
results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, orig, cached_method_name, list_name, num_args=1,
|
def __init__(self, orig, cached_method_name, list_name, num_args=None,
|
||||||
inlineCallbacks=False):
|
inlineCallbacks=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
orig (function)
|
orig (function)
|
||||||
method_name (str); The name of the chached method.
|
cached_method_name (str): The name of the chached method.
|
||||||
list_name (str): Name of the argument which is the bulk lookup list
|
list_name (str): Name of the argument which is the bulk lookup list
|
||||||
num_args (int)
|
num_args (int): number of positional arguments (excluding ``self``,
|
||||||
|
but including list_name) to use as cache keys. Defaults to all
|
||||||
|
named args of the function.
|
||||||
inlineCallbacks (bool): Whether orig is a generator that should
|
inlineCallbacks (bool): Whether orig is a generator that should
|
||||||
be wrapped by defer.inlineCallbacks
|
be wrapped by defer.inlineCallbacks
|
||||||
"""
|
"""
|
||||||
self.orig = orig
|
super(CacheListDescriptor, self).__init__(
|
||||||
|
orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
|
||||||
|
|
||||||
if inlineCallbacks:
|
|
||||||
self.function_to_call = defer.inlineCallbacks(orig)
|
|
||||||
else:
|
|
||||||
self.function_to_call = orig
|
|
||||||
|
|
||||||
self.num_args = num_args
|
|
||||||
self.list_name = list_name
|
self.list_name = list_name
|
||||||
|
|
||||||
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
|
||||||
self.list_pos = self.arg_names.index(self.list_name)
|
self.list_pos = self.arg_names.index(self.list_name)
|
||||||
|
|
||||||
self.cached_method_name = cached_method_name
|
self.cached_method_name = cached_method_name
|
||||||
|
|
||||||
self.sentinel = object()
|
self.sentinel = object()
|
||||||
|
|
||||||
if len(self.arg_names) < self.num_args:
|
|
||||||
raise Exception(
|
|
||||||
"Not enough explicit positional arguments to key off of for %r."
|
|
||||||
" (@cached cannot key off of *args or **kwars)"
|
|
||||||
% (orig.__name__,)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.list_name not in self.arg_names:
|
if self.list_name not in self.arg_names:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Couldn't see arguments %r for %r."
|
"Couldn't see arguments %r for %r."
|
||||||
|
@ -425,8 +433,7 @@ class CacheListDescriptor(object):
|
||||||
args_to_call[self.list_name] = missing
|
args_to_call[self.list_name] = missing
|
||||||
|
|
||||||
ret_d = defer.maybeDeferred(
|
ret_d = defer.maybeDeferred(
|
||||||
preserve_context_over_fn,
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
self.function_to_call,
|
|
||||||
**args_to_call
|
**args_to_call
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -435,7 +442,6 @@ class CacheListDescriptor(object):
|
||||||
# We need to create deferreds for each arg in the list so that
|
# We need to create deferreds for each arg in the list so that
|
||||||
# we can insert the new deferred into the cache.
|
# we can insert the new deferred into the cache.
|
||||||
for arg in missing:
|
for arg in missing:
|
||||||
with PreserveLoggingContext():
|
|
||||||
observer = ret_d.observe()
|
observer = ret_d.observe()
|
||||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
||||||
|
|
||||||
|
@ -463,7 +469,7 @@ class CacheListDescriptor(object):
|
||||||
results.update(res)
|
results.update(res)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
return preserve_context_over_deferred(defer.gatherResults(
|
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
cached_defers.values(),
|
cached_defers.values(),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addCallback(update_results_dict).addErrback(
|
).addCallback(update_results_dict).addErrback(
|
||||||
|
@ -487,7 +493,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
||||||
self.cache.invalidate(self.key)
|
self.cache.invalidate(self.key)
|
||||||
|
|
||||||
|
|
||||||
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
|
||||||
iterable=False):
|
iterable=False):
|
||||||
return lambda orig: CacheDescriptor(
|
return lambda orig: CacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
|
@ -499,8 +505,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
|
||||||
iterable=False):
|
cache_context=False, iterable=False):
|
||||||
return lambda orig: CacheDescriptor(
|
return lambda orig: CacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
|
@ -512,7 +518,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
|
def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
|
||||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||||
|
|
||||||
Used to do batch lookups for an already created cache. A single argument
|
Used to do batch lookups for an already created cache. A single argument
|
||||||
|
@ -525,7 +531,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
|
||||||
cache (Cache): The underlying cache to use.
|
cache (Cache): The underlying cache to use.
|
||||||
list_name (str): The name of the argument that is the list to use to
|
list_name (str): The name of the argument that is the list to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args (int): Number of arguments to use as the key in the cache.
|
num_args (int): Number of arguments to use as the key in the cache
|
||||||
|
(including list_name). Defaults to all named parameters.
|
||||||
inlineCallbacks (bool): Should the function be wrapped in an
|
inlineCallbacks (bool): Should the function be wrapped in an
|
||||||
`defer.inlineCallbacks`?
|
`defer.inlineCallbacks`?
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ class StreamChangeCache(object):
|
||||||
def has_entity_changed(self, entity, stream_pos):
|
def has_entity_changed(self, entity, stream_pos):
|
||||||
"""Returns True if the entity may have been updated since stream_pos
|
"""Returns True if the entity may have been updated since stream_pos
|
||||||
"""
|
"""
|
||||||
assert type(stream_pos) is int
|
assert type(stream_pos) is int or type(stream_pos) is long
|
||||||
|
|
||||||
if stream_pos < self._earliest_known_stream_pos:
|
if stream_pos < self._earliest_known_stream_pos:
|
||||||
self.metrics.inc_misses()
|
self.metrics.inc_misses()
|
||||||
|
|
|
@ -12,6 +12,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
""" Thread-local-alike tracking of log contexts within synapse
|
||||||
|
|
||||||
|
This module provides objects and utilities for tracking contexts through
|
||||||
|
synapse code, so that log lines can include a request identifier, and so that
|
||||||
|
CPU and database activity can be accounted for against the request that caused
|
||||||
|
them.
|
||||||
|
|
||||||
|
See doc/log_contexts.rst for details on how this works.
|
||||||
|
"""
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
@ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
|
||||||
def preserve_context_over_deferred(deferred, context=None):
|
def preserve_context_over_deferred(deferred, context=None):
|
||||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||||
be invoked with the current context.
|
be invoked with the current context.
|
||||||
|
|
||||||
|
Deprecated: this almost certainly doesn't do want you want, ie make
|
||||||
|
the deferred follow the synapse logcontext rules: try
|
||||||
|
``make_deferred_yieldable`` instead.
|
||||||
"""
|
"""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
|
@ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None):
|
||||||
|
|
||||||
|
|
||||||
def preserve_fn(f):
|
def preserve_fn(f):
|
||||||
"""Ensures that function is called with correct context and that context is
|
"""Wraps a function, to ensure that the current context is restored after
|
||||||
restored after return. Useful for wrapping functions that return a deferred
|
return from the function, and that the sentinel context is set once the
|
||||||
which you don't yield on.
|
deferred returned by the funtion completes.
|
||||||
|
|
||||||
|
Useful for wrapping functions that return a deferred which you don't yield
|
||||||
|
on.
|
||||||
"""
|
"""
|
||||||
|
def reset_context(result):
|
||||||
|
LoggingContext.set_current_context(LoggingContext.sentinel)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# XXX: why is this here rather than inside g? surely we want to preserve
|
||||||
|
# the context from the time the function was called, not when it was
|
||||||
|
# wrapped?
|
||||||
current = LoggingContext.current_context()
|
current = LoggingContext.current_context()
|
||||||
|
|
||||||
def g(*args, **kwargs):
|
def g(*args, **kwargs):
|
||||||
with PreserveLoggingContext(current):
|
|
||||||
res = f(*args, **kwargs)
|
res = f(*args, **kwargs)
|
||||||
if isinstance(res, defer.Deferred):
|
if isinstance(res, defer.Deferred) and not res.called:
|
||||||
return preserve_context_over_deferred(
|
# The function will have reset the context before returning, so
|
||||||
res, context=LoggingContext.sentinel
|
# we need to restore it now.
|
||||||
)
|
LoggingContext.set_current_context(current)
|
||||||
else:
|
|
||||||
|
# The original context will be restored when the deferred
|
||||||
|
# completes, but there is nothing waiting for it, so it will
|
||||||
|
# get leaked into the reactor or some other function which
|
||||||
|
# wasn't expecting it. We therefore need to reset the context
|
||||||
|
# here.
|
||||||
|
#
|
||||||
|
# (If this feels asymmetric, consider it this way: we are
|
||||||
|
# effectively forking a new thread of execution. We are
|
||||||
|
# probably currently within a ``with LoggingContext()`` block,
|
||||||
|
# which is supposed to have a single entry and exit point. But
|
||||||
|
# by spawning off another deferred, we are effectively
|
||||||
|
# adding a new exit point.)
|
||||||
|
res.addBoth(reset_context)
|
||||||
return res
|
return res
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def make_deferred_yieldable(deferred):
|
||||||
|
"""Given a deferred, make it follow the Synapse logcontext rules:
|
||||||
|
|
||||||
|
If the deferred has completed (or is not actually a Deferred), essentially
|
||||||
|
does nothing (just returns another completed deferred with the
|
||||||
|
result/failure).
|
||||||
|
|
||||||
|
If the deferred has not yet completed, resets the logcontext before
|
||||||
|
returning a deferred. Then, when the deferred completes, restores the
|
||||||
|
current logcontext before running callbacks/errbacks.
|
||||||
|
|
||||||
|
(This is more-or-less the opposite operation to preserve_fn.)
|
||||||
|
"""
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
r = yield deferred
|
||||||
|
defer.returnValue(r)
|
||||||
|
|
||||||
|
|
||||||
# modules to ignore in `logcontext_tracer`
|
# modules to ignore in `logcontext_tracer`
|
||||||
_to_ignore = [
|
_to_ignore = [
|
||||||
"synapse.util.logcontext",
|
"synapse.util.logcontext",
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import phonenumbers
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
|
|
||||||
|
def phone_number_to_msisdn(country, number):
|
||||||
|
"""
|
||||||
|
Takes an ISO-3166-1 2 letter country code and phone number and
|
||||||
|
returns an msisdn representing the canonical version of that
|
||||||
|
phone number.
|
||||||
|
Args:
|
||||||
|
country (str): ISO-3166-1 2 letter country code
|
||||||
|
number (str): Phone number in a national or international format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(str) The canonical form of the phone number, as an msisdn
|
||||||
|
Raises:
|
||||||
|
SynapseError if the number could not be parsed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
phoneNumber = phonenumbers.parse(number, country)
|
||||||
|
except phonenumbers.NumberParseException:
|
||||||
|
raise SynapseError(400, "Unable to parse phone number")
|
||||||
|
return phonenumbers.format_number(
|
||||||
|
phoneNumber, phonenumbers.PhoneNumberFormat.E164
|
||||||
|
)[1:]
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import synapse.util.logcontext
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
|
@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_retry_limiter(destination, clock, store, **kwargs):
|
def get_retry_limiter(destination, clock, store, ignore_backoff=False,
|
||||||
|
**kwargs):
|
||||||
"""For a given destination check if we have previously failed to
|
"""For a given destination check if we have previously failed to
|
||||||
send a request there and are waiting before retrying the destination.
|
send a request there and are waiting before retrying the destination.
|
||||||
If we are not ready to retry the destination, this will raise a
|
If we are not ready to retry the destination, this will raise a
|
||||||
|
@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
|
||||||
that will mark the destination as down if an exception is thrown (excluding
|
that will mark the destination as down if an exception is thrown (excluding
|
||||||
CodeMessageException with code < 500)
|
CodeMessageException with code < 500)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): name of homeserver
|
||||||
|
clock (synapse.util.clock): timing source
|
||||||
|
store (synapse.storage.transactions.TransactionStore): datastore
|
||||||
|
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||||
|
try the request anyway. We will still update the next
|
||||||
|
retry_interval on success/failure.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
|
||||||
|
|
||||||
now = int(clock.time_msec())
|
now = int(clock.time_msec())
|
||||||
|
|
||||||
if retry_last_ts + retry_interval > now:
|
if not ignore_backoff and retry_last_ts + retry_interval > now:
|
||||||
raise NotRetryingDestination(
|
raise NotRetryingDestination(
|
||||||
retry_last_ts=retry_last_ts,
|
retry_last_ts=retry_last_ts,
|
||||||
retry_interval=retry_interval,
|
retry_interval=retry_interval,
|
||||||
|
@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
valid_err_code = False
|
valid_err_code = False
|
||||||
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
if exc_type is None:
|
||||||
|
valid_err_code = True
|
||||||
|
elif not issubclass(exc_type, Exception):
|
||||||
|
# avoid treating exceptions which don't derive from Exception as
|
||||||
|
# failures; this is mostly so as not to catch defer._DefGen.
|
||||||
|
valid_err_code = True
|
||||||
|
elif issubclass(exc_type, CodeMessageException):
|
||||||
# Some error codes are perfectly fine for some APIs, whereas other
|
# Some error codes are perfectly fine for some APIs, whereas other
|
||||||
# APIs may expect to never received e.g. a 404. It's important to
|
# APIs may expect to never received e.g. a 404. It's important to
|
||||||
# handle 404 as some remote servers will return a 404 when the HS
|
# handle 404 as some remote servers will return a 404 when the HS
|
||||||
|
@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
|
||||||
else:
|
else:
|
||||||
valid_err_code = False
|
valid_err_code = False
|
||||||
|
|
||||||
if exc_type is None or valid_err_code:
|
if valid_err_code:
|
||||||
# We connected successfully.
|
# We connected successfully.
|
||||||
if not self.retry_interval:
|
if not self.retry_interval:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
logger.debug("Connection to %s was successful; clearing backoff",
|
||||||
|
self.destination)
|
||||||
retry_last_ts = 0
|
retry_last_ts = 0
|
||||||
self.retry_interval = 0
|
self.retry_interval = 0
|
||||||
else:
|
else:
|
||||||
|
@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
|
||||||
else:
|
else:
|
||||||
self.retry_interval = self.min_retry_interval
|
self.retry_interval = self.min_retry_interval
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
|
||||||
|
self.destination, exc_type, exc_val, self.retry_interval
|
||||||
|
)
|
||||||
retry_last_ts = int(self.clock.time_msec())
|
retry_last_ts = int(self.clock.time_msec())
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
|
||||||
"Failed to store set_destination_retry_timings",
|
"Failed to store set_destination_retry_timings",
|
||||||
)
|
)
|
||||||
|
|
||||||
store_retry_timings()
|
# we deliberately do this in the background.
|
||||||
|
synapse.util.logcontext.preserve_fn(store_retry_timings)()
|
||||||
|
|
|
@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
||||||
if prev_membership not in MEMBERSHIP_PRIORITY:
|
if prev_membership not in MEMBERSHIP_PRIORITY:
|
||||||
prev_membership = "leave"
|
prev_membership = "leave"
|
||||||
|
|
||||||
|
# Always allow the user to see their own leave events, otherwise
|
||||||
|
# they won't see the room disappear if they reject the invite
|
||||||
|
if membership == "leave" and (
|
||||||
|
prev_membership == "join" or prev_membership == "invite"
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
new_priority = MEMBERSHIP_PRIORITY.index(membership)
|
new_priority = MEMBERSHIP_PRIORITY.index(membership)
|
||||||
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
|
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
|
||||||
if old_priority < new_priority:
|
if old_priority < new_priority:
|
||||||
|
|
|
@ -23,6 +23,9 @@ from tests.utils import (
|
||||||
|
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
|
import jsonschema
|
||||||
|
|
||||||
user_localpart = "test_user"
|
user_localpart = "test_user"
|
||||||
|
|
||||||
|
@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.datastore = hs.get_datastore()
|
self.datastore = hs.get_datastore()
|
||||||
|
|
||||||
|
def test_errors_on_invalid_filters(self):
|
||||||
|
invalid_filters = [
|
||||||
|
{"boom": {}},
|
||||||
|
{"account_data": "Hello World"},
|
||||||
|
{"event_fields": ["\\foo"]},
|
||||||
|
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
|
||||||
|
{"event_format": "other"},
|
||||||
|
{"room": {"not_rooms": ["#foo:pik-test"]}},
|
||||||
|
{"presence": {"senders": ["@bar;pik.test.com"]}}
|
||||||
|
]
|
||||||
|
for filter in invalid_filters:
|
||||||
|
with self.assertRaises(SynapseError) as check_filter_error:
|
||||||
|
self.filtering.check_valid_filter(filter)
|
||||||
|
self.assertIsInstance(check_filter_error.exception, SynapseError)
|
||||||
|
|
||||||
|
def test_valid_filters(self):
|
||||||
|
valid_filters = [
|
||||||
|
{
|
||||||
|
"room": {
|
||||||
|
"timeline": {"limit": 20},
|
||||||
|
"state": {"not_types": ["m.room.member"]},
|
||||||
|
"ephemeral": {"limit": 0, "not_types": ["*"]},
|
||||||
|
"include_leave": False,
|
||||||
|
"rooms": ["!dee:pik-test"],
|
||||||
|
"not_rooms": ["!gee:pik-test"],
|
||||||
|
"account_data": {"limit": 0, "types": ["*"]}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"room": {
|
||||||
|
"state": {
|
||||||
|
"types": ["m.room.*"],
|
||||||
|
"not_rooms": ["!726s6s6q:example.com"]
|
||||||
|
},
|
||||||
|
"timeline": {
|
||||||
|
"limit": 10,
|
||||||
|
"types": ["m.room.message"],
|
||||||
|
"not_rooms": ["!726s6s6q:example.com"],
|
||||||
|
"not_senders": ["@spam:example.com"]
|
||||||
|
},
|
||||||
|
"ephemeral": {
|
||||||
|
"types": ["m.receipt", "m.typing"],
|
||||||
|
"not_rooms": ["!726s6s6q:example.com"],
|
||||||
|
"not_senders": ["@spam:example.com"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"presence": {
|
||||||
|
"types": ["m.presence"],
|
||||||
|
"not_senders": ["@alice:example.com"]
|
||||||
|
},
|
||||||
|
"event_format": "client",
|
||||||
|
"event_fields": ["type", "content", "sender"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
for filter in valid_filters:
|
||||||
|
try:
|
||||||
|
self.filtering.check_valid_filter(filter)
|
||||||
|
except jsonschema.ValidationError as e:
|
||||||
|
self.fail(e)
|
||||||
|
|
||||||
|
def test_limits_are_applied(self):
|
||||||
|
# TODO
|
||||||
|
pass
|
||||||
|
|
||||||
def test_definition_types_works_with_literals(self):
|
def test_definition_types_works_with_literals(self):
|
||||||
definition = {
|
definition = {
|
||||||
"types": ["m.room.message", "org.matrix.foo.bar"]
|
"types": ["m.room.message", "org.matrix.foo.bar"]
|
||||||
|
|
|
@ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase):
|
||||||
"room_alias": "#another:remote",
|
"room_alias": "#another:remote",
|
||||||
},
|
},
|
||||||
retry_on_dns_fail=False,
|
retry_on_dns_fail=False,
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -324,7 +324,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
state = UserPresenceState.default(user_id)
|
state = UserPresenceState.default(user_id)
|
||||||
state = state.copy_and_replace(
|
state = state.copy_and_replace(
|
||||||
state=PresenceState.ONLINE,
|
state=PresenceState.ONLINE,
|
||||||
last_active_ts=now,
|
last_active_ts=0,
|
||||||
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
|
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
self.mock_federation.make_query.assert_called_with(
|
self.mock_federation.make_query.assert_called_with(
|
||||||
destination="remote",
|
destination="remote",
|
||||||
query_type="profile",
|
query_type="profile",
|
||||||
args={"user_id": "@alice:remote", "field": "displayname"}
|
args={"user_id": "@alice:remote", "field": "displayname"},
|
||||||
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
),
|
),
|
||||||
json_data_callback=ANY,
|
json_data_callback=ANY,
|
||||||
long_retries=True,
|
long_retries=True,
|
||||||
|
backoff_on_404=True,
|
||||||
),
|
),
|
||||||
defer.succeed((200, "OK"))
|
defer.succeed((200, "OK"))
|
||||||
)
|
)
|
||||||
|
@ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
),
|
),
|
||||||
json_data_callback=ANY,
|
json_data_callback=ANY,
|
||||||
long_retries=True,
|
long_retries=True,
|
||||||
|
backoff_on_404=True,
|
||||||
),
|
),
|
||||||
defer.succeed((200, "OK"))
|
defer.succeed((200, "OK"))
|
||||||
)
|
)
|
||||||
|
|
|
@ -68,7 +68,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
code, body = yield get
|
code, body = yield get
|
||||||
self.assertEquals(code, 200)
|
self.assertEquals(code, 200)
|
||||||
self.assertEquals(body["events"]["field_names"], [
|
self.assertEquals(body["events"]["field_names"], [
|
||||||
"position", "internal", "json", "state_group"
|
"position", "event_id", "room_id", "type", "state_key",
|
||||||
])
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||||
class FilterTestCase(unittest.TestCase):
|
class FilterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
USER_ID = "@apple:test"
|
USER_ID = "@apple:test"
|
||||||
EXAMPLE_FILTER = {"type": ["m.*"]}
|
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
||||||
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
|
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||||
TO_REGISTER = [filter]
|
TO_REGISTER = [filter]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_one_1col(self):
|
def test_select_one_1col(self):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
self.mock_txn.fetchall.return_value = [("Value",)]
|
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
|
||||||
|
|
||||||
value = yield self.datastore._simple_select_one_onecol(
|
value = yield self.datastore._simple_select_one_onecol(
|
||||||
table="tablename",
|
table="tablename",
|
||||||
|
@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_list(self):
|
def test_select_list(self):
|
||||||
self.mock_txn.rowcount = 3
|
self.mock_txn.rowcount = 3
|
||||||
self.mock_txn.fetchall.return_value = ((1,), (2,), (3,))
|
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
||||||
self.mock_txn.description = (
|
self.mock_txn.description = (
|
||||||
("colA", None, None, None, None, None, None),
|
("colA", None, None, None, None, None, None),
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import signedjson.key
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import tests.unittest
|
||||||
|
import tests.utils
|
||||||
|
|
||||||
|
|
||||||
|
class KeyStoreTestCase(tests.unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
|
||||||
|
self.store = None # type: synapse.storage.keys.KeyStore
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
hs = yield tests.utils.setup_test_homeserver()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_server_verify_keys(self):
|
||||||
|
key1 = signedjson.key.decode_verify_key_base64(
|
||||||
|
"ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
|
||||||
|
)
|
||||||
|
key2 = signedjson.key.decode_verify_key_base64(
|
||||||
|
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
|
||||||
|
)
|
||||||
|
yield self.store.store_server_verify_key(
|
||||||
|
"server1", "from_server", 0, key1
|
||||||
|
)
|
||||||
|
yield self.store.store_server_verify_key(
|
||||||
|
"server1", "from_server", 0, key2
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield self.store.get_server_verify_keys(
|
||||||
|
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
|
||||||
|
|
||||||
|
self.assertEqual(len(res.keys()), 2)
|
||||||
|
self.assertEqual(res["ed25519:key1"].version, "key1")
|
||||||
|
self.assertEqual(res["ed25519:key2"].version, "key2")
|
|
@ -0,0 +1,14 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
|
@ -0,0 +1,177 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import mock
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.util import async
|
||||||
|
from synapse.util import logcontext
|
||||||
|
from twisted.internet import defer
|
||||||
|
from synapse.util.caches import descriptors
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DescriptorTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
obj.mock.return_value = 'fish'
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = 'chips'
|
||||||
|
r = yield obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_called_once_with(1, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache_num_args(self):
|
||||||
|
"""Only the first num_args arguments should matter to the cache"""
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached(num_args=1)
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = 'fish'
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = 'chips'
|
||||||
|
r = yield obj.fn(2, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_called_once_with(2, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached; we should be able to vary
|
||||||
|
# the second argument and still get the cached result.
|
||||||
|
r = yield obj.fn(1, 4)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(2, 5)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_cache_logcontexts(self):
|
||||||
|
"""Check that logcontexts are set and restored correctly when
|
||||||
|
using the cache."""
|
||||||
|
|
||||||
|
complete_lookup = defer.Deferred()
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def inner_fn():
|
||||||
|
with logcontext.PreserveLoggingContext():
|
||||||
|
yield complete_lookup
|
||||||
|
defer.returnValue(1)
|
||||||
|
|
||||||
|
return inner_fn()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_lookup():
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.name = "c1"
|
||||||
|
r = yield obj.fn(1)
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
c1)
|
||||||
|
defer.returnValue(r)
|
||||||
|
|
||||||
|
def check_result(r):
|
||||||
|
self.assertEqual(r, 1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
# set off a deferred which will do a cache lookup
|
||||||
|
d1 = do_lookup()
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel)
|
||||||
|
d1.addCallback(check_result)
|
||||||
|
|
||||||
|
# and another
|
||||||
|
d2 = do_lookup()
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel)
|
||||||
|
d2.addCallback(check_result)
|
||||||
|
|
||||||
|
# let the lookup complete
|
||||||
|
complete_lookup.callback(None)
|
||||||
|
|
||||||
|
return defer.gatherResults([d1, d2])
|
||||||
|
|
||||||
|
def test_cache_logcontexts_with_exception(self):
|
||||||
|
"""Check that the cache sets and restores logcontexts correctly when
|
||||||
|
the lookup function throws an exception"""
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def inner_fn():
|
||||||
|
yield async.run_on_reactor()
|
||||||
|
raise SynapseError(400, "blah")
|
||||||
|
|
||||||
|
return inner_fn()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_lookup():
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.name = "c1"
|
||||||
|
try:
|
||||||
|
yield obj.fn(1)
|
||||||
|
self.fail("No exception thrown")
|
||||||
|
except SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
c1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
# set off a deferred which will do a cache lookup
|
||||||
|
d1 = do_lookup()
|
||||||
|
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel)
|
||||||
|
|
||||||
|
return d1
|
|
@ -0,0 +1,33 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from synapse import util
|
||||||
|
from twisted.internet import defer
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ClockTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_time_bound_deferred(self):
|
||||||
|
# just a deferred which never resolves
|
||||||
|
slow_deferred = defer.Deferred()
|
||||||
|
|
||||||
|
clock = util.Clock()
|
||||||
|
time_bound = clock.time_bound_deferred(slow_deferred, 0.001)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield time_bound
|
||||||
|
self.fail("Expected timedout error, but got nothing")
|
||||||
|
except util.DeferredTimedOutError:
|
||||||
|
pass
|
|
@ -1,8 +1,10 @@
|
||||||
|
import twisted.python.failure
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util import logcontext
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||||
context_one.test_key = "one"
|
context_one.test_key = "one"
|
||||||
yield sleep(0)
|
yield sleep(0)
|
||||||
self._check_test_key("one")
|
self._check_test_key("one")
|
||||||
|
|
||||||
|
def _test_preserve_fn(self, function):
|
||||||
|
sentinel_context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
callback_completed = [False]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def cb():
|
||||||
|
context_one.test_key = "one"
|
||||||
|
yield function()
|
||||||
|
self._check_test_key("one")
|
||||||
|
|
||||||
|
callback_completed[0] = True
|
||||||
|
|
||||||
|
with LoggingContext() as context_one:
|
||||||
|
context_one.test_key = "one"
|
||||||
|
|
||||||
|
# fire off function, but don't wait on it.
|
||||||
|
logcontext.preserve_fn(cb)()
|
||||||
|
|
||||||
|
self._check_test_key("one")
|
||||||
|
|
||||||
|
# now wait for the function under test to have run, and check that
|
||||||
|
# the logcontext is left in a sane state.
|
||||||
|
d2 = defer.Deferred()
|
||||||
|
|
||||||
|
def check_logcontext():
|
||||||
|
if not callback_completed[0]:
|
||||||
|
reactor.callLater(0.01, check_logcontext)
|
||||||
|
return
|
||||||
|
|
||||||
|
# make sure that the context was reset before it got thrown back
|
||||||
|
# into the reactor
|
||||||
|
try:
|
||||||
|
self.assertIs(LoggingContext.current_context(),
|
||||||
|
sentinel_context)
|
||||||
|
d2.callback(None)
|
||||||
|
except BaseException:
|
||||||
|
d2.errback(twisted.python.failure.Failure())
|
||||||
|
|
||||||
|
reactor.callLater(0.01, check_logcontext)
|
||||||
|
|
||||||
|
# test is done once d2 finishes
|
||||||
|
return d2
|
||||||
|
|
||||||
|
def test_preserve_fn_with_blocking_fn(self):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def blocking_function():
|
||||||
|
yield sleep(0)
|
||||||
|
|
||||||
|
return self._test_preserve_fn(blocking_function)
|
||||||
|
|
||||||
|
def test_preserve_fn_with_non_blocking_fn(self):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def nonblocking_function():
|
||||||
|
with logcontext.PreserveLoggingContext():
|
||||||
|
yield defer.succeed(None)
|
||||||
|
|
||||||
|
return self._test_preserve_fn(nonblocking_function)
|
||||||
|
|
Loading…
Reference in New Issue