Merge branch 'release-v0.18.5' of github.com:matrix-org/synapse
This commit is contained in:
commit
f5a4001bb1
|
@ -0,0 +1,17 @@
|
|||
sudo: false
|
||||
language: python
|
||||
python: 2.7
|
||||
|
||||
# tell travis to cache ~/.cache/pip
|
||||
cache: pip
|
||||
|
||||
env:
|
||||
- TOX_ENV=packaging
|
||||
- TOX_ENV=pep8
|
||||
- TOX_ENV=py27
|
||||
|
||||
install:
|
||||
- pip install tox
|
||||
|
||||
script:
|
||||
- tox -e $TOX_ENV
|
65
CHANGES.rst
65
CHANGES.rst
|
@ -1,3 +1,68 @@
|
|||
Changes in synapse v0.18.5 (2016-12-16)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix federation /backfill returning events it shouldn't (PR #1700)
|
||||
* Fix crash in url preview (PR #1701)
|
||||
|
||||
|
||||
Changes in synapse v0.18.5-rc3 (2016-12-13)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add support for E2E for guests (PR #1653)
|
||||
* Add new API appservice specific public room list (PR #1676)
|
||||
* Add new room membership APIs (PR #1680)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Enable guest access for private rooms by default (PR #653)
|
||||
* Limit the number of events that can be created on a given room concurrently
|
||||
(PR #1620)
|
||||
* Log the args that we have on UI auth completion (PR #1649)
|
||||
* Stop generating refresh_tokens (PR #1654)
|
||||
* Stop putting a time caveat on access tokens (PR #1656)
|
||||
* Remove unspecced GET endpoints for e2e keys (PR #1694)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix handling of 500 and 429's over federation (PR #1650)
|
||||
* Fix Content-Type header parsing (PR #1660)
|
||||
* Fix error when previewing sites that include unicode, thanks to kyrias (PR
|
||||
#1664)
|
||||
* Fix some cases where we drop read receipts (PR #1678)
|
||||
* Fix bug where calls to ``/sync`` didn't correctly timeout (PR #1683)
|
||||
* Fix bug where E2E key query would fail if a single remote host failed (PR
|
||||
#1686)
|
||||
|
||||
|
||||
|
||||
Changes in synapse v0.18.5-rc2 (2016-11-24)
|
||||
===========================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Don't send old events over federation, fixes bug in -rc1.
|
||||
|
||||
Changes in synapse v0.18.5-rc1 (2016-11-24)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Implement "event_fields" in filters (PR #1638)
|
||||
|
||||
Changes:
|
||||
|
||||
* Use external ldap auth pacakge (PR #1628)
|
||||
* Split out federation transaction sending to a worker (PR #1635)
|
||||
* Fail with a coherent error message if `/sync?filter=` is invalid (PR #1636)
|
||||
* More efficient notif count queries (PR #1644)
|
||||
|
||||
|
||||
Changes in synapse v0.18.4 (2016-11-22)
|
||||
=======================================
|
||||
|
||||
|
|
31
README.rst
31
README.rst
|
@ -120,6 +120,7 @@ Installing prerequisites on Mac OS X::
|
|||
xcode-select --install
|
||||
sudo easy_install pip
|
||||
sudo pip install virtualenv
|
||||
brew install pkg-config libffi
|
||||
|
||||
Installing prerequisites on Raspbian::
|
||||
|
||||
|
@ -136,6 +137,10 @@ Installing prerequisites on openSUSE::
|
|||
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
|
||||
python-devel libffi-devel libopenssl-devel libjpeg62-devel
|
||||
|
||||
Installing prerequisites on OpenBSD::
|
||||
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
|
||||
libxslt
|
||||
|
||||
To install the synapse homeserver run::
|
||||
|
||||
virtualenv -p python2.7 ~/.synapse
|
||||
|
@ -369,6 +374,32 @@ Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Mo
|
|||
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
|
||||
- Packages: ``pkg install py27-matrix-synapse``
|
||||
|
||||
|
||||
OpenBSD
|
||||
-------
|
||||
|
||||
There is currently no port for OpenBSD. Additionally, OpenBSD's security
|
||||
settings require a slightly more difficult installation process.
|
||||
|
||||
1) Create a new directory in ``/usr/local`` called ``_synapse``. Also, create a
|
||||
new user called ``_synapse`` and set that directory as the new user's home.
|
||||
This is required because, by default, OpenBSD only allows binaries which need
|
||||
write and execute permissions on the same memory space to be run from
|
||||
``/usr/local``.
|
||||
2) ``su`` to the new ``_synapse`` user and change to their home directory.
|
||||
3) Create a new virtualenv: ``virtualenv -p python2.7 ~/.synapse``
|
||||
4) Source the virtualenv configuration located at
|
||||
``/usr/local/_synapse/.synapse/bin/activate``. This is done in ``ksh`` by
|
||||
using the ``.`` command, rather than ``bash``'s ``source``.
|
||||
5) Optionally, use ``pip`` to install ``lxml``, which Synapse needs to parse
|
||||
webpages for their titles.
|
||||
6) Use ``pip`` to install this repository: ``pip install
|
||||
https://github.com/matrix-org/synapse/tarball/master``
|
||||
7) Optionally, change ``_synapse``'s shell to ``/bin/false`` to reduce the
|
||||
chance of a compromised Synapse server being used to take over your box.
|
||||
|
||||
After this, you may proceed with the rest of the install directions.
|
||||
|
||||
NixOS
|
||||
-----
|
||||
|
||||
|
|
|
@ -22,3 +22,4 @@ export SYNAPSE_CACHE_FACTOR=1
|
|||
--federation-reader \
|
||||
--client-reader \
|
||||
--appservice \
|
||||
--federation-sender \
|
||||
|
|
|
@ -15,6 +15,6 @@ tox -e py27 --notest -v
|
|||
|
||||
TOX_BIN=$TOX_DIR/py27/bin
|
||||
$TOX_BIN/pip install setuptools
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
$TOX_BIN/pip install psycopg2
|
||||
{ python synapse/python_dependencies.py
|
||||
echo lxml psycopg2
|
||||
} | xargs $TOX_BIN/pip install
|
||||
|
|
73
setup.py
73
setup.py
|
@ -23,6 +23,45 @@ import sys
|
|||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
# Some notes on `setup.py test`:
|
||||
#
|
||||
# Once upon a time we used to try to make `setup.py test` run `tox` to run the
|
||||
# tests. That's a bad idea for three reasons:
|
||||
#
|
||||
# 1: `setup.py test` is supposed to find out whether the tests work in the
|
||||
# *current* environmentt, not whatever tox sets up.
|
||||
# 2: Empirically, trying to install tox during the test run wasn't working ("No
|
||||
# module named virtualenv").
|
||||
# 3: The tox documentation advises against it[1].
|
||||
#
|
||||
# Even further back in time, we used to use setuptools_trial [2]. That has its
|
||||
# own set of issues: for instance, it requires installation of Twisted to build
|
||||
# an sdist (because the recommended mode of usage is to add it to
|
||||
# `setup_requires`). That in turn means that in order to successfully run tox
|
||||
# you have to have the python header files installed for whichever version of
|
||||
# python tox uses (which is python3 on recent ubuntus, for example).
|
||||
#
|
||||
# So, for now at least, we stick with what appears to be the convention among
|
||||
# Twisted projects, and don't attempt to do anything when someone runs
|
||||
# `setup.py test`; instead we direct people to run `trial` directly if they
|
||||
# care.
|
||||
#
|
||||
# [1]: http://tox.readthedocs.io/en/2.5.0/example/basic.html#integration-with-setup-py-test-command
|
||||
# [2]: https://pypi.python.org/pypi/setuptools_trial
|
||||
class TestCommand(Command):
|
||||
user_options = []
|
||||
|
||||
def initialize_options(self):
|
||||
pass
|
||||
|
||||
def finalize_options(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
print ("""Synapse's tests cannot be run via setup.py. To run them, try:
|
||||
PYTHONPATH="." trial tests
|
||||
""")
|
||||
|
||||
def read_file(path_segments):
|
||||
"""Read a file from the package. Takes a list of strings to join to
|
||||
make the path"""
|
||||
|
@ -39,38 +78,6 @@ def exec_file(path_segments):
|
|||
return result
|
||||
|
||||
|
||||
class Tox(Command):
|
||||
user_options = [('tox-args=', 'a', "Arguments to pass to tox")]
|
||||
|
||||
def initialize_options(self):
|
||||
self.tox_args = None
|
||||
|
||||
def finalize_options(self):
|
||||
self.test_args = []
|
||||
self.test_suite = True
|
||||
|
||||
def run(self):
|
||||
#import here, cause outside the eggs aren't loaded
|
||||
try:
|
||||
import tox
|
||||
except ImportError:
|
||||
try:
|
||||
self.distribution.fetch_build_eggs("tox")
|
||||
import tox
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"The tests need 'tox' to run. Please install 'tox'."
|
||||
)
|
||||
import shlex
|
||||
args = self.tox_args
|
||||
if args:
|
||||
args = shlex.split(self.tox_args)
|
||||
else:
|
||||
args = []
|
||||
errno = tox.cmdline(args=args)
|
||||
sys.exit(errno)
|
||||
|
||||
|
||||
version = exec_file(("synapse", "__init__.py"))["__version__"]
|
||||
dependencies = exec_file(("synapse", "python_dependencies.py"))
|
||||
long_description = read_file(("README.rst",))
|
||||
|
@ -86,5 +93,5 @@ setup(
|
|||
zip_safe=False,
|
||||
long_description=long_description,
|
||||
scripts=["synctl"] + glob.glob("scripts/*"),
|
||||
cmdclass={'test': Tox},
|
||||
cmdclass={'test': TestCommand},
|
||||
)
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.18.4"
|
||||
__version__ = "0.18.5"
|
||||
|
|
|
@ -39,6 +39,9 @@ AuthEventTypes = (
|
|||
EventTypes.ThirdPartyInvite,
|
||||
)
|
||||
|
||||
# guests always get this device id.
|
||||
GUEST_DEVICE_ID = "guest_device"
|
||||
|
||||
|
||||
class Auth(object):
|
||||
"""
|
||||
|
@ -51,17 +54,6 @@ class Auth(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||
# Docs for these currently lives at
|
||||
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||
# In addition, we have type == delete_pusher which grants access only to
|
||||
# delete pushers.
|
||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||
"gen = ",
|
||||
"guest = ",
|
||||
"type = ",
|
||||
"time < ",
|
||||
"user_id = ",
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_from_context(self, event, context, do_sig_check=True):
|
||||
|
@ -685,31 +677,28 @@ class Auth(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_access_token(self, token, rights="access"):
|
||||
""" Get a registered user's ID.
|
||||
""" Validate access token and get user_id from it
|
||||
|
||||
Args:
|
||||
token (str): The access token to get the user by.
|
||||
rights (str): The operation being performed; the access token must
|
||||
allow this.
|
||||
Returns:
|
||||
dict : dict that includes the user and the ID of their access token.
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
ret = yield self.get_user_from_macaroon(token, rights)
|
||||
except AuthError:
|
||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||
# have been re-issued as macaroons.
|
||||
if self.hs.config.expire_access_token:
|
||||
raise
|
||||
ret = yield self._look_up_user_by_access_token(token)
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
except Exception: # deserialize can throw more-or-less anything
|
||||
# doesn't look like a macaroon: treat it as an opaque token which
|
||||
# must be in the database.
|
||||
# TODO: it would be nice to get rid of this, but apparently some
|
||||
# people use access tokens which aren't macaroons
|
||||
r = yield self._look_up_user_by_access_token(token)
|
||||
defer.returnValue(r)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_from_macaroon(self, macaroon_str, rights="access"):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
|
||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
|
@ -724,11 +713,36 @@ class Auth(object):
|
|||
guest = True
|
||||
|
||||
if guest:
|
||||
# Guest access tokens are not stored in the database (there can
|
||||
# only be one access token per guest, anyway).
|
||||
#
|
||||
# In order to prevent guest access tokens being used as regular
|
||||
# user access tokens (and hence getting around the invalidation
|
||||
# process), we look up the user id and check that it is indeed
|
||||
# a guest user.
|
||||
#
|
||||
# It would of course be much easier to store guest access
|
||||
# tokens in the database as well, but that would break existing
|
||||
# guest tokens.
|
||||
stored_user = yield self.store.get_user_by_id(user_id)
|
||||
if not stored_user:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||
"Unknown user_id %s" % user_id,
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
if not stored_user["is_guest"]:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||
"Guest access token used for regular user",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
# all guests get the same device id
|
||||
"device_id": GUEST_DEVICE_ID,
|
||||
}
|
||||
elif rights == "delete_pusher":
|
||||
# We don't store these tokens in the database
|
||||
|
@ -750,7 +764,7 @@ class Auth(object):
|
|||
# macaroon. They probably should be.
|
||||
# TODO: build the dictionary from the macaroon once the
|
||||
# above are fixed
|
||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
||||
ret = yield self._look_up_user_by_access_token(token)
|
||||
if ret["user"] != user:
|
||||
logger.error(
|
||||
"Macaroon user (%s) != DB user (%s)",
|
||||
|
@ -798,27 +812,38 @@ class Auth(object):
|
|||
|
||||
Args:
|
||||
macaroon(pymacaroons.Macaroon): The macaroon to validate
|
||||
type_string(str): The kind of token required (e.g. "access", "refresh",
|
||||
type_string(str): The kind of token required (e.g. "access",
|
||||
"delete_pusher")
|
||||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||
This should really always be True, but no clients currently implement
|
||||
token refresh, so we can't enforce expiry yet.
|
||||
user_id (str): The user_id required
|
||||
"""
|
||||
v = pymacaroons.Verifier()
|
||||
|
||||
# the verifier runs a test for every caveat on the macaroon, to check
|
||||
# that it is met for the current request. Each caveat must match at
|
||||
# least one of the predicates specified by satisfy_exact or
|
||||
# specify_general.
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = " + type_string)
|
||||
v.satisfy_exact("user_id = %s" % user_id)
|
||||
v.satisfy_exact("guest = true")
|
||||
|
||||
# verify_expiry should really always be True, but there exist access
|
||||
# tokens in the wild which expire when they should not, so we can't
|
||||
# enforce expiry yet (so we have to allow any caveat starting with
|
||||
# 'time < ' in access tokens).
|
||||
#
|
||||
# On the other hand, short-term login tokens (as used by CAS login, for
|
||||
# example) have an expiry time which we do want to enforce.
|
||||
|
||||
if verify_expiry:
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
else:
|
||||
v.satisfy_general(lambda c: c.startswith("time < "))
|
||||
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
# access_tokens include a nonce for uniqueness: any value is acceptable
|
||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_general(self._verify_recognizes_caveats)
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
|
||||
def _verify_expiry(self, caveat):
|
||||
|
@ -829,15 +854,6 @@ class Auth(object):
|
|||
now = self.hs.get_clock().time_msec()
|
||||
return now < expiry
|
||||
|
||||
def _verify_recognizes_caveats(self, caveat):
|
||||
first_space = caveat.find(" ")
|
||||
if first_space < 0:
|
||||
return False
|
||||
second_space = caveat.find(" ", first_space + 1)
|
||||
if second_space < 0:
|
||||
return False
|
||||
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _look_up_user_by_access_token(self, token):
|
||||
ret = yield self.store.get_user_by_access_token(token)
|
||||
|
|
|
@ -39,6 +39,7 @@ class Codes(object):
|
|||
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
||||
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
||||
MISSING_PARAM = "M_MISSING_PARAM"
|
||||
INVALID_PARAM = "M_INVALID_PARAM"
|
||||
TOO_LARGE = "M_TOO_LARGE"
|
||||
EXCLUSIVE = "M_EXCLUSIVE"
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
|
|
|
@ -71,6 +71,21 @@ class Filtering(object):
|
|||
if key in user_filter_json["room"]:
|
||||
self._check_definition(user_filter_json["room"][key])
|
||||
|
||||
if "event_fields" in user_filter_json:
|
||||
if type(user_filter_json["event_fields"]) != list:
|
||||
raise SynapseError(400, "event_fields must be a list of strings")
|
||||
for field in user_filter_json["event_fields"]:
|
||||
if not isinstance(field, basestring):
|
||||
raise SynapseError(400, "Event field must be a string")
|
||||
# Don't allow '\\' in event field filters. This makes matching
|
||||
# events a lot easier as we can then use a negative lookbehind
|
||||
# assertion to split '\.' If we allowed \\ then it would
|
||||
# incorrectly split '\\.' See synapse.events.utils.serialize_event
|
||||
if r'\\' in field:
|
||||
raise SynapseError(
|
||||
400, r'The escape character \ cannot itself be escaped'
|
||||
)
|
||||
|
||||
def _check_definition_room_lists(self, definition):
|
||||
"""Check that "rooms" and "not_rooms" are lists of room ids if they
|
||||
are present
|
||||
|
@ -152,6 +167,7 @@ class FilterCollection(object):
|
|||
self.include_leave = filter_json.get("room", {}).get(
|
||||
"include_leave", False
|
||||
)
|
||||
self.event_fields = filter_json.get("event_fields", [])
|
||||
|
||||
def __repr__(self):
|
||||
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
||||
|
@ -186,6 +202,26 @@ class FilterCollection(object):
|
|||
def filter_room_account_data(self, events):
|
||||
return self._room_account_data.filter(self._room_filter.filter(events))
|
||||
|
||||
def blocks_all_presence(self):
|
||||
return (
|
||||
self._presence_filter.filters_all_types() or
|
||||
self._presence_filter.filters_all_senders()
|
||||
)
|
||||
|
||||
def blocks_all_room_ephemeral(self):
|
||||
return (
|
||||
self._room_ephemeral_filter.filters_all_types() or
|
||||
self._room_ephemeral_filter.filters_all_senders() or
|
||||
self._room_ephemeral_filter.filters_all_rooms()
|
||||
)
|
||||
|
||||
def blocks_all_room_timeline(self):
|
||||
return (
|
||||
self._room_timeline_filter.filters_all_types() or
|
||||
self._room_timeline_filter.filters_all_senders() or
|
||||
self._room_timeline_filter.filters_all_rooms()
|
||||
)
|
||||
|
||||
|
||||
class Filter(object):
|
||||
def __init__(self, filter_json):
|
||||
|
@ -202,6 +238,15 @@ class Filter(object):
|
|||
|
||||
self.contains_url = self.filter_json.get("contains_url", None)
|
||||
|
||||
def filters_all_types(self):
|
||||
return "*" in self.not_types
|
||||
|
||||
def filters_all_senders(self):
|
||||
return "*" in self.not_senders
|
||||
|
||||
def filters_all_rooms(self):
|
||||
return "*" in self.not_rooms
|
||||
|
||||
def check(self, event):
|
||||
"""Checks whether the filter matches the given event.
|
||||
|
||||
|
|
|
@ -0,0 +1,331 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.federation import send_queue
|
||||
from synapse.federation.units import Edu
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from synapse import events
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger("synapse.app.appservice")
|
||||
|
||||
|
||||
class FederationSenderSlaveStore(
|
||||
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
|
||||
SlavedRegistrationStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class FederationSenderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse federation_sender now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
send_handler = FederationSenderHandler(self)
|
||||
|
||||
send_handler.on_start()
|
||||
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args.update((yield send_handler.stream_positions()))
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
yield store.process_replication(result)
|
||||
yield send_handler.process_replication(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(30)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse federation sender", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.federation_sender"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
if config.send_federation:
|
||||
sys.stderr.write(
|
||||
"\nThe send_federation must be disabled in the main synapse process"
|
||||
"\nbefore they can be run in a separate worker."
|
||||
"\nPlease add ``send_federation: false`` to the main config"
|
||||
"\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.send_federation = True
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
ps = FederationSenderServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ps.setup()
|
||||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ps.replicate()
|
||||
ps.get_datastore().start_profiling()
|
||||
ps.get_state_handler().start_caching()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-federation-sender",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
class FederationSenderHandler(object):
|
||||
"""Processes the replication stream and forwards the appropriate entries
|
||||
to the federation sender.
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self._room_serials = {}
|
||||
self._room_typing = {}
|
||||
|
||||
def on_start(self):
|
||||
# There may be some events that are persisted but haven't been sent,
|
||||
# so send them now.
|
||||
self.federation_sender.notify_new_events(
|
||||
self.store.get_room_max_stream_ordering()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stream_positions(self):
|
||||
stream_id = yield self.store.get_federation_out_pos("federation")
|
||||
defer.returnValue({
|
||||
"federation": stream_id,
|
||||
|
||||
# Ack stuff we've "processed", this should only be called from
|
||||
# one process.
|
||||
"federation_ack": stream_id,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_replication(self, result):
|
||||
# The federation stream contains things that we want to send out, e.g.
|
||||
# presence, typing, etc.
|
||||
fed_stream = result.get("federation")
|
||||
if fed_stream:
|
||||
latest_id = int(fed_stream["position"])
|
||||
|
||||
# The federation stream containis a bunch of different types of
|
||||
# rows that need to be handled differently. We parse the rows, put
|
||||
# them into the appropriate collection and then send them off.
|
||||
presence_to_send = {}
|
||||
keyed_edus = {}
|
||||
edus = {}
|
||||
failures = {}
|
||||
device_destinations = set()
|
||||
|
||||
# Parse the rows in the stream
|
||||
for row in fed_stream["rows"]:
|
||||
position, typ, content_js = row
|
||||
content = json.loads(content_js)
|
||||
|
||||
if typ == send_queue.PRESENCE_TYPE:
|
||||
destination = content["destination"]
|
||||
state = UserPresenceState.from_dict(content["state"])
|
||||
|
||||
presence_to_send.setdefault(destination, []).append(state)
|
||||
elif typ == send_queue.KEYED_EDU_TYPE:
|
||||
key = content["key"]
|
||||
edu = Edu(**content["edu"])
|
||||
|
||||
keyed_edus.setdefault(
|
||||
edu.destination, {}
|
||||
)[(edu.destination, tuple(key))] = edu
|
||||
elif typ == send_queue.EDU_TYPE:
|
||||
edu = Edu(**content)
|
||||
|
||||
edus.setdefault(edu.destination, []).append(edu)
|
||||
elif typ == send_queue.FAILURE_TYPE:
|
||||
destination = content["destination"]
|
||||
failure = content["failure"]
|
||||
|
||||
failures.setdefault(destination, []).append(failure)
|
||||
elif typ == send_queue.DEVICE_MESSAGE_TYPE:
|
||||
device_destinations.add(content["destination"])
|
||||
else:
|
||||
raise Exception("Unrecognised federation type: %r", typ)
|
||||
|
||||
# We've finished collecting, send everything off
|
||||
for destination, states in presence_to_send.items():
|
||||
self.federation_sender.send_presence(destination, states)
|
||||
|
||||
for destination, edu_map in keyed_edus.items():
|
||||
for key, edu in edu_map.items():
|
||||
self.federation_sender.send_edu(
|
||||
edu.destination, edu.edu_type, edu.content, key=key,
|
||||
)
|
||||
|
||||
for destination, edu_list in edus.items():
|
||||
for edu in edu_list:
|
||||
self.federation_sender.send_edu(
|
||||
edu.destination, edu.edu_type, edu.content, key=None,
|
||||
)
|
||||
|
||||
for destination, failure_list in failures.items():
|
||||
for failure in failure_list:
|
||||
self.federation_sender.send_failure(destination, failure)
|
||||
|
||||
for destination in device_destinations:
|
||||
self.federation_sender.send_device_messages(destination)
|
||||
|
||||
# Record where we are in the stream.
|
||||
yield self.store.update_federation_out_pos(
|
||||
"federation", latest_id
|
||||
)
|
||||
|
||||
# We also need to poke the federation sender when new events happen
|
||||
event_stream = result.get("events")
|
||||
if event_stream:
|
||||
latest_pos = event_stream["position"]
|
||||
self.federation_sender.notify_new_events(latest_pos)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
|
@ -89,6 +89,9 @@ class ApplicationService(object):
|
|||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.id = id
|
||||
|
||||
if "|" in self.id:
|
||||
raise Exception("application service ID cannot contain '|' character")
|
||||
|
||||
# .protocols is a publicly visible field
|
||||
if protocols:
|
||||
self.protocols = set(protocols)
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException
|
|||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
|
@ -177,6 +178,13 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
" valid result", uri)
|
||||
defer.returnValue(None)
|
||||
|
||||
for instance in info.get("instances", []):
|
||||
network_id = instance.get("network_id", None)
|
||||
if network_id is not None:
|
||||
instance["instance_id"] = ThirdPartyInstanceID(
|
||||
service.id, network_id,
|
||||
).to_string()
|
||||
|
||||
defer.returnValue(info)
|
||||
except Exception as ex:
|
||||
logger.warning("query_3pe_protocol to %s threw exception %s",
|
||||
|
|
|
@ -50,6 +50,7 @@ handlers:
|
|||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
filters: [context]
|
||||
|
||||
loggers:
|
||||
synapse:
|
||||
|
|
|
@ -27,17 +27,23 @@ class PasswordAuthProviderConfig(Config):
|
|||
ldap_config = config.get("ldap_config", {})
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
if self.ldap_enabled:
|
||||
from synapse.util.ldap_auth_provider import LdapAuthProvider
|
||||
from ldap_auth_provider import LdapAuthProvider
|
||||
parsed_config = LdapAuthProvider.parse_config(ldap_config)
|
||||
self.password_providers.append((LdapAuthProvider, parsed_config))
|
||||
|
||||
providers = config.get("password_providers", [])
|
||||
for provider in providers:
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider['module'].rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
# This is for backwards compat when the ldap auth provider resided
|
||||
# in this package.
|
||||
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||
from ldap_auth_provider import LdapAuthProvider
|
||||
provider_class = LdapAuthProvider
|
||||
else:
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider['module'].rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
|
@ -50,7 +56,7 @@ class PasswordAuthProviderConfig(Config):
|
|||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# password_providers:
|
||||
# - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
|
||||
# - module: "ldap_auth_provider.LdapAuthProvider"
|
||||
# config:
|
||||
# enabled: true
|
||||
# uri: "ldap://ldap.example.com:389"
|
||||
|
|
|
@ -32,7 +32,6 @@ class RegistrationConfig(Config):
|
|||
)
|
||||
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||
|
@ -55,11 +54,6 @@ class RegistrationConfig(Config):
|
|||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
|
||||
# Sets the expiry for the short term user creation in
|
||||
# milliseconds. For instance the bellow duration is two weeks
|
||||
# in milliseconds.
|
||||
user_creation_max_duration: 1209600000
|
||||
|
||||
# Set the number of bcrypt rounds used to generate password hash.
|
||||
# Larger numbers increase the work factor needed to generate the hash.
|
||||
# The default number of rounds is 12.
|
||||
|
|
|
@ -30,6 +30,11 @@ class ServerConfig(Config):
|
|||
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
||||
self.public_baseurl = config.get("public_baseurl")
|
||||
|
||||
# Whether to send federation traffic out in this process. This only
|
||||
# applies to some federation traffic, and so shouldn't be used to
|
||||
# "disable" federation
|
||||
self.send_federation = config.get("send_federation", True)
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != '/':
|
||||
self.public_baseurl += '/'
|
||||
|
|
|
@ -16,6 +16,17 @@
|
|||
from synapse.api.constants import EventTypes
|
||||
from . import EventBase
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
import re
|
||||
|
||||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||
# (?<!stuff) matches if the current position in the string is not preceded
|
||||
# by a match for 'stuff'.
|
||||
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
|
||||
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
|
||||
SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
|
||||
|
||||
|
||||
def prune_event(event):
|
||||
""" Returns a pruned version of the given event, which removes all keys we
|
||||
|
@ -97,6 +108,83 @@ def prune_event(event):
|
|||
)
|
||||
|
||||
|
||||
def _copy_field(src, dst, field):
|
||||
"""Copy the field in 'src' to 'dst'.
|
||||
|
||||
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
|
||||
then dst={"foo":{"bar":5}}.
|
||||
|
||||
Args:
|
||||
src(dict): The dict to read from.
|
||||
dst(dict): The dict to modify.
|
||||
field(list<str>): List of keys to drill down to in 'src'.
|
||||
"""
|
||||
if len(field) == 0: # this should be impossible
|
||||
return
|
||||
if len(field) == 1: # common case e.g. 'origin_server_ts'
|
||||
if field[0] in src:
|
||||
dst[field[0]] = src[field[0]]
|
||||
return
|
||||
|
||||
# Else is a nested field e.g. 'content.body'
|
||||
# Pop the last field as that's the key to move across and we need the
|
||||
# parent dict in order to access the data. Drill down to the right dict.
|
||||
key_to_move = field.pop(-1)
|
||||
sub_dict = src
|
||||
for sub_field in field: # e.g. sub_field => "content"
|
||||
if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
|
||||
sub_dict = sub_dict[sub_field]
|
||||
else:
|
||||
return
|
||||
|
||||
if key_to_move not in sub_dict:
|
||||
return
|
||||
|
||||
# Insert the key into the output dictionary, creating nested objects
|
||||
# as required. We couldn't do this any earlier or else we'd need to delete
|
||||
# the empty objects if the key didn't exist.
|
||||
sub_out_dict = dst
|
||||
for sub_field in field:
|
||||
sub_out_dict = sub_out_dict.setdefault(sub_field, {})
|
||||
sub_out_dict[key_to_move] = sub_dict[key_to_move]
|
||||
|
||||
|
||||
def only_fields(dictionary, fields):
|
||||
"""Return a new dict with only the fields in 'dictionary' which are present
|
||||
in 'fields'.
|
||||
|
||||
If there are no event fields specified then all fields are included.
|
||||
The entries may include '.' charaters to indicate sub-fields.
|
||||
So ['content.body'] will include the 'body' field of the 'content' object.
|
||||
A literal '.' character in a field name may be escaped using a '\'.
|
||||
|
||||
Args:
|
||||
dictionary(dict): The dictionary to read from.
|
||||
fields(list<str>): A list of fields to copy over. Only shallow refs are
|
||||
taken.
|
||||
Returns:
|
||||
dict: A new dictionary with only the given fields. If fields was empty,
|
||||
the same dictionary is returned.
|
||||
"""
|
||||
if len(fields) == 0:
|
||||
return dictionary
|
||||
|
||||
# for each field, convert it:
|
||||
# ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
|
||||
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
|
||||
|
||||
# for each element of the output array of arrays:
|
||||
# remove escaping so we can use the right key names.
|
||||
split_fields[:] = [
|
||||
[f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
|
||||
]
|
||||
|
||||
output = {}
|
||||
for field_array in split_fields:
|
||||
_copy_field(dictionary, output, field_array)
|
||||
return output
|
||||
|
||||
|
||||
def format_event_raw(d):
|
||||
return d
|
||||
|
||||
|
@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d):
|
|||
|
||||
def serialize_event(e, time_now_ms, as_client_event=True,
|
||||
event_format=format_event_for_client_v1,
|
||||
token_id=None):
|
||||
token_id=None, only_event_fields=None):
|
||||
# FIXME(erikj): To handle the case of presence events and the like
|
||||
if not isinstance(e, EventBase):
|
||||
return e
|
||||
|
@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
|
|||
d["unsigned"]["transaction_id"] = txn_id
|
||||
|
||||
if as_client_event:
|
||||
return event_format(d)
|
||||
else:
|
||||
return d
|
||||
d = event_format(d)
|
||||
|
||||
if only_event_fields:
|
||||
if (not isinstance(only_event_fields, list) or
|
||||
not all(isinstance(f, basestring) for f in only_event_fields)):
|
||||
raise TypeError("only_event_fields must be a list of strings")
|
||||
d = only_fields(d, only_event_fields)
|
||||
|
||||
return d
|
||||
|
|
|
@ -17,10 +17,9 @@
|
|||
"""
|
||||
|
||||
from .replication import ReplicationLayer
|
||||
from .transport.client import TransportLayerClient
|
||||
|
||||
|
||||
def initialize_http_replication(homeserver):
|
||||
transport = TransportLayerClient(homeserver)
|
||||
def initialize_http_replication(hs):
|
||||
transport = hs.get_federation_transport_client()
|
||||
|
||||
return ReplicationLayer(homeserver, transport)
|
||||
return ReplicationLayer(hs, transport)
|
||||
|
|
|
@ -18,7 +18,6 @@ from twisted.internet import defer
|
|||
|
||||
from .federation_base import FederationBase
|
||||
from synapse.api.constants import Membership
|
||||
from .units import Edu
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
|
@ -45,10 +44,6 @@ logger = logging.getLogger(__name__)
|
|||
# synapse.federation.federation_client is a silly name
|
||||
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
|
||||
|
||||
sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
|
||||
|
||||
sent_edus_counter = metrics.register_counter("sent_edus")
|
||||
|
||||
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||
|
||||
|
||||
|
@ -92,63 +87,6 @@ class FederationClient(FederationBase):
|
|||
|
||||
self._get_pdu_cache.start()
|
||||
|
||||
@log_function
|
||||
def send_pdu(self, pdu, destinations):
|
||||
"""Informs the replication layer about a new PDU generated within the
|
||||
home server that should be transmitted to others.
|
||||
|
||||
TODO: Figure out when we should actually resolve the deferred.
|
||||
|
||||
Args:
|
||||
pdu (Pdu): The new Pdu.
|
||||
|
||||
Returns:
|
||||
Deferred: Completes when we have successfully processed the PDU
|
||||
and replicated it to any interested remote home servers.
|
||||
"""
|
||||
order = self._order
|
||||
self._order += 1
|
||||
|
||||
sent_pdus_destination_dist.inc_by(len(destinations))
|
||||
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
|
||||
|
||||
# TODO, add errback, etc.
|
||||
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
|
||||
|
||||
logger.debug(
|
||||
"[%s] transaction_layer.enqueue_pdu... done",
|
||||
pdu.event_id
|
||||
)
|
||||
|
||||
def send_presence(self, destination, states):
|
||||
if destination != self.server_name:
|
||||
self._transaction_queue.enqueue_presence(destination, states)
|
||||
|
||||
@log_function
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
edu = Edu(
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
edu_type=edu_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
sent_edus_counter.inc()
|
||||
|
||||
self._transaction_queue.enqueue_edu(edu, key=key)
|
||||
|
||||
@log_function
|
||||
def send_device_messages(self, destination):
|
||||
"""Sends the device messages in the local database to the remote
|
||||
destination"""
|
||||
self._transaction_queue.enqueue_device_messages(destination)
|
||||
|
||||
@log_function
|
||||
def send_failure(self, failure, destination):
|
||||
self._transaction_queue.enqueue_failure(failure, destination)
|
||||
return defer.succeed(None)
|
||||
|
||||
@log_function
|
||||
def make_query(self, destination, query_type, args,
|
||||
retry_on_dns_fail=False):
|
||||
|
@ -717,12 +655,15 @@ class FederationClient(FederationBase):
|
|||
raise RuntimeError("Failed to send to any server.")
|
||||
|
||||
def get_public_rooms(self, destination, limit=None, since_token=None,
|
||||
search_filter=None):
|
||||
search_filter=None, include_all_networks=False,
|
||||
third_party_instance_id=None):
|
||||
if destination == self.server_name:
|
||||
return
|
||||
|
||||
return self.transport_layer.get_public_rooms(
|
||||
destination, limit, since_token, search_filter
|
||||
destination, limit, since_token, search_filter,
|
||||
include_all_networks=include_all_networks,
|
||||
third_party_instance_id=third_party_instance_id,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -20,8 +20,6 @@ a given transport.
|
|||
from .federation_client import FederationClient
|
||||
from .federation_server import FederationServer
|
||||
|
||||
from .transaction_queue import TransactionQueue
|
||||
|
||||
from .persistence import TransactionActions
|
||||
|
||||
import logging
|
||||
|
@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer):
|
|||
self._clock = hs.get_clock()
|
||||
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
self._transaction_queue = TransactionQueue(hs, transport_layer)
|
||||
|
||||
self._order = 0
|
||||
|
||||
self.hs = hs
|
||||
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A federation sender that forwards things to be sent across replication to
|
||||
a worker process.
|
||||
|
||||
It assumes there is a single worker process feeding off of it.
|
||||
|
||||
Each row in the replication stream consists of a type and some json, where the
|
||||
types indicate whether they are presence, or edus, etc.
|
||||
|
||||
Ephemeral or non-event data are queued up in-memory. When the worker requests
|
||||
updates since a particular point, all in-memory data since before that point is
|
||||
dropped. We also expire things in the queue after 5 minutes, to ensure that a
|
||||
dead worker doesn't cause the queues to grow limitlessly.
|
||||
|
||||
Events are replicated via a separate events stream.
|
||||
"""
|
||||
|
||||
from .units import Edu
|
||||
|
||||
from synapse.util.metrics import Measure
|
||||
import synapse.metrics
|
||||
|
||||
from blist import sorteddict
|
||||
import ujson
|
||||
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
|
||||
PRESENCE_TYPE = "p"
|
||||
KEYED_EDU_TYPE = "k"
|
||||
EDU_TYPE = "e"
|
||||
FAILURE_TYPE = "f"
|
||||
DEVICE_MESSAGE_TYPE = "d"
|
||||
|
||||
|
||||
class FederationRemoteSendQueue(object):
|
||||
"""A drop in replacement for TransactionQueue"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.presence_map = {}
|
||||
self.presence_changed = sorteddict()
|
||||
|
||||
self.keyed_edu = {}
|
||||
self.keyed_edu_changed = sorteddict()
|
||||
|
||||
self.edus = sorteddict()
|
||||
|
||||
self.failures = sorteddict()
|
||||
|
||||
self.device_messages = sorteddict()
|
||||
|
||||
self.pos = 1
|
||||
self.pos_time = sorteddict()
|
||||
|
||||
# EVERYTHING IS SAD. In particular, python only makes new scopes when
|
||||
# we make a new function, so we need to make a new function so the inner
|
||||
# lambda binds to the queue rather than to the name of the queue which
|
||||
# changes. ARGH.
|
||||
def register(name, queue):
|
||||
metrics.register_callback(
|
||||
queue_name + "_size",
|
||||
lambda: len(queue),
|
||||
)
|
||||
|
||||
for queue_name in [
|
||||
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
||||
"edus", "failures", "device_messages", "pos_time",
|
||||
]:
|
||||
register(queue_name, getattr(self, queue_name))
|
||||
|
||||
self.clock.looping_call(self._clear_queue, 30 * 1000)
|
||||
|
||||
def _next_pos(self):
|
||||
pos = self.pos
|
||||
self.pos += 1
|
||||
self.pos_time[self.clock.time_msec()] = pos
|
||||
return pos
|
||||
|
||||
def _clear_queue(self):
|
||||
"""Clear the queues for anything older than N minutes"""
|
||||
|
||||
FIVE_MINUTES_AGO = 5 * 60 * 1000
|
||||
now = self.clock.time_msec()
|
||||
|
||||
keys = self.pos_time.keys()
|
||||
time = keys.bisect_left(now - FIVE_MINUTES_AGO)
|
||||
if not keys[:time]:
|
||||
return
|
||||
|
||||
position_to_delete = max(keys[:time])
|
||||
for key in keys[:time]:
|
||||
del self.pos_time[key]
|
||||
|
||||
self._clear_queue_before_pos(position_to_delete)
|
||||
|
||||
def _clear_queue_before_pos(self, position_to_delete):
|
||||
"""Clear all the queues from before a given position"""
|
||||
with Measure(self.clock, "send_queue._clear"):
|
||||
# Delete things out of presence maps
|
||||
keys = self.presence_changed.keys()
|
||||
i = keys.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.presence_changed[key]
|
||||
|
||||
user_ids = set(
|
||||
user_id for uids in self.presence_changed.values() for _, user_id in uids
|
||||
)
|
||||
|
||||
to_del = [
|
||||
user_id for user_id in self.presence_map if user_id not in user_ids
|
||||
]
|
||||
for user_id in to_del:
|
||||
del self.presence_map[user_id]
|
||||
|
||||
# Delete things out of keyed edus
|
||||
keys = self.keyed_edu_changed.keys()
|
||||
i = keys.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.keyed_edu_changed[key]
|
||||
|
||||
live_keys = set()
|
||||
for edu_key in self.keyed_edu_changed.values():
|
||||
live_keys.add(edu_key)
|
||||
|
||||
to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
|
||||
for edu_key in to_del:
|
||||
del self.keyed_edu[edu_key]
|
||||
|
||||
# Delete things out of edu map
|
||||
keys = self.edus.keys()
|
||||
i = keys.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.edus[key]
|
||||
|
||||
# Delete things out of failure map
|
||||
keys = self.failures.keys()
|
||||
i = keys.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.failures[key]
|
||||
|
||||
# Delete things out of device map
|
||||
keys = self.device_messages.keys()
|
||||
i = keys.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.device_messages[key]
|
||||
|
||||
def notify_new_events(self, current_id):
|
||||
"""As per TransactionQueue"""
|
||||
# We don't need to replicate this as it gets sent down a different
|
||||
# stream.
|
||||
pass
|
||||
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
||||
edu = Edu(
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
edu_type=edu_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
if key:
|
||||
assert isinstance(key, tuple)
|
||||
self.keyed_edu[(destination, key)] = edu
|
||||
self.keyed_edu_changed[pos] = (destination, key)
|
||||
else:
|
||||
self.edus[pos] = edu
|
||||
|
||||
def send_presence(self, destination, states):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
||||
self.presence_map.update({
|
||||
state.user_id: state
|
||||
for state in states
|
||||
})
|
||||
|
||||
self.presence_changed[pos] = [
|
||||
(destination, state.user_id) for state in states
|
||||
]
|
||||
|
||||
def send_failure(self, failure, destination):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
||||
self.failures[pos] = (destination, str(failure))
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
self.device_messages[pos] = destination
|
||||
|
||||
def get_current_token(self):
|
||||
return self.pos - 1
|
||||
|
||||
def get_replication_rows(self, token, limit, federation_ack=None):
|
||||
"""
|
||||
Args:
|
||||
token (int)
|
||||
limit (int)
|
||||
federation_ack (int): Optional. The position where the worker is
|
||||
explicitly acknowledged it has handled. Allows us to drop
|
||||
data from before that point
|
||||
"""
|
||||
# TODO: Handle limit.
|
||||
|
||||
# To handle restarts where we wrap around
|
||||
if token > self.pos:
|
||||
token = -1
|
||||
|
||||
rows = []
|
||||
|
||||
# There should be only one reader, so lets delete everything its
|
||||
# acknowledged its seen.
|
||||
if federation_ack:
|
||||
self._clear_queue_before_pos(federation_ack)
|
||||
|
||||
# Fetch changed presence
|
||||
keys = self.presence_changed.keys()
|
||||
i = keys.bisect_right(token)
|
||||
dest_user_ids = set(
|
||||
(pos, dest_user_id)
|
||||
for pos in keys[i:]
|
||||
for dest_user_id in self.presence_changed[pos]
|
||||
)
|
||||
|
||||
for (key, (dest, user_id)) in dest_user_ids:
|
||||
rows.append((key, PRESENCE_TYPE, ujson.dumps({
|
||||
"destination": dest,
|
||||
"state": self.presence_map[user_id].as_dict(),
|
||||
})))
|
||||
|
||||
# Fetch changes keyed edus
|
||||
keys = self.keyed_edu_changed.keys()
|
||||
i = keys.bisect_right(token)
|
||||
keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
|
||||
|
||||
for (pos, (destination, edu_key)) in keyed_edus:
|
||||
rows.append(
|
||||
(pos, KEYED_EDU_TYPE, ujson.dumps({
|
||||
"key": edu_key,
|
||||
"edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
|
||||
}))
|
||||
)
|
||||
|
||||
# Fetch changed edus
|
||||
keys = self.edus.keys()
|
||||
i = keys.bisect_right(token)
|
||||
edus = set((k, self.edus[k]) for k in keys[i:])
|
||||
|
||||
for (pos, edu) in edus:
|
||||
rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
|
||||
|
||||
# Fetch changed failures
|
||||
keys = self.failures.keys()
|
||||
i = keys.bisect_right(token)
|
||||
failures = set((k, self.failures[k]) for k in keys[i:])
|
||||
|
||||
for (pos, (destination, failure)) in failures:
|
||||
rows.append((pos, FAILURE_TYPE, ujson.dumps({
|
||||
"destination": destination,
|
||||
"failure": failure,
|
||||
})))
|
||||
|
||||
# Fetch changed device messages
|
||||
keys = self.device_messages.keys()
|
||||
i = keys.bisect_right(token)
|
||||
device_messages = set((k, self.device_messages[k]) for k in keys[i:])
|
||||
|
||||
for (pos, destination) in device_messages:
|
||||
rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
|
||||
"destination": destination,
|
||||
})))
|
||||
|
||||
# Sort rows based on pos
|
||||
rows.sort()
|
||||
|
||||
return rows
|
|
@ -19,6 +19,7 @@ from twisted.internet import defer
|
|||
from .persistence import TransactionActions
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
|
@ -26,6 +27,7 @@ from synapse.util.retryutils import (
|
|||
get_retry_limiter, NotRetryingDestination,
|
||||
)
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
import synapse.metrics
|
||||
|
||||
|
@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
|
||||
sent_pdus_destination_dist = client_metrics.register_distribution(
|
||||
"sent_pdu_destinations"
|
||||
)
|
||||
sent_edus_counter = client_metrics.register_counter("sent_edus")
|
||||
|
||||
|
||||
class TransactionQueue(object):
|
||||
"""This class makes sure we only have one transaction in flight at
|
||||
|
@ -44,13 +52,14 @@ class TransactionQueue(object):
|
|||
It batches pending PDUs into single transactions.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, transport_layer):
|
||||
def __init__(self, hs):
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
|
||||
self.transport_layer = transport_layer
|
||||
self.transport_layer = hs.get_federation_transport_client()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
@ -95,6 +104,11 @@ class TransactionQueue(object):
|
|||
# HACK to get unique tx id
|
||||
self._next_txn_id = int(self.clock.time_msec())
|
||||
|
||||
self._order = 1
|
||||
|
||||
self._is_processing = False
|
||||
self._last_poked_id = -1
|
||||
|
||||
def can_send_to(self, destination):
|
||||
"""Can we send messages to the given server?
|
||||
|
||||
|
@ -115,11 +129,61 @@ class TransactionQueue(object):
|
|||
else:
|
||||
return not destination.startswith("localhost")
|
||||
|
||||
def enqueue_pdu(self, pdu, destinations, order):
|
||||
@defer.inlineCallbacks
|
||||
def notify_new_events(self, current_id):
|
||||
"""This gets called when we have some new events we might want to
|
||||
send out to other servers.
|
||||
"""
|
||||
self._last_poked_id = max(current_id, self._last_poked_id)
|
||||
|
||||
if self._is_processing:
|
||||
return
|
||||
|
||||
try:
|
||||
self._is_processing = True
|
||||
while True:
|
||||
last_token = yield self.store.get_federation_out_pos("events")
|
||||
next_token, events = yield self.store.get_all_new_events_stream(
|
||||
last_token, self._last_poked_id, limit=20,
|
||||
)
|
||||
|
||||
logger.debug("Handling %s -> %s", last_token, next_token)
|
||||
|
||||
if not events and next_token >= self._last_poked_id:
|
||||
break
|
||||
|
||||
for event in events:
|
||||
users_in_room = yield self.state.get_current_user_in_room(
|
||||
event.room_id, latest_event_ids=[event.event_id],
|
||||
)
|
||||
|
||||
destinations = set(
|
||||
get_domain_from_id(user_id) for user_id in users_in_room
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(event.state_key))
|
||||
|
||||
logger.debug("Sending %s to %r", event, destinations)
|
||||
|
||||
self._send_pdu(event, destinations)
|
||||
|
||||
yield self.store.update_federation_out_pos(
|
||||
"events", next_token
|
||||
)
|
||||
|
||||
finally:
|
||||
self._is_processing = False
|
||||
|
||||
def _send_pdu(self, pdu, destinations):
|
||||
# We loop through all destinations to see whether we already have
|
||||
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||
# table and we'll get back to it later.
|
||||
|
||||
order = self._order
|
||||
self._order += 1
|
||||
|
||||
destinations = set(destinations)
|
||||
destinations = set(
|
||||
dest for dest in destinations if self.can_send_to(dest)
|
||||
|
@ -130,6 +194,8 @@ class TransactionQueue(object):
|
|||
if not destinations:
|
||||
return
|
||||
|
||||
sent_pdus_destination_dist.inc_by(len(destinations))
|
||||
|
||||
for destination in destinations:
|
||||
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
||||
(pdu, order)
|
||||
|
@ -139,7 +205,10 @@ class TransactionQueue(object):
|
|||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def enqueue_presence(self, destination, states):
|
||||
def send_presence(self, destination, states):
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
self.pending_presence_by_dest.setdefault(destination, {}).update({
|
||||
state.user_id: state for state in states
|
||||
})
|
||||
|
@ -148,12 +217,19 @@ class TransactionQueue(object):
|
|||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def enqueue_edu(self, edu, key=None):
|
||||
destination = edu.destination
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
edu = Edu(
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
edu_type=edu_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
sent_edus_counter.inc()
|
||||
|
||||
if key:
|
||||
self.pending_edus_keyed_by_dest.setdefault(
|
||||
destination, {}
|
||||
|
@ -165,7 +241,7 @@ class TransactionQueue(object):
|
|||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def enqueue_failure(self, failure, destination):
|
||||
def send_failure(self, failure, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
return
|
||||
|
||||
|
@ -180,7 +256,7 @@ class TransactionQueue(object):
|
|||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def enqueue_device_messages(self, destination):
|
||||
def send_device_messages(self, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
return
|
||||
|
||||
|
@ -191,6 +267,9 @@ class TransactionQueue(object):
|
|||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def get_current_token(self):
|
||||
return 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _attempt_new_transaction(self, destination):
|
||||
# list of (pending_pdu, deferred, order)
|
||||
|
@ -383,6 +462,13 @@ class TransactionQueue(object):
|
|||
code = e.code
|
||||
response = e.response
|
||||
|
||||
if e.code == 429 or 500 <= e.code:
|
||||
logger.info(
|
||||
"TX [%s] {%s} got %d response",
|
||||
destination, txn_id, code
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
"TX [%s] {%s} got %d response",
|
||||
destination, txn_id, code
|
||||
|
|
|
@ -249,10 +249,15 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_public_rooms(self, remote_server, limit, since_token,
|
||||
search_filter=None):
|
||||
search_filter=None, include_all_networks=False,
|
||||
third_party_instance_id=None):
|
||||
path = PREFIX + "/publicRooms"
|
||||
|
||||
args = {}
|
||||
args = {
|
||||
"include_all_networks": "true" if include_all_networks else "false",
|
||||
}
|
||||
if third_party_instance_id:
|
||||
args["third_party_instance_id"] = third_party_instance_id,
|
||||
if limit:
|
||||
args["limit"] = [str(limit)]
|
||||
if since_token:
|
||||
|
|
|
@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
|
|||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import (
|
||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||
parse_boolean_from_args,
|
||||
)
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
|
|||
def on_GET(self, origin, content, query):
|
||||
limit = parse_integer_from_args(query, "limit", 0)
|
||||
since_token = parse_string_from_args(query, "since", None)
|
||||
include_all_networks = parse_boolean_from_args(
|
||||
query, "include_all_networks", False
|
||||
)
|
||||
third_party_instance_id = parse_string_from_args(
|
||||
query, "third_party_instance_id", None
|
||||
)
|
||||
|
||||
if include_all_networks:
|
||||
network_tuple = None
|
||||
elif third_party_instance_id:
|
||||
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
|
||||
else:
|
||||
network_tuple = ThirdPartyInstanceID(None, None)
|
||||
|
||||
data = yield self.room_list_handler.get_local_public_room_list(
|
||||
limit, since_token
|
||||
limit, since_token,
|
||||
network_tuple=network_tuple
|
||||
)
|
||||
defer.returnValue((200, data))
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@ from .profile import ProfileHandler
|
|||
from .directory import DirectoryHandler
|
||||
from .admin import AdminHandler
|
||||
from .identity import IdentityHandler
|
||||
from .receipts import ReceiptsHandler
|
||||
from .search import SearchHandler
|
||||
|
||||
|
||||
|
@ -56,7 +55,6 @@ class Handlers(object):
|
|||
self.profile_handler = ProfileHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.receipts_handler = ReceiptsHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
self.search_handler = SearchHandler(hs)
|
||||
self.room_context_handler = RoomContextHandler(hs)
|
||||
|
|
|
@ -61,6 +61,8 @@ class AuthHandler(BaseHandler):
|
|||
for module, config in hs.config.password_providers
|
||||
]
|
||||
|
||||
logger.info("Extra password_providers: %r", self.password_providers)
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
|
@ -160,7 +162,15 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
for f in flows:
|
||||
if len(set(f) - set(creds.keys())) == 0:
|
||||
logger.info("Auth completed with creds: %r", creds)
|
||||
# it's very useful to know what args are stored, but this can
|
||||
# include the password in the case of registering, so only log
|
||||
# the keys (confusingly, clientdict may contain a password
|
||||
# param, creds is just what the user authed as for UI auth
|
||||
# and is not sensitive).
|
||||
logger.info(
|
||||
"Auth completed with creds: %r. Client dict has keys: %r",
|
||||
creds, clientdict.keys()
|
||||
)
|
||||
defer.returnValue((True, creds, clientdict, session['id']))
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session)
|
||||
|
@ -378,12 +388,10 @@ class AuthHandler(BaseHandler):
|
|||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
def get_access_token_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
"""
|
||||
Gets login tuple for the user with the given user ID.
|
||||
|
||||
Creates a new access/refresh token for the user.
|
||||
Creates a new access token for the user with the given user ID.
|
||||
|
||||
The user is assumed to have been authenticated by some other
|
||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||
|
@ -398,16 +406,13 @@ class AuthHandler(BaseHandler):
|
|||
initial_display_name (str): display name to associate with the
|
||||
device if it needs re-registering
|
||||
Returns:
|
||||
A tuple of:
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||
|
||||
# the device *should* have been registered before we got here; however,
|
||||
# it's possible we raced against a DELETE operation. The thing we
|
||||
|
@ -418,7 +423,7 @@ class AuthHandler(BaseHandler):
|
|||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
defer.returnValue((access_token, refresh_token))
|
||||
defer.returnValue(access_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_user_exists(self, user_id):
|
||||
|
@ -529,35 +534,19 @@ class AuthHandler(BaseHandler):
|
|||
device_id)
|
||||
defer.returnValue(access_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def issue_refresh_token(self, user_id, device_id=None):
|
||||
refresh_token = self.generate_refresh_token(user_id)
|
||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
|
||||
device_id)
|
||||
defer.returnValue(refresh_token)
|
||||
|
||||
def generate_access_token(self, user_id, extra_caveats=None,
|
||||
duration_in_ms=(60 * 60 * 1000)):
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
extra_caveats = extra_caveats or []
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
# Include a nonce, to make sure that each login gets a different
|
||||
# access token.
|
||||
macaroon.add_first_party_caveat("nonce = %s" % (
|
||||
stringutils.random_string_with_symbols(16),
|
||||
))
|
||||
for caveat in extra_caveats:
|
||||
macaroon.add_first_party_caveat(caveat)
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_refresh_token(self, user_id):
|
||||
m = self._generate_base_macaroon(user_id)
|
||||
m.add_first_party_caveat("type = refresh")
|
||||
# Important to add a nonce, because otherwise every refresh token for a
|
||||
# user will be the same.
|
||||
m.add_first_party_caveat("nonce = %s" % (
|
||||
stringutils.random_string_with_symbols(16),
|
||||
))
|
||||
return m.serialize()
|
||||
|
||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
|
|
|
@ -34,9 +34,9 @@ class DeviceMessageHandler(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation = hs.get_federation_sender()
|
||||
|
||||
self.federation.register_edu_handler(
|
||||
hs.get_replication_layer().register_edu_handler(
|
||||
"m.direct_to_device", self.on_direct_to_device_edu
|
||||
)
|
||||
|
||||
|
|
|
@ -339,3 +339,22 @@ class DirectoryHandler(BaseHandler):
|
|||
yield self.auth.check_can_change_room_list(room_id, requester.user)
|
||||
|
||||
yield self.store.set_room_is_public(room_id, visibility == "public")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def edit_published_appservice_room_list(self, appservice_id, network_id,
|
||||
room_id, visibility):
|
||||
"""Add or remove a room from the appservice/network specific public
|
||||
room list.
|
||||
|
||||
Args:
|
||||
appservice_id (str): ID of the appservice that owns the list
|
||||
network_id (str): The ID of the network the list is associated with
|
||||
room_id (str)
|
||||
visibility (str): either "public" or "private"
|
||||
"""
|
||||
if visibility not in ["public", "private"]:
|
||||
raise SynapseError(400, "Invalid visibility setting")
|
||||
|
||||
yield self.store.set_room_is_public_appservice(
|
||||
room_id, appservice_id, network_id, visibility == "public"
|
||||
)
|
||||
|
|
|
@ -111,6 +111,11 @@ class E2eKeysHandler(object):
|
|||
failures[destination] = {
|
||||
"status": 503, "message": "Not ready for retry",
|
||||
}
|
||||
except Exception as e:
|
||||
# include ConnectionRefused and other errors
|
||||
failures[destination] = {
|
||||
"status": 503, "message": e.message
|
||||
}
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(do_remote_query)(destination)
|
||||
|
@ -222,6 +227,11 @@ class E2eKeysHandler(object):
|
|||
failures[destination] = {
|
||||
"status": 503, "message": "Not ready for retry",
|
||||
}
|
||||
except Exception as e:
|
||||
# include ConnectionRefused and other errors
|
||||
failures[destination] = {
|
||||
"status": 503, "message": e.message
|
||||
}
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(claim_client_keys)(destination)
|
||||
|
|
|
@ -80,22 +80,6 @@ class FederationHandler(BaseHandler):
|
|||
# When joining a room we need to queue any events for that room up
|
||||
self.room_queues = {}
|
||||
|
||||
def handle_new_event(self, event, destinations):
|
||||
""" Takes in an event from the client to server side, that has already
|
||||
been authed and handled by the state module, and sends it to any
|
||||
remote home servers that may be interested.
|
||||
|
||||
Args:
|
||||
event: The event to send
|
||||
destinations: A list of destinations to send it to
|
||||
|
||||
Returns:
|
||||
Deferred: Resolved when it has successfully been queued for
|
||||
processing.
|
||||
"""
|
||||
|
||||
return self.replication_layer.send_pdu(event, destinations)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
||||
|
@ -268,9 +252,12 @@ class FederationHandler(BaseHandler):
|
|||
except:
|
||||
return False
|
||||
|
||||
# Parses mapping `event_id -> (type, state_key) -> state event_id`
|
||||
# to get all state ids that we're interested in.
|
||||
event_map = yield self.store.get_events([
|
||||
e_id for key_to_eid in event_to_state_ids.values()
|
||||
for key, e_id in key_to_eid
|
||||
e_id
|
||||
for key_to_eid in event_to_state_ids.values()
|
||||
for key, e_id in key_to_eid.items()
|
||||
if key[0] != EventTypes.Member or check_match(key[1])
|
||||
])
|
||||
|
||||
|
@ -830,25 +817,6 @@ class FederationHandler(BaseHandler):
|
|||
user = UserID.from_string(event.state_key)
|
||||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
|
||||
new_pdu = event
|
||||
|
||||
users_in_room = yield self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
destinations = set(
|
||||
get_domain_from_id(user_id) for user_id in users_in_room
|
||||
if not self.hs.is_mine_id(user_id)
|
||||
)
|
||||
|
||||
destinations.discard(origin)
|
||||
|
||||
logger.debug(
|
||||
"on_send_join_request: Sending event: %s, signatures: %s",
|
||||
event.event_id,
|
||||
event.signatures,
|
||||
)
|
||||
|
||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||
|
||||
state_ids = context.prev_state_ids.values()
|
||||
auth_chain = yield self.store.get_auth_chain(set(
|
||||
[event.event_id] + state_ids
|
||||
|
@ -1055,24 +1023,6 @@ class FederationHandler(BaseHandler):
|
|||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||
)
|
||||
|
||||
new_pdu = event
|
||||
|
||||
users_in_room = yield self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
destinations = set(
|
||||
get_domain_from_id(user_id) for user_id in users_in_room
|
||||
if not self.hs.is_mine_id(user_id)
|
||||
)
|
||||
destinations.discard(origin)
|
||||
|
||||
logger.debug(
|
||||
"on_send_leave_request: Sending event: %s, signatures: %s",
|
||||
event.event_id,
|
||||
event.signatures,
|
||||
)
|
||||
|
||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -372,11 +372,12 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_receipts():
|
||||
receipts_handler = self.hs.get_handlers().receipts_handler
|
||||
receipts = yield receipts_handler.get_receipts_for_room(
|
||||
receipts = yield self.store.get_linearized_receipts_for_room(
|
||||
room_id,
|
||||
now_token.receipt_key
|
||||
to_key=now_token.receipt_key,
|
||||
)
|
||||
if not receipts:
|
||||
receipts = []
|
||||
defer.returnValue(receipts)
|
||||
|
||||
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||
|
|
|
@ -22,9 +22,9 @@ from synapse.events.utils import serialize_event
|
|||
from synapse.events.validator import EventValidator
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.types import (
|
||||
UserID, RoomAlias, RoomStreamToken, get_domain_from_id
|
||||
UserID, RoomAlias, RoomStreamToken,
|
||||
)
|
||||
from synapse.util.async import run_on_reactor, ReadWriteLock
|
||||
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -50,6 +50,10 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
# We arbitrarily limit concurrent event creation for a room to 5.
|
||||
# This is to stop us from diverging history *too* much.
|
||||
self.limiter = Limiter(max_count=5)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
@ -191,36 +195,38 @@ class MessageHandler(BaseHandler):
|
|||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
self.validator.validate_new(builder)
|
||||
with (yield self.limiter.queue(builder.room_id)):
|
||||
self.validator.validate_new(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
target = UserID.from_string(builder.state_key)
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
target = UserID.from_string(builder.state_key)
|
||||
|
||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
content = builder.content
|
||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to get profile information for %r: %s",
|
||||
target, e
|
||||
)
|
||||
try:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to get profile information for %r: %s",
|
||||
target, e
|
||||
)
|
||||
|
||||
if token_id is not None:
|
||||
builder.internal_metadata.token_id = token_id
|
||||
if token_id is not None:
|
||||
builder.internal_metadata.token_id = token_id
|
||||
|
||||
if txn_id is not None:
|
||||
builder.internal_metadata.txn_id = txn_id
|
||||
if txn_id is not None:
|
||||
builder.internal_metadata.txn_id = txn_id
|
||||
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
prev_event_ids=prev_event_ids,
|
||||
)
|
||||
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
prev_event_ids=prev_event_ids,
|
||||
)
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -599,13 +605,6 @@ class MessageHandler(BaseHandler):
|
|||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
users_in_room = yield self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
destinations = [
|
||||
get_domain_from_id(user_id) for user_id in users_in_room
|
||||
if not self.hs.is_mine_id(user_id)
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _notify():
|
||||
yield run_on_reactor()
|
||||
|
@ -618,7 +617,3 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
# If invite, remove room_state from unsigned before sending.
|
||||
event.unsigned.pop("invite_room_state", None)
|
||||
|
||||
preserve_fn(federation_handler.handle_new_event)(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
|
|
@ -91,28 +91,29 @@ class PresenceHandler(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.wheel_timer = WheelTimer()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.replication = hs.get_replication_layer()
|
||||
self.federation = hs.get_federation_sender()
|
||||
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
self.federation.register_edu_handler(
|
||||
self.replication.register_edu_handler(
|
||||
"m.presence", self.incoming_presence
|
||||
)
|
||||
self.federation.register_edu_handler(
|
||||
self.replication.register_edu_handler(
|
||||
"m.presence_invite",
|
||||
lambda origin, content: self.invite_presence(
|
||||
observed_user=UserID.from_string(content["observed_user"]),
|
||||
observer_user=UserID.from_string(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
self.federation.register_edu_handler(
|
||||
self.replication.register_edu_handler(
|
||||
"m.presence_accept",
|
||||
lambda origin, content: self.accept_presence(
|
||||
observed_user=UserID.from_string(content["observed_user"]),
|
||||
observer_user=UserID.from_string(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
self.federation.register_edu_handler(
|
||||
self.replication.register_edu_handler(
|
||||
"m.presence_deny",
|
||||
lambda origin, content: self.deny_presence(
|
||||
observed_user=UserID.from_string(content["observed_user"]),
|
||||
|
|
|
@ -33,8 +33,8 @@ class ReceiptsHandler(BaseHandler):
|
|||
self.server_name = hs.config.server_name
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_edu_handler(
|
||||
self.federation = hs.get_federation_sender()
|
||||
hs.get_replication_layer().register_edu_handler(
|
||||
"m.receipt", self._received_remote_receipt
|
||||
)
|
||||
self.clock = self.hs.get_clock()
|
||||
|
@ -100,7 +100,7 @@ class ReceiptsHandler(BaseHandler):
|
|||
|
||||
if not res:
|
||||
# res will be None if this read receipt is 'old'
|
||||
defer.returnValue(False)
|
||||
continue
|
||||
|
||||
stream_id, max_persisted_id = res
|
||||
|
||||
|
@ -109,6 +109,10 @@ class ReceiptsHandler(BaseHandler):
|
|||
if max_batch_id is None or max_persisted_id > max_batch_id:
|
||||
max_batch_id = max_persisted_id
|
||||
|
||||
if min_batch_id is None:
|
||||
# no new receipts
|
||||
defer.returnValue(False)
|
||||
|
||||
affected_room_ids = list(set([r["room_id"] for r in receipts]))
|
||||
|
||||
with PreserveLoggingContext():
|
||||
|
|
|
@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler):
|
|||
"User ID already taken.",
|
||||
errcode=Codes.USER_IN_USE,
|
||||
)
|
||||
user_data = yield self.auth.get_user_from_macaroon(guest_access_token)
|
||||
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
|
||||
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
||||
raise AuthError(
|
||||
403,
|
||||
|
@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler):
|
|||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
|
||||
def get_or_create_user(self, requester, localpart, displayname,
|
||||
password_hash=None):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
token = self.auth_handler().generate_access_token(
|
||||
user_id, None, duration_in_ms)
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
|
|
|
@ -44,16 +44,19 @@ class RoomCreationHandler(BaseHandler):
|
|||
"join_rules": JoinRules.INVITE,
|
||||
"history_visibility": "shared",
|
||||
"original_invitees_have_ops": False,
|
||||
"guest_can_join": True,
|
||||
},
|
||||
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
|
||||
"join_rules": JoinRules.INVITE,
|
||||
"history_visibility": "shared",
|
||||
"original_invitees_have_ops": True,
|
||||
"guest_can_join": True,
|
||||
},
|
||||
RoomCreationPreset.PUBLIC_CHAT: {
|
||||
"join_rules": JoinRules.PUBLIC,
|
||||
"history_visibility": "shared",
|
||||
"original_invitees_have_ops": False,
|
||||
"guest_can_join": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -336,6 +339,13 @@ class RoomCreationHandler(BaseHandler):
|
|||
content={"history_visibility": config["history_visibility"]}
|
||||
)
|
||||
|
||||
if config["guest_can_join"]:
|
||||
if (EventTypes.GuestAccess, '') not in initial_state:
|
||||
yield send(
|
||||
etype=EventTypes.GuestAccess,
|
||||
content={"guest_access": "can_join"}
|
||||
)
|
||||
|
||||
for (etype, state_key), content in initial_state.items():
|
||||
yield send(
|
||||
etype=etype,
|
||||
|
|
|
@ -22,6 +22,7 @@ from synapse.api.constants import (
|
|||
)
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
|
||||
from collections import namedtuple
|
||||
from unpaddedbase64 import encode_base64, decode_base64
|
||||
|
@ -34,6 +35,10 @@ logger = logging.getLogger(__name__)
|
|||
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
||||
|
||||
|
||||
# This is used to indicate we should only return rooms published to the main list.
|
||||
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||
|
||||
|
||||
class RoomListHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomListHandler, self).__init__(hs)
|
||||
|
@ -41,22 +46,43 @@ class RoomListHandler(BaseHandler):
|
|||
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
|
||||
|
||||
def get_local_public_room_list(self, limit=None, since_token=None,
|
||||
search_filter=None):
|
||||
if search_filter:
|
||||
# We explicitly don't bother caching searches.
|
||||
return self._get_public_room_list(limit, since_token, search_filter)
|
||||
search_filter=None,
|
||||
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||
"""Generate a local public room list.
|
||||
|
||||
There are multiple different lists: the main one plus one per third
|
||||
party network. A client can ask for a specific list or to return all.
|
||||
|
||||
Args:
|
||||
limit (int)
|
||||
since_token (str)
|
||||
search_filter (dict)
|
||||
network_tuple (ThirdPartyInstanceID): Which public list to use.
|
||||
This can be (None, None) to indicate the main list, or a particular
|
||||
appservice and network id to use an appservice specific one.
|
||||
Setting to None returns all public rooms across all lists.
|
||||
"""
|
||||
if search_filter or (network_tuple and network_tuple.appservice_id is not None):
|
||||
# We explicitly don't bother caching searches or requests for
|
||||
# appservice specific lists.
|
||||
return self._get_public_room_list(
|
||||
limit, since_token, search_filter, network_tuple=network_tuple,
|
||||
)
|
||||
|
||||
result = self.response_cache.get((limit, since_token))
|
||||
if not result:
|
||||
result = self.response_cache.set(
|
||||
(limit, since_token),
|
||||
self._get_public_room_list(limit, since_token)
|
||||
self._get_public_room_list(
|
||||
limit, since_token, network_tuple=network_tuple
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_public_room_list(self, limit=None, since_token=None,
|
||||
search_filter=None):
|
||||
search_filter=None,
|
||||
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||
if since_token and since_token != "END":
|
||||
since_token = RoomListNextBatch.from_token(since_token)
|
||||
else:
|
||||
|
@ -73,14 +99,15 @@ class RoomListHandler(BaseHandler):
|
|||
current_public_id = yield self.store.get_current_public_room_stream_id()
|
||||
public_room_stream_id = since_token.public_room_stream_id
|
||||
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
|
||||
public_room_stream_id, current_public_id
|
||||
public_room_stream_id, current_public_id,
|
||||
network_tuple=network_tuple,
|
||||
)
|
||||
else:
|
||||
stream_token = yield self.store.get_room_max_stream_ordering()
|
||||
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
|
||||
|
||||
room_ids = yield self.store.get_public_room_ids_at_stream_id(
|
||||
public_room_stream_id
|
||||
public_room_stream_id, network_tuple=network_tuple,
|
||||
)
|
||||
|
||||
# We want to return rooms in a particular order: the number of joined
|
||||
|
@ -311,7 +338,8 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
||||
search_filter=None):
|
||||
search_filter=None, include_all_networks=False,
|
||||
third_party_instance_id=None,):
|
||||
if search_filter:
|
||||
# We currently don't support searching across federation, so we have
|
||||
# to do it manually without pagination
|
||||
|
@ -320,6 +348,8 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
res = yield self._get_remote_list_cached(
|
||||
server_name, limit=limit, since_token=since_token,
|
||||
include_all_networks=include_all_networks,
|
||||
third_party_instance_id=third_party_instance_id,
|
||||
)
|
||||
|
||||
if search_filter:
|
||||
|
@ -332,22 +362,30 @@ class RoomListHandler(BaseHandler):
|
|||
defer.returnValue(res)
|
||||
|
||||
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
||||
search_filter=None):
|
||||
search_filter=None, include_all_networks=False,
|
||||
third_party_instance_id=None,):
|
||||
repl_layer = self.hs.get_replication_layer()
|
||||
if search_filter:
|
||||
# We can't cache when asking for search
|
||||
return repl_layer.get_public_rooms(
|
||||
server_name, limit=limit, since_token=since_token,
|
||||
search_filter=search_filter,
|
||||
search_filter=search_filter, include_all_networks=include_all_networks,
|
||||
third_party_instance_id=third_party_instance_id,
|
||||
)
|
||||
|
||||
result = self.remote_response_cache.get((server_name, limit, since_token))
|
||||
key = (
|
||||
server_name, limit, since_token, include_all_networks,
|
||||
third_party_instance_id,
|
||||
)
|
||||
result = self.remote_response_cache.get(key)
|
||||
if not result:
|
||||
result = self.remote_response_cache.set(
|
||||
(server_name, limit, since_token),
|
||||
key,
|
||||
repl_layer.get_public_rooms(
|
||||
server_name, limit=limit, since_token=since_token,
|
||||
search_filter=search_filter,
|
||||
include_all_networks=include_all_networks,
|
||||
third_party_instance_id=third_party_instance_id,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
|
|
@ -277,6 +277,7 @@ class SyncHandler(object):
|
|||
"""
|
||||
with Measure(self.clock, "load_filtered_recents"):
|
||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||
block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
|
||||
|
||||
if recents is None or newly_joined_room or timeline_limit < len(recents):
|
||||
limited = True
|
||||
|
@ -293,7 +294,7 @@ class SyncHandler(object):
|
|||
else:
|
||||
recents = []
|
||||
|
||||
if not limited:
|
||||
if not limited or block_all_timeline:
|
||||
defer.returnValue(TimelineBatch(
|
||||
events=recents,
|
||||
prev_batch=now_token,
|
||||
|
@ -509,6 +510,7 @@ class SyncHandler(object):
|
|||
Returns:
|
||||
Deferred(SyncResult)
|
||||
"""
|
||||
logger.info("Calculating sync response for %r", sync_config.user)
|
||||
|
||||
# NB: The now_token gets changed by some of the generate_sync_* methods,
|
||||
# this is due to some of the underlying streams not supporting the ability
|
||||
|
@ -531,9 +533,14 @@ class SyncHandler(object):
|
|||
)
|
||||
newly_joined_rooms, newly_joined_users = res
|
||||
|
||||
yield self._generate_sync_entry_for_presence(
|
||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||
block_all_presence_data = (
|
||||
since_token is None and
|
||||
sync_config.filter_collection.blocks_all_presence()
|
||||
)
|
||||
if not block_all_presence_data:
|
||||
yield self._generate_sync_entry_for_presence(
|
||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||
)
|
||||
|
||||
yield self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||
|
||||
|
@ -569,16 +576,20 @@ class SyncHandler(object):
|
|||
# We only delete messages when a new message comes in, but that's
|
||||
# fine so long as we delete them at some point.
|
||||
|
||||
logger.debug("Deleting messages up to %d", since_stream_id)
|
||||
yield self.store.delete_messages_for_device(
|
||||
deleted = yield self.store.delete_messages_for_device(
|
||||
user_id, device_id, since_stream_id
|
||||
)
|
||||
logger.info("Deleted %d to-device messages up to %d",
|
||||
deleted, since_stream_id)
|
||||
|
||||
logger.debug("Getting messages up to %d", now_token.to_device_key)
|
||||
messages, stream_id = yield self.store.get_new_messages_for_device(
|
||||
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||
)
|
||||
logger.debug("Got messages up to %d: %r", stream_id, messages)
|
||||
|
||||
logger.info(
|
||||
"Returning %d to-device messages between %d and %d (current token: %d)",
|
||||
len(messages), since_stream_id, stream_id, now_token.to_device_key
|
||||
)
|
||||
sync_result_builder.now_token = now_token.copy_and_replace(
|
||||
"to_device_key", stream_id
|
||||
)
|
||||
|
@ -709,13 +720,20 @@ class SyncHandler(object):
|
|||
`(newly_joined_rooms, newly_joined_users)`
|
||||
"""
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
|
||||
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
||||
sync_result_builder.sync_config,
|
||||
now_token=sync_result_builder.now_token,
|
||||
since_token=sync_result_builder.since_token,
|
||||
block_all_room_ephemeral = (
|
||||
sync_result_builder.since_token is None and
|
||||
sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
|
||||
)
|
||||
sync_result_builder.now_token = now_token
|
||||
|
||||
if block_all_room_ephemeral:
|
||||
ephemeral_by_room = {}
|
||||
else:
|
||||
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
||||
sync_result_builder.sync_config,
|
||||
now_token=sync_result_builder.now_token,
|
||||
since_token=sync_result_builder.since_token,
|
||||
)
|
||||
sync_result_builder.now_token = now_token
|
||||
|
||||
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
|
||||
"m.ignored_user_list", user_id=user_id,
|
||||
|
|
|
@ -55,9 +55,9 @@ class TypingHandler(object):
|
|||
self.clock = hs.get_clock()
|
||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation = hs.get_federation_sender()
|
||||
|
||||
self.federation.register_edu_handler("m.typing", self._recv_edu)
|
||||
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
|
||||
|
||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
|||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
import cgi
|
||||
import simplejson as json
|
||||
import logging
|
||||
import random
|
||||
|
@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
c_type = response.headers.getRawHeaders("Content-Type")
|
||||
|
||||
if "application/json" not in c_type:
|
||||
raise RuntimeError(
|
||||
"Content-Type not application/json"
|
||||
)
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
defer.returnValue(json.loads(body))
|
||||
|
@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
c_type = response.headers.getRawHeaders("Content-Type")
|
||||
|
||||
if "application/json" not in c_type:
|
||||
raise RuntimeError(
|
||||
"Content-Type not application/json"
|
||||
)
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
|
@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
c_type = response.headers.getRawHeaders("Content-Type")
|
||||
|
||||
if "application/json" not in c_type:
|
||||
raise RuntimeError(
|
||||
"Content-Type not application/json"
|
||||
)
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
|
||||
|
@ -525,3 +511,29 @@ def _flatten_response_never_received(e):
|
|||
)
|
||||
else:
|
||||
return "%s: %s" % (type(e).__name__, e.message,)
|
||||
|
||||
|
||||
def check_content_type_is_json(headers):
|
||||
"""
|
||||
Check that a set of HTTP headers have a Content-Type header, and that it
|
||||
is application/json.
|
||||
|
||||
Args:
|
||||
headers (twisted.web.http_headers.Headers): headers to check
|
||||
|
||||
Raises:
|
||||
RuntimeError if the
|
||||
|
||||
"""
|
||||
c_type = headers.getRawHeaders("Content-Type")
|
||||
if c_type is None:
|
||||
raise RuntimeError(
|
||||
"No Content-Type header"
|
||||
)
|
||||
|
||||
c_type = c_type[0] # only the first header
|
||||
val, options = cgi.parse_header(c_type)
|
||||
if val != "application/json":
|
||||
raise RuntimeError(
|
||||
"Content-Type not application/json: was '%s'" % c_type
|
||||
)
|
||||
|
|
|
@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False):
|
|||
parameter is present and not one of "true" or "false".
|
||||
"""
|
||||
|
||||
if name in request.args:
|
||||
return parse_boolean_from_args(request.args, name, default, required)
|
||||
|
||||
|
||||
def parse_boolean_from_args(args, name, default=None, required=False):
|
||||
if name in args:
|
||||
try:
|
||||
return {
|
||||
"true": True,
|
||||
"false": False,
|
||||
}[request.args[name][0]]
|
||||
}[args[name][0]]
|
||||
except:
|
||||
message = (
|
||||
"Boolean query parameter %r must be one of"
|
||||
|
|
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError
|
||||
|
||||
from synapse.util import DeferredTimedOutError
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||
|
@ -143,6 +144,12 @@ class Notifier(object):
|
|||
|
||||
self.clock = hs.get_clock()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
if hs.should_send_federation():
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
else:
|
||||
self.federation_sender = None
|
||||
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
||||
self.clock.looping_call(
|
||||
|
@ -220,6 +227,9 @@ class Notifier(object):
|
|||
# poke any interested application service.
|
||||
self.appservice_handler.notify_interested_services(room_stream_id)
|
||||
|
||||
if self.federation_sender:
|
||||
self.federation_sender.notify_new_events(room_stream_id)
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
self._user_joined_room(event.state_key, event.room_id)
|
||||
|
||||
|
@ -285,14 +295,7 @@ class Notifier(object):
|
|||
|
||||
result = None
|
||||
if timeout:
|
||||
# Will be set to a _NotificationListener that we'll be waiting on.
|
||||
# Allows us to cancel it.
|
||||
listener = None
|
||||
|
||||
def timed_out():
|
||||
if listener:
|
||||
listener.deferred.cancel()
|
||||
timer = self.clock.call_later(timeout / 1000., timed_out)
|
||||
end_time = self.clock.time_msec() + timeout
|
||||
|
||||
prev_token = from_token
|
||||
while not result:
|
||||
|
@ -303,6 +306,10 @@ class Notifier(object):
|
|||
if result:
|
||||
break
|
||||
|
||||
now = self.clock.time_msec()
|
||||
if end_time <= now:
|
||||
break
|
||||
|
||||
# Now we wait for the _NotifierUserStream to be told there
|
||||
# is a new token.
|
||||
# We need to supply the token we supplied to callback so
|
||||
|
@ -310,11 +317,14 @@ class Notifier(object):
|
|||
prev_token = current_token
|
||||
listener = user_stream.new_listener(prev_token)
|
||||
with PreserveLoggingContext():
|
||||
yield listener.deferred
|
||||
yield self.clock.time_bound_deferred(
|
||||
listener.deferred,
|
||||
time_out=(end_time - now) / 1000.
|
||||
)
|
||||
except DeferredTimedOutError:
|
||||
break
|
||||
except defer.CancelledError:
|
||||
break
|
||||
|
||||
self.clock.cancel_call_later(timer, ignore_errs=True)
|
||||
else:
|
||||
current_token = user_stream.current_token
|
||||
result = yield callback(from_token, current_token)
|
||||
|
@ -483,22 +493,27 @@ class Notifier(object):
|
|||
"""
|
||||
listener = _NotificationListener(None)
|
||||
|
||||
def timed_out():
|
||||
listener.deferred.cancel()
|
||||
end_time = self.clock.time_msec() + timeout
|
||||
|
||||
timer = self.clock.call_later(timeout / 1000., timed_out)
|
||||
while True:
|
||||
listener.deferred = self.replication_deferred.observe()
|
||||
result = yield callback()
|
||||
if result:
|
||||
break
|
||||
|
||||
now = self.clock.time_msec()
|
||||
if end_time <= now:
|
||||
break
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
yield listener.deferred
|
||||
yield self.clock.time_bound_deferred(
|
||||
listener.deferred,
|
||||
time_out=(end_time - now) / 1000.
|
||||
)
|
||||
except DeferredTimedOutError:
|
||||
break
|
||||
except defer.CancelledError:
|
||||
break
|
||||
|
||||
self.clock.cancel_call_later(timer, ignore_errs=True)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
|
|
@ -87,12 +87,12 @@ class BulkPushRuleEvaluator:
|
|||
condition_cache = {}
|
||||
|
||||
for uid, rules in self.rules_by_user.items():
|
||||
display_name = None
|
||||
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
|
||||
if member_ev_id:
|
||||
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
|
||||
if member_ev:
|
||||
display_name = member_ev.content.get("displayname", None)
|
||||
display_name = room_members.get(uid, {}).get("display_name", None)
|
||||
if not display_name:
|
||||
# Handle the case where we are pushing a membership event to
|
||||
# that user, as they might not be already joined.
|
||||
if event.type == EventTypes.Member and event.state_key == uid:
|
||||
display_name = event.content.get("displayname", None)
|
||||
|
||||
filtered = filtered_by_user[uid]
|
||||
if len(filtered) == 0:
|
||||
|
|
|
@ -49,8 +49,8 @@ CONDITIONAL_REQUIREMENTS = {
|
|||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||
},
|
||||
"ldap": {
|
||||
"ldap3>=1.0": ["ldap3>=1.0"],
|
||||
"matrix-synapse-ldap3": {
|
||||
"matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
|
||||
},
|
||||
"psutil": {
|
||||
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
||||
|
@ -69,6 +69,7 @@ def requirements(config=None, include_conditional=False):
|
|||
def github_link(project, version, egg):
|
||||
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
|
||||
|
||||
|
||||
DEPENDENCY_LINKS = {
|
||||
}
|
||||
|
||||
|
@ -156,6 +157,7 @@ def list_requirements():
|
|||
result.append(requirement)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.stdout.writelines(req + "\n" for req in list_requirements())
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import respond_with_json_bytes, request_handler
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
|
||||
class ExpireCacheResource(Resource):
|
||||
"""
|
||||
HTTP endpoint for expiring storage caches.
|
||||
|
||||
POST /_synapse/replication/expire_cache HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"invalidate": [
|
||||
{
|
||||
"name": "func_name",
|
||||
"keys": ["key1", "key2"]
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self) # Resource is old-style, so no super()
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
def _async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
for row in content["invalidate"]:
|
||||
name = row["name"]
|
||||
keys = tuple(row["keys"])
|
||||
|
||||
getattr(self.store, name).invalidate(keys)
|
||||
|
||||
respond_with_json_bytes(request, 200, "{}")
|
|
@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
|
|||
from synapse.http.server import request_handler, finish_request
|
||||
from synapse.replication.pusher_resource import PusherResource
|
||||
from synapse.replication.presence_resource import PresenceResource
|
||||
from synapse.replication.expire_cache import ExpireCacheResource
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
@ -44,6 +45,7 @@ STREAM_NAMES = (
|
|||
("caches",),
|
||||
("to_device",),
|
||||
("public_rooms",),
|
||||
("federation",),
|
||||
)
|
||||
|
||||
|
||||
|
@ -116,11 +118,14 @@ class ReplicationResource(Resource):
|
|||
self.sources = hs.get_event_sources()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
self.notifier = hs.notifier
|
||||
self.clock = hs.get_clock()
|
||||
self.config = hs.get_config()
|
||||
|
||||
self.putChild("remove_pushers", PusherResource(hs))
|
||||
self.putChild("syncing_users", PresenceResource(hs))
|
||||
self.putChild("expire_cache", ExpireCacheResource(hs))
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
|
@ -134,6 +139,7 @@ class ReplicationResource(Resource):
|
|||
pushers_token = self.store.get_pushers_stream_token()
|
||||
caches_token = self.store.get_cache_stream_token()
|
||||
public_rooms_token = self.store.get_current_public_room_stream_id()
|
||||
federation_token = self.federation_sender.get_current_token()
|
||||
|
||||
defer.returnValue(_ReplicationToken(
|
||||
room_stream_token,
|
||||
|
@ -148,6 +154,7 @@ class ReplicationResource(Resource):
|
|||
caches_token,
|
||||
int(stream_token.to_device_key),
|
||||
int(public_rooms_token),
|
||||
int(federation_token),
|
||||
))
|
||||
|
||||
@request_handler()
|
||||
|
@ -164,8 +171,13 @@ class ReplicationResource(Resource):
|
|||
}
|
||||
request_streams["streams"] = parse_string(request, "streams")
|
||||
|
||||
federation_ack = parse_integer(request, "federation_ack", None)
|
||||
|
||||
def replicate():
|
||||
return self.replicate(request_streams, limit)
|
||||
return self.replicate(
|
||||
request_streams, limit,
|
||||
federation_ack=federation_ack
|
||||
)
|
||||
|
||||
writer = yield self.notifier.wait_for_replication(replicate, timeout)
|
||||
result = writer.finish()
|
||||
|
@ -183,7 +195,7 @@ class ReplicationResource(Resource):
|
|||
finish_request(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self, request_streams, limit):
|
||||
def replicate(self, request_streams, limit, federation_ack=None):
|
||||
writer = _Writer()
|
||||
current_token = yield self.current_replication_token()
|
||||
logger.debug("Replicating up to %r", current_token)
|
||||
|
@ -202,6 +214,7 @@ class ReplicationResource(Resource):
|
|||
yield self.caches(writer, current_token, limit, request_streams)
|
||||
yield self.to_device(writer, current_token, limit, request_streams)
|
||||
yield self.public_rooms(writer, current_token, limit, request_streams)
|
||||
self.federation(writer, current_token, limit, request_streams, federation_ack)
|
||||
self.streams(writer, current_token, request_streams)
|
||||
|
||||
logger.debug("Replicated %d rows", writer.total)
|
||||
|
@ -462,7 +475,24 @@ class ReplicationResource(Resource):
|
|||
)
|
||||
upto_token = _position_from_rows(public_rooms_rows, current_position)
|
||||
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
|
||||
"position", "room_id", "visibility"
|
||||
"position", "room_id", "visibility", "appservice_id", "network_id",
|
||||
), position=upto_token)
|
||||
|
||||
def federation(self, writer, current_token, limit, request_streams, federation_ack):
|
||||
if self.config.send_federation:
|
||||
return
|
||||
|
||||
current_position = current_token.federation
|
||||
|
||||
federation = request_streams.get("federation")
|
||||
|
||||
if federation is not None and federation != current_position:
|
||||
federation_rows = self.federation_sender.get_replication_rows(
|
||||
federation, limit, federation_ack=federation_ack,
|
||||
)
|
||||
upto_token = _position_from_rows(federation_rows, current_position)
|
||||
writer.write_header_and_rows("federation", federation_rows, (
|
||||
"position", "type", "content",
|
||||
), position=upto_token)
|
||||
|
||||
|
||||
|
@ -497,6 +527,7 @@ class _Writer(object):
|
|||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
|
||||
"federation",
|
||||
))):
|
||||
__slots__ = []
|
||||
|
||||
|
|
|
@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
|
|||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
def stream_positions(self):
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
|
@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
|
|||
logger.info("Got unexpected cache_func: %r", cache_func)
|
||||
self._cache_id_gen.advance(int(stream["position"]))
|
||||
return defer.succeed(None)
|
||||
|
||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||
txn.call_after(cache_func.invalidate, keys)
|
||||
txn.call_after(self._send_invalidation_poke, cache_func, keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_invalidation_poke(self, cache_func, keys):
|
||||
try:
|
||||
yield self.http_client.post_json_get_json(self.expire_cache_url, {
|
||||
"invalidate": [{
|
||||
"name": cache_func.__name__,
|
||||
"keys": list(keys),
|
||||
}]
|
||||
})
|
||||
except:
|
||||
logger.exception("Failed to poke on expire_cache")
|
||||
|
|
|
@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
|
|||
"DeviceInboxStreamChangeCache",
|
||||
self._device_inbox_id_gen.get_current_token()
|
||||
)
|
||||
self._device_federation_outbox_stream_cache = StreamChangeCache(
|
||||
"DeviceFederationOutboxStreamChangeCache",
|
||||
self._device_inbox_id_gen.get_current_token()
|
||||
)
|
||||
|
||||
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
|
||||
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
|
||||
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
|
||||
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
|
||||
delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||
|
@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
|
|||
self._device_inbox_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
stream_id = row[0]
|
||||
user_id = row[1]
|
||||
self._device_inbox_stream_cache.entity_has_changed(
|
||||
user_id, stream_id
|
||||
)
|
||||
entity = row[1]
|
||||
|
||||
if entity.startswith("@"):
|
||||
self._device_inbox_stream_cache.entity_has_changed(
|
||||
entity, stream_id
|
||||
)
|
||||
else:
|
||||
self._device_federation_outbox_stream_cache.entity_has_changed(
|
||||
entity, stream_id
|
||||
)
|
||||
|
||||
return super(SlavedDeviceInboxStore, self).process_replication(result)
|
||||
|
|
|
@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
|
|||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
import ujson as json
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# So, um, we want to borrow a load of functions intended for reading from
|
||||
# a DataStore, but we don't want to take functions that either write to the
|
||||
|
@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
|
||||
)
|
||||
|
||||
get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
|
||||
|
||||
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
|
||||
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
|
@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
stream = result.get("events")
|
||||
if stream:
|
||||
self._stream_id_gen.advance(int(stream["position"]))
|
||||
|
||||
if stream["rows"]:
|
||||
logger.info("Got %d event rows", len(stream["rows"]))
|
||||
|
||||
for row in stream["rows"]:
|
||||
self._process_replication_row(
|
||||
row, backfilled=False, state_resets=state_resets
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.room import RoomStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
|
@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
|
|||
DataStore.get_current_public_room_stream_id.__func__
|
||||
)
|
||||
get_public_room_ids_at_stream_id = (
|
||||
DataStore.get_public_room_ids_at_stream_id.__func__
|
||||
RoomStore.__dict__["get_public_room_ids_at_stream_id"]
|
||||
)
|
||||
get_public_room_ids_at_stream_id_txn = (
|
||||
DataStore.get_public_room_ids_at_stream_id_txn.__func__
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.transactions import TransactionStore
|
||||
|
@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
|
|||
class TransactionStore(BaseSlavedStore):
|
||||
get_destination_retry_timings = TransactionStore.__dict__[
|
||||
"get_destination_retry_timings"
|
||||
].orig
|
||||
]
|
||||
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
||||
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
|
||||
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
|
||||
|
||||
# For now, don't record the destination rety timings
|
||||
def set_destination_retry_timings(*args, **kwargs):
|
||||
return defer.succeed(None)
|
||||
prep_send_transaction = DataStore.prep_send_transaction.__func__
|
||||
delivered_txn = DataStore.delivered_txn.__func__
|
||||
|
|
|
@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
|
|||
def register_servlets(hs, http_server):
|
||||
ClientDirectoryServer(hs).register(http_server)
|
||||
ClientDirectoryListServer(hs).register(http_server)
|
||||
ClientAppserviceDirectoryListServer(hs).register(http_server)
|
||||
|
||||
|
||||
class ClientDirectoryServer(ClientV1RestServlet):
|
||||
|
@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ClientAppserviceDirectoryListServer, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_PUT(self, request, network_id, room_id):
|
||||
content = parse_json_object_from_request(request)
|
||||
visibility = content.get("visibility", "public")
|
||||
return self._edit(request, network_id, room_id, visibility)
|
||||
|
||||
def on_DELETE(self, request, network_id, room_id):
|
||||
return self._edit(request, network_id, room_id, "private")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _edit(self, request, network_id, room_id, visibility):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if not requester.app_service:
|
||||
raise AuthError(
|
||||
403, "Only appservices can edit the appservice published room list"
|
||||
)
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_appservice_room_list(
|
||||
requester.app_service.id, network_id, room_id, visibility,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
password=login_submission["password"],
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
device_id = yield self._register_device(
|
||||
registered_user_id, login_submission
|
||||
)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
)
|
||||
|
||||
result = {
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
else:
|
||||
|
|
|
@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
if "displayname" not in user_json:
|
||||
raise SynapseError(400, "Expected 'displayname' key.")
|
||||
|
||||
if "duration_seconds" not in user_json:
|
||||
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
||||
|
||||
localpart = user_json["localpart"].encode("utf-8")
|
||||
displayname = user_json["displayname"].encode("utf-8")
|
||||
duration_seconds = 0
|
||||
try:
|
||||
duration_seconds = int(user_json["duration_seconds"])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||
if user_json.get("password_hash") else None
|
||||
|
||||
|
@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
requester=requester,
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_in_ms=(duration_seconds * 1000),
|
||||
password_hash=password_hash
|
||||
)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError
|
|||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.types import UserID, RoomID, RoomAlias
|
||||
from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
|
||||
from synapse.events.utils import serialize_event, format_event_for_client_v2
|
||||
from synapse.http.servlet import (
|
||||
parse_json_object_from_request, parse_string, parse_integer
|
||||
|
@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||
since_token = content.get("since", None)
|
||||
search_filter = content.get("filter", None)
|
||||
|
||||
include_all_networks = content.get("include_all_networks", False)
|
||||
third_party_instance_id = content.get("third_party_instance_id", None)
|
||||
|
||||
if include_all_networks:
|
||||
network_tuple = None
|
||||
if third_party_instance_id is not None:
|
||||
raise SynapseError(
|
||||
400, "Can't use include_all_networks with an explicit network"
|
||||
)
|
||||
elif third_party_instance_id is None:
|
||||
network_tuple = ThirdPartyInstanceID(None, None)
|
||||
else:
|
||||
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
|
||||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
if server:
|
||||
data = yield handler.get_remote_public_room_list(
|
||||
|
@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||
limit=limit,
|
||||
since_token=since_token,
|
||||
search_filter=search_filter,
|
||||
include_all_networks=include_all_networks,
|
||||
third_party_instance_id=third_party_instance_id,
|
||||
)
|
||||
else:
|
||||
data = yield handler.get_local_public_room_list(
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
search_filter=search_filter,
|
||||
network_tuple=network_tuple,
|
||||
)
|
||||
|
||||
defer.returnValue((200, data))
|
||||
|
@ -369,6 +386,24 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||
}))
|
||||
|
||||
|
||||
class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"joined": users_with_profile
|
||||
}))
|
||||
|
||||
|
||||
# TODO: Needs better unit testing
|
||||
class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
||||
|
@ -692,6 +727,22 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, results))
|
||||
|
||||
|
||||
class JoinedRoomsRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/joined_rooms$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(JoinedRoomsRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
|
||||
room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
|
||||
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
|
||||
|
||||
|
||||
def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
||||
"""Registers a transaction-based path.
|
||||
|
||||
|
@ -727,6 +778,7 @@ def register_servlets(hs, http_server):
|
|||
RoomStateEventRestServlet(hs).register(http_server)
|
||||
RoomCreateRestServlet(hs).register(http_server)
|
||||
RoomMemberListRestServlet(hs).register(http_server)
|
||||
JoinedRoomMemberListRestServlet(hs).register(http_server)
|
||||
RoomMessageListRestServlet(hs).register(http_server)
|
||||
JoinRoomAliasServlet(hs).register(http_server)
|
||||
RoomForgetRestServlet(hs).register(http_server)
|
||||
|
@ -738,4 +790,5 @@ def register_servlets(hs, http_server):
|
|||
RoomRedactEventRestServlet(hs).register(http_server)
|
||||
RoomTypingRestServlet(hs).register(http_server)
|
||||
SearchRestServlet(hs).register(http_server)
|
||||
JoinedRoomsRestServlet(hs).register(http_server)
|
||||
RoomEventContext(hs).register(http_server)
|
||||
|
|
|
@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
devices = yield self.device_handler.get_devices_by_user(
|
||||
requester.user.to_string()
|
||||
)
|
||||
|
@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
device = yield self.device_handler.get_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
|
@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
|
|
|
@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -94,10 +94,6 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
class KeyQueryServlet(RestServlet):
|
||||
"""
|
||||
GET /keys/query/<user_id> HTTP/1.1
|
||||
|
||||
GET /keys/query/<user_id>/<device_id> HTTP/1.1
|
||||
|
||||
POST /keys/query HTTP/1.1
|
||||
Content-Type: application/json
|
||||
{
|
||||
|
@ -131,11 +127,7 @@ class KeyQueryServlet(RestServlet):
|
|||
"""
|
||||
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/keys/query(?:"
|
||||
"/(?P<user_id>[^/]*)(?:"
|
||||
"/(?P<device_id>[^/]*)"
|
||||
")?"
|
||||
")?",
|
||||
"/keys/query$",
|
||||
releases=()
|
||||
)
|
||||
|
||||
|
@ -149,31 +141,16 @@ class KeyQueryServlet(RestServlet):
|
|||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
auth_user_id = requester.user.to_string()
|
||||
user_id = user_id if user_id else auth_user_id
|
||||
device_ids = [device_id] if device_id else []
|
||||
result = yield self.e2e_keys_handler.query_devices(
|
||||
{"device_keys": {user_id: device_ids}},
|
||||
timeout,
|
||||
)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class OneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
|
||||
|
||||
POST /keys/claim HTTP/1.1
|
||||
{
|
||||
"one_time_keys": {
|
||||
|
@ -191,9 +168,7 @@ class OneTimeKeyServlet(RestServlet):
|
|||
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/keys/claim(?:/?|(?:/"
|
||||
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
|
||||
")?)",
|
||||
"/keys/claim$",
|
||||
releases=()
|
||||
)
|
||||
|
||||
|
@ -203,18 +178,8 @@ class OneTimeKeyServlet(RestServlet):
|
|||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id, algorithm):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
||||
{"one_time_keys": {user_id: {device_id: algorithm}}},
|
||||
timeout,
|
||||
)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id, algorithm):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
||||
|
|
|
@ -36,7 +36,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
super(ReceiptRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.receipts_handler = hs.get_handlers().receipts_handler
|
||||
self.receipts_handler = hs.get_receipts_handler()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||
|
@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
kind = "user"
|
||||
if "kind" in request.args:
|
||||
kind = request.args["kind"][0]
|
||||
|
||||
if kind == "guest":
|
||||
ret = yield self._do_guest_registration()
|
||||
ret = yield self._do_guest_registration(body)
|
||||
defer.returnValue(ret)
|
||||
return
|
||||
elif kind != "user":
|
||||
|
@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet):
|
|||
"Do not understand membership kind: %s" % (kind,)
|
||||
)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# we do basic sanity checks here because the auth layer will store these
|
||||
# in sessions. Pull out the username/password provided to us.
|
||||
desired_password = None
|
||||
|
@ -373,8 +374,7 @@ class RegisterRestServlet(RestServlet):
|
|||
def _create_registration_details(self, user_id, params):
|
||||
"""Complete registration of newly-registered user
|
||||
|
||||
Allocates device_id if one was not given; also creates access_token
|
||||
and refresh_token.
|
||||
Allocates device_id if one was not given; also creates access_token.
|
||||
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
|
@ -385,8 +385,8 @@ class RegisterRestServlet(RestServlet):
|
|||
"""
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token, refresh_token = (
|
||||
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||
access_token = (
|
||||
yield self.auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
initial_display_name=params.get("initial_device_display_name")
|
||||
)
|
||||
|
@ -396,7 +396,6 @@ class RegisterRestServlet(RestServlet):
|
|||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"refresh_token": refresh_token,
|
||||
"device_id": device_id,
|
||||
})
|
||||
|
||||
|
@ -421,20 +420,28 @@ class RegisterRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self):
|
||||
def _do_guest_registration(self, params):
|
||||
if not self.hs.config.allow_guest_access:
|
||||
defer.returnValue((403, "Guest access is disabled"))
|
||||
user_id, _ = yield self.registration_handler.register(
|
||||
generate_token=False,
|
||||
make_guest=True
|
||||
)
|
||||
|
||||
# we don't allow guests to specify their own device_id, because
|
||||
# we have nowhere to store it.
|
||||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
access_token = self.auth_handler.generate_access_token(
|
||||
user_id, ["guest = true"]
|
||||
)
|
||||
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
|
||||
# so long as we don't return a refresh_token here.
|
||||
defer.returnValue((200, {
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}))
|
||||
|
|
|
@ -50,7 +50,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _put(self, request, message_type, txn_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
|
|
@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet):
|
|||
time_now = self.clock.time_msec()
|
||||
|
||||
joined = self.encode_joined(
|
||||
sync_result.joined, time_now, requester.access_token_id
|
||||
sync_result.joined, time_now, requester.access_token_id, filter.event_fields
|
||||
)
|
||||
|
||||
invited = self.encode_invited(
|
||||
|
@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
archived = self.encode_archived(
|
||||
sync_result.archived, time_now, requester.access_token_id
|
||||
sync_result.archived, time_now, requester.access_token_id, filter.event_fields
|
||||
)
|
||||
|
||||
response_content = {
|
||||
|
@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet):
|
|||
formatted.append(event)
|
||||
return {"events": formatted}
|
||||
|
||||
def encode_joined(self, rooms, time_now, token_id):
|
||||
def encode_joined(self, rooms, time_now, token_id, event_fields):
|
||||
"""
|
||||
Encode the joined rooms in a sync result
|
||||
|
||||
|
@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet):
|
|||
calculations
|
||||
token_id(int): ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
|
||||
event_fields(list<str>): List of event fields to include. If empty,
|
||||
all fields will be returned.
|
||||
Returns:
|
||||
dict[str, dict[str, object]]: the joined rooms list, in our
|
||||
response format
|
||||
|
@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet):
|
|||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = self.encode_room(
|
||||
room, time_now, token_id
|
||||
room, time_now, token_id, only_fields=event_fields
|
||||
)
|
||||
|
||||
return joined
|
||||
|
@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
return invited
|
||||
|
||||
def encode_archived(self, rooms, time_now, token_id):
|
||||
def encode_archived(self, rooms, time_now, token_id, event_fields):
|
||||
"""
|
||||
Encode the archived rooms in a sync result
|
||||
|
||||
|
@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet):
|
|||
calculations
|
||||
token_id(int): ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
|
||||
event_fields(list<str>): List of event fields to include. If empty,
|
||||
all fields will be returned.
|
||||
Returns:
|
||||
dict[str, dict[str, object]]: The invited rooms list, in our
|
||||
response format
|
||||
|
@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet):
|
|||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = self.encode_room(
|
||||
room, time_now, token_id, joined=False
|
||||
room, time_now, token_id, joined=False, only_fields=event_fields
|
||||
)
|
||||
|
||||
return joined
|
||||
|
||||
@staticmethod
|
||||
def encode_room(room, time_now, token_id, joined=True):
|
||||
def encode_room(room, time_now, token_id, joined=True, only_fields=None):
|
||||
"""
|
||||
Args:
|
||||
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
|
||||
|
@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet):
|
|||
of transaction IDs
|
||||
joined (bool): True if the user is joined to this room - will mean
|
||||
we handle ephemeral events
|
||||
|
||||
only_fields(list<str>): Optional. The list of event fields to include.
|
||||
Returns:
|
||||
dict[str, object]: the room, encoded in our response format
|
||||
"""
|
||||
|
@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet):
|
|||
return serialize_event(
|
||||
event, time_now, token_id=token_id,
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
only_event_fields=only_fields,
|
||||
)
|
||||
|
||||
state_dict = room.state
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
|
@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(TokenRefreshRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
try:
|
||||
old_refresh_token = body["refresh_token"]
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
refresh_result = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token
|
||||
)
|
||||
(user_id, new_refresh_token, device_id) = refresh_result
|
||||
new_access_token = yield auth_handler.issue_access_token(
|
||||
user_id, device_id
|
||||
)
|
||||
defer.returnValue((200, {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
}))
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing required key 'refresh_token'.")
|
||||
except StoreError:
|
||||
raise AuthError(403, "Did not recognize refresh token")
|
||||
raise AuthError(403, "tokenrefresh is no longer supported.")
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -381,7 +381,10 @@ def _calc_og(tree, media_uri):
|
|||
if 'og:title' not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
og['og:title'] = title[0].text.strip() if title else None
|
||||
if title and title[0].text is not None:
|
||||
og['og:title'] = title[0].text.strip()
|
||||
else:
|
||||
og['og:title'] = None
|
||||
|
||||
if 'og:image' not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
|
@ -543,5 +546,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
|||
|
||||
# We always add an ellipsis because at the very least
|
||||
# we chopped mid paragraph.
|
||||
description = new_desc.strip() + "…"
|
||||
description = new_desc.strip() + u"…"
|
||||
return description if description else None
|
||||
|
|
|
@ -32,6 +32,9 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
|
|||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||
from synapse.federation.transport.client import TransportLayerClient
|
||||
from synapse.federation.transaction_queue import TransactionQueue
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
|
@ -44,6 +47,7 @@ from synapse.handlers.sync import SyncHandler
|
|||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.receipts import ReceiptsHandler
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
|
@ -124,6 +128,9 @@ class HomeServer(object):
|
|||
'http_client_context_factory',
|
||||
'simple_http_client',
|
||||
'media_repository',
|
||||
'federation_transport_client',
|
||||
'federation_sender',
|
||||
'receipts_handler',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
|
@ -265,9 +272,30 @@ class HomeServer(object):
|
|||
def build_media_repository(self):
|
||||
return MediaRepository(self)
|
||||
|
||||
def build_federation_transport_client(self):
|
||||
return TransportLayerClient(self)
|
||||
|
||||
def build_federation_sender(self):
|
||||
if self.should_send_federation():
|
||||
return TransactionQueue(self)
|
||||
elif not self.config.worker_app:
|
||||
return FederationRemoteSendQueue(self)
|
||||
else:
|
||||
raise Exception("Workers cannot send federation traffic")
|
||||
|
||||
def build_receipts_handler(self):
|
||||
return ReceiptsHandler(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
def should_send_federation(self):
|
||||
"Should this server be sending federation traffic directly?"
|
||||
return self.config.send_federation and (
|
||||
not self.config.worker_app
|
||||
or self.config.worker_app == "synapse.app.federation_sender"
|
||||
)
|
||||
|
||||
|
||||
def _make_dependency_method(depname):
|
||||
def _get(hs):
|
||||
|
|
|
@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||
|
@ -223,6 +222,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
)
|
||||
|
||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||
|
||||
super(DataStore, self).__init__(hs)
|
||||
|
||||
|
|
|
@ -561,12 +561,17 @@ class SQLBaseStore(object):
|
|||
|
||||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
||||
"SELECT %(retcol)s FROM %(table)s %(where)s"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
|
||||
"where": where,
|
||||
}
|
||||
|
||||
txn.execute(sql, keyvalues.values())
|
||||
|
@ -744,10 +749,15 @@ class SQLBaseStore(object):
|
|||
|
||||
@staticmethod
|
||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
update_sql = "UPDATE %s SET %s %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
|
|
|
@ -39,6 +39,14 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
def get_app_services(self):
|
||||
return self.services_cache
|
||||
|
||||
def get_if_app_services_interested_in_user(self, user_id):
|
||||
"""Check if the user is one associated with an app service
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.is_interested_in_user(user_id):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_app_service_by_user_id(self, user_id):
|
||||
"""Retrieve an application service from their user ID.
|
||||
|
||||
|
|
|
@ -242,7 +242,7 @@ class DeviceInboxStore(SQLBaseStore):
|
|||
device_id(str): The recipient device_id.
|
||||
up_to_stream_id(int): Where to delete messages up to.
|
||||
Returns:
|
||||
A deferred that resolves when the messages have been deleted.
|
||||
A deferred that resolves to the number of messages deleted.
|
||||
"""
|
||||
def delete_messages_for_device_txn(txn):
|
||||
sql = (
|
||||
|
@ -251,6 +251,7 @@ class DeviceInboxStore(SQLBaseStore):
|
|||
" AND stream_id <= ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||
return txn.rowcount
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_messages_for_device", delete_messages_for_device_txn
|
||||
|
@ -269,27 +270,29 @@ class DeviceInboxStore(SQLBaseStore):
|
|||
return defer.succeed([])
|
||||
|
||||
def get_all_new_device_messages_txn(txn):
|
||||
# We limit like this as we might have multiple rows per stream_id, and
|
||||
# we want to make sure we always get all entries for any stream_id
|
||||
# we return.
|
||||
upper_pos = min(current_pos, last_pos + limit)
|
||||
sql = (
|
||||
"SELECT stream_id FROM device_inbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY stream_id"
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_pos, current_pos, limit))
|
||||
stream_ids = txn.fetchall()
|
||||
if not stream_ids:
|
||||
return []
|
||||
max_stream_id_in_limit = stream_ids[-1]
|
||||
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, device_id, message_json"
|
||||
"SELECT stream_id, user_id"
|
||||
" FROM device_inbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
)
|
||||
txn.execute(sql, (last_pos, max_stream_id_in_limit))
|
||||
return txn.fetchall()
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = (
|
||||
"SELECT stream_id, destination"
|
||||
" FROM device_federation_outbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows.extend(txn.fetchall())
|
||||
|
||||
return rows
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
|
|
|
@ -39,6 +39,14 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
columns=["user_id", "stream_ordering"],
|
||||
)
|
||||
|
||||
self.register_background_index_update(
|
||||
"event_push_actions_highlights_index",
|
||||
index_name="event_push_actions_highlights_index",
|
||||
table="event_push_actions",
|
||||
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
|
||||
where_clause="highlight=1"
|
||||
)
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
|
||||
"""
|
||||
Args:
|
||||
|
@ -88,8 +96,11 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
topological_ordering, stream_ordering
|
||||
)
|
||||
|
||||
# First get number of notifications.
|
||||
# We don't need to put a notif=1 clause as all rows always have
|
||||
# notif=1
|
||||
sql = (
|
||||
"SELECT sum(notif), sum(highlight)"
|
||||
"SELECT count(*)"
|
||||
" FROM event_push_actions ea"
|
||||
" WHERE"
|
||||
" user_id = ?"
|
||||
|
@ -99,13 +110,27 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return {
|
||||
"notify_count": row[0] or 0,
|
||||
"highlight_count": row[1] or 0,
|
||||
}
|
||||
else:
|
||||
return {"notify_count": 0, "highlight_count": 0}
|
||||
notify_count = row[0] if row else 0
|
||||
|
||||
# Now get the number of highlights
|
||||
sql = (
|
||||
"SELECT count(*)"
|
||||
" FROM event_push_actions ea"
|
||||
" WHERE"
|
||||
" highlight = 1"
|
||||
" AND user_id = ?"
|
||||
" AND room_id = ?"
|
||||
" AND %s"
|
||||
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
row = txn.fetchone()
|
||||
highlight_count = row[0] if row else 0
|
||||
|
||||
return {
|
||||
"notify_count": notify_count,
|
||||
"highlight_count": highlight_count,
|
||||
}
|
||||
|
||||
ret = yield self.runInteraction(
|
||||
"get_unread_event_push_actions_by_room",
|
||||
|
|
|
@ -54,6 +54,7 @@ def encode_json(json_object):
|
|||
else:
|
||||
return json.dumps(json_object, ensure_ascii=False)
|
||||
|
||||
|
||||
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
||||
# control how we batch/bulk fetch events from the database.
|
||||
# The values are plucked out of thing air to make initial sync run faster
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
|
||||
import simplejson as json
|
||||
|
@ -24,6 +25,13 @@ import simplejson as json
|
|||
class FilteringStore(SQLBaseStore):
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
def get_user_filter(self, user_localpart, filter_id):
|
||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||
try:
|
||||
int(filter_id)
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
||||
|
||||
def_json = yield self._simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 38
|
||||
SCHEMA_VERSION = 39
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
|
|
@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState",
|
|||
status_msg (str): User set status message.
|
||||
"""
|
||||
|
||||
def as_dict(self):
|
||||
return dict(self._asdict())
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return UserPresenceState(**d)
|
||||
|
||||
def copy_and_replace(self, **kwargs):
|
||||
return self._replace(**kwargs)
|
||||
|
||||
|
|
|
@ -156,12 +156,20 @@ class PushRuleStore(SQLBaseStore):
|
|||
event=event,
|
||||
)
|
||||
|
||||
local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
|
||||
# We ignore app service users for now. This is so that we don't fill
|
||||
# up the `get_if_users_have_pushers` cache with AS entries that we
|
||||
# know don't have pushers, nor even read receipts.
|
||||
local_users_in_room = set(
|
||||
u for u in users_in_room
|
||||
if self.hs.is_mine_id(u)
|
||||
and not self.get_if_app_services_interested_in_user(u)
|
||||
)
|
||||
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||
local_users_in_room, on_invalidate=cache_context.invalidate,
|
||||
local_users_in_room,
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
user_ids = set(
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
|
|
|
@ -405,7 +405,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
|
||||
max_persisted_id = self._stream_id_gen.get_current_token()
|
||||
max_persisted_id = self._receipts_id_gen.get_current_token()
|
||||
|
||||
defer.returnValue((stream_id, max_persisted_id))
|
||||
|
||||
|
|
|
@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_refresh_token_to_user(self, user_id, token, device_id=None):
|
||||
"""Adds a refresh token for the given user.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
token (str): The new refresh token to add.
|
||||
device_id (str): ID of the device to associate with the access
|
||||
token
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
"""
|
||||
next_id = self._refresh_tokens_id_gen.get_next()
|
||||
|
||||
yield self._simple_insert(
|
||||
"refresh_tokens",
|
||||
{
|
||||
"id": next_id,
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"device_id": device_id,
|
||||
},
|
||||
desc="add_refresh_token_to_user",
|
||||
)
|
||||
|
||||
def register(self, user_id, token=None, password_hash=None,
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
create_profile_with_localpart=None, admin=False):
|
||||
|
@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
token
|
||||
)
|
||||
|
||||
def exchange_refresh_token(self, refresh_token, token_generator):
|
||||
"""Exchange a refresh token for a new one.
|
||||
|
||||
Doing so invalidates the old refresh token - refresh tokens are single
|
||||
use.
|
||||
|
||||
Args:
|
||||
refresh_token (str): The refresh token of a user.
|
||||
token_generator (fn: str -> str): Function which, when given a
|
||||
user ID, returns a unique refresh token for that user. This
|
||||
function must never return the same value twice.
|
||||
Returns:
|
||||
tuple of (user_id, new_refresh_token, device_id)
|
||||
Raises:
|
||||
StoreError if no user was found with that refresh token.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"exchange_refresh_token",
|
||||
self._exchange_refresh_token,
|
||||
refresh_token,
|
||||
token_generator
|
||||
)
|
||||
|
||||
def _exchange_refresh_token(self, txn, old_token, token_generator):
|
||||
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
|
||||
txn.execute(sql, (old_token,))
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
raise StoreError(403, "Did not recognize refresh token")
|
||||
user_id = rows[0]["user_id"]
|
||||
device_id = rows[0]["device_id"]
|
||||
|
||||
# TODO(danielwh): Maybe perform a validation on the macaroon that
|
||||
# macaroon.user_id == user_id.
|
||||
|
||||
new_token = token_generator(user_id)
|
||||
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
|
||||
txn.execute(sql, (new_token, old_token,))
|
||||
|
||||
return user_id, new_token, device_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_server_admin(self, user):
|
||||
res = yield self._simple_select_one_onecol(
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from .engines import PostgresEngine, Sqlite3Engine
|
||||
|
@ -106,7 +107,11 @@ class RoomStore(SQLBaseStore):
|
|||
entries = self._simple_select_list_txn(
|
||||
txn,
|
||||
table="public_room_list_stream",
|
||||
keyvalues={"room_id": room_id},
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"appservice_id": None,
|
||||
"network_id": None,
|
||||
},
|
||||
retcols=("stream_id", "visibility"),
|
||||
)
|
||||
|
||||
|
@ -124,6 +129,8 @@ class RoomStore(SQLBaseStore):
|
|||
"stream_id": next_id,
|
||||
"room_id": room_id,
|
||||
"visibility": is_public,
|
||||
"appservice_id": None,
|
||||
"network_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -132,6 +139,87 @@ class RoomStore(SQLBaseStore):
|
|||
"set_room_is_public",
|
||||
set_room_is_public_txn, next_id,
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
|
||||
is_public):
|
||||
"""Edit the appservice/network specific public room list.
|
||||
|
||||
Each appservice can have a number of published room lists associated
|
||||
with them, keyed off of an appservice defined `network_id`, which
|
||||
basically represents a single instance of a bridge to a third party
|
||||
network.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
appservice_id (str)
|
||||
network_id (str)
|
||||
is_public (bool): Whether to publish or unpublish the room from the
|
||||
list.
|
||||
"""
|
||||
def set_room_is_public_appservice_txn(txn, next_id):
|
||||
if is_public:
|
||||
try:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="appservice_room_list",
|
||||
values={
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
"room_id": room_id
|
||||
},
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
# We've already inserted, nothing to do.
|
||||
return
|
||||
else:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="appservice_room_list",
|
||||
keyvalues={
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
"room_id": room_id
|
||||
},
|
||||
)
|
||||
|
||||
entries = self._simple_select_list_txn(
|
||||
txn,
|
||||
table="public_room_list_stream",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
},
|
||||
retcols=("stream_id", "visibility"),
|
||||
)
|
||||
|
||||
entries.sort(key=lambda r: r["stream_id"])
|
||||
|
||||
add_to_stream = True
|
||||
if entries:
|
||||
add_to_stream = bool(entries[-1]["visibility"]) != is_public
|
||||
|
||||
if add_to_stream:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="public_room_list_stream",
|
||||
values={
|
||||
"stream_id": next_id,
|
||||
"room_id": room_id,
|
||||
"visibility": is_public,
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
}
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn, next_id,
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
def get_public_room_ids(self):
|
||||
return self._simple_select_onecol(
|
||||
|
@ -259,38 +347,96 @@ class RoomStore(SQLBaseStore):
|
|||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
def get_public_room_ids_at_stream_id(self, stream_id):
|
||||
@cached(num_args=2, max_entries=100)
|
||||
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
|
||||
"""Get pulbic rooms for a particular list, or across all lists.
|
||||
|
||||
Args:
|
||||
stream_id (int)
|
||||
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
|
||||
means the main list, None means all lsits.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_public_room_ids_at_stream_id",
|
||||
self.get_public_room_ids_at_stream_id_txn, stream_id
|
||||
self.get_public_room_ids_at_stream_id_txn,
|
||||
stream_id, network_tuple=network_tuple
|
||||
)
|
||||
|
||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
|
||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
|
||||
network_tuple):
|
||||
return {
|
||||
rm
|
||||
for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
|
||||
for rm, vis in self.get_published_at_stream_id_txn(
|
||||
txn, stream_id, network_tuple=network_tuple
|
||||
).items()
|
||||
if vis
|
||||
}
|
||||
|
||||
def get_published_at_stream_id_txn(self, txn, stream_id):
|
||||
sql = ("""
|
||||
SELECT room_id, visibility FROM public_room_list_stream
|
||||
INNER JOIN (
|
||||
SELECT room_id, max(stream_id) AS stream_id
|
||||
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
|
||||
if network_tuple:
|
||||
# We want to get from a particular list. No aggregation required.
|
||||
|
||||
sql = ("""
|
||||
SELECT room_id, visibility FROM public_room_list_stream
|
||||
INNER JOIN (
|
||||
SELECT room_id, max(stream_id) AS stream_id
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id <= ? %s
|
||||
GROUP BY room_id
|
||||
) grouped USING (room_id, stream_id)
|
||||
""")
|
||||
|
||||
if network_tuple.appservice_id is not None:
|
||||
txn.execute(
|
||||
sql % ("AND appservice_id = ? AND network_id = ?",),
|
||||
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
|
||||
)
|
||||
else:
|
||||
txn.execute(
|
||||
sql % ("AND appservice_id IS NULL",),
|
||||
(stream_id,)
|
||||
)
|
||||
return dict(txn.fetchall())
|
||||
else:
|
||||
# We want to get from all lists, so we need to aggregate the results
|
||||
|
||||
logger.info("Executing full list")
|
||||
|
||||
sql = ("""
|
||||
SELECT room_id, visibility
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id <= ?
|
||||
GROUP BY room_id
|
||||
) grouped USING (room_id, stream_id)
|
||||
""")
|
||||
INNER JOIN (
|
||||
SELECT
|
||||
room_id, max(stream_id) AS stream_id, appservice_id,
|
||||
network_id
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id <= ?
|
||||
GROUP BY room_id, appservice_id, network_id
|
||||
) grouped USING (room_id, stream_id)
|
||||
""")
|
||||
|
||||
txn.execute(sql, (stream_id,))
|
||||
return dict(txn.fetchall())
|
||||
txn.execute(
|
||||
sql,
|
||||
(stream_id,)
|
||||
)
|
||||
|
||||
def get_public_room_changes(self, prev_stream_id, new_stream_id):
|
||||
results = {}
|
||||
# A room is visible if its visible on any list.
|
||||
for room_id, visibility in txn.fetchall():
|
||||
results[room_id] = bool(visibility) or results.get(room_id, False)
|
||||
|
||||
return results
|
||||
|
||||
def get_public_room_changes(self, prev_stream_id, new_stream_id,
|
||||
network_tuple):
|
||||
def get_public_room_changes_txn(txn):
|
||||
then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
|
||||
then_rooms = self.get_public_room_ids_at_stream_id_txn(
|
||||
txn, prev_stream_id, network_tuple
|
||||
)
|
||||
|
||||
now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
|
||||
now_rooms_dict = self.get_published_at_stream_id_txn(
|
||||
txn, new_stream_id, network_tuple
|
||||
)
|
||||
|
||||
now_rooms_visible = set(
|
||||
rm for rm, vis in now_rooms_dict.items() if vis
|
||||
|
@ -311,7 +457,8 @@ class RoomStore(SQLBaseStore):
|
|||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||
def get_all_new_public_rooms(txn):
|
||||
sql = ("""
|
||||
SELECT stream_id, room_id, visibility FROM public_room_list_stream
|
||||
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id > ? AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes
|
|||
from synapse.types import get_domain_from_id
|
||||
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -34,7 +35,15 @@ RoomsForUser = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
||||
|
||||
|
||||
class RoomMemberStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberStore, self).__init__(hs)
|
||||
self.register_background_update_handler(
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||
)
|
||||
|
||||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database.
|
||||
|
@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore):
|
|||
"sender": event.user_id,
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
"display_name": event.content.get("displayname", None),
|
||||
"avatar_url": event.content.get("avatar_url", None),
|
||||
}
|
||||
for event in events
|
||||
]
|
||||
|
@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=member_event_ids,
|
||||
retcols=['user_id'],
|
||||
retcols=['user_id', 'display_name', 'avatar_url'],
|
||||
keyvalues={
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
|
@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore):
|
|||
desc="_get_joined_users_from_context",
|
||||
)
|
||||
|
||||
users_in_room = set(row["user_id"] for row in rows)
|
||||
users_in_room = {
|
||||
row["user_id"]: {
|
||||
"display_name": row["display_name"],
|
||||
"avatar_url": row["avatar_url"],
|
||||
}
|
||||
for row in rows
|
||||
}
|
||||
|
||||
if event is not None and event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
if event.event_id in member_event_ids:
|
||||
users_in_room.add(event.state_key)
|
||||
users_in_room[event.state_key] = {
|
||||
"display_name": event.content.get("displayname", None),
|
||||
"avatar_url": event.content.get("avatar_url", None),
|
||||
}
|
||||
|
||||
defer.returnValue(users_in_room)
|
||||
|
||||
|
@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore):
|
|||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_add_membership_profile(self, progress, batch_size):
|
||||
target_min_stream_id = progress.get(
|
||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
||||
)
|
||||
max_stream_id = progress.get(
|
||||
"max_stream_id_exclusive", self._stream_order_on_start + 1
|
||||
)
|
||||
|
||||
INSERT_CLUMP_SIZE = 1000
|
||||
|
||||
def add_membership_profile_txn(txn):
|
||||
sql = ("""
|
||||
SELECT stream_ordering, event_id, events.room_id, content
|
||||
FROM events
|
||||
INNER JOIN room_memberships USING (event_id)
|
||||
WHERE ? <= stream_ordering AND stream_ordering < ?
|
||||
AND type = 'm.room.member'
|
||||
ORDER BY stream_ordering DESC
|
||||
LIMIT ?
|
||||
""")
|
||||
|
||||
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
min_stream_id = rows[-1]["stream_ordering"]
|
||||
|
||||
to_update = []
|
||||
for row in rows:
|
||||
event_id = row["event_id"]
|
||||
room_id = row["room_id"]
|
||||
try:
|
||||
content = json.loads(row["content"])
|
||||
except:
|
||||
continue
|
||||
|
||||
display_name = content.get("displayname", None)
|
||||
avatar_url = content.get("avatar_url", None)
|
||||
|
||||
if display_name or avatar_url:
|
||||
to_update.append((
|
||||
display_name, avatar_url, event_id, room_id
|
||||
))
|
||||
|
||||
to_update_sql = ("""
|
||||
UPDATE room_memberships SET display_name = ?, avatar_url = ?
|
||||
WHERE event_id = ? AND room_id = ?
|
||||
""")
|
||||
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
|
||||
clump = to_update[index:index + INSERT_CLUMP_SIZE]
|
||||
txn.executemany(to_update_sql, clump)
|
||||
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
"max_stream_id_exclusive": min_stream_id,
|
||||
}
|
||||
|
||||
self._background_update_progress_txn(
|
||||
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
result = yield self.runInteraction(
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE appservice_room_list(
|
||||
appservice_id TEXT NOT NULL,
|
||||
network_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Each appservice can have multiple published room lists associated with them,
|
||||
-- keyed of a particular network_id
|
||||
CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
|
||||
appservice_id, network_id, room_id
|
||||
);
|
||||
|
||||
ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
|
||||
ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;
|
|
@ -0,0 +1,16 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id);
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('event_push_actions_highlights_index', '{}');
|
|
@ -0,0 +1,22 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE federation_stream_position(
|
||||
type TEXT NOT NULL,
|
||||
stream_id INTEGER NOT NULL
|
||||
);
|
||||
|
||||
INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1);
|
||||
INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events;
|
|
@ -0,0 +1,20 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE room_memberships ADD COLUMN display_name TEXT;
|
||||
ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT;
|
||||
|
||||
INSERT into background_updates (update_name, progress_json)
|
||||
VALUES ('room_membership_profile_update', '{}');
|
|
@ -653,7 +653,10 @@ class StateStore(SQLBaseStore):
|
|||
else:
|
||||
state_dict = results[group]
|
||||
|
||||
state_dict.update(group_state_dict)
|
||||
state_dict.update({
|
||||
(intern_string(k[0]), intern_string(k[1])): v
|
||||
for k, v in group_state_dict.items()
|
||||
})
|
||||
|
||||
self._state_group_cache.update(
|
||||
cache_seq_num,
|
||||
|
|
|
@ -541,6 +541,9 @@ class StreamStore(SQLBaseStore):
|
|||
def get_room_max_stream_ordering(self):
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_room_min_stream_ordering(self):
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
|
||||
def get_stream_token_for_event(self, event_id):
|
||||
"""The stream token for an event
|
||||
Args:
|
||||
|
@ -765,3 +768,50 @@ class StreamStore(SQLBaseStore):
|
|||
"token": end_token,
|
||||
},
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_new_events_stream(self, from_id, current_id, limit):
|
||||
"""Get all new events"""
|
||||
|
||||
def get_all_new_events_stream_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id"
|
||||
" FROM events AS e"
|
||||
" WHERE"
|
||||
" ? < e.stream_ordering AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (from_id, current_id, limit))
|
||||
rows = txn.fetchall()
|
||||
|
||||
upper_bound = current_id
|
||||
if len(rows) == limit:
|
||||
upper_bound = rows[-1][0]
|
||||
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.runInteraction(
|
||||
"get_all_new_events_stream", get_all_new_events_stream_txn,
|
||||
)
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue((upper_bound, events))
|
||||
|
||||
def get_federation_out_pos(self, typ):
|
||||
return self._simple_select_one_onecol(
|
||||
table="federation_stream_position",
|
||||
retcol="stream_id",
|
||||
keyvalues={"type": typ},
|
||||
desc="get_federation_out_pos"
|
||||
)
|
||||
|
||||
def update_federation_out_pos(self, typ, stream_id):
|
||||
return self._simple_update_one(
|
||||
table="federation_stream_position",
|
||||
keyvalues={"type": typ},
|
||||
updatevalues={"stream_id": stream_id},
|
||||
desc="update_federation_out_pos",
|
||||
)
|
||||
|
|
|
@ -200,25 +200,48 @@ class TransactionStore(SQLBaseStore):
|
|||
|
||||
def _set_destination_retry_timings(self, txn, destination,
|
||||
retry_last_ts, retry_interval):
|
||||
txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
|
||||
self.database_engine.lock_table(txn, "destinations")
|
||||
|
||||
self._simple_upsert_txn(
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_destination_retry_timings, (destination,)
|
||||
)
|
||||
|
||||
# We need to be careful here as the data may have changed from under us
|
||||
# due to a worker setting the timings.
|
||||
|
||||
prev_row = self._simple_select_one_txn(
|
||||
txn,
|
||||
"destinations",
|
||||
table="destinations",
|
||||
keyvalues={
|
||||
"destination": destination,
|
||||
},
|
||||
values={
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval,
|
||||
},
|
||||
insertion_values={
|
||||
"destination": destination,
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval,
|
||||
}
|
||||
retcols=("retry_last_ts", "retry_interval"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if not prev_row:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="destinations",
|
||||
values={
|
||||
"destination": destination,
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval,
|
||||
}
|
||||
)
|
||||
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
"destinations",
|
||||
keyvalues={
|
||||
"destination": destination,
|
||||
},
|
||||
updatevalues={
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval,
|
||||
},
|
||||
)
|
||||
|
||||
def get_destinations_needing_retry(self):
|
||||
"""Get all destinations which are due a retry for sending a transaction.
|
||||
|
||||
|
|
|
@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
return "t%d-%d" % (self.topological, self.stream)
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
class ThirdPartyInstanceID(
|
||||
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
|
||||
):
|
||||
# Deny iteration because it will bite you if you try to create a singleton
|
||||
# set by:
|
||||
# users = set(user)
|
||||
def __iter__(self):
|
||||
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
|
||||
|
||||
# Because this class is a namedtuple of strings, it is deeply immutable.
|
||||
def __copy__(self):
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, s):
|
||||
bits = s.split("|", 2)
|
||||
if len(bits) != 2:
|
||||
raise SynapseError(400, "Invalid ID %r" % (s,))
|
||||
|
||||
return cls(appservice_id=bits[0], network_id=bits[1])
|
||||
|
||||
def to_string(self):
|
||||
return "%s|%s" % (self.appservice_id, self.network_id,)
|
||||
|
||||
__str__ = to_string
|
||||
|
||||
@classmethod
|
||||
def create(cls, appservice_id, network_id,):
|
||||
return cls(appservice_id=appservice_id, network_id=network_id)
|
||||
|
|
|
@ -24,6 +24,11 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeferredTimedOutError(SynapseError):
|
||||
def __init__(self):
|
||||
super(SynapseError).__init__(504, "Timed out")
|
||||
|
||||
|
||||
def unwrapFirstError(failure):
|
||||
# defer.gatherResults and DeferredLists wrap failures.
|
||||
failure.trap(defer.FirstError)
|
||||
|
@ -89,7 +94,7 @@ class Clock(object):
|
|||
|
||||
def timed_out_fn():
|
||||
try:
|
||||
ret_deferred.errback(SynapseError(504, "Timed out"))
|
||||
ret_deferred.errback(DeferredTimedOutError())
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
@ -197,6 +197,64 @@ class Linearizer(object):
|
|||
defer.returnValue(_ctx_manager())
|
||||
|
||||
|
||||
class Limiter(object):
|
||||
"""Limits concurrent access to resources based on a key. Useful to ensure
|
||||
only a few thing happen at a time on a given resource.
|
||||
|
||||
Example:
|
||||
|
||||
with (yield limiter.queue("test_key")):
|
||||
# do some work.
|
||||
|
||||
"""
|
||||
def __init__(self, max_count):
|
||||
"""
|
||||
Args:
|
||||
max_count(int): The maximum number of concurrent access
|
||||
"""
|
||||
self.max_count = max_count
|
||||
|
||||
# key_to_defer is a map from the key to a 2 element list where
|
||||
# the first element is the number of things executing
|
||||
# the second element is a list of deferreds for the things blocked from
|
||||
# executing.
|
||||
self.key_to_defer = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def queue(self, key):
|
||||
entry = self.key_to_defer.setdefault(key, [0, []])
|
||||
|
||||
# If the number of things executing is greater than the maximum
|
||||
# then add a deferred to the list of blocked items
|
||||
# When on of the things currently executing finishes it will callback
|
||||
# this item so that it can continue executing.
|
||||
if entry[0] >= self.max_count:
|
||||
new_defer = defer.Deferred()
|
||||
entry[1].append(new_defer)
|
||||
with PreserveLoggingContext():
|
||||
yield new_defer
|
||||
|
||||
entry[0] += 1
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# We've finished executing so check if there are any things
|
||||
# blocked waiting to execute and start one of them
|
||||
entry[0] -= 1
|
||||
try:
|
||||
entry[1].pop(0).callback(None)
|
||||
except IndexError:
|
||||
# If nothing else is executing for this key then remove it
|
||||
# from the map
|
||||
if entry[0] == 0:
|
||||
self.key_to_defer.pop(key, None)
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
|
||||
|
||||
class ReadWriteLock(object):
|
||||
"""A deferred style read write lock.
|
||||
|
||||
|
|
|
@ -76,15 +76,26 @@ class JsonEncodedObject(object):
|
|||
d.update(self.unrecognized_keys)
|
||||
return d
|
||||
|
||||
def get_internal_dict(self):
|
||||
d = {
|
||||
k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
|
||||
if k in self.valid_keys
|
||||
}
|
||||
d.update(self.unrecognized_keys)
|
||||
return d
|
||||
|
||||
def __str__(self):
|
||||
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
|
||||
|
||||
|
||||
def _encode(obj):
|
||||
def _encode(obj, internal=False):
|
||||
if type(obj) is list:
|
||||
return [_encode(o) for o in obj]
|
||||
return [_encode(o, internal=internal) for o in obj]
|
||||
|
||||
if isinstance(obj, JsonEncodedObject):
|
||||
return obj.get_dict()
|
||||
if internal:
|
||||
return obj.get_internal_dict()
|
||||
else:
|
||||
return obj.get_dict()
|
||||
|
||||
return obj
|
||||
|
|
|
@ -1,369 +0,0 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.types import UserID
|
||||
|
||||
import ldap3
|
||||
import ldap3.core.exceptions
|
||||
|
||||
import logging
|
||||
|
||||
try:
|
||||
import ldap3
|
||||
import ldap3.core.exceptions
|
||||
except ImportError:
|
||||
ldap3 = None
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LDAPMode(object):
|
||||
SIMPLE = "simple",
|
||||
SEARCH = "search",
|
||||
|
||||
LIST = (SIMPLE, SEARCH)
|
||||
|
||||
|
||||
class LdapAuthProvider(object):
|
||||
__version__ = "0.1"
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
self.account_handler = account_handler
|
||||
|
||||
if not ldap3:
|
||||
raise RuntimeError(
|
||||
'Missing ldap3 library. This is required for LDAP Authentication.'
|
||||
)
|
||||
|
||||
self.ldap_mode = config.mode
|
||||
self.ldap_uri = config.uri
|
||||
self.ldap_start_tls = config.start_tls
|
||||
self.ldap_base = config.base
|
||||
self.ldap_attributes = config.attributes
|
||||
if self.ldap_mode == LDAPMode.SEARCH:
|
||||
self.ldap_bind_dn = config.bind_dn
|
||||
self.ldap_bind_password = config.bind_password
|
||||
self.ldap_filter = config.filter
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_password(self, user_id, password):
|
||||
""" Attempt to authenticate a user against an LDAP Server
|
||||
and register an account if none exists.
|
||||
|
||||
Returns:
|
||||
True if authentication against LDAP was successful
|
||||
"""
|
||||
localpart = UserID.from_string(user_id).localpart
|
||||
|
||||
try:
|
||||
server = ldap3.Server(self.ldap_uri)
|
||||
logger.debug(
|
||||
"Attempting LDAP connection with %s",
|
||||
self.ldap_uri
|
||||
)
|
||||
|
||||
if self.ldap_mode == LDAPMode.SIMPLE:
|
||||
result, conn = self._ldap_simple_bind(
|
||||
server=server, localpart=localpart, password=password
|
||||
)
|
||||
logger.debug(
|
||||
'LDAP authentication method simple bind returned: %s (conn: %s)',
|
||||
result,
|
||||
conn
|
||||
)
|
||||
if not result:
|
||||
defer.returnValue(False)
|
||||
elif self.ldap_mode == LDAPMode.SEARCH:
|
||||
result, conn = self._ldap_authenticated_search(
|
||||
server=server, localpart=localpart, password=password
|
||||
)
|
||||
logger.debug(
|
||||
'LDAP auth method authenticated search returned: %s (conn: %s)',
|
||||
result,
|
||||
conn
|
||||
)
|
||||
if not result:
|
||||
defer.returnValue(False)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Invalid LDAP mode specified: {mode}'.format(
|
||||
mode=self.ldap_mode
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"User authenticated against LDAP server: %s",
|
||||
conn
|
||||
)
|
||||
except NameError:
|
||||
logger.warn(
|
||||
"Authentication method yielded no LDAP connection, aborting!"
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
# check if user with user_id exists
|
||||
if (yield self.account_handler.check_user_exists(user_id)):
|
||||
# exists, authentication complete
|
||||
conn.unbind()
|
||||
defer.returnValue(True)
|
||||
|
||||
else:
|
||||
# does not exist, fetch metadata for account creation from
|
||||
# existing ldap connection
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
|
||||
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
|
||||
query = "(&{filter}{user_filter})".format(
|
||||
filter=query,
|
||||
user_filter=self.ldap_filter
|
||||
)
|
||||
logger.debug(
|
||||
"ldap registration filter: %s",
|
||||
query
|
||||
)
|
||||
|
||||
conn.search(
|
||||
search_base=self.ldap_base,
|
||||
search_filter=query,
|
||||
attributes=[
|
||||
self.ldap_attributes['name'],
|
||||
self.ldap_attributes['mail']
|
||||
]
|
||||
)
|
||||
|
||||
if len(conn.response) == 1:
|
||||
attrs = conn.response[0]['attributes']
|
||||
mail = attrs[self.ldap_attributes['mail']][0]
|
||||
name = attrs[self.ldap_attributes['name']][0]
|
||||
|
||||
# create account
|
||||
user_id, access_token = (
|
||||
yield self.account_handler.register(localpart=localpart)
|
||||
)
|
||||
|
||||
# TODO: bind email, set displayname with data from ldap directory
|
||||
|
||||
logger.info(
|
||||
"Registration based on LDAP data was successful: %d: %s (%s, %)",
|
||||
user_id,
|
||||
localpart,
|
||||
name,
|
||||
mail
|
||||
)
|
||||
|
||||
defer.returnValue(True)
|
||||
else:
|
||||
if len(conn.response) == 0:
|
||||
logger.warn("LDAP registration failed, no result.")
|
||||
else:
|
||||
logger.warn(
|
||||
"LDAP registration failed, too many results (%s)",
|
||||
len(conn.response)
|
||||
)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.warn("Error during ldap authentication: %s", e)
|
||||
defer.returnValue(False)
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
class _LdapConfig(object):
|
||||
pass
|
||||
|
||||
ldap_config = _LdapConfig()
|
||||
|
||||
ldap_config.enabled = config.get("enabled", False)
|
||||
|
||||
ldap_config.mode = LDAPMode.SIMPLE
|
||||
|
||||
# verify config sanity
|
||||
_require_keys(config, [
|
||||
"uri",
|
||||
"base",
|
||||
"attributes",
|
||||
])
|
||||
|
||||
ldap_config.uri = config["uri"]
|
||||
ldap_config.start_tls = config.get("start_tls", False)
|
||||
ldap_config.base = config["base"]
|
||||
ldap_config.attributes = config["attributes"]
|
||||
|
||||
if "bind_dn" in config:
|
||||
ldap_config.mode = LDAPMode.SEARCH
|
||||
_require_keys(config, [
|
||||
"bind_dn",
|
||||
"bind_password",
|
||||
])
|
||||
|
||||
ldap_config.bind_dn = config["bind_dn"]
|
||||
ldap_config.bind_password = config["bind_password"]
|
||||
ldap_config.filter = config.get("filter", None)
|
||||
|
||||
# verify attribute lookup
|
||||
_require_keys(config['attributes'], [
|
||||
"uid",
|
||||
"name",
|
||||
"mail",
|
||||
])
|
||||
|
||||
return ldap_config
|
||||
|
||||
def _ldap_simple_bind(self, server, localpart, password):
|
||||
""" Attempt a simple bind with the credentials
|
||||
given by the user against the LDAP server.
|
||||
|
||||
Returns True, LDAP3Connection
|
||||
if the bind was successful
|
||||
Returns False, None
|
||||
if an error occured
|
||||
"""
|
||||
|
||||
try:
|
||||
# bind with the the local users ldap credentials
|
||||
bind_dn = "{prop}={value},{base}".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart,
|
||||
base=self.ldap_base
|
||||
)
|
||||
conn = ldap3.Connection(server, bind_dn, password,
|
||||
authentication=ldap3.AUTH_SIMPLE)
|
||||
logger.debug(
|
||||
"Established LDAP connection in simple bind mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if conn.bind():
|
||||
# GOOD: bind okay
|
||||
logger.debug("LDAP Bind successful in simple bind mode.")
|
||||
return True, conn
|
||||
|
||||
# BAD: bind failed
|
||||
logger.info(
|
||||
"Binding against LDAP failed for '%s' failed: %s",
|
||||
localpart, conn.result['description']
|
||||
)
|
||||
conn.unbind()
|
||||
return False, None
|
||||
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.warn("Error during LDAP authentication: %s", e)
|
||||
return False, None
|
||||
|
||||
def _ldap_authenticated_search(self, server, localpart, password):
|
||||
""" Attempt to login with the preconfigured bind_dn
|
||||
and then continue searching and filtering within
|
||||
the base_dn
|
||||
|
||||
Returns (True, LDAP3Connection)
|
||||
if a single matching DN within the base was found
|
||||
that matched the filter expression, and with which
|
||||
a successful bind was achieved
|
||||
|
||||
The LDAP3Connection returned is the instance that was used to
|
||||
verify the password not the one using the configured bind_dn.
|
||||
Returns (False, None)
|
||||
if an error occured
|
||||
"""
|
||||
|
||||
try:
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
self.ldap_bind_dn,
|
||||
self.ldap_bind_password
|
||||
)
|
||||
logger.debug(
|
||||
"Established LDAP connection in search mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded LDAP connection in search mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if not conn.bind():
|
||||
logger.warn(
|
||||
"Binding against LDAP with `bind_dn` failed: %s",
|
||||
conn.result['description']
|
||||
)
|
||||
conn.unbind()
|
||||
return False, None
|
||||
|
||||
# construct search_filter like (uid=localpart)
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
if self.ldap_filter:
|
||||
# combine with the AND expression
|
||||
query = "(&{query}{filter})".format(
|
||||
query=query,
|
||||
filter=self.ldap_filter
|
||||
)
|
||||
logger.debug(
|
||||
"LDAP search filter: %s",
|
||||
query
|
||||
)
|
||||
conn.search(
|
||||
search_base=self.ldap_base,
|
||||
search_filter=query
|
||||
)
|
||||
|
||||
if len(conn.response) == 1:
|
||||
# GOOD: found exactly one result
|
||||
user_dn = conn.response[0]['dn']
|
||||
logger.debug('LDAP search found dn: %s', user_dn)
|
||||
|
||||
# unbind and simple bind with user_dn to verify the password
|
||||
# Note: do not use rebind(), for some reason it did not verify
|
||||
# the password for me!
|
||||
conn.unbind()
|
||||
return self._ldap_simple_bind(server, localpart, password)
|
||||
else:
|
||||
# BAD: found 0 or > 1 results, abort!
|
||||
if len(conn.response) == 0:
|
||||
logger.info(
|
||||
"LDAP search returned no results for '%s'",
|
||||
localpart
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"LDAP search returned too many (%s) results for '%s'",
|
||||
len(conn.response), localpart
|
||||
)
|
||||
conn.unbind()
|
||||
return False, None
|
||||
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.warn("Error during LDAP authentication: %s", e)
|
||||
return False, None
|
||||
|
||||
|
||||
def _require_keys(config, required):
|
||||
missing = [key for key in required if key not in config]
|
||||
if missing:
|
||||
raise ConfigError(
|
||||
"LDAP enabled but missing required config values: {}".format(
|
||||
", ".join(missing)
|
||||
)
|
||||
)
|
|
@ -121,15 +121,9 @@ class RetryDestinationLimiter(object):
|
|||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def err(failure):
|
||||
logger.exception(
|
||||
"Failed to store set_destination_retry_timings",
|
||||
failure.value
|
||||
)
|
||||
|
||||
valid_err_code = False
|
||||
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
||||
valid_err_code = 0 <= exc_val.code < 500
|
||||
valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
|
||||
|
||||
if exc_type is None or valid_err_code:
|
||||
# We connected successfully.
|
||||
|
@ -151,6 +145,15 @@ class RetryDestinationLimiter(object):
|
|||
|
||||
retry_last_ts = int(self.clock.time_msec())
|
||||
|
||||
self.store.set_destination_retry_timings(
|
||||
self.destination, retry_last_ts, self.retry_interval
|
||||
).addErrback(err)
|
||||
@defer.inlineCallbacks
|
||||
def store_retry_timings():
|
||||
try:
|
||||
yield self.store.set_destination_retry_timings(
|
||||
self.destination, retry_last_ts, self.retry_interval
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to store set_destination_retry_timings",
|
||||
)
|
||||
|
||||
store_retry_timings()
|
||||
|
|
|
@ -12,17 +12,22 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from tests import unittest
|
||||
|
||||
import pymacaroons
|
||||
from mock import Mock
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import synapse.handlers.auth
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.types import UserID
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver, mock_getRawHeaders
|
||||
|
||||
import pymacaroons
|
||||
|
||||
class TestHandlers(object):
|
||||
def __init__(self, hs):
|
||||
self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
|
||||
|
||||
|
||||
class AuthTestCase(unittest.TestCase):
|
||||
|
@ -34,14 +39,17 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
self.hs = yield setup_test_homeserver(handlers=None)
|
||||
self.hs.get_datastore = Mock(return_value=self.store)
|
||||
self.hs.handlers = TestHandlers(self.hs)
|
||||
self.auth = Auth(self.hs)
|
||||
|
||||
self.test_user = "@foo:bar"
|
||||
self.test_token = "_test_token_"
|
||||
|
||||
# this is overridden for the appservice tests
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_user_valid_token(self):
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
user_info = {
|
||||
"name": self.test_user,
|
||||
"token_id": "ditto",
|
||||
|
@ -56,7 +64,6 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_user_bad_token(self):
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
|
@ -66,7 +73,6 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.failureResultOf(d, AuthError)
|
||||
|
||||
def test_get_user_by_req_user_missing_token(self):
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
user_info = {
|
||||
"name": self.test_user,
|
||||
"token_id": "ditto",
|
||||
|
@ -158,7 +164,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
|
||||
|
@ -168,6 +174,10 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_guest_user_from_macaroon(self):
|
||||
self.store.get_user_by_id = Mock(return_value={
|
||||
"is_guest": True,
|
||||
})
|
||||
|
||||
user_id = "@baldrick:matrix.org"
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
|
@ -179,11 +189,12 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("guest = true")
|
||||
serialized = macaroon.serialize()
|
||||
|
||||
user_info = yield self.auth.get_user_from_macaroon(serialized)
|
||||
user_info = yield self.auth.get_user_by_access_token(serialized)
|
||||
user = user_info["user"]
|
||||
is_guest = user_info["is_guest"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
self.assertTrue(is_guest)
|
||||
self.store.get_user_by_id.assert_called_with(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_from_macaroon_user_db_mismatch(self):
|
||||
|
@ -200,7 +211,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("User mismatch", cm.exception.msg)
|
||||
|
||||
|
@ -220,7 +231,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("type = access")
|
||||
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("No user caveat", cm.exception.msg)
|
||||
|
||||
|
@ -242,7 +253,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
|
||||
|
@ -265,7 +276,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("cunning > fox")
|
||||
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
|
||||
|
@ -293,12 +304,12 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
self.hs.clock.now = 5000 # seconds
|
||||
self.hs.config.expire_access_token = True
|
||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
# yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
# TODO(daniel): Turn on the check that we validate expiration, when we
|
||||
# validate expiration (and remove the above line, which will start
|
||||
# throwing).
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
|
||||
|
@ -327,6 +338,58 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.hs.clock.now = 5000 # seconds
|
||||
self.hs.config.expire_access_token = True
|
||||
|
||||
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cannot_use_regular_token_as_guest(self):
|
||||
USER_ID = "@percy:matrix.org"
|
||||
self.store.add_access_token_to_user = Mock()
|
||||
|
||||
token = yield self.hs.handlers.auth_handler.issue_access_token(
|
||||
USER_ID, "DEVICE"
|
||||
)
|
||||
self.store.add_access_token_to_user.assert_called_with(
|
||||
USER_ID, token, "DEVICE"
|
||||
)
|
||||
|
||||
def get_user(tok):
|
||||
if token != tok:
|
||||
return None
|
||||
return {
|
||||
"name": USER_ID,
|
||||
"is_guest": False,
|
||||
"token_id": 1234,
|
||||
"device_id": "DEVICE",
|
||||
}
|
||||
self.store.get_user_by_access_token = get_user
|
||||
self.store.get_user_by_id = Mock(return_value={
|
||||
"is_guest": False,
|
||||
})
|
||||
|
||||
# check the token works
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||
self.assertFalse(requester.is_guest)
|
||||
|
||||
# add an is_guest caveat
|
||||
mac = pymacaroons.Macaroon.deserialize(token)
|
||||
mac.add_first_party_caveat("guest = true")
|
||||
guest_tok = mac.serialize()
|
||||
|
||||
# the token should *not* work now
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [guest_tok]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
|
||||
|
||||
self.store.get_user_by_id.assert_called_with(USER_ID)
|
||||
|
|
|
@ -17,7 +17,11 @@
|
|||
from .. import unittest
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.events.utils import prune_event, serialize_event
|
||||
|
||||
|
||||
def MockEvent(**kwargs):
|
||||
return FrozenEvent(kwargs)
|
||||
|
||||
|
||||
class PruneEventTestCase(unittest.TestCase):
|
||||
|
@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase):
|
|||
'unsigned': {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SerializeEventTestCase(unittest.TestCase):
|
||||
|
||||
def serialize(self, ev, fields):
|
||||
return serialize_event(ev, 1479807801915, only_event_fields=fields)
|
||||
|
||||
def test_event_fields_works_with_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar"
|
||||
),
|
||||
["room_id"]
|
||||
),
|
||||
{
|
||||
"room_id": "!foo:bar",
|
||||
}
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_nested_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"body": "A message",
|
||||
},
|
||||
),
|
||||
["content.body"]
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"body": "A message",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_dot_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"key.with.dots": {},
|
||||
},
|
||||
),
|
||||
["content.key\.with\.dots"]
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"key.with.dots": {},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_nested_dot_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"not_me": 1,
|
||||
"nested.dot.key": {
|
||||
"leaf.key": 42,
|
||||
"not_me_either": 1,
|
||||
},
|
||||
},
|
||||
),
|
||||
["content.nested\.dot\.key.leaf\.key"]
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"nested.dot.key": {
|
||||
"leaf.key": 42,
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def test_event_fields_nops_with_unknown_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": "bar",
|
||||
},
|
||||
),
|
||||
["content.foo", "content.notexists"]
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"foo": "bar",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def test_event_fields_nops_with_non_dict_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": ["I", "am", "an", "array"],
|
||||
},
|
||||
),
|
||||
["content.foo.am"]
|
||||
),
|
||||
{}
|
||||
)
|
||||
|
||||
def test_event_fields_nops_with_array_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": ["I", "am", "an", "array"],
|
||||
},
|
||||
),
|
||||
["content.foo.1"]
|
||||
),
|
||||
{}
|
||||
)
|
||||
|
||||
def test_event_fields_all_fields_if_empty(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": "bar",
|
||||
},
|
||||
),
|
||||
[]
|
||||
),
|
||||
{
|
||||
"room_id": "!foo:bar",
|
||||
"content": {
|
||||
"foo": "bar",
|
||||
},
|
||||
"unsigned": {}
|
||||
}
|
||||
)
|
||||
|
||||
def test_event_fields_fail_if_fields_not_str(self):
|
||||
with self.assertRaises(TypeError):
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": "bar",
|
||||
},
|
||||
),
|
||||
["room_id", 4]
|
||||
)
|
||||
|
|
|
@ -61,14 +61,14 @@ class AuthTestCase(unittest.TestCase):
|
|||
def verify_type(caveat):
|
||||
return caveat == "type = access"
|
||||
|
||||
def verify_expiry(caveat):
|
||||
return caveat == "time < 8600000"
|
||||
def verify_nonce(caveat):
|
||||
return caveat.startswith("nonce =")
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_general(verify_gen)
|
||||
v.satisfy_general(verify_user)
|
||||
v.satisfy_general(verify_type)
|
||||
v.satisfy_general(verify_expiry)
|
||||
v.satisfy_general(verify_nonce)
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
|
|
|
@ -53,13 +53,12 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
duration_ms = 200
|
||||
local_part = "someone"
|
||||
display_name = "someone"
|
||||
user_id = "@someone:test"
|
||||
requester = create_requester("@as:test")
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
requester, local_part, display_name, duration_ms)
|
||||
requester, local_part, display_name)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
||||
|
||||
|
@ -71,12 +70,11 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
user_id=frank.to_string(),
|
||||
token="jkv;g498752-43gj['eamb!-5",
|
||||
password_hash=None)
|
||||
duration_ms = 200
|
||||
local_part = "frank"
|
||||
display_name = "Frank"
|
||||
user_id = "@frank:test"
|
||||
requester = create_requester("@as:test")
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
requester, local_part, display_name, duration_ms)
|
||||
requester, local_part, display_name)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
||||
|
|
|
@ -103,7 +103,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||
room_id = yield self.create_room()
|
||||
event_id = yield self.send_text_message(room_id, "Hello, World")
|
||||
get = self.get(receipts="-1")
|
||||
yield self.hs.get_handlers().receipts_handler.received_client_receipt(
|
||||
yield self.hs.get_receipts_handler().received_client_receipt(
|
||||
room_id, "m.read", self.user_id, event_id
|
||||
)
|
||||
code, body = yield get
|
||||
|
|
|
@ -67,8 +67,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.registration_handler.appservice_register = Mock(
|
||||
return_value=user_id
|
||||
)
|
||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||
return_value=(token, "kermits_refresh_token")
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||
return_value=token
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
|
@ -76,11 +76,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"refresh_token": "kermits_refresh_token",
|
||||
"home_server": self.hs.hostname
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertIn("refresh_token", result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_invalid(self):
|
||||
|
@ -126,8 +124,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
"password": "monkey"
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||
return_value=(token, "kermits_refresh_token")
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||
return_value=token
|
||||
)
|
||||
self.device_handler.check_device_registered = \
|
||||
Mock(return_value=device_id)
|
||||
|
@ -137,12 +135,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"refresh_token": "kermits_refresh_token",
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertIn("refresh_token", result)
|
||||
self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id, initial_device_display_name=None)
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||
event_cache_size=1,
|
||||
password_providers=[],
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
||||
|
||||
self.as_token = "token1"
|
||||
self.as_url = "some_url"
|
||||
|
@ -112,7 +112,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
event_cache_size=1,
|
||||
password_providers=[],
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
||||
self.db_pool = hs.get_db_pool()
|
||||
|
||||
self.as_list = [
|
||||
|
@ -443,7 +443,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
app_service_config_files=[f1, f2], event_cache_size=1,
|
||||
password_providers=[]
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
datastore=Mock(),
|
||||
federation_sender=Mock()
|
||||
)
|
||||
|
||||
ApplicationServiceStore(hs)
|
||||
|
||||
|
@ -456,7 +460,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
app_service_config_files=[f1, f2], event_cache_size=1,
|
||||
password_providers=[]
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
datastore=Mock(),
|
||||
federation_sender=Mock()
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
|
@ -475,7 +483,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
app_service_config_files=[f1, f2], event_cache_size=1,
|
||||
password_providers=[]
|
||||
)
|
||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||
hs = yield setup_test_homeserver(
|
||||
config=config,
|
||||
datastore=Mock(),
|
||||
federation_sender=Mock()
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue