Merge branch 'release-v0.25.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
552f123bea
58
CHANGES.rst
58
CHANGES.rst
|
@ -1,3 +1,61 @@
|
|||
Changes in synapse v0.25.0 (2017-11-15)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix port script (PR #2673)
|
||||
|
||||
|
||||
Changes in synapse v0.25.0-rc1 (2017-11-14)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add is_public to groups table to allow for private groups (PR #2582)
|
||||
* Add a route for determining who you are (PR #2668) Thanks to @turt2live!
|
||||
* Add more features to the password providers (PR #2608, #2610, #2620, #2622,
|
||||
#2623, #2624, #2626, #2628, #2629)
|
||||
* Add a hook for custom rest endpoints (PR #2627)
|
||||
* Add API to update group room visibility (PR #2651)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Ignore <noscript> tags when generating URL preview descriptions (PR #2576)
|
||||
Thanks to @maximevaillancourt!
|
||||
* Register some /unstable endpoints in /r0 as well (PR #2579) Thanks to
|
||||
@krombel!
|
||||
* Support /keys/upload on /r0 as well as /unstable (PR #2585)
|
||||
* Front-end proxy: pass through auth header (PR #2586)
|
||||
* Allow ASes to deactivate their own users (PR #2589)
|
||||
* Remove refresh tokens (PR #2613)
|
||||
* Automatically set default displayname on register (PR #2617)
|
||||
* Log login requests (PR #2618)
|
||||
* Always return `is_public` in the `/groups/:group_id/rooms` API (PR #2630)
|
||||
* Avoid no-op media deletes (PR #2637) Thanks to @spantaleev!
|
||||
* Fix various embarrassing typos around user_directory and add some doc. (PR
|
||||
#2643)
|
||||
* Return whether a user is an admin within a group (PR #2647)
|
||||
* Namespace visibility options for groups (PR #2657)
|
||||
* Downcase UserIDs on registration (PR #2662)
|
||||
* Cache failures when fetching URL previews (PR #2669)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix port script (PR #2577)
|
||||
* Fix error when running synapse with no logfile (PR #2581)
|
||||
* Fix UI auth when deleting devices (PR #2591)
|
||||
* Fix typo when checking if user is invited to group (PR #2599)
|
||||
* Fix the port script to drop NUL values in all tables (PR #2611)
|
||||
* Fix appservices being backlogged and not receiving new events due to a bug in
|
||||
notify_interested_services (PR #2631) Thanks to @xyzz!
|
||||
* Fix updating rooms avatar/display name when modified by admin (PR #2636)
|
||||
Thanks to @farialima!
|
||||
* Fix bug in state group storage (PR #2649)
|
||||
* Fix 500 on invalid utf-8 in request (PR #2663)
|
||||
|
||||
|
||||
Changes in synapse v0.24.1 (2017-10-24)
|
||||
=======================================
|
||||
|
||||
|
|
|
@ -823,7 +823,9 @@ spidering 'internal' URLs on your network. At the very least we recommend that
|
|||
your loopback and RFC1918 IP addresses are blacklisted.
|
||||
|
||||
This also requires the optional lxml and netaddr python dependencies to be
|
||||
installed.
|
||||
installed. This in turn requires the libxml2 library to be available - on
|
||||
Debian/Ubuntu this means ``apt-get install libxml2-dev``, or equivalent for
|
||||
your OS.
|
||||
|
||||
|
||||
Password reset
|
||||
|
|
|
@ -1,52 +1,119 @@
|
|||
Basically, PEP8
|
||||
- Everything should comply with PEP8. Code should pass
|
||||
``pep8 --max-line-length=100`` without any warnings.
|
||||
|
||||
- NEVER tabs. 4 spaces to indent.
|
||||
- Max line width: 79 chars (with flexibility to overflow by a "few chars" if
|
||||
- **Indenting**:
|
||||
|
||||
- NEVER tabs. 4 spaces to indent.
|
||||
|
||||
- follow PEP8; either hanging indent or multiline-visual indent depending
|
||||
on the size and shape of the arguments and what makes more sense to the
|
||||
author. In other words, both this::
|
||||
|
||||
print("I am a fish %s" % "moo")
|
||||
|
||||
and this::
|
||||
|
||||
print("I am a fish %s" %
|
||||
"moo")
|
||||
|
||||
and this::
|
||||
|
||||
print(
|
||||
"I am a fish %s" %
|
||||
"moo",
|
||||
)
|
||||
|
||||
...are valid, although given each one takes up 2x more vertical space than
|
||||
the previous, it's up to the author's discretion as to which layout makes
|
||||
most sense for their function invocation. (e.g. if they want to add
|
||||
comments per-argument, or put expressions in the arguments, or group
|
||||
related arguments together, or want to deliberately extend or preserve
|
||||
vertical/horizontal space)
|
||||
|
||||
- **Line length**:
|
||||
|
||||
Max line length is 79 chars (with flexibility to overflow by a "few chars" if
|
||||
the overflowing content is not semantically significant and avoids an
|
||||
explosion of vertical whitespace).
|
||||
- Use camel case for class and type names
|
||||
- Use underscores for functions and variables.
|
||||
- Use double quotes.
|
||||
- Use parentheses instead of '\\' for line continuation where ever possible
|
||||
(which is pretty much everywhere)
|
||||
- There should be max a single new line between:
|
||||
|
||||
Use parentheses instead of ``\`` for line continuation where ever possible
|
||||
(which is pretty much everywhere).
|
||||
|
||||
- **Naming**:
|
||||
|
||||
- Use camel case for class and type names
|
||||
- Use underscores for functions and variables.
|
||||
|
||||
- Use double quotes ``"foo"`` rather than single quotes ``'foo'``.
|
||||
|
||||
- **Blank lines**:
|
||||
|
||||
- There should be max a single new line between:
|
||||
|
||||
- statements
|
||||
- functions in a class
|
||||
- There should be two new lines between:
|
||||
|
||||
- There should be two new lines between:
|
||||
|
||||
- definitions in a module (e.g., between different classes)
|
||||
- There should be spaces where spaces should be and not where there shouldn't be:
|
||||
- a single space after a comma
|
||||
- a single space before and after for '=' when used as assignment
|
||||
- no spaces before and after for '=' for default values and keyword arguments.
|
||||
- Indenting must follow PEP8; either hanging indent or multiline-visual indent
|
||||
depending on the size and shape of the arguments and what makes more sense to
|
||||
the author. In other words, both this::
|
||||
|
||||
print("I am a fish %s" % "moo")
|
||||
- **Whitespace**:
|
||||
|
||||
and this::
|
||||
There should be spaces where spaces should be and not where there shouldn't
|
||||
be:
|
||||
|
||||
print("I am a fish %s" %
|
||||
"moo")
|
||||
- a single space after a comma
|
||||
- a single space before and after for '=' when used as assignment
|
||||
- no spaces before and after for '=' for default values and keyword arguments.
|
||||
|
||||
and this::
|
||||
- **Comments**: should follow the `google code style
|
||||
<http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||
This is so that we can generate documentation with `sphinx
|
||||
<http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||
`examples
|
||||
<http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||
in the sphinx documentation.
|
||||
|
||||
print(
|
||||
"I am a fish %s" %
|
||||
"moo"
|
||||
)
|
||||
- **Imports**:
|
||||
|
||||
...are valid, although given each one takes up 2x more vertical space than
|
||||
the previous, it's up to the author's discretion as to which layout makes most
|
||||
sense for their function invocation. (e.g. if they want to add comments
|
||||
per-argument, or put expressions in the arguments, or group related arguments
|
||||
together, or want to deliberately extend or preserve vertical/horizontal
|
||||
space)
|
||||
- Prefer to import classes and functions than packages or modules.
|
||||
|
||||
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||
This is so that we can generate documentation with
|
||||
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||
in the sphinx documentation.
|
||||
Example::
|
||||
|
||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
||||
from synapse.types import UserID
|
||||
...
|
||||
user_id = UserID(local, server)
|
||||
|
||||
is preferred over::
|
||||
|
||||
from synapse import types
|
||||
...
|
||||
user_id = types.UserID(local, server)
|
||||
|
||||
(or any other variant).
|
||||
|
||||
This goes against the advice in the Google style guide, but it means that
|
||||
errors in the name are caught early (at import time).
|
||||
|
||||
- Multiple imports from the same package can be combined onto one line::
|
||||
|
||||
from synapse.types import GroupID, RoomID, UserID
|
||||
|
||||
An effort should be made to keep the individual imports in alphabetical
|
||||
order.
|
||||
|
||||
If the list becomes long, wrap it with parentheses and split it over
|
||||
multiple lines.
|
||||
|
||||
- As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_,
|
||||
imports should be grouped in the following order, with a blank line between
|
||||
each group:
|
||||
|
||||
1. standard library imports
|
||||
2. related third party imports
|
||||
3. local application/library specific imports
|
||||
|
||||
- Imports within each group should be sorted alphabetically by module name.
|
||||
|
||||
- Avoid wildcard imports (``from synapse.types import *``) and relative
|
||||
imports (``from .types import UserID``).
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
Password auth provider modules
|
||||
==============================
|
||||
|
||||
Password auth providers offer a way for server administrators to integrate
|
||||
their Synapse installation with an existing authentication system.
|
||||
|
||||
A password auth provider is a Python class which is dynamically loaded into
|
||||
Synapse, and provides a number of methods by which it can integrate with the
|
||||
authentication system.
|
||||
|
||||
This document serves as a reference for those looking to implement their own
|
||||
password auth providers.
|
||||
|
||||
Required methods
|
||||
----------------
|
||||
|
||||
Password auth provider classes must provide the following methods:
|
||||
|
||||
*class* ``SomeProvider.parse_config``\(*config*)
|
||||
|
||||
This method is passed the ``config`` object for this module from the
|
||||
homeserver configuration file.
|
||||
|
||||
It should perform any appropriate sanity checks on the provided
|
||||
configuration, and return an object which is then passed into ``__init__``.
|
||||
|
||||
*class* ``SomeProvider``\(*config*, *account_handler*)
|
||||
|
||||
The constructor is passed the config object returned by ``parse_config``,
|
||||
and a ``synapse.module_api.ModuleApi`` object which allows the
|
||||
password provider to check if accounts exist and/or create new ones.
|
||||
|
||||
Optional methods
|
||||
----------------
|
||||
|
||||
Password auth provider classes may optionally provide the following methods.
|
||||
|
||||
*class* ``SomeProvider.get_db_schema_files``\()
|
||||
|
||||
This method, if implemented, should return an Iterable of ``(name,
|
||||
stream)`` pairs of database schema files. Each file is applied in turn at
|
||||
initialisation, and a record is then made in the database so that it is
|
||||
not re-applied on the next start.
|
||||
|
||||
``someprovider.get_supported_login_types``\()
|
||||
|
||||
This method, if implemented, should return a ``dict`` mapping from a login
|
||||
type identifier (such as ``m.login.password``) to an iterable giving the
|
||||
fields which must be provided by the user in the submission to the
|
||||
``/login`` api. These fields are passed in the ``login_dict`` dictionary
|
||||
to ``check_auth``.
|
||||
|
||||
For example, if a password auth provider wants to implement a custom login
|
||||
type of ``com.example.custom_login``, where the client is expected to pass
|
||||
the fields ``secret1`` and ``secret2``, the provider should implement this
|
||||
method and return the following dict::
|
||||
|
||||
{"com.example.custom_login": ("secret1", "secret2")}
|
||||
|
||||
``someprovider.check_auth``\(*username*, *login_type*, *login_dict*)
|
||||
|
||||
This method is the one that does the real work. If implemented, it will be
|
||||
called for each login attempt where the login type matches one of the keys
|
||||
returned by ``get_supported_login_types``.
|
||||
|
||||
It is passed the (possibly UNqualified) ``user`` provided by the client,
|
||||
the login type, and a dictionary of login secrets passed by the client.
|
||||
|
||||
The method should return a Twisted ``Deferred`` object, which resolves to
|
||||
the canonical ``@localpart:domain`` user id if authentication is successful,
|
||||
and ``None`` if not.
|
||||
|
||||
Alternatively, the ``Deferred`` can resolve to a ``(str, func)`` tuple, in
|
||||
which case the second field is a callback which will be called with the
|
||||
result from the ``/login`` call (including ``access_token``, ``device_id``,
|
||||
etc.)
|
||||
|
||||
``someprovider.check_password``\(*user_id*, *password*)
|
||||
|
||||
This method provides a simpler interface than ``get_supported_login_types``
|
||||
and ``check_auth`` for password auth providers that just want to provide a
|
||||
mechanism for validating ``m.login.password`` logins.
|
||||
|
||||
Iif implemented, it will be called to check logins with an
|
||||
``m.login.password`` login type. It is passed a qualified
|
||||
``@localpart:domain`` user id, and the password provided by the user.
|
||||
|
||||
The method should return a Twisted ``Deferred`` object, which resolves to
|
||||
``True`` if authentication is successful, and ``False`` if not.
|
||||
|
||||
``someprovider.on_logged_out``\(*user_id*, *device_id*, *access_token*)
|
||||
|
||||
This method, if implemented, is called when a user logs out. It is passed
|
||||
the qualified user ID, the ID of the deactivated device (if any: access
|
||||
tokens are occasionally created without an associated device ID), and the
|
||||
(now deactivated) access token.
|
||||
|
||||
It may return a Twisted ``Deferred`` object; the logout request will wait
|
||||
for the deferred to complete but the result is ignored.
|
|
@ -56,6 +56,7 @@ As a first cut, let's do #2 and have the receiver hit the API to calculate its o
|
|||
API
|
||||
---
|
||||
|
||||
```
|
||||
GET /_matrix/media/r0/preview_url?url=http://wherever.com
|
||||
200 OK
|
||||
{
|
||||
|
@ -66,6 +67,7 @@ GET /_matrix/media/r0/preview_url?url=http://wherever.com
|
|||
"og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
|
||||
"og:site_name" : "Twitter"
|
||||
}
|
||||
```
|
||||
|
||||
* Downloads the URL
|
||||
* If HTML, just stores it in RAM and parses it for OG meta tags
|
|
@ -0,0 +1,17 @@
|
|||
User Directory API Implementation
|
||||
=================================
|
||||
|
||||
The user directory is currently maintained based on the 'visible' users
|
||||
on this particular server - i.e. ones which your account shares a room with, or
|
||||
who are present in a publicly viewable room present on the server.
|
||||
|
||||
The directory info is stored in various tables, which can (typically after
|
||||
DB corruption) get stale or out of sync. If this happens, for now the
|
||||
quickest solution to fix it is:
|
||||
|
||||
```
|
||||
UPDATE user_directory_stream_pos SET stream_id = NULL;
|
||||
```
|
||||
|
||||
and restart the synapse, which should then start a background task to
|
||||
flush the current tables and regenerate the directory.
|
|
@ -42,6 +42,14 @@ BOOLEAN_COLUMNS = {
|
|||
"public_room_list_stream": ["visibility"],
|
||||
"device_lists_outbound_pokes": ["sent"],
|
||||
"users_who_share_rooms": ["share_private"],
|
||||
"groups": ["is_public"],
|
||||
"group_rooms": ["is_public"],
|
||||
"group_users": ["is_public", "is_admin"],
|
||||
"group_summary_rooms": ["is_public"],
|
||||
"group_room_categories": ["is_public"],
|
||||
"group_summary_users": ["is_public"],
|
||||
"group_roles": ["is_public"],
|
||||
"local_group_membership": ["is_publicised", "is_admin"],
|
||||
}
|
||||
|
||||
|
||||
|
@ -112,6 +120,7 @@ class Store(object):
|
|||
|
||||
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
||||
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
||||
_simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
|
||||
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
def r(conn):
|
||||
|
@ -318,7 +327,7 @@ class Porter(object):
|
|||
backward_chunk = min(row[0] for row in brows) - 1
|
||||
|
||||
rows = frows + brows
|
||||
self._convert_rows(table, headers, rows)
|
||||
rows = self._convert_rows(table, headers, rows)
|
||||
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(
|
||||
|
@ -554,17 +563,29 @@ class Porter(object):
|
|||
i for i, h in enumerate(headers) if h in bool_col_names
|
||||
]
|
||||
|
||||
class BadValueException(Exception):
|
||||
pass
|
||||
|
||||
def conv(j, col):
|
||||
if j in bool_cols:
|
||||
return bool(col)
|
||||
elif isinstance(col, basestring) and "\0" in col:
|
||||
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
|
||||
raise BadValueException();
|
||||
return col
|
||||
|
||||
outrows = []
|
||||
for i, row in enumerate(rows):
|
||||
rows[i] = tuple(
|
||||
conv(j, col)
|
||||
for j, col in enumerate(row)
|
||||
if j > 0
|
||||
)
|
||||
try:
|
||||
outrows.append(tuple(
|
||||
conv(j, col)
|
||||
for j, col in enumerate(row)
|
||||
if j > 0
|
||||
))
|
||||
except BadValueException:
|
||||
pass
|
||||
|
||||
return outrows
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _setup_sent_transactions(self):
|
||||
|
@ -592,7 +613,7 @@ class Porter(object):
|
|||
"select", r,
|
||||
)
|
||||
|
||||
self._convert_rows("sent_transactions", headers, rows)
|
||||
rows = self._convert_rows("sent_transactions", headers, rows)
|
||||
|
||||
inserted_rows = len(rows)
|
||||
if inserted_rows:
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.24.1"
|
||||
__version__ = "0.25.0"
|
||||
|
|
|
@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy")
|
|||
|
||||
|
||||
class KeyUploadServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||
releases=())
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
if body:
|
||||
# They're actually trying to upload something, proxy to main synapse.
|
||||
# Pass through the auth headers, if any, in case the access token
|
||||
# is there.
|
||||
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
|
||||
headers = {
|
||||
"Authorization": auth_headers,
|
||||
}
|
||||
result = yield self.http_client.post_json_get_json(
|
||||
self.main_uri + request.uri,
|
||||
body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
|
|
@ -30,6 +30,8 @@ from synapse.config._base import ConfigError
|
|||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.http.additional_resource import AdditionalResource
|
||||
from synapse.http.server import RootRedirect
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics import register_memory_metrics
|
||||
|
@ -49,6 +51,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
|
|||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.module_loader import load_module
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from twisted.application import service
|
||||
|
@ -107,52 +110,18 @@ class SynapseHomeServer(HomeServer):
|
|||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "client":
|
||||
client_resource = ClientRestResource(self)
|
||||
if res["compress"]:
|
||||
client_resource = gz_wrap(client_resource)
|
||||
resources.update(self._configure_named_resource(
|
||||
name, res.get("compress", False),
|
||||
))
|
||||
|
||||
resources.update({
|
||||
"/_matrix/client/api/v1": client_resource,
|
||||
"/_matrix/client/r0": client_resource,
|
||||
"/_matrix/client/unstable": client_resource,
|
||||
"/_matrix/client/v2_alpha": client_resource,
|
||||
"/_matrix/client/versions": client_resource,
|
||||
})
|
||||
|
||||
if name == "federation":
|
||||
resources.update({
|
||||
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||
})
|
||||
|
||||
if name in ["static", "client"]:
|
||||
resources.update({
|
||||
STATIC_PREFIX: File(
|
||||
os.path.join(os.path.dirname(synapse.__file__), "static")
|
||||
),
|
||||
})
|
||||
|
||||
if name in ["media", "federation", "client"]:
|
||||
media_repo = MediaRepositoryResource(self)
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
if name in ["keys", "federation"]:
|
||||
resources.update({
|
||||
SERVER_KEY_PREFIX: LocalKey(self),
|
||||
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
|
||||
})
|
||||
|
||||
if name == "webclient":
|
||||
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
|
||||
|
||||
if name == "metrics" and self.get_config().enable_metrics:
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
additional_resources = listener_config.get("additional_resources", {})
|
||||
logger.debug("Configuring additional resources: %r",
|
||||
additional_resources)
|
||||
module_api = ModuleApi(self, self.get_auth_handler())
|
||||
for path, resmodule in additional_resources.items():
|
||||
handler_cls, config = load_module(resmodule)
|
||||
handler = handler_cls(config, module_api)
|
||||
resources[path] = AdditionalResource(self, handler.handle_request)
|
||||
|
||||
if WEB_CLIENT_PREFIX in resources:
|
||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||
|
@ -188,6 +157,67 @@ class SynapseHomeServer(HomeServer):
|
|||
)
|
||||
logger.info("Synapse now listening on port %d", port)
|
||||
|
||||
def _configure_named_resource(self, name, compress=False):
|
||||
"""Build a resource map for a named resource
|
||||
|
||||
Args:
|
||||
name (str): named resource: one of "client", "federation", etc
|
||||
compress (bool): whether to enable gzip compression for this
|
||||
resource
|
||||
|
||||
Returns:
|
||||
dict[str, Resource]: map from path to HTTP resource
|
||||
"""
|
||||
resources = {}
|
||||
if name == "client":
|
||||
client_resource = ClientRestResource(self)
|
||||
if compress:
|
||||
client_resource = gz_wrap(client_resource)
|
||||
|
||||
resources.update({
|
||||
"/_matrix/client/api/v1": client_resource,
|
||||
"/_matrix/client/r0": client_resource,
|
||||
"/_matrix/client/unstable": client_resource,
|
||||
"/_matrix/client/v2_alpha": client_resource,
|
||||
"/_matrix/client/versions": client_resource,
|
||||
})
|
||||
|
||||
if name == "federation":
|
||||
resources.update({
|
||||
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||
})
|
||||
|
||||
if name in ["static", "client"]:
|
||||
resources.update({
|
||||
STATIC_PREFIX: File(
|
||||
os.path.join(os.path.dirname(synapse.__file__), "static")
|
||||
),
|
||||
})
|
||||
|
||||
if name in ["media", "federation", "client"]:
|
||||
media_repo = MediaRepositoryResource(self)
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
if name in ["keys", "federation"]:
|
||||
resources.update({
|
||||
SERVER_KEY_PREFIX: LocalKey(self),
|
||||
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
|
||||
})
|
||||
|
||||
if name == "webclient":
|
||||
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
|
||||
|
||||
if name == "metrics" and self.get_config().enable_metrics:
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
|
||||
return resources
|
||||
|
||||
def start_listening(self):
|
||||
config = self.get_config()
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind
|
|||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
|
||||
|
@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
defer.returnValue(None)
|
||||
|
||||
key = (service.id, protocol)
|
||||
return self.protocol_meta_cache.get(key) or (
|
||||
self.protocol_meta_cache.set(key, _get())
|
||||
)
|
||||
result = self.protocol_meta_cache.get(key)
|
||||
if not result:
|
||||
result = self.protocol_meta_cache.set(
|
||||
key, preserve_fn(_get)()
|
||||
)
|
||||
return make_deferred_yieldable(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_bulk(self, service, events, txn_id=None):
|
||||
|
|
|
@ -41,7 +41,7 @@ class CasConfig(Config):
|
|||
#cas_config:
|
||||
# enabled: true
|
||||
# server_url: "https://cas-server.com"
|
||||
# service_url: "https://homesever.domain.com:8448"
|
||||
# service_url: "https://homeserver.domain.com:8448"
|
||||
# #required_attributes:
|
||||
# # name: value
|
||||
"""
|
||||
|
|
|
@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False):
|
|||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
)
|
||||
if log_config is None:
|
||||
|
||||
if log_config is None:
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if config.verbosity:
|
||||
|
@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False):
|
|||
logger.info("Opened new log file due to SIGHUP")
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
|
||||
def sighup(signum, stack):
|
||||
pass
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
|
|
|
@ -13,41 +13,40 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
from ._base import Config
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
|
||||
|
||||
|
||||
class PasswordAuthProviderConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.password_providers = []
|
||||
|
||||
provider_config = None
|
||||
providers = []
|
||||
|
||||
# We want to be backwards compatible with the old `ldap_config`
|
||||
# param.
|
||||
ldap_config = config.get("ldap_config", {})
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
if self.ldap_enabled:
|
||||
from ldap_auth_provider import LdapAuthProvider
|
||||
parsed_config = LdapAuthProvider.parse_config(ldap_config)
|
||||
self.password_providers.append((LdapAuthProvider, parsed_config))
|
||||
if ldap_config.get("enabled", False):
|
||||
providers.append[{
|
||||
'module': LDAP_PROVIDER,
|
||||
'config': ldap_config,
|
||||
}]
|
||||
|
||||
providers = config.get("password_providers", [])
|
||||
providers.extend(config.get("password_providers", []))
|
||||
for provider in providers:
|
||||
mod_name = provider['module']
|
||||
|
||||
# This is for backwards compat when the ldap auth provider resided
|
||||
# in this package.
|
||||
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||
from ldap_auth_provider import LdapAuthProvider
|
||||
provider_class = LdapAuthProvider
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
else:
|
||||
(provider_class, provider_config) = load_module(provider)
|
||||
if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||
mod_name = LDAP_PROVIDER
|
||||
|
||||
(provider_class, provider_config) = load_module({
|
||||
"module": mod_name,
|
||||
"config": provider['config'],
|
||||
})
|
||||
|
||||
self.password_providers.append((provider_class, provider_config))
|
||||
|
||||
|
|
|
@ -247,6 +247,13 @@ class ServerConfig(Config):
|
|||
- names: [federation] # Federation APIs
|
||||
compress: false
|
||||
|
||||
# optional list of additional endpoints which can be loaded via
|
||||
# dynamic modules
|
||||
# additional_resources:
|
||||
# "/_matrix/my/custom/endpoint":
|
||||
# module: my_module.CustomRequestHandler
|
||||
# config: {}
|
||||
|
||||
# Unsecure HTTP listener,
|
||||
# For when matrix traffic passes through loadbalancer that unwraps TLS.
|
||||
- port: %(unsecure_port)s
|
||||
|
|
|
@ -109,6 +109,12 @@ class TlsConfig(Config):
|
|||
# key. It may be necessary to publish the fingerprints of a new
|
||||
# certificate and wait until the "valid_until_ts" of the previous key
|
||||
# responses have passed before deploying it.
|
||||
#
|
||||
# You can calculate a fingerprint from a given TLS listener via:
|
||||
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
|
||||
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
|
||||
# or by checking matrix.org/federationtester/api/report?server_name=$host
|
||||
#
|
||||
tls_fingerprints: []
|
||||
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||
""" % locals()
|
||||
|
|
|
@ -18,6 +18,7 @@ from .federation_base import FederationBase
|
|||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util import async
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
|
@ -253,12 +254,13 @@ class FederationServer(FederationBase):
|
|||
result = self._state_resp_cache.get((room_id, event_id))
|
||||
if not result:
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
resp = yield self._state_resp_cache.set(
|
||||
d = self._state_resp_cache.set(
|
||||
(room_id, event_id),
|
||||
self._on_context_state_request_compute(room_id, event_id)
|
||||
preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
|
||||
)
|
||||
resp = yield make_deferred_yieldable(d)
|
||||
else:
|
||||
resp = yield result
|
||||
resp = yield make_deferred_yieldable(result)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
|
|
@ -545,6 +545,20 @@ class TransportLayerClient(object):
|
|||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
|
||||
config_key, content):
|
||||
"""Update room in group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||
"""Remove a room from a group
|
||||
"""
|
||||
|
|
|
@ -676,7 +676,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
|
|||
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
||||
"""Add/remove room from group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, room_id):
|
||||
|
@ -703,6 +703,27 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
|||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
|
||||
"""Update room config in group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
|
||||
"/config/(?P<config_key>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, room_id, config_key):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
result = yield self.groups_handler.update_room_in_group(
|
||||
group_id, requester_user_id, room_id, config_key, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class FederationGroupsUsersServlet(BaseFederationServlet):
|
||||
"""Get the users in a group on behalf of a user
|
||||
"""
|
||||
|
@ -1142,6 +1163,8 @@ GROUP_SERVER_SERVLET_CLASSES = (
|
|||
FederationGroupsRolesServlet,
|
||||
FederationGroupsRoleServlet,
|
||||
FederationGroupsSummaryUsersServlet,
|
||||
FederationGroupsAddRoomsServlet,
|
||||
FederationGroupsAddRoomsConfigServlet,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,31 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Attestations ensure that users and groups can't lie about their memberships.
|
||||
|
||||
When a user joins a group the HS and GS swap attestations, which allow them
|
||||
both to independently prove to third parties their membership.These
|
||||
attestations have a validity period so need to be periodically renewed.
|
||||
|
||||
If a user leaves (or gets kicked out of) a group, either side can still use
|
||||
their attestation to "prove" their membership, until the attestation expires.
|
||||
Therefore attestations shouldn't be relied on to prove membership in important
|
||||
cases, but can for less important situtations, e.g. showing a users membership
|
||||
of groups on their profile, showing flairs, etc.abs
|
||||
|
||||
An attestsation is a signed blob of json that looks like:
|
||||
|
||||
{
|
||||
"user_id": "@foo:a.example.com",
|
||||
"group_id": "+bar:b.example.com",
|
||||
"valid_until_ms": 1507994728530,
|
||||
"signatures":{"matrix.org":{"ed25519:auto":"..."}}
|
||||
}
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
@ -22,9 +47,17 @@ from synapse.util.logcontext import preserve_fn
|
|||
from signedjson.sign import sign_json
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default validity duration for new attestations we create
|
||||
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
|
||||
|
||||
# We add some jitter to the validity duration of attestations so that if we
|
||||
# add lots of users at once we don't need to renew them all at once.
|
||||
# The jitter is a multiplier picked randomly between the first and second number
|
||||
DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
|
||||
|
||||
# Start trying to update our attestations when they come this close to expiring
|
||||
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
||||
|
||||
|
@ -73,10 +106,14 @@ class GroupAttestationSigning(object):
|
|||
"""Create an attestation for the group_id and user_id with default
|
||||
validity length.
|
||||
"""
|
||||
validity_period = DEFAULT_ATTESTATION_LENGTH_MS
|
||||
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
|
||||
valid_until_ms = int(self.clock.time_msec() + validity_period)
|
||||
|
||||
return sign_json({
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
|
||||
"valid_until_ms": valid_until_ms,
|
||||
}, self.server_name, self.signing_key)
|
||||
|
||||
|
||||
|
@ -128,12 +165,19 @@ class GroupAttestionRenewer(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestation(group_id, user_id):
|
||||
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
if not self.is_mine_id(group_id):
|
||||
destination = get_domain_from_id(group_id)
|
||||
elif not self.is_mine_id(user_id):
|
||||
destination = get_domain_from_id(user_id)
|
||||
else:
|
||||
destination = get_domain_from_id(group_id)
|
||||
logger.warn(
|
||||
"Incorrectly trying to do attestations for user: %r in %r",
|
||||
user_id, group_id,
|
||||
)
|
||||
yield self.store.remove_attestation_renewal(group_id, user_id)
|
||||
return
|
||||
|
||||
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
yield self.transport_client.renew_group_attestation(
|
||||
destination, group_id, user_id,
|
||||
|
|
|
@ -49,7 +49,8 @@ class GroupsServerHandler(object):
|
|||
hs.get_groups_attestation_renewer()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
|
||||
def check_group_is_ours(self, group_id, requester_user_id,
|
||||
and_exists=False, and_is_admin=None):
|
||||
"""Check that the group is ours, and optionally if it exists.
|
||||
|
||||
If group does exist then return group.
|
||||
|
@ -67,6 +68,10 @@ class GroupsServerHandler(object):
|
|||
if and_exists and not group:
|
||||
raise SynapseError(404, "Unknown group")
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
if group and not is_user_in_group and not group["is_public"]:
|
||||
raise SynapseError(404, "Unknown group")
|
||||
|
||||
if and_is_admin:
|
||||
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
|
||||
if not is_admin:
|
||||
|
@ -84,7 +89,7 @@ class GroupsServerHandler(object):
|
|||
|
||||
A user/room may appear in multiple roles/categories.
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
|
@ -153,10 +158,16 @@ class GroupsServerHandler(object):
|
|||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
|
||||
def update_group_summary_room(self, group_id, requester_user_id,
|
||||
room_id, category_id, content):
|
||||
"""Add/update a room to the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
yield self.check_group_is_ours(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
and_exists=True,
|
||||
and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
|
@ -175,10 +186,16 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
|
||||
def delete_group_summary_room(self, group_id, requester_user_id,
|
||||
room_id, category_id):
|
||||
"""Remove a room from the summary
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
yield self.check_group_is_ours(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
and_exists=True,
|
||||
and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
yield self.store.remove_room_from_summary(
|
||||
group_id=group_id,
|
||||
|
@ -189,10 +206,10 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_categories(self, group_id, user_id):
|
||||
def get_group_categories(self, group_id, requester_user_id):
|
||||
"""Get all categories in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
categories = yield self.store.get_group_categories(
|
||||
group_id=group_id,
|
||||
|
@ -200,10 +217,10 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({"categories": categories})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_category(self, group_id, user_id, category_id):
|
||||
def get_group_category(self, group_id, requester_user_id, category_id):
|
||||
"""Get a specific category in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
res = yield self.store.get_group_category(
|
||||
group_id=group_id,
|
||||
|
@ -213,10 +230,15 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_category(self, group_id, user_id, category_id, content):
|
||||
def update_group_category(self, group_id, requester_user_id, category_id, content):
|
||||
"""Add/Update a group category
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
yield self.check_group_is_ours(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
and_exists=True,
|
||||
and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
profile = content.get("profile")
|
||||
|
@ -231,10 +253,15 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_category(self, group_id, user_id, category_id):
|
||||
def delete_group_category(self, group_id, requester_user_id, category_id):
|
||||
"""Delete a group category
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
yield self.check_group_is_ours(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
and_exists=True,
|
||||
and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
yield self.store.remove_group_category(
|
||||
group_id=group_id,
|
||||
|
@ -244,10 +271,10 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_roles(self, group_id, user_id):
|
||||
def get_group_roles(self, group_id, requester_user_id):
|
||||
"""Get all roles in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
roles = yield self.store.get_group_roles(
|
||||
group_id=group_id,
|
||||
|
@ -255,10 +282,10 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({"roles": roles})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_role(self, group_id, user_id, role_id):
|
||||
def get_group_role(self, group_id, requester_user_id, role_id):
|
||||
"""Get a specific role in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
res = yield self.store.get_group_role(
|
||||
group_id=group_id,
|
||||
|
@ -267,10 +294,15 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_role(self, group_id, user_id, role_id, content):
|
||||
def update_group_role(self, group_id, requester_user_id, role_id, content):
|
||||
"""Add/update a role in a group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
yield self.check_group_is_ours(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
and_exists=True,
|
||||
and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
|
@ -286,10 +318,15 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_role(self, group_id, user_id, role_id):
|
||||
def delete_group_role(self, group_id, requester_user_id, role_id):
|
||||
"""Remove role from group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
yield self.check_group_is_ours(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
and_exists=True,
|
||||
and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
yield self.store.remove_group_role(
|
||||
group_id=group_id,
|
||||
|
@ -304,7 +341,7 @@ class GroupsServerHandler(object):
|
|||
"""Add/update a users entry in the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
order = content.get("order", None)
|
||||
|
@ -326,7 +363,7 @@ class GroupsServerHandler(object):
|
|||
"""Remove a user from the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
yield self.store.remove_user_from_summary(
|
||||
|
@ -342,7 +379,7 @@ class GroupsServerHandler(object):
|
|||
"""Get the group profile as seen by requester_user_id
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id)
|
||||
|
||||
group_description = yield self.store.get_group(group_id)
|
||||
|
||||
|
@ -356,7 +393,7 @@ class GroupsServerHandler(object):
|
|||
"""Update the group profile
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
profile = {}
|
||||
|
@ -377,7 +414,7 @@ class GroupsServerHandler(object):
|
|||
The ordering is arbitrary at the moment
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
|
@ -389,14 +426,15 @@ class GroupsServerHandler(object):
|
|||
for user_result in user_results:
|
||||
g_user_id = user_result["user_id"]
|
||||
is_public = user_result["is_public"]
|
||||
is_privileged = user_result["is_admin"]
|
||||
|
||||
entry = {"user_id": g_user_id}
|
||||
|
||||
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
|
||||
entry.update(profile)
|
||||
|
||||
if not is_public:
|
||||
entry["is_public"] = False
|
||||
entry["is_public"] = bool(is_public)
|
||||
entry["is_privileged"] = bool(is_privileged)
|
||||
|
||||
if not self.is_mine_id(g_user_id):
|
||||
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
|
||||
|
@ -425,7 +463,7 @@ class GroupsServerHandler(object):
|
|||
The ordering is arbitrary at the moment
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
|
@ -459,7 +497,7 @@ class GroupsServerHandler(object):
|
|||
This returns rooms in order of decreasing number of joined users
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
|
@ -470,7 +508,6 @@ class GroupsServerHandler(object):
|
|||
chunk = []
|
||||
for room_result in room_results:
|
||||
room_id = room_result["room_id"]
|
||||
is_public = room_result["is_public"]
|
||||
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
entry = yield self.room_list_handler.generate_room_entry(
|
||||
|
@ -481,8 +518,7 @@ class GroupsServerHandler(object):
|
|||
if not entry:
|
||||
continue
|
||||
|
||||
if not is_public:
|
||||
entry["is_public"] = False
|
||||
entry["is_public"] = bool(room_result["is_public"])
|
||||
|
||||
chunk.append(entry)
|
||||
|
||||
|
@ -500,7 +536,7 @@ class GroupsServerHandler(object):
|
|||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
@ -509,12 +545,35 @@ class GroupsServerHandler(object):
|
|||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
|
||||
content):
|
||||
"""Update room in group
|
||||
"""
|
||||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
yield self.check_group_is_ours(
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
if config_key == "m.visibility":
|
||||
is_public = _parse_visibility_dict(content)
|
||||
|
||||
yield self.store.update_room_in_group_visibility(
|
||||
group_id, room_id,
|
||||
is_public=is_public,
|
||||
)
|
||||
else:
|
||||
raise SynapseError(400, "Uknown config option")
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
||||
"""Remove room from group
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
yield self.store.remove_room_from_group(group_id, room_id)
|
||||
|
@ -527,7 +586,7 @@ class GroupsServerHandler(object):
|
|||
"""
|
||||
|
||||
group = yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
# TODO: Check if user knocked
|
||||
|
@ -596,35 +655,40 @@ class GroupsServerHandler(object):
|
|||
raise SynapseError(502, "Unknown state returned by HS")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_invite(self, group_id, user_id, content):
|
||||
def accept_invite(self, group_id, requester_user_id, content):
|
||||
"""User tries to accept an invite to the group.
|
||||
|
||||
This is different from them asking to join, and so should error if no
|
||||
invite exists (and they're not a member of the group)
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
if not self.store.is_user_invited_to_local_group(group_id, user_id):
|
||||
is_invited = yield self.store.is_user_invited_to_local_group(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
if not is_invited:
|
||||
raise SynapseError(403, "User not invited to group")
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.hs.is_mine_id(requester_user_id):
|
||||
local_attestation = self.attestations.create_attestation(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
user_id=requester_user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
is_admin=False,
|
||||
is_public=is_public,
|
||||
local_attestation=local_attestation,
|
||||
|
@ -637,31 +701,31 @@ class GroupsServerHandler(object):
|
|||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def knock(self, group_id, user_id, content):
|
||||
def knock(self, group_id, requester_user_id, content):
|
||||
"""A user requests becoming a member of the group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_knock(self, group_id, user_id, content):
|
||||
def accept_knock(self, group_id, requester_user_id, content):
|
||||
"""Accept a users knock to the room.
|
||||
|
||||
Errors if the user hasn't knocked, rather than inviting them.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Remove a user from the group; either a user is leaving or and admin
|
||||
kicked htem.
|
||||
"""Remove a user from the group; either a user is leaving or an admin
|
||||
kicked them.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
is_kick = False
|
||||
if requester_user_id != user_id:
|
||||
|
@ -692,8 +756,8 @@ class GroupsServerHandler(object):
|
|||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, content):
|
||||
group = yield self.check_group_is_ours(group_id)
|
||||
def create_group(self, group_id, requester_user_id, content):
|
||||
group = yield self.check_group_is_ours(group_id, requester_user_id)
|
||||
|
||||
logger.info("Attempting to create group with ID: %r", group_id)
|
||||
|
||||
|
@ -703,11 +767,11 @@ class GroupsServerHandler(object):
|
|||
if group:
|
||||
raise SynapseError(400, "Group already exists")
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
|
||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
|
||||
if not is_admin:
|
||||
if not self.hs.config.enable_group_creation:
|
||||
raise SynapseError(
|
||||
403, "Only server admin can create group on this server",
|
||||
403, "Only a server admin can create groups on this server",
|
||||
)
|
||||
localpart = group_id_obj.localpart
|
||||
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
||||
|
@ -727,38 +791,41 @@ class GroupsServerHandler(object):
|
|||
|
||||
yield self.store.create_group(
|
||||
group_id,
|
||||
user_id,
|
||||
requester_user_id,
|
||||
name=name,
|
||||
avatar_url=avatar_url,
|
||||
short_description=short_description,
|
||||
long_description=long_description,
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.hs.is_mine_id(requester_user_id):
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
user_id=requester_user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
local_attestation = self.attestations.create_attestation(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
)
|
||||
else:
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
is_admin=True,
|
||||
is_public=True, # TODO
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.hs.is_mine_id(requester_user_id):
|
||||
yield self.store.add_remote_profile_cache(
|
||||
user_id,
|
||||
requester_user_id,
|
||||
displayname=user_profile.get("displayname"),
|
||||
avatar_url=user_profile.get("avatar_url"),
|
||||
)
|
||||
|
@ -773,15 +840,25 @@ def _parse_visibility_from_contents(content):
|
|||
public or not
|
||||
"""
|
||||
|
||||
visibility = content.get("visibility")
|
||||
visibility = content.get("m.visibility")
|
||||
if visibility:
|
||||
vis_type = visibility["type"]
|
||||
if vis_type not in ("public", "private"):
|
||||
raise SynapseError(
|
||||
400, "Synapse only supports 'public'/'private' visibility"
|
||||
)
|
||||
is_public = vis_type == "public"
|
||||
return _parse_visibility_dict(visibility)
|
||||
else:
|
||||
is_public = True
|
||||
|
||||
return is_public
|
||||
|
||||
|
||||
def _parse_visibility_dict(visibility):
|
||||
"""Given a dict for the "m.visibility" config return if the entity should
|
||||
be public or not
|
||||
"""
|
||||
vis_type = visibility.get("type")
|
||||
if not vis_type:
|
||||
return True
|
||||
|
||||
if vis_type not in ("public", "private"):
|
||||
raise SynapseError(
|
||||
400, "Synapse only supports 'public'/'private' visibility"
|
||||
)
|
||||
return vis_type == "public"
|
||||
|
|
|
@ -70,11 +70,10 @@ class ApplicationServicesHandler(object):
|
|||
with Measure(self.clock, "notify_interested_services"):
|
||||
self.is_processing = True
|
||||
try:
|
||||
upper_bound = self.current_max
|
||||
limit = 100
|
||||
while True:
|
||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||
upper_bound, limit
|
||||
self.current_max, limit
|
||||
)
|
||||
|
||||
if not events:
|
||||
|
@ -105,9 +104,6 @@ class ApplicationServicesHandler(object):
|
|||
)
|
||||
|
||||
yield self.store.set_appservice_last_pos(upper_bound)
|
||||
|
||||
if len(events) < limit:
|
||||
break
|
||||
finally:
|
||||
self.is_processing = False
|
||||
|
||||
|
|
|
@ -13,13 +13,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
|
@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
|
|||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
account_handler = _AccountHandler(
|
||||
hs, check_user_exists=self.check_user_exists
|
||||
)
|
||||
|
||||
account_handler = ModuleApi(hs, self)
|
||||
self.password_providers = [
|
||||
module(config=config, account_handler=account_handler)
|
||||
for module, config in hs.config.password_providers
|
||||
|
@ -75,14 +72,24 @@ class AuthHandler(BaseHandler):
|
|||
logger.info("Extra password_providers: %r", self.password_providers)
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
|
||||
login_types = set()
|
||||
if self._password_enabled:
|
||||
login_types.add(LoginType.PASSWORD)
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "get_supported_login_types"):
|
||||
login_types.update(
|
||||
provider.get_supported_login_types().keys()
|
||||
)
|
||||
self._supported_login_types = frozenset(login_types)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
"""
|
||||
Takes a dictionary sent by the client in the login / registration
|
||||
protocol and handles the login flow.
|
||||
protocol and handles the User-Interactive Auth flow.
|
||||
|
||||
As a side effect, this function fills in the 'creds' key on the user's
|
||||
session with a map, which maps each auth-type (str) to the relevant
|
||||
|
@ -260,16 +267,19 @@ class AuthHandler(BaseHandler):
|
|||
sess = self._get_session_info(session_id)
|
||||
return sess.setdefault('serverdict', {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
user_id = authdict["user"]
|
||||
password = authdict["password"]
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID(user_id, self.hs.hostname).to_string()
|
||||
|
||||
return self._check_password(user_id, password)
|
||||
(canonical_id, callback) = yield self.validate_login(user_id, {
|
||||
"type": LoginType.PASSWORD,
|
||||
"password": password,
|
||||
})
|
||||
defer.returnValue(canonical_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip):
|
||||
|
@ -398,26 +408,8 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def validate_password_login(self, user_id, password):
|
||||
"""
|
||||
Authenticates the user with their username and password.
|
||||
|
||||
Used only by the v1 login API.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
password (str): Password
|
||||
Returns:
|
||||
defer.Deferred: (str) canonical user id
|
||||
Raises:
|
||||
StoreError if there was a problem accessing the database
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_access_token_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
def get_access_token_for_user_id(self, user_id, device_id=None):
|
||||
"""
|
||||
Creates a new access token for the user with the given user ID.
|
||||
|
||||
|
@ -431,13 +423,10 @@ class AuthHandler(BaseHandler):
|
|||
device_id (str|None): the device ID to associate with the tokens.
|
||||
None to leave the tokens unassociated with a device (deprecated:
|
||||
we should always have a device ID)
|
||||
initial_display_name (str): display name to associate with the
|
||||
device if it needs re-registering
|
||||
Returns:
|
||||
The access token for the user's session.
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
|
@ -447,9 +436,11 @@ class AuthHandler(BaseHandler):
|
|||
# really don't want is active access_tokens without a record of the
|
||||
# device, so we double-check it here.
|
||||
if device_id is not None:
|
||||
yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
try:
|
||||
yield self.store.get_device(user_id, device_id)
|
||||
except StoreError:
|
||||
yield self.store.delete_access_token(access_token)
|
||||
raise StoreError(400, "Login raced against device deletion")
|
||||
|
||||
defer.returnValue(access_token)
|
||||
|
||||
|
@ -501,29 +492,115 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password(self, user_id, password):
|
||||
"""Authenticate a user against the LDAP and local databases.
|
||||
def get_supported_login_types(self):
|
||||
"""Get a the login types supported for the /login API
|
||||
|
||||
user_id is checked case insensitively against the local database, but
|
||||
will throw if there are multiple inexact matches.
|
||||
By default this is just 'm.login.password' (unless password_enabled is
|
||||
False in the config file), but password auth providers can provide
|
||||
other login types.
|
||||
|
||||
Returns:
|
||||
Iterable[str]: login types
|
||||
"""
|
||||
return self._supported_login_types
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_login(self, username, login_submission):
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate
|
||||
m.login.password auth types.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
username (str): username supplied by the user
|
||||
login_submission (dict): the whole of the login submission
|
||||
(including 'type' and other relevant fields)
|
||||
Returns:
|
||||
(str) the canonical_user_id
|
||||
Deferred[str, func]: canonical user id, and optional callback
|
||||
to be called once the access token and device id are issued
|
||||
Raises:
|
||||
LoginError if login fails
|
||||
StoreError if there was a problem accessing the database
|
||||
SynapseError if there was a problem with the request
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
|
||||
if username.startswith('@'):
|
||||
qualified_user_id = username
|
||||
else:
|
||||
qualified_user_id = UserID(
|
||||
username, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
login_type = login_submission.get("type")
|
||||
known_login_type = False
|
||||
|
||||
# special case to check for "password" for the check_password interface
|
||||
# for the auth providers
|
||||
password = login_submission.get("password")
|
||||
if login_type == LoginType.PASSWORD:
|
||||
if not self._password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
if not password:
|
||||
raise SynapseError(400, "Missing parameter: password")
|
||||
|
||||
for provider in self.password_providers:
|
||||
is_valid = yield provider.check_password(user_id, password)
|
||||
if is_valid:
|
||||
defer.returnValue(user_id)
|
||||
if (hasattr(provider, "check_password")
|
||||
and login_type == LoginType.PASSWORD):
|
||||
known_login_type = True
|
||||
is_valid = yield provider.check_password(
|
||||
qualified_user_id, password,
|
||||
)
|
||||
if is_valid:
|
||||
defer.returnValue(qualified_user_id)
|
||||
|
||||
canonical_user_id = yield self._check_local_password(user_id, password)
|
||||
if (not hasattr(provider, "get_supported_login_types")
|
||||
or not hasattr(provider, "check_auth")):
|
||||
# this password provider doesn't understand custom login types
|
||||
continue
|
||||
|
||||
if canonical_user_id:
|
||||
defer.returnValue(canonical_user_id)
|
||||
supported_login_types = provider.get_supported_login_types()
|
||||
if login_type not in supported_login_types:
|
||||
# this password provider doesn't understand this login type
|
||||
continue
|
||||
|
||||
known_login_type = True
|
||||
login_fields = supported_login_types[login_type]
|
||||
|
||||
missing_fields = []
|
||||
login_dict = {}
|
||||
for f in login_fields:
|
||||
if f not in login_submission:
|
||||
missing_fields.append(f)
|
||||
else:
|
||||
login_dict[f] = login_submission[f]
|
||||
if missing_fields:
|
||||
raise SynapseError(
|
||||
400, "Missing parameters for login type %s: %s" % (
|
||||
login_type,
|
||||
missing_fields,
|
||||
),
|
||||
)
|
||||
|
||||
result = yield provider.check_auth(
|
||||
username, login_type, login_dict,
|
||||
)
|
||||
if result:
|
||||
if isinstance(result, str):
|
||||
result = (result, None)
|
||||
defer.returnValue(result)
|
||||
|
||||
if login_type == LoginType.PASSWORD:
|
||||
known_login_type = True
|
||||
|
||||
canonical_user_id = yield self._check_local_password(
|
||||
qualified_user_id, password,
|
||||
)
|
||||
|
||||
if canonical_user_id:
|
||||
defer.returnValue((canonical_user_id, None))
|
||||
|
||||
if not known_login_type:
|
||||
raise SynapseError(400, "Unknown login type %s" % login_type)
|
||||
|
||||
# unknown username or invalid password. We raise a 403 here, but note
|
||||
# that if we're doing user-interactive login, it turns all LoginErrors
|
||||
|
@ -584,13 +661,80 @@ class AuthHandler(BaseHandler):
|
|||
if e.code == 404:
|
||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, except_access_token_id
|
||||
yield self.delete_access_tokens_for_user(
|
||||
user_id, except_token_id=except_access_token_id,
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deactivate_account(self, user_id):
|
||||
"""Deactivate a user's account
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user to be deactivated
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# FIXME: Theoretically there is a race here wherein user resets
|
||||
# password using threepid.
|
||||
yield self.delete_access_tokens_for_user(user_id)
|
||||
yield self.store.user_delete_threepids(user_id)
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_token(self, access_token):
|
||||
"""Invalidate a single access token
|
||||
|
||||
Args:
|
||||
access_token (str): access token to be deleted
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
user_info = yield self.auth.get_user_by_access_token(access_token)
|
||||
yield self.store.delete_access_token(access_token)
|
||||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
yield provider.on_logged_out(
|
||||
user_id=str(user_info["user"]),
|
||||
device_id=user_info["device_id"],
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
|
||||
device_id=None):
|
||||
"""Invalidate access tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
except_token_id (str|None): access_token ID which should *not* be
|
||||
deleted
|
||||
device_id (str|None): ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
tokens_and_devices = yield self.store.user_delete_access_tokens(
|
||||
user_id, except_token_id=except_token_id, device_id=device_id,
|
||||
)
|
||||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
for token, device_id in tokens_and_devices:
|
||||
yield provider.on_logged_out(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
access_token=token,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_threepid(self, user_id, medium, address, validated_at):
|
||||
# 'Canonicalise' email addresses down to lower case.
|
||||
|
@ -696,30 +840,3 @@ class MacaroonGeneartor(object):
|
|||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
return macaroon
|
||||
|
||||
|
||||
class _AccountHandler(object):
|
||||
"""A proxy object that gets passed to password auth providers so they
|
||||
can register new users etc if necessary.
|
||||
"""
|
||||
def __init__(self, hs, check_user_exists):
|
||||
self.hs = hs
|
||||
|
||||
self._check_user_exists = check_user_exists
|
||||
|
||||
def check_user_exists(self, user_id):
|
||||
"""Check if user exissts.
|
||||
|
||||
Returns:
|
||||
Deferred(bool)
|
||||
"""
|
||||
return self._check_user_exists(user_id)
|
||||
|
||||
def register(self, localpart):
|
||||
"""Registers a new user with given localpart
|
||||
|
||||
Returns:
|
||||
Deferred: a 2-tuple of (user_id, access_token)
|
||||
"""
|
||||
reg = self.hs.get_handlers().registration_handler
|
||||
return reg.register(localpart=localpart)
|
||||
|
|
|
@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
|
|||
|
||||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
self.federation = hs.get_replication_layer()
|
||||
|
||||
|
@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler):
|
|||
else:
|
||||
raise
|
||||
|
||||
yield self.store.user_delete_access_tokens(
|
||||
yield self._auth_handler.delete_access_tokens_for_user(
|
||||
user_id, device_id=device_id,
|
||||
delete_refresh_tokens=True,
|
||||
)
|
||||
|
||||
yield self.store.delete_e2e_keys_by_device(
|
||||
|
@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler):
|
|||
# Delete access tokens and e2e keys for each device. Not optimised as it is not
|
||||
# considered as part of a critical path.
|
||||
for device_id in device_ids:
|
||||
yield self.store.user_delete_access_tokens(
|
||||
yield self._auth_handler.delete_access_tokens_for_user(
|
||||
user_id, device_id=device_id,
|
||||
delete_refresh_tokens=True,
|
||||
)
|
||||
yield self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=device_id
|
||||
|
|
|
@ -1706,6 +1706,17 @@ class FederationHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def do_auth(self, origin, event, context, auth_events):
|
||||
"""
|
||||
|
||||
Args:
|
||||
origin (str):
|
||||
event (synapse.events.FrozenEvent):
|
||||
context (synapse.events.snapshot.EventContext):
|
||||
auth_events (dict[(str, str)->str]):
|
||||
|
||||
Returns:
|
||||
defer.Deferred[None]
|
||||
"""
|
||||
# Check if we have all the auth events.
|
||||
current_state = set(e.event_id for e in auth_events.values())
|
||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||
|
@ -1817,16 +1828,9 @@ class FederationHandler(BaseHandler):
|
|||
current_state = set(e.event_id for e in auth_events.values())
|
||||
different_auth = event_auth_events - current_state
|
||||
|
||||
context.current_state_ids = dict(context.current_state_ids)
|
||||
context.current_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
if k != event_key
|
||||
})
|
||||
context.prev_state_ids = dict(context.prev_state_ids)
|
||||
context.prev_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
})
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
self._update_context_for_auth_events(
|
||||
context, auth_events, event_key,
|
||||
)
|
||||
|
||||
if different_auth and not event.internal_metadata.is_outlier():
|
||||
logger.info("Different auth after resolution: %s", different_auth)
|
||||
|
@ -1906,16 +1910,9 @@ class FederationHandler(BaseHandler):
|
|||
# 4. Look at rejects and their proofs.
|
||||
# TODO.
|
||||
|
||||
context.current_state_ids = dict(context.current_state_ids)
|
||||
context.current_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
if k != event_key
|
||||
})
|
||||
context.prev_state_ids = dict(context.prev_state_ids)
|
||||
context.prev_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
})
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
self._update_context_for_auth_events(
|
||||
context, auth_events, event_key,
|
||||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=auth_events)
|
||||
|
@ -1923,6 +1920,35 @@ class FederationHandler(BaseHandler):
|
|||
logger.warn("Failed auth resolution for %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
def _update_context_for_auth_events(self, context, auth_events,
|
||||
event_key):
|
||||
"""Update the state_ids in an event context after auth event resolution
|
||||
|
||||
Args:
|
||||
context (synapse.events.snapshot.EventContext): event context
|
||||
to be updated
|
||||
|
||||
auth_events (dict[(str, str)->str]): Events to update in the event
|
||||
context.
|
||||
|
||||
event_key ((str, str)): (type, state_key) for the current event.
|
||||
this will not be included in the current_state in the context.
|
||||
"""
|
||||
state_updates = {
|
||||
k: a.event_id for k, a in auth_events.iteritems()
|
||||
if k != event_key
|
||||
}
|
||||
context.current_state_ids = dict(context.current_state_ids)
|
||||
context.current_state_ids.update(state_updates)
|
||||
if context.delta_ids is not None:
|
||||
context.delta_ids = dict(context.delta_ids)
|
||||
context.delta_ids.update(state_updates)
|
||||
context.prev_state_ids = dict(context.prev_state_ids)
|
||||
context.prev_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.iteritems()
|
||||
})
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def construct_auth_difference(self, local_auth, remote_auth):
|
||||
""" Given a local and remote auth chain, find the differences. This
|
||||
|
|
|
@ -71,6 +71,7 @@ class GroupsLocalHandler(object):
|
|||
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
|
||||
|
||||
add_room_to_group = _create_rerouter("add_room_to_group")
|
||||
update_room_in_group = _create_rerouter("update_room_in_group")
|
||||
remove_room_from_group = _create_rerouter("remove_room_from_group")
|
||||
|
||||
update_group_summary_room = _create_rerouter("update_group_summary_room")
|
||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from ._base import BaseHandler
|
||||
|
@ -140,7 +139,7 @@ class ProfileHandler(BaseHandler):
|
|||
target_user.localpart, new_displayname
|
||||
)
|
||||
|
||||
yield self._update_join_states(requester)
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_avatar_url(self, target_user):
|
||||
|
@ -184,7 +183,7 @@ class ProfileHandler(BaseHandler):
|
|||
target_user.localpart, new_avatar_url
|
||||
)
|
||||
|
||||
yield self._update_join_states(requester)
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_profile_query(self, args):
|
||||
|
@ -209,28 +208,24 @@ class ProfileHandler(BaseHandler):
|
|||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_join_states(self, requester):
|
||||
user = requester.user
|
||||
if not self.hs.is_mine(user):
|
||||
def _update_join_states(self, requester, target_user):
|
||||
if not self.hs.is_mine(target_user):
|
||||
return
|
||||
|
||||
yield self.ratelimit(requester)
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(
|
||||
user.to_string(),
|
||||
target_user.to_string(),
|
||||
)
|
||||
|
||||
for room_id in room_ids:
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
try:
|
||||
# Assume the user isn't a guest because we don't let guests set
|
||||
# profile or avatar data.
|
||||
# XXX why are we recreating `requester` here for each room?
|
||||
# what was wrong with the `requester` we were passed?
|
||||
requester = synapse.types.create_requester(user)
|
||||
# Assume the target_user isn't a guest,
|
||||
# because we don't let guests set profile or avatar data.
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
user,
|
||||
target_user,
|
||||
room_id,
|
||||
"join", # We treat a profile update like a join.
|
||||
ratelimit=False, # Try to hide that these events aren't atomic.
|
||||
|
|
|
@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
|
|||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
|
@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
|
|||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
else:
|
||||
yield self.store.user_delete_access_tokens(user_id=user_id)
|
||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
if displayname is not None:
|
||||
|
|
|
@ -20,6 +20,7 @@ from ._base import BaseHandler
|
|||
from synapse.api.constants import (
|
||||
EventTypes, JoinRules,
|
||||
)
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
|
|||
if search_filter:
|
||||
# We explicitly don't bother caching searches or requests for
|
||||
# appservice specific lists.
|
||||
logger.info("Bypassing cache as search request.")
|
||||
return self._get_public_room_list(
|
||||
limit, since_token, search_filter, network_tuple=network_tuple,
|
||||
)
|
||||
|
@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
|
|||
key = (limit, since_token, network_tuple)
|
||||
result = self.response_cache.get(key)
|
||||
if not result:
|
||||
logger.info("No cached result, calculating one.")
|
||||
result = self.response_cache.set(
|
||||
key,
|
||||
self._get_public_room_list(
|
||||
preserve_fn(self._get_public_room_list)(
|
||||
limit, since_token, network_tuple=network_tuple
|
||||
)
|
||||
)
|
||||
return result
|
||||
else:
|
||||
logger.info("Using cached deferred result.")
|
||||
return make_deferred_yieldable(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_public_room_list(self, limit=None, since_token=None,
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
|
@ -184,11 +184,11 @@ class SyncHandler(object):
|
|||
if not result:
|
||||
result = self.response_cache.set(
|
||||
sync_config.request_key,
|
||||
self._wait_for_sync_for_user(
|
||||
preserve_fn(self._wait_for_sync_for_user)(
|
||||
sync_config, since_token, timeout, full_state
|
||||
)
|
||||
)
|
||||
return result
|
||||
return make_deferred_yieldable(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
|
||||
|
|
|
@ -152,7 +152,7 @@ class UserDirectoyHandler(object):
|
|||
|
||||
for room_id in room_ids:
|
||||
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
|
||||
yield self._handle_intial_room(room_id)
|
||||
yield self._handle_initial_room(room_id)
|
||||
num_processed_rooms += 1
|
||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||
|
||||
|
@ -166,7 +166,7 @@ class UserDirectoyHandler(object):
|
|||
yield self.store.update_user_directory_stream_pos(new_pos)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_intial_room(self, room_id):
|
||||
def _handle_initial_room(self, room_id):
|
||||
"""Called when we initially fill out user_directory one room at a time
|
||||
"""
|
||||
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import wrap_request_handler
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
|
||||
class AdditionalResource(Resource):
|
||||
"""Resource wrapper for additional_resources
|
||||
|
||||
If the user has configured additional_resources, we need to wrap the
|
||||
handler class with a Resource so that we can map it into the resource tree.
|
||||
|
||||
This class is also where we wrap the request handler with logging, metrics,
|
||||
and exception handling.
|
||||
"""
|
||||
def __init__(self, hs, handler):
|
||||
"""Initialise AdditionalResource
|
||||
|
||||
The ``handler`` should return a deferred which completes when it has
|
||||
done handling the request. It should write a response with
|
||||
``request.write()``, and call ``request.finish()``.
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): homeserver
|
||||
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
||||
function to be called to handle the request.
|
||||
"""
|
||||
Resource.__init__(self)
|
||||
self._handler = handler
|
||||
|
||||
# these are required by the request_handler wrapper
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render(self, request):
|
||||
self._async_render(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_request_handler
|
||||
def _async_render(self, request):
|
||||
return self._handler(request)
|
|
@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
|
|||
from synapse.api.errors import (
|
||||
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
||||
)
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util import logcontext
|
||||
import synapse.metrics
|
||||
from synapse.http.endpoint import SpiderEndpoint
|
||||
|
@ -114,43 +114,73 @@ class SimpleHttpClient(object):
|
|||
raise e
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_urlencoded_get_json(self, uri, args={}):
|
||||
def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||
"""
|
||||
Args:
|
||||
uri (str):
|
||||
args (dict[str, str|List[str]]): query params
|
||||
headers (dict[str, List[str]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
|
||||
Returns:
|
||||
Deferred[object]: parsed json
|
||||
"""
|
||||
|
||||
# TODO: Do we ever want to log message contents?
|
||||
logger.debug("post_urlencoded_get_json args: %s", args)
|
||||
|
||||
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
|
||||
|
||||
actual_headers = {
|
||||
b"Content-Type": [b"application/x-www-form-urlencoded"],
|
||||
b"User-Agent": [self.user_agent],
|
||||
}
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
"POST",
|
||||
uri.encode("ascii"),
|
||||
headers=Headers({
|
||||
b"Content-Type": [b"application/x-www-form-urlencoded"],
|
||||
b"User-Agent": [self.user_agent],
|
||||
}),
|
||||
headers=Headers(actual_headers),
|
||||
bodyProducer=FileBodyProducer(StringIO(query_bytes))
|
||||
)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json_get_json(self, uri, post_json):
|
||||
def post_json_get_json(self, uri, post_json, headers=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
uri (str):
|
||||
post_json (object):
|
||||
headers (dict[str, List[str]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
|
||||
Returns:
|
||||
Deferred[object]: parsed json
|
||||
"""
|
||||
json_str = encode_canonical_json(post_json)
|
||||
|
||||
logger.debug("HTTP POST %s -> %s", json_str, uri)
|
||||
|
||||
actual_headers = {
|
||||
b"Content-Type": [b"application/json"],
|
||||
b"User-Agent": [self.user_agent],
|
||||
}
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
"POST",
|
||||
uri.encode("ascii"),
|
||||
headers=Headers({
|
||||
b"Content-Type": [b"application/json"],
|
||||
b"User-Agent": [self.user_agent],
|
||||
}),
|
||||
headers=Headers(actual_headers),
|
||||
bodyProducer=FileBodyProducer(StringIO(json_str))
|
||||
)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
|
@ -160,7 +190,7 @@ class SimpleHttpClient(object):
|
|||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, uri, args={}):
|
||||
def get_json(self, uri, args={}, headers=None):
|
||||
""" Gets some json from the given URI.
|
||||
|
||||
Args:
|
||||
|
@ -169,6 +199,8 @@ class SimpleHttpClient(object):
|
|||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
headers (dict[str, List[str]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
|
@ -177,13 +209,13 @@ class SimpleHttpClient(object):
|
|||
error message.
|
||||
"""
|
||||
try:
|
||||
body = yield self.get_raw(uri, args)
|
||||
body = yield self.get_raw(uri, args, headers=headers)
|
||||
defer.returnValue(json.loads(body))
|
||||
except CodeMessageException as e:
|
||||
raise self._exceptionFromFailedRequest(e.code, e.msg)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, uri, json_body, args={}):
|
||||
def put_json(self, uri, json_body, args={}, headers=None):
|
||||
""" Puts some json to the given URI.
|
||||
|
||||
Args:
|
||||
|
@ -193,6 +225,8 @@ class SimpleHttpClient(object):
|
|||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
headers (dict[str, List[str]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
|
@ -205,17 +239,21 @@ class SimpleHttpClient(object):
|
|||
|
||||
json_str = encode_canonical_json(json_body)
|
||||
|
||||
actual_headers = {
|
||||
b"Content-Type": [b"application/json"],
|
||||
b"User-Agent": [self.user_agent],
|
||||
}
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
"PUT",
|
||||
uri.encode("ascii"),
|
||||
headers=Headers({
|
||||
b"User-Agent": [self.user_agent],
|
||||
"Content-Type": ["application/json"]
|
||||
}),
|
||||
headers=Headers(actual_headers),
|
||||
bodyProducer=FileBodyProducer(StringIO(json_str))
|
||||
)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
|
@ -226,7 +264,7 @@ class SimpleHttpClient(object):
|
|||
raise CodeMessageException(response.code, body)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_raw(self, uri, args={}):
|
||||
def get_raw(self, uri, args={}, headers=None):
|
||||
""" Gets raw text from the given URI.
|
||||
|
||||
Args:
|
||||
|
@ -235,6 +273,8 @@ class SimpleHttpClient(object):
|
|||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
headers (dict[str, List[str]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body at text.
|
||||
|
@ -246,15 +286,19 @@ class SimpleHttpClient(object):
|
|||
query_bytes = urllib.urlencode(args, True)
|
||||
uri = "%s?%s" % (uri, query_bytes)
|
||||
|
||||
actual_headers = {
|
||||
b"User-Agent": [self.user_agent],
|
||||
}
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
"GET",
|
||||
uri.encode("ascii"),
|
||||
headers=Headers({
|
||||
b"User-Agent": [self.user_agent],
|
||||
})
|
||||
headers=Headers(actual_headers),
|
||||
)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
defer.returnValue(body)
|
||||
|
@ -274,27 +318,33 @@ class SimpleHttpClient(object):
|
|||
# The two should be factored out.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(self, url, output_stream, max_size=None):
|
||||
def get_file(self, url, output_stream, max_size=None, headers=None):
|
||||
"""GETs a file from a given URL
|
||||
Args:
|
||||
url (str): The URL to GET
|
||||
output_stream (file): File to write the response body to.
|
||||
headers (dict[str, List[str]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
A (int,dict,string,int) tuple of the file length, dict of the response
|
||||
headers, absolute URI of the response and HTTP response code.
|
||||
"""
|
||||
|
||||
actual_headers = {
|
||||
b"User-Agent": [self.user_agent],
|
||||
}
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
"GET",
|
||||
url.encode("ascii"),
|
||||
headers=Headers({
|
||||
b"User-Agent": [self.user_agent],
|
||||
})
|
||||
headers=Headers(actual_headers),
|
||||
)
|
||||
|
||||
headers = dict(response.headers.getAllRawHeaders())
|
||||
resp_headers = dict(response.headers.getAllRawHeaders())
|
||||
|
||||
if 'Content-Length' in headers and headers['Content-Length'] > max_size:
|
||||
if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
|
||||
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
|
||||
raise SynapseError(
|
||||
502,
|
||||
|
@ -315,10 +365,9 @@ class SimpleHttpClient(object):
|
|||
# straight back in again
|
||||
|
||||
try:
|
||||
length = yield preserve_context_over_fn(
|
||||
_readBodyToFile,
|
||||
response, output_stream, max_size
|
||||
)
|
||||
length = yield make_deferred_yieldable(_readBodyToFile(
|
||||
response, output_stream, max_size,
|
||||
))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to download body")
|
||||
raise SynapseError(
|
||||
|
@ -327,7 +376,9 @@ class SimpleHttpClient(object):
|
|||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
defer.returnValue((length, headers, response.request.absoluteURI, response.code))
|
||||
defer.returnValue(
|
||||
(length, resp_headers, response.request.absoluteURI, response.code),
|
||||
)
|
||||
|
||||
|
||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||
|
@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
|
|||
)
|
||||
|
||||
try:
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
defer.returnValue(body)
|
||||
except PartialDownloadError as e:
|
||||
# twisted dislikes google's response, no content length.
|
||||
|
|
|
@ -167,7 +167,8 @@ def parse_json_value_from_request(request):
|
|||
|
||||
try:
|
||||
content = simplejson.loads(content_bytes)
|
||||
except simplejson.JSONDecodeError:
|
||||
except Exception as e:
|
||||
logger.warn("Unable to parse JSON: %s", e)
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
|
||||
return content
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
class ModuleApi(object):
|
||||
"""A proxy object that gets passed to password auth providers so they
|
||||
can register new users etc if necessary.
|
||||
"""
|
||||
def __init__(self, hs, auth_handler):
|
||||
self.hs = hs
|
||||
|
||||
self._store = hs.get_datastore()
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = auth_handler
|
||||
|
||||
def get_user_by_req(self, req, allow_guest=False):
|
||||
"""Check the access_token provided for a request
|
||||
|
||||
Args:
|
||||
req (twisted.web.server.Request): Incoming HTTP request
|
||||
allow_guest (bool): True if guest users should be allowed. If this
|
||||
is False, and the access token is for a guest user, an
|
||||
AuthError will be thrown
|
||||
Returns:
|
||||
twisted.internet.defer.Deferred[synapse.types.Requester]:
|
||||
the requester for this request
|
||||
Raises:
|
||||
synapse.api.errors.AuthError: if no user by that token exists,
|
||||
or the token is invalid.
|
||||
"""
|
||||
return self._auth.get_user_by_req(req, allow_guest)
|
||||
|
||||
def get_qualified_user_id(self, username):
|
||||
"""Qualify a user id, if necessary
|
||||
|
||||
Takes a user id provided by the user and adds the @ and :domain to
|
||||
qualify it, if necessary
|
||||
|
||||
Args:
|
||||
username (str): provided user id
|
||||
|
||||
Returns:
|
||||
str: qualified @user:id
|
||||
"""
|
||||
if username.startswith('@'):
|
||||
return username
|
||||
return UserID(username, self.hs.hostname).to_string()
|
||||
|
||||
def check_user_exists(self, user_id):
|
||||
"""Check if user exists.
|
||||
|
||||
Args:
|
||||
user_id (str): Complete @user:id
|
||||
|
||||
Returns:
|
||||
Deferred[str|None]: Canonical (case-corrected) user_id, or None
|
||||
if the user is not registered.
|
||||
"""
|
||||
return self._auth_handler.check_user_exists(user_id)
|
||||
|
||||
def register(self, localpart):
|
||||
"""Registers a new user with given localpart
|
||||
|
||||
Returns:
|
||||
Deferred: a 2-tuple of (user_id, access_token)
|
||||
"""
|
||||
reg = self.hs.get_handlers().registration_handler
|
||||
return reg.register(localpart=localpart)
|
||||
|
||||
def invalidate_access_token(self, access_token):
|
||||
"""Invalidate an access token for a user
|
||||
|
||||
Args:
|
||||
access_token(str): access token
|
||||
|
||||
Returns:
|
||||
twisted.internet.defer.Deferred - resolves once the access token
|
||||
has been removed.
|
||||
|
||||
Raises:
|
||||
synapse.api.errors.AuthError: the access token is invalid
|
||||
"""
|
||||
|
||||
return self._auth_handler.delete_access_token(access_token)
|
||||
|
||||
def run_db_interaction(self, desc, func, *args, **kwargs):
|
||||
"""Run a function with a database connection
|
||||
|
||||
Args:
|
||||
desc (str): description for the transaction, for metrics etc
|
||||
func (func): function to be run. Passed a database cursor object
|
||||
as well as *args and **kwargs
|
||||
*args: positional args to be passed to func
|
||||
**kwargs: named args to be passed to func
|
||||
|
||||
Returns:
|
||||
Deferred[object]: result of func
|
||||
"""
|
||||
return self._store.runInteraction(desc, func, *args, **kwargs)
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class BaseSlavedStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(BaseSlavedStore, self).__init__(hs)
|
||||
super(BaseSlavedStore, self).__init__(db_conn, hs)
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = SlavedIdTracker(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
|
|
|
@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(target_user_id)
|
||||
yield self.store.user_delete_threepids(target_user_id)
|
||||
yield self.store.user_set_password_hash(target_user_id, None)
|
||||
|
||||
yield self._auth_handler.deactivate_account(target_user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
|
|
@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
|
|||
|
||||
class LoginRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login$")
|
||||
PASS_TYPE = "m.login.password"
|
||||
SAML2_TYPE = "m.login.saml2"
|
||||
CAS_TYPE = "m.login.cas"
|
||||
TOKEN_TYPE = "m.login.token"
|
||||
|
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(LoginRestServlet, self).__init__(hs)
|
||||
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||
self.password_enabled = hs.config.password_enabled
|
||||
self.saml2_enabled = hs.config.saml2_enabled
|
||||
self.jwt_enabled = hs.config.jwt_enabled
|
||||
self.jwt_secret = hs.config.jwt_secret
|
||||
|
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
# fall back to the fallback API if they don't understand one of the
|
||||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
if self.password_enabled:
|
||||
flows.append({"type": LoginRestServlet.PASS_TYPE})
|
||||
|
||||
flows.extend((
|
||||
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||
))
|
||||
|
||||
return (200, {"flows": flows})
|
||||
|
||||
|
@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
def on_POST(self, request):
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
||||
if not self.password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
|
||||
result = yield self.do_password_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
elif self.saml2_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.SAML2_TYPE):
|
||||
if self.saml2_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.SAML2_TYPE):
|
||||
relay_state = ""
|
||||
if "relay_state" in login_submission:
|
||||
relay_state = "&RelayState=" + urllib.quote(
|
||||
|
@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
result = yield self.do_token_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
else:
|
||||
raise SynapseError(400, "Bad login type.")
|
||||
result = yield self._do_other_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_password_login(self, login_submission):
|
||||
if "password" not in login_submission:
|
||||
raise SynapseError(400, "Missing parameter: password")
|
||||
def _do_other_login(self, login_submission):
|
||||
"""Handle non-token/saml/jwt logins
|
||||
|
||||
Args:
|
||||
login_submission:
|
||||
|
||||
Returns:
|
||||
(int, object): HTTP code/response
|
||||
"""
|
||||
# Log the request we got, but only certain fields to minimise the chance of
|
||||
# logging someone's password (even if they accidentally put it in the wrong
|
||||
# field)
|
||||
logger.info(
|
||||
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
|
||||
login_submission.get('identifier'),
|
||||
login_submission.get('medium'),
|
||||
login_submission.get('address'),
|
||||
login_submission.get('user'),
|
||||
)
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
if "identifier" not in login_submission:
|
||||
|
@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
if "user" not in identifier:
|
||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
user_id = identifier["user"]
|
||||
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID(
|
||||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
auth_handler = self.auth_handler
|
||||
user_id = yield auth_handler.validate_password_login(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"],
|
||||
canonical_user_id, callback = yield auth_handler.validate_login(
|
||||
identifier["user"],
|
||||
login_submission,
|
||||
)
|
||||
|
||||
device_id = yield self._register_device(
|
||||
canonical_user_id, login_submission,
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
canonical_user_id, device_id,
|
||||
)
|
||||
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": canonical_user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
if callback is not None:
|
||||
yield callback(result)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
|
@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
)
|
||||
|
||||
result = {
|
||||
|
|
|
@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(LogoutRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
access_token = get_access_token_from_request(request)
|
||||
yield self.store.delete_access_token(access_token)
|
||||
yield self._auth_handler.delete_access_token(access_token)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(LogoutAllRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
|||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
yield self.store.user_delete_access_tokens(user_id)
|
||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
|
|
@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
if compare_digest(want_mac, got_mac):
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.register(
|
||||
localpart=user,
|
||||
localpart=user.lower(),
|
||||
password=password,
|
||||
admin=bool(admin),
|
||||
)
|
||||
|
|
|
@ -13,22 +13,21 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import LoginError, SynapseError, Codes
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, assert_params_in_request
|
||||
RestServlet, assert_params_in_request,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
super(DeactivateAccountRestServlet, self).__init__()
|
||||
|
@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# if the caller provides an access token, it ought to be valid.
|
||||
requester = None
|
||||
if has_access_token(request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
) # type: synapse.types.Requester
|
||||
|
||||
# allow ASes to dectivate their own users
|
||||
if requester and requester.app_service:
|
||||
yield self.auth_handler.deactivate_account(
|
||||
requester.user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
user_id = None
|
||||
requester = None
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
user_id = result[LoginType.PASSWORD]
|
||||
# if using password, they should also be logged in
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
if user_id != result[LoginType.PASSWORD]:
|
||||
if requester is None:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Deactivate account requires an access_token",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
if requester.user.to_string() != user_id:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(user_id)
|
||||
yield self.store.user_delete_threepids(user_id)
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
||||
|
||||
yield self.auth_handler.deactivate_account(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
@ -373,6 +382,20 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class WhoamiRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/whoami$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(WhoamiRestServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
defer.returnValue((200, {'user_id': requester.user.to_string()}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
|
@ -382,3 +405,4 @@ def register_servlets(hs, http_server):
|
|||
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
ThreepidRestServlet(hs).register(http_server)
|
||||
ThreepidDeleteRestServlet(hs).register(http_server)
|
||||
WhoamiRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||
API for bulk deletion of devices. Accepts a JSON object with a devices
|
||||
key which lists the device_ids to delete. Requires user interactive auth.
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
|
||||
PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DeleteDevicesRestServlet, self).__init__()
|
||||
|
@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||
|
||||
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
|
||||
|
@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
# check that the UI auth matched the access token
|
||||
user_id = result[constants.LoginType.PASSWORD]
|
||||
if user_id != requester.user.to_string():
|
||||
raise errors.AuthError(403, "Invalid auth")
|
||||
|
||||
yield self.device_handler.delete_device(user_id, device_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -39,20 +39,23 @@ class GroupServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
|
||||
group_description = yield self.groups_handler.get_group_profile(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, group_description))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
yield self.groups_handler.update_group_profile(
|
||||
group_id, user_id, content,
|
||||
group_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
|
||||
get_group_summary = yield self.groups_handler.get_group_summary(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, get_group_summary))
|
||||
|
||||
|
@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, category_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_room(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
|
@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, category_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_room(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_category(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
|
@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_category(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_category(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
|
@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_categories(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_role(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
|
@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_role(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_role(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
|
@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_roles(
|
||||
group_id, user_id,
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
|
||||
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
|
||||
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
|
||||
result = yield self.groups_handler.get_invited_users_in_group(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
# TODO: Create group on remote server
|
||||
content = parse_json_object_from_request(request)
|
||||
localpart = content.pop("localpart")
|
||||
group_id = GroupID(localpart, self.server_name).to_string()
|
||||
|
||||
result = yield self.groups_handler.create_group(group_id, user_id, content)
|
||||
result = yield self.groups_handler.create_group(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.add_room_to_group(
|
||||
group_id, user_id, room_id, content,
|
||||
group_id, requester_user_id, room_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -447,10 +460,37 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.remove_room_from_group(
|
||||
group_id, user_id, room_id,
|
||||
group_id, requester_user_id, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminRoomsConfigServlet(RestServlet):
|
||||
"""Update the config of a room in a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
|
||||
"/config/(?P<config_key>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminRoomsConfigServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, room_id, config_key):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.update_room_in_group(
|
||||
group_id, requester_user_id, room_id, config_key, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -685,9 +725,9 @@ class GroupsForUserServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_joined_groups(user_id)
|
||||
result = yield self.groups_handler.get_joined_groups(requester_user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -700,6 +740,7 @@ def register_servlets(hs, http_server):
|
|||
GroupRoomServlet(hs).register(http_server)
|
||||
GroupCreateServlet(hs).register(http_server)
|
||||
GroupAdminRoomsServlet(hs).register(http_server)
|
||||
GroupAdminRoomsConfigServlet(hs).register(http_server)
|
||||
GroupAdminUsersInviteServlet(hs).register(http_server)
|
||||
GroupAdminUsersKickServlet(hs).register(http_server)
|
||||
GroupSelfLeaveServlet(hs).register(http_server)
|
||||
|
|
|
@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
|
|||
},
|
||||
}
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||
releases=())
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
|
|||
} } } } } }
|
||||
"""
|
||||
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/keys/query$",
|
||||
releases=()
|
||||
)
|
||||
PATTERNS = client_v2_patterns("/keys/query$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
|
|||
200 OK
|
||||
{ "changed": ["@foo:example.com"] }
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/keys/changes$",
|
||||
releases=()
|
||||
)
|
||||
PATTERNS = client_v2_patterns("/keys/changes$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
|
|||
} } } }
|
||||
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/keys/claim$",
|
||||
releases=()
|
||||
)
|
||||
PATTERNS = client_v2_patterns("/keys/claim$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(OneTimeKeyServlet, self).__init__()
|
||||
|
|
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class NotificationsServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/notifications$", releases=())
|
||||
PATTERNS = client_v2_patterns("/notifications$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(NotificationsServlet, self).__init__()
|
||||
|
|
|
@ -224,6 +224,12 @@ class RegisterRestServlet(RestServlet):
|
|||
# 'user' key not 'username'). Since this is a new addition, we'll
|
||||
# fallback to 'username' if they gave one.
|
||||
desired_username = body.get("user", desired_username)
|
||||
|
||||
# XXX we should check that desired_username is valid. Currently
|
||||
# we give appservices carte blanche for any insanity in mxids,
|
||||
# because the IRC bridges rely on being able to register stupid
|
||||
# IDs.
|
||||
|
||||
access_token = get_access_token_from_request(request)
|
||||
|
||||
if isinstance(desired_username, basestring):
|
||||
|
@ -233,6 +239,15 @@ class RegisterRestServlet(RestServlet):
|
|||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
||||
# for either shared secret or regular registration, downcase the
|
||||
# provided username before attempting to register it. This should mean
|
||||
# that people who try to register with upper-case in their usernames
|
||||
# don't get a nasty surprise. (Note that we treat username
|
||||
# case-insenstively in login, so they are free to carry on imagining
|
||||
# that their username is CrAzYh4cKeR if that keeps them happy)
|
||||
if desired_username is not None:
|
||||
desired_username = desired_username.lower()
|
||||
|
||||
# == Shared Secret Registration == (e.g. create new user scripts)
|
||||
if 'mac' in body:
|
||||
# FIXME: Should we really be determining if this is shared secret
|
||||
|
@ -336,6 +351,9 @@ class RegisterRestServlet(RestServlet):
|
|||
new_password = params.get("password", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
|
||||
if desired_username is not None:
|
||||
desired_username = desired_username.lower()
|
||||
|
||||
(registered_user_id, _) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=new_password,
|
||||
|
@ -417,13 +435,22 @@ class RegisterRestServlet(RestServlet):
|
|||
def _do_shared_secret_registration(self, username, password, body):
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
if not username:
|
||||
raise SynapseError(
|
||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
user = username.encode("utf-8")
|
||||
# use the username from the original request rather than the
|
||||
# downcased one in `username` for the mac calculation
|
||||
user = body["username"].encode("utf-8")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
got_mac = str(body["mac"])
|
||||
|
||||
# FIXME this is different to the /v1/register endpoint, which
|
||||
# includes the password and admin flag in the hashed text. Why are
|
||||
# these different?
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
|
@ -557,25 +584,28 @@ class RegisterRestServlet(RestServlet):
|
|||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) params: registration parameters, from which we pull
|
||||
device_id and initial_device_name
|
||||
device_id, initial_device_name and inhibit_login
|
||||
Returns:
|
||||
defer.Deferred: (object) dictionary for response from /register
|
||||
"""
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token = (
|
||||
yield self.auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
initial_display_name=params.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
})
|
||||
}
|
||||
if not params.get("inhibit_login", False):
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token = (
|
||||
yield self.auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
)
|
||||
)
|
||||
|
||||
result.update({
|
||||
"access_token": access_token,
|
||||
"device_id": device_id,
|
||||
})
|
||||
defer.returnValue(result)
|
||||
|
||||
def _register_device(self, user_id, params):
|
||||
"""Register a device for a user.
|
||||
|
|
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||
class SendToDeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False
|
||||
v2_alpha=False
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
|
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ThirdPartyProtocolsServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
|
||||
PATTERNS = client_v2_patterns("/thirdparty/protocols")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyProtocolsServlet, self).__init__()
|
||||
|
@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
|
|||
|
||||
|
||||
class ThirdPartyProtocolServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
|
||||
releases=())
|
||||
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyProtocolServlet, self).__init__()
|
||||
|
@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
|
|||
|
||||
|
||||
class ThirdPartyUserServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
|
||||
releases=())
|
||||
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyUserServlet, self).__init__()
|
||||
|
@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
|
|||
|
||||
|
||||
class ThirdPartyLocationServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
|
||||
releases=())
|
||||
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyLocationServlet, self).__init__()
|
||||
|
|
|
@ -20,6 +20,7 @@ from twisted.web.resource import Resource
|
|||
from synapse.api.errors import (
|
||||
SynapseError, Codes,
|
||||
)
|
||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.http.client import SpiderHttpClient
|
||||
|
@ -63,16 +64,15 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
|
||||
# simple memory cache mapping urls to OG metadata
|
||||
self.cache = ExpiringCache(
|
||||
# memory cache mapping urls to an ObservableDeferred returning
|
||||
# JSON-encoded OG metadata
|
||||
self._cache = ExpiringCache(
|
||||
cache_name="url_previews",
|
||||
clock=self.clock,
|
||||
# don't spider URLs more often than once an hour
|
||||
expiry_ms=60 * 60 * 1000,
|
||||
)
|
||||
self.cache.start()
|
||||
|
||||
self.downloads = {}
|
||||
self._cache.start()
|
||||
|
||||
self._cleaner_loop = self.clock.looping_call(
|
||||
self._expire_url_cache_data, 10 * 1000
|
||||
|
@ -94,6 +94,7 @@ class PreviewUrlResource(Resource):
|
|||
else:
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
# XXX: we could move this into _do_preview if we wanted.
|
||||
url_tuple = urlparse.urlsplit(url)
|
||||
for entry in self.url_preview_url_blacklist:
|
||||
match = True
|
||||
|
@ -126,14 +127,42 @@ class PreviewUrlResource(Resource):
|
|||
Codes.UNKNOWN
|
||||
)
|
||||
|
||||
# first check the memory cache - good to handle all the clients on this
|
||||
# HS thundering away to preview the same URL at the same time.
|
||||
og = self.cache.get(url)
|
||||
if og:
|
||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
||||
return
|
||||
# the in-memory cache:
|
||||
# * ensures that only one request is active at a time
|
||||
# * takes load off the DB for the thundering herds
|
||||
# * also caches any failures (unlike the DB) so we don't keep
|
||||
# requesting the same endpoint
|
||||
|
||||
# then check the URL cache in the DB (which will also provide us with
|
||||
observable = self._cache.get(url)
|
||||
|
||||
if not observable:
|
||||
download = preserve_fn(self._do_preview)(
|
||||
url, requester.user, ts,
|
||||
)
|
||||
observable = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self._cache[url] = observable
|
||||
else:
|
||||
logger.info("Returning cached response")
|
||||
|
||||
og = yield make_deferred_yieldable(observable.observe())
|
||||
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_preview(self, url, user, ts):
|
||||
"""Check the db, and download the URL and build a preview
|
||||
|
||||
Args:
|
||||
url (str):
|
||||
user (str):
|
||||
ts (int):
|
||||
|
||||
Returns:
|
||||
Deferred[str]: json-encoded og data
|
||||
"""
|
||||
# check the URL cache in the DB (which will also provide us with
|
||||
# historical previews, if we have any)
|
||||
cache_result = yield self.store.get_url_cache(url, ts)
|
||||
if (
|
||||
|
@ -141,32 +170,10 @@ class PreviewUrlResource(Resource):
|
|||
cache_result["expires_ts"] > ts and
|
||||
cache_result["response_code"] / 100 == 2
|
||||
):
|
||||
respond_with_json_bytes(
|
||||
request, 200, cache_result["og"].encode('utf-8'),
|
||||
send_cors=True
|
||||
)
|
||||
defer.returnValue(cache_result["og"])
|
||||
return
|
||||
|
||||
# Ensure only one download for a given URL is active at a time
|
||||
download = self.downloads.get(url)
|
||||
if download is None:
|
||||
download = self._download_url(url, requester.user)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self.downloads[url] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[url]
|
||||
return media_info
|
||||
media_info = yield download.observe()
|
||||
|
||||
# FIXME: we should probably update our cache now anyway, so that
|
||||
# even if the OG calculation raises, we don't keep hammering on the
|
||||
# remote server. For now, leave it uncached to aid debugging OG
|
||||
# calculation problems
|
||||
media_info = yield self._download_url(url, user)
|
||||
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
|
@ -212,7 +219,7 @@ class PreviewUrlResource(Resource):
|
|||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
_rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
_rebase_url(og['og:image'], media_info['uri']), user
|
||||
)
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
|
@ -239,8 +246,7 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
logger.debug("Calculated OG for %s as %s" % (url, og))
|
||||
|
||||
# store OG in ephemeral in-memory cache
|
||||
self.cache[url] = og
|
||||
jsonog = json.dumps(og)
|
||||
|
||||
# store OG in history-aware DB cache
|
||||
yield self.store.store_url_cache(
|
||||
|
@ -248,12 +254,12 @@ class PreviewUrlResource(Resource):
|
|||
media_info["response_code"],
|
||||
media_info["etag"],
|
||||
media_info["expires"] + media_info["created_ts"],
|
||||
json.dumps(og),
|
||||
jsonog,
|
||||
media_info["filesystem_id"],
|
||||
media_info["created_ts"],
|
||||
)
|
||||
|
||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
||||
defer.returnValue(jsonog)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_url(self, url, user):
|
||||
|
@ -520,7 +526,14 @@ def _calc_og(tree, media_uri):
|
|||
from lxml import etree
|
||||
|
||||
TAGS_TO_REMOVE = (
|
||||
"header", "nav", "aside", "footer", "script", "style", etree.Comment
|
||||
"header",
|
||||
"nav",
|
||||
"aside",
|
||||
"footer",
|
||||
"script",
|
||||
"noscript",
|
||||
"style",
|
||||
etree.Comment
|
||||
)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
|
|
|
@ -268,7 +268,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||
|
||||
super(DataStore, self).__init__(hs)
|
||||
super(DataStore, self).__init__(db_conn, hs)
|
||||
|
||||
def take_presence_startup_info(self):
|
||||
active_on_startup = self._presence_on_startup
|
||||
|
|
|
@ -162,7 +162,7 @@ class PerformanceCounters(object):
|
|||
class SQLBaseStore(object):
|
||||
_TXN_ID = 0
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, db_conn, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self._db_pool = hs.get_db_pool()
|
||||
|
|
|
@ -63,7 +63,7 @@ class AccountDataStore(SQLBaseStore):
|
|||
"get_account_data_for_user", get_account_data_for_user_txn
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
@cachedInlineCallbacks(num_args=2, max_entries=5000)
|
||||
def get_global_account_data_by_type_for_user(self, data_type, user_id):
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
@ -48,8 +48,8 @@ def _make_exclusive_regex(services_cache):
|
|||
|
||||
class ApplicationServiceStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ApplicationServiceStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(ApplicationServiceStore, self).__init__(db_conn, hs)
|
||||
self.hostname = hs.hostname
|
||||
self.services_cache = load_appservices(
|
||||
hs.hostname,
|
||||
|
@ -173,8 +173,8 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
|
||||
class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ApplicationServiceTransactionStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_appservices_by_state(self, state):
|
||||
|
|
|
@ -80,8 +80,8 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
||||
BACKGROUND_UPDATE_DURATION_MS = 100
|
||||
|
||||
def __init__(self, hs):
|
||||
super(BackgroundUpdateStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(BackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
self._background_update_performance = {}
|
||||
self._background_update_queue = []
|
||||
self._background_update_handlers = {}
|
||||
|
|
|
@ -32,14 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
|
|||
|
||||
|
||||
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, db_conn, hs):
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
keylen=4,
|
||||
max_entries=50000 * CACHE_SIZE_FACTOR,
|
||||
)
|
||||
|
||||
super(ClientIpStore, self).__init__(hs)
|
||||
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
"user_ips_device_index",
|
||||
|
|
|
@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
|
|||
class DeviceInboxStore(BackgroundUpdateStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DeviceInboxStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceInboxStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
"device_inbox_stream_index",
|
||||
|
|
|
@ -26,8 +26,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class DeviceStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(DeviceStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceStore, self).__init__(db_conn, hs)
|
||||
|
||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||
# the device exists.
|
||||
|
|
|
@ -39,8 +39,8 @@ class EventFederationStore(SQLBaseStore):
|
|||
|
||||
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(EventFederationStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(EventFederationStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_AUTH_STATE_ONLY,
|
||||
|
|
|
@ -65,8 +65,8 @@ def _deserialize_action(actions, is_highlight):
|
|||
class EventPushActionsStore(SQLBaseStore):
|
||||
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(EventPushActionsStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(EventPushActionsStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
self.EPA_HIGHLIGHT_INDEX,
|
||||
|
|
|
@ -197,8 +197,8 @@ class EventsStore(SQLBaseStore):
|
|||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(EventsStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(EventsStore, self).__init__(db_conn, hs)
|
||||
self._clock = hs.get_clock()
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
||||
|
|
|
@ -35,7 +35,9 @@ class GroupServerStore(SQLBaseStore):
|
|||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
retcols=("name", "short_description", "long_description", "avatar_url",),
|
||||
retcols=(
|
||||
"name", "short_description", "long_description", "avatar_url", "is_public"
|
||||
),
|
||||
allow_none=True,
|
||||
desc="is_user_in_group",
|
||||
)
|
||||
|
@ -52,7 +54,7 @@ class GroupServerStore(SQLBaseStore):
|
|||
return self._simple_select_list(
|
||||
table="group_users",
|
||||
keyvalues=keyvalues,
|
||||
retcols=("user_id", "is_public",),
|
||||
retcols=("user_id", "is_public", "is_admin",),
|
||||
desc="get_users_in_group",
|
||||
)
|
||||
|
||||
|
@ -855,6 +857,19 @@ class GroupServerStore(SQLBaseStore):
|
|||
desc="add_room_to_group",
|
||||
)
|
||||
|
||||
def update_room_in_group_visibility(self, group_id, room_id, is_public):
|
||||
return self._simple_update(
|
||||
table="group_rooms",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
updatevalues={
|
||||
"is_public": is_public,
|
||||
},
|
||||
desc="update_room_in_group_visibility",
|
||||
)
|
||||
|
||||
def remove_room_from_group(self, group_id, room_id):
|
||||
def _remove_room_from_group_txn(txn):
|
||||
self._simple_delete_txn(
|
||||
|
@ -1026,6 +1041,7 @@ class GroupServerStore(SQLBaseStore):
|
|||
"avatar_url": avatar_url,
|
||||
"short_description": short_description,
|
||||
"long_description": long_description,
|
||||
"is_public": True,
|
||||
},
|
||||
desc="create_group",
|
||||
)
|
||||
|
@ -1086,6 +1102,24 @@ class GroupServerStore(SQLBaseStore):
|
|||
desc="update_remote_attestion",
|
||||
)
|
||||
|
||||
def remove_attestation_renewal(self, group_id, user_id):
|
||||
"""Remove an attestation that we thought we should renew, but actually
|
||||
shouldn't. Ideally this would never get called as we would never
|
||||
incorrectly try and do attestations for local users on local groups.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
user_id (str)
|
||||
"""
|
||||
return self._simple_delete(
|
||||
table="group_attestations_renewals",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
desc="remove_attestation_renewal",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_attestation(self, group_id, user_id):
|
||||
"""Get the attestation that proves the remote agrees that the user is
|
||||
|
|
|
@ -254,6 +254,9 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
|
||||
|
||||
def delete_url_cache(self, media_ids):
|
||||
if len(media_ids) == 0:
|
||||
return
|
||||
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_url_cache"
|
||||
" WHERE media_id = ?"
|
||||
|
@ -281,6 +284,9 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def delete_url_cache_media(self, media_ids):
|
||||
if len(media_ids) == 0:
|
||||
return
|
||||
|
||||
def _delete_url_cache_media_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository"
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 45
|
||||
SCHEMA_VERSION = 46
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
|
|||
|
||||
If `config` is None then prepare_database will assert that no upgrade is
|
||||
necessary, *or* will create a fresh database if the database is empty.
|
||||
|
||||
Args:
|
||||
db_conn:
|
||||
database_engine:
|
||||
config (synapse.config.homeserver.HomeServerConfig|None):
|
||||
application config, or None if we are connecting to an existing
|
||||
database which we expect to be configured already
|
||||
"""
|
||||
try:
|
||||
cur = db_conn.cursor()
|
||||
|
@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
|
|||
else:
|
||||
_setup_new_database(cur, database_engine)
|
||||
|
||||
# check if any of our configured dynamic modules want a database
|
||||
if config is not None:
|
||||
_apply_module_schemas(cur, database_engine, config)
|
||||
|
||||
cur.close()
|
||||
db_conn.commit()
|
||||
except Exception:
|
||||
|
@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
|||
)
|
||||
|
||||
|
||||
def _apply_module_schemas(txn, database_engine, config):
|
||||
"""Apply the module schemas for the dynamic modules, if any
|
||||
|
||||
Args:
|
||||
cur: database cursor
|
||||
database_engine: synapse database engine class
|
||||
config (synapse.config.homeserver.HomeServerConfig):
|
||||
application config
|
||||
"""
|
||||
for (mod, _config) in config.password_providers:
|
||||
if not hasattr(mod, 'get_db_schema_files'):
|
||||
continue
|
||||
modname = ".".join((mod.__module__, mod.__name__))
|
||||
_apply_module_schema_files(
|
||||
txn, database_engine, modname, mod.get_db_schema_files(),
|
||||
)
|
||||
|
||||
|
||||
def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
|
||||
"""Apply the module schemas for a single module
|
||||
|
||||
Args:
|
||||
cur: database cursor
|
||||
database_engine: synapse database engine class
|
||||
modname (str): fully qualified name of the module
|
||||
names_and_streams (Iterable[(str, file)]): the names and streams of
|
||||
schemas to be applied
|
||||
"""
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
|
||||
),
|
||||
(modname,)
|
||||
)
|
||||
applied_deltas = set(d for d, in cur)
|
||||
for (name, stream) in names_and_streams:
|
||||
if name in applied_deltas:
|
||||
continue
|
||||
|
||||
root_name, ext = os.path.splitext(name)
|
||||
if ext != '.sql':
|
||||
raise PrepareDatabaseException(
|
||||
"only .sql files are currently supported for module schemas",
|
||||
)
|
||||
|
||||
logger.info("applying schema %s for %s", name, modname)
|
||||
for statement in get_statements(stream):
|
||||
cur.execute(statement)
|
||||
|
||||
# Mark as done.
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO applied_module_schemas (module_name, file)"
|
||||
" VALUES (?,?)",
|
||||
),
|
||||
(modname, name)
|
||||
)
|
||||
|
||||
|
||||
def get_statements(f):
|
||||
statement_buffer = ""
|
||||
in_comment = False # If we're in a /* ... */ style comment
|
||||
|
|
|
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ReceiptsStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(ReceiptsStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(ReceiptsStore, self).__init__(db_conn, hs)
|
||||
|
||||
self._receipts_stream_cache = StreamChangeCache(
|
||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
||||
|
|
|
@ -24,8 +24,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
|||
|
||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RegistrationStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
columns=["user_id", "device_id"],
|
||||
)
|
||||
|
||||
self.register_background_index_update(
|
||||
"refresh_tokens_device_index",
|
||||
index_name="refresh_tokens_device_id",
|
||||
table="refresh_tokens",
|
||||
columns=["user_id", "device_id"],
|
||||
)
|
||||
# we no longer use refresh tokens, but it's possible that some people
|
||||
# might have a background update queued to build this index. Just
|
||||
# clear the background update.
|
||||
@defer.inlineCallbacks
|
||||
def noop_update(progress, batch_size):
|
||||
yield self._end_background_update("refresh_tokens_device_index")
|
||||
defer.returnValue(1)
|
||||
self.register_background_update_handler(
|
||||
"refresh_tokens_device_index", noop_update)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||
|
@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
)
|
||||
|
||||
if create_profile_with_localpart:
|
||||
# set a default displayname serverside to avoid ugly race
|
||||
# between auto-joins and clients trying to set displaynames
|
||||
txn.execute(
|
||||
"INSERT INTO profiles(user_id) VALUES (?)",
|
||||
(create_profile_with_localpart,)
|
||||
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
||||
(create_profile_with_localpart, create_profile_with_localpart)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
|
@ -236,12 +241,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
"user_set_password_hash", user_set_password_hash_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||
device_id=None,
|
||||
delete_refresh_tokens=False):
|
||||
device_id=None):
|
||||
"""
|
||||
Invalidate access/refresh tokens belonging to a user
|
||||
Invalidate access tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
|
@ -250,10 +253,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
device_id (str|None): ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
delete_refresh_tokens (bool): True to delete refresh tokens as
|
||||
well as access tokens.
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
defer.Deferred[list[str, str|None]]: a list of the deleted tokens
|
||||
and device IDs
|
||||
"""
|
||||
def f(txn):
|
||||
keyvalues = {
|
||||
|
@ -262,13 +264,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
if device_id is not None:
|
||||
keyvalues["device_id"] = device_id
|
||||
|
||||
if delete_refresh_tokens:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="refresh_tokens",
|
||||
keyvalues=keyvalues,
|
||||
)
|
||||
|
||||
items = keyvalues.items()
|
||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||
values = [v for _, v in items]
|
||||
|
@ -277,14 +272,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
values.append(except_token_id)
|
||||
|
||||
txn.execute(
|
||||
"SELECT token FROM access_tokens WHERE %s" % where_clause,
|
||||
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
tokens_and_devices = [(r[0], r[1]) for r in txn]
|
||||
|
||||
for row in rows:
|
||||
for token, _ in tokens_and_devices:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (row["token"],)
|
||||
txn, self.get_user_by_access_token, (token,)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
|
@ -292,7 +287,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
values
|
||||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
return tokens_and_devices
|
||||
|
||||
return self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
)
|
||||
|
||||
|
|
|
@ -49,8 +49,8 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
|||
|
||||
|
||||
class RoomMemberStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RoomMemberStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||
)
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('refresh_tokens_device_index', '{}');
|
|
@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
|
|||
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
|
||||
|
||||
|
||||
-- Make sure that we popualte the table initially
|
||||
-- Make sure that we populate the table initially
|
||||
UPDATE user_directory_stream_pos SET stream_id = NULL;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,4 +13,5 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
|
||||
/* we no longer use (or create) the refresh_tokens table */
|
||||
DROP TABLE IF EXISTS refresh_tokens;
|
|
@ -0,0 +1,32 @@
|
|||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE groups_new (
|
||||
group_id TEXT NOT NULL,
|
||||
name TEXT, -- the display name of the room
|
||||
avatar_url TEXT,
|
||||
short_description TEXT,
|
||||
long_description TEXT,
|
||||
is_public BOOL NOT NULL -- whether non-members can access group APIs
|
||||
);
|
||||
|
||||
-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
|
||||
INSERT INTO groups_new
|
||||
SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
|
||||
|
||||
DROP TABLE groups;
|
||||
ALTER TABLE groups_new RENAME TO groups;
|
||||
|
||||
CREATE UNIQUE INDEX groups_idx ON groups(group_id);
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2015, 2016 OpenMarket Ltd
|
||||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,9 +13,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens(
|
||||
id INTEGER PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (token)
|
||||
);
|
||||
-- this is just embarassing :|
|
||||
ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
|
||||
|
||||
-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
|
||||
-- so is hopefully not going to block anyone else for that long...
|
||||
CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
|
||||
CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
|
||||
DROP INDEX users_in_pubic_room_room_idx;
|
||||
DROP INDEX users_in_pubic_room_user_idx;
|
|
@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
|
|||
file TEXT NOT NULL,
|
||||
UNIQUE(version, file)
|
||||
);
|
||||
|
||||
-- a list of schema files we have loaded on behalf of dynamic modules
|
||||
CREATE TABLE IF NOT EXISTS applied_module_schemas(
|
||||
module_name TEXT NOT NULL,
|
||||
file TEXT NOT NULL,
|
||||
UNIQUE(module_name, file)
|
||||
);
|
||||
|
|
|
@ -33,8 +33,8 @@ class SearchStore(BackgroundUpdateStore):
|
|||
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
||||
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(SearchStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SearchStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
||||
)
|
||||
|
|
|
@ -63,8 +63,8 @@ class StateStore(SQLBaseStore):
|
|||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
||||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(StateStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||
self._background_deduplicate_state,
|
||||
|
|
|
@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore):
|
|||
"""A collection of queries for handling PDUs.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
super(TransactionStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(TransactionStore, self).__init__(db_conn, hs)
|
||||
|
||||
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
user_ids (list(str)): Users to add
|
||||
"""
|
||||
yield self._simple_insert_many(
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
values=[
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
@ -219,7 +219,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
def update_user_in_public_user_list(self, user_id, room_id):
|
||||
yield self._simple_update_one(
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={"room_id": room_id},
|
||||
desc="update_user_in_public_user_list",
|
||||
|
@ -240,7 +240,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
txn.call_after(
|
||||
|
@ -256,7 +256,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
def remove_from_user_in_public_room(self, user_id):
|
||||
yield self._simple_delete(
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="remove_from_user_in_public_room",
|
||||
)
|
||||
|
@ -267,7 +267,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
in the given room_id
|
||||
"""
|
||||
return self._simple_select_onecol(
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="user_id",
|
||||
desc="get_users_in_public_due_to_room",
|
||||
|
@ -286,7 +286,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
user_ids_pub = yield self._simple_select_onecol(
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="user_id",
|
||||
desc="get_users_in_dir_due_to_room",
|
||||
|
@ -514,7 +514,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
def _delete_all_from_user_dir_txn(txn):
|
||||
txn.execute("DELETE FROM user_directory")
|
||||
txn.execute("DELETE FROM user_directory_search")
|
||||
txn.execute("DELETE FROM users_in_pubic_room")
|
||||
txn.execute("DELETE FROM users_in_public_rooms")
|
||||
txn.execute("DELETE FROM users_who_share_rooms")
|
||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||
txn.call_after(self.get_user_in_public_room.invalidate_all)
|
||||
|
@ -537,7 +537,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
@cached()
|
||||
def get_user_in_public_room(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="users_in_pubic_room",
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("room_id",),
|
||||
allow_none=True,
|
||||
|
@ -641,7 +641,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
SELECT d.user_id, display_name, avatar_url
|
||||
FROM user_directory_search
|
||||
INNER JOIN user_directory AS d USING (user_id)
|
||||
LEFT JOIN users_in_pubic_room AS p USING (user_id)
|
||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||
LEFT JOIN (
|
||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||
WHERE user_id = ? AND share_private
|
||||
|
@ -680,7 +680,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
SELECT d.user_id, display_name, avatar_url
|
||||
FROM user_directory_search
|
||||
INNER JOIN user_directory AS d USING (user_id)
|
||||
LEFT JOIN users_in_pubic_room AS p USING (user_id)
|
||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||
LEFT JOIN (
|
||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||
WHERE user_id = ? AND share_private
|
||||
|
|
|
@ -278,8 +278,13 @@ class Limiter(object):
|
|||
if entry[0] >= self.max_count:
|
||||
new_defer = defer.Deferred()
|
||||
entry[1].append(new_defer)
|
||||
|
||||
logger.info("Waiting to acquire limiter lock for key %r", key)
|
||||
with PreserveLoggingContext():
|
||||
yield new_defer
|
||||
logger.info("Acquired limiter lock for key %r", key)
|
||||
else:
|
||||
logger.info("Acquired uncontended limiter lock for key %r", key)
|
||||
|
||||
entry[0] += 1
|
||||
|
||||
|
@ -288,16 +293,21 @@ class Limiter(object):
|
|||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Releasing limiter lock for key %r", key)
|
||||
|
||||
# We've finished executing so check if there are any things
|
||||
# blocked waiting to execute and start one of them
|
||||
entry[0] -= 1
|
||||
try:
|
||||
entry[1].pop(0).callback(None)
|
||||
except IndexError:
|
||||
# If nothing else is executing for this key then remove it
|
||||
# from the map
|
||||
if entry[0] == 0:
|
||||
self.key_to_defer.pop(key, None)
|
||||
|
||||
if entry[1]:
|
||||
next_def = entry[1].pop(0)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
next_def.callback(None)
|
||||
elif entry[0] == 0:
|
||||
# We were the last thing for this key: remove it from the
|
||||
# map.
|
||||
del self.key_to_defer[key]
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
|
||||
|
|
|
@ -53,7 +53,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||
(0, [event]),
|
||||
(0, [])
|
||||
]
|
||||
self.mock_as_api.push = Mock()
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
||||
|
@ -75,7 +78,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||
(0, [event]),
|
||||
(0, [])
|
||||
]
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_as_api.query_user.assert_called_once_with(
|
||||
services[0], user_id
|
||||
|
@ -98,7 +104,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||
(0, [event]),
|
||||
(0, [])
|
||||
]
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.assertFalse(
|
||||
self.mock_as_api.query_user.called,
|
||||
|
|
|
@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
|
||||
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
|
||||
# must be done after inserts
|
||||
self.store = ApplicationServiceStore(hs)
|
||||
self.store = ApplicationServiceStore(None, hs)
|
||||
|
||||
def tearDown(self):
|
||||
# TODO: suboptimal that we need to create files for tests!
|
||||
|
@ -150,7 +150,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
|
||||
self.as_yaml_files = []
|
||||
|
||||
self.store = TestTransactionStore(hs)
|
||||
self.store = TestTransactionStore(None, hs)
|
||||
|
||||
def _add_service(self, url, as_token, id):
|
||||
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
|
||||
|
@ -420,8 +420,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
class TestTransactionStore(ApplicationServiceTransactionStore,
|
||||
ApplicationServiceStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(TestTransactionStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
super(TestTransactionStore, self).__init__(db_conn, hs)
|
||||
|
||||
|
||||
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||
|
@ -458,7 +458,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
replication_layer=Mock(),
|
||||
)
|
||||
|
||||
ApplicationServiceStore(hs)
|
||||
ApplicationServiceStore(None, hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_duplicate_ids(self):
|
||||
|
@ -477,7 +477,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
ApplicationServiceStore(None, hs)
|
||||
|
||||
e = cm.exception
|
||||
self.assertIn(f1, e.message)
|
||||
|
@ -501,7 +501,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
ApplicationServiceStore(None, hs)
|
||||
|
||||
e = cm.exception
|
||||
self.assertIn(f1, e.message)
|
||||
|
|
|
@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
database_engine=create_engine(config.database_config),
|
||||
)
|
||||
|
||||
self.datastore = SQLBaseStore(hs)
|
||||
self.datastore = SQLBaseStore(None, hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_insert_1col(self):
|
||||
|
|
|
@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
hs = yield setup_test_homeserver()
|
||||
|
||||
self.store = DirectoryStore(hs)
|
||||
self.store = DirectoryStore(None, hs)
|
||||
|
||||
self.room = RoomID.from_string("!abcde:test")
|
||||
self.alias = RoomAlias.from_string("#my-room:test")
|
||||
|
|
|
@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
hs = yield setup_test_homeserver(clock=MockClock())
|
||||
|
||||
self.store = PresenceStore(hs)
|
||||
self.store = PresenceStore(None, hs)
|
||||
|
||||
self.u_apple = UserID.from_string("@apple:test")
|
||||
self.u_banana = UserID.from_string("@banana:test")
|
||||
|
|
|
@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
hs = yield setup_test_homeserver()
|
||||
|
||||
self.store = ProfileStore(hs)
|
||||
self.store = ProfileStore(None, hs)
|
||||
|
||||
self.u_frank = UserID.from_string("@frank:test")
|
||||
|
||||
|
|
|
@ -86,7 +86,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
|
||||
# now delete some
|
||||
yield self.store.user_delete_access_tokens(
|
||||
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
|
||||
self.user_id, device_id=self.device_id,
|
||||
)
|
||||
|
||||
# check they were deleted
|
||||
user = yield self.store.get_user_by_access_token(self.tokens[1])
|
||||
|
@ -97,8 +98,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
self.assertEqual(self.user_id, user["name"])
|
||||
|
||||
# now delete the rest
|
||||
yield self.store.user_delete_access_tokens(
|
||||
self.user_id, delete_refresh_tokens=True)
|
||||
yield self.store.user_delete_access_tokens(self.user_id)
|
||||
|
||||
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||
self.assertIsNone(user,
|
||||
|
|
|
@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
|
|||
)
|
||||
|
||||
self.config = Mock()
|
||||
self.config.password_providers = []
|
||||
self.config.database_config = {"name": "sqlite3"}
|
||||
|
||||
def prepare(self):
|
||||
|
|
Loading…
Reference in New Issue