Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2023-08-18 11:06:02 +01:00
commit de16789d87
109 changed files with 1164 additions and 555 deletions

View File

@ -5,6 +5,9 @@ on:
- cron: 0 8 * * *
workflow_dispatch:
# NB: inputs are only present when this workflow is dispatched manually.
# (The default below is the default field value in the form to trigger
# a manual dispatch). Otherwise the inputs will evaluate to null.
inputs:
twisted_ref:
description: Commit, branch or tag to checkout from upstream Twisted.
@ -49,7 +52,7 @@ jobs:
extras: "all"
- run: |
poetry remove twisted
poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref }}
poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }}
poetry install --no-interaction --extras "all test"
- name: Remove warn_unused_ignores from mypy config
run: sed '/warn_unused_ignores = True/d' -i mypy.ini

View File

@ -1,3 +1,8 @@
# Synapse 1.90.0 (2023-08-15)
No significant changes since 1.90.0rc1.
# Synapse 1.90.0rc1 (2023-08-08)
### Features

4
Cargo.lock generated
View File

@ -132,9 +132,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.19"
version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "memchr"

View File

@ -0,0 +1 @@
Implements an admin API to lock an user without deactivating them. Based on [MSC3939](https://github.com/matrix-org/matrix-spec-proposals/pull/3939).

1
changelog.d/16010.misc Normal file
View File

@ -0,0 +1 @@
Update dehydrated devices implementation.

1
changelog.d/16052.bugfix Normal file
View File

@ -0,0 +1 @@
Fix long-standing bug where concurrent requests to change a user's push rules could cause a deadlock. Contributed by Nick @ Beeper (@fizzadar).

1
changelog.d/16061.misc Normal file
View File

@ -0,0 +1 @@
Fix database performance of read/write worker locks.

1
changelog.d/16080.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bu in `/sync` where timeout=0 does not skip caching, resulting in slow calls in cases where there are no new changes. Contributed by @PlasmaIntec.

1
changelog.d/16085.misc Normal file
View File

@ -0,0 +1 @@
Override global statement timeout when creating indexes in Postgres.

1
changelog.d/16089.misc Normal file
View File

@ -0,0 +1 @@
Fix the type annotation on `run_db_interaction` in the Module API.

1
changelog.d/16091.doc Normal file
View File

@ -0,0 +1 @@
Structured logging docs: add a link to explain the ELK stack

1
changelog.d/16092.misc Normal file
View File

@ -0,0 +1 @@
Clean-up the presence code.

View File

@ -0,0 +1 @@
Allow customising the IdP display name, icon, and brand for SAML and CAS providers (in addition to OIDC provider).

1
changelog.d/16110.misc Normal file
View File

@ -0,0 +1 @@
Run `pyupgrade` for Python 3.8+.

1
changelog.d/16112.misc Normal file
View File

@ -0,0 +1 @@
Rename pagination and purge locks and add comments to explain why they exist and how they work.

1
changelog.d/16115.misc Normal file
View File

@ -0,0 +1 @@
Attempt to fix the twisted trunk job.

1
changelog.d/16117.misc Normal file
View File

@ -0,0 +1 @@
Cache token introspection response from OIDC provider.

1
changelog.d/16123.misc Normal file
View File

@ -0,0 +1 @@
Add cache to `get_server_keys_json_for_remote`.

View File

@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path):
global CONFIG_JSON
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global
try:
with open(config_path, "r") as config:
with open(config_path) as config:
syn_cmd.config = json.load(config)
try:
http_client.verbose = "on" == syn_cmd.config["verbose"]

6
debian/changelog vendored
View File

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.90.0) stable; urgency=medium
* New Synapse release 1.90.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 15 Aug 2023 11:17:34 +0100
matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium
* New Synapse release 1.90.0rc1.

View File

@ -861,7 +861,7 @@ def generate_worker_files(
# Then a worker config file
convert(
"/conf/worker.yaml.j2",
"/conf/workers/{name}.yaml".format(name=worker_name),
f"/conf/workers/{worker_name}.yaml",
**worker_config,
worker_log_config_filepath=log_config_filepath,
using_unix_sockets=using_unix_sockets,

View File

@ -82,7 +82,7 @@ def generate_config_from_template(
with open(filename) as handle:
value = handle.read()
else:
log("Generating a random secret for {}".format(secret))
log(f"Generating a random secret for {secret}")
value = codecs.encode(os.urandom(32), "hex").decode()
with open(filename, "w") as handle:
handle.write(value)

View File

@ -146,6 +146,7 @@ Body parameters:
- `admin` - **bool**, optional, defaults to `false`. Whether the user is a homeserver administrator,
granting them access to the Admin API, among other things.
- `deactivated` - **bool**, optional. If unspecified, deactivation state will be left unchanged.
- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged.
Note: the `password` field must also be set if both of the following are true:
- `deactivated` is set to `false` and the user was previously deactivated (you are reactivating this user)

View File

@ -3,7 +3,7 @@
A structured logging system can be useful when your logs are destined for a
machine to parse and process. By maintaining its machine-readable characteristics,
it enables more efficient searching and aggregations when consumed by software
such as the "ELK stack".
such as the [ELK stack](https://opensource.com/article/18/9/open-source-log-aggregation-tools).
Synapse's structured logging system is configured via the file that Synapse's
`log_config` config option points to. The file should include a formatter which

View File

@ -3025,6 +3025,16 @@ enable SAML login. You can either put your entire pysaml config inline using the
option, or you can specify a path to a psyaml config file with the sub-option `config_path`.
This setting has the following sub-options:
* `idp_name`: A user-facing name for this identity provider, which is used to
offer the user a choice of login mechanisms.
* `idp_icon`: An optional icon for this identity provider, which is presented
by clients and Synapse's own IdP picker page. If given, must be an
MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
obtain such an MXC URI is to upload an image to an (unencrypted) room
and then copy the "url" from the source of the event.)
* `idp_brand`: An optional brand for this identity provider, allowing clients
to style the login flow according to the identity provider in question.
See the [spec](https://spec.matrix.org/latest/) for possible options here.
* `sp_config`: the configuration for the pysaml2 Service Provider. See pysaml2 docs for format of config.
Default values will be used for the `entityid` and `service` settings,
so it is not normally necessary to specify them unless you need to
@ -3176,7 +3186,7 @@ Options for each entry include:
* `idp_icon`: An optional icon for this identity provider, which is presented
by clients and Synapse's own IdP picker page. If given, must be an
MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
obtain such an MXC URI is to upload an image to an (unencrypted) room
and then copy the "url" from the source of the event.)
@ -3391,6 +3401,16 @@ Enable Central Authentication Service (CAS) for registration and login.
Has the following sub-options:
* `enabled`: Set this to true to enable authorization against a CAS server.
Defaults to false.
* `idp_name`: A user-facing name for this identity provider, which is used to
offer the user a choice of login mechanisms.
* `idp_icon`: An optional icon for this identity provider, which is presented
by clients and Synapse's own IdP picker page. If given, must be an
MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
obtain such an MXC URI is to upload an image to an (unencrypted) room
and then copy the "url" from the source of the event.)
* `idp_brand`: An optional brand for this identity provider, allowing clients
to style the login flow according to the identity provider in question.
See the [spec](https://spec.matrix.org/latest/) for possible options here.
* `server_url`: The URL of the CAS authorization endpoint.
* `displayname_attribute`: The attribute of the CAS response to use as the display name.
If no name is given here, no displayname will be set.
@ -3631,6 +3651,7 @@ This option has the following sub-options:
* `prefer_local_users`: Defines whether to prefer local users in search query results.
If set to true, local users are more likely to appear above remote users when searching the
user directory. Defaults to false.
* `show_locked_users`: Defines whether to show locked users in search query results. Defaults to false.
Example configuration:
```yaml
@ -3638,6 +3659,7 @@ user_directory:
enabled: false
search_all_users: true
prefer_local_users: true
show_locked_users: true
```
---
### `user_consent`

View File

@ -45,6 +45,13 @@ warn_unused_ignores = False
disallow_untyped_defs = False
disallow_incomplete_defs = False
[mypy-synapse.util.manhole]
# This module imports something from Twisted which has a bad annotation in Twisted trunk,
# but is unannotated in Twisted's latest release. We want to type-ignore the problem
# in the twisted trunk job, even though it has no effect on normal mypy runs.
warn_unused_ignores = False
;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here:

34
poetry.lock generated
View File

@ -589,13 +589,13 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
version = "3.1.31"
version = "3.1.32"
description = "GitPython is a Python library used to interact with Git repositories"
optional = false
python-versions = ">=3.7"
files = [
{file = "GitPython-3.1.31-py3-none-any.whl", hash = "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"},
{file = "GitPython-3.1.31.tar.gz", hash = "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573"},
{file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"},
{file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"},
]
[package.dependencies]
@ -887,17 +887,17 @@ scripts = ["click (>=6.0)", "twisted (>=16.4.0)"]
[[package]]
name = "isort"
version = "5.11.5"
version = "5.12.0"
description = "A Python utility / library to sort Python imports."
optional = false
python-versions = ">=3.7.0"
python-versions = ">=3.8.0"
files = [
{file = "isort-5.11.5-py3-none-any.whl", hash = "sha256:ba1d72fb2595a01c7895a5128f9585a5cc4b6d395f1c8d514989b9a7eb2a8746"},
{file = "isort-5.11.5.tar.gz", hash = "sha256:6be1f76a507cb2ecf16c7cf14a37e41609ca082330be4e3436a18ef74add55db"},
{file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"},
{file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"},
]
[package.extras]
colors = ["colorama (>=0.4.3,<0.5.0)"]
colors = ["colorama (>=0.4.3)"]
pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"]
plugins = ["setuptools"]
requirements-deprecated-finder = ["pip-api", "pipreqs"]
@ -2921,13 +2921,13 @@ files = [
[[package]]
name = "txredisapi"
version = "1.4.9"
version = "1.4.10"
description = "non-blocking redis client for python"
optional = true
python-versions = "*"
files = [
{file = "txredisapi-1.4.9-py3-none-any.whl", hash = "sha256:72e6ad09cc5fffe3bec2e55e5bfb74407bd357565fc212e6003f7e26ef7d8f78"},
{file = "txredisapi-1.4.9.tar.gz", hash = "sha256:c9607062d05e4d0b8ef84719eb76a3fe7d5ccd606a2acf024429da51d6e84559"},
{file = "txredisapi-1.4.10-py3-none-any.whl", hash = "sha256:0a6ea77f27f8cf092f907654f08302a97b48fa35f24e0ad99dfb74115f018161"},
{file = "txredisapi-1.4.10.tar.gz", hash = "sha256:7609a6af6ff4619a3189c0adfb86aeda789afba69eb59fc1e19ac0199e725395"},
]
[package.dependencies]
@ -2936,13 +2936,13 @@ twisted = "*"
[[package]]
name = "types-bleach"
version = "6.0.0.3"
version = "6.0.0.4"
description = "Typing stubs for bleach"
optional = false
python-versions = "*"
files = [
{file = "types-bleach-6.0.0.3.tar.gz", hash = "sha256:8ce7896d4f658c562768674ffcf07492c7730e128018f03edd163ff912bfadee"},
{file = "types_bleach-6.0.0.3-py3-none-any.whl", hash = "sha256:d43eaf30a643ca824e16e2dcdb0c87ef9226237e2fa3ac4732a50cb3f32e145f"},
{file = "types-bleach-6.0.0.4.tar.gz", hash = "sha256:357b0226f65c4f20ab3b13ca8d78a6b91c78aad256d8ec168d4e90fc3303ebd4"},
{file = "types_bleach-6.0.0.4-py3-none-any.whl", hash = "sha256:2b8767eb407c286b7f02803678732e522e04db8d56cbc9f1270bee49627eae92"},
]
[[package]]
@ -2991,13 +2991,13 @@ files = [
[[package]]
name = "types-pillow"
version = "10.0.0.1"
version = "10.0.0.2"
description = "Typing stubs for Pillow"
optional = false
python-versions = "*"
files = [
{file = "types-Pillow-10.0.0.1.tar.gz", hash = "sha256:834a07a04504f8bf37936679bc6a5802945e7644d0727460c0c4d4307967e2a3"},
{file = "types_Pillow-10.0.0.1-py3-none-any.whl", hash = "sha256:be576b67418f1cb3b93794cf7946581be1009a33a10085b3c132eb0875a819b4"},
{file = "types-Pillow-10.0.0.2.tar.gz", hash = "sha256:fe09380ab22d412ced989a067e9ee4af719fa3a47ba1b53b232b46514a871042"},
{file = "types_Pillow-10.0.0.2-py3-none-any.whl", hash = "sha256:29d51a3ce6ef51fabf728a504d33b4836187ff14256b2e86996d55c91ab214b1"},
]
[[package]]

View File

@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry]
name = "matrix-synapse"
version = "1.90.0rc1"
version = "1.90.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"

View File

@ -47,7 +47,7 @@ can be passed on the commandline for debugging.
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
class Builder(object):
class Builder:
def __init__(
self,
redirect_stdout: bool = False,

View File

@ -43,7 +43,7 @@ def main(force_colors: bool) -> None:
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
# Get the schema version of the local file to check against current schema on develop
with open("synapse/storage/schema/__init__.py", "r") as file:
with open("synapse/storage/schema/__init__.py") as file:
local_schema = file.read()
new_locals: Dict[str, Any] = {}
exec(local_schema, new_locals)

View File

@ -247,7 +247,7 @@ def main() -> None:
def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh:
with open(args.config) as fh:
config = yaml.safe_load(fh)
if not args.server_name:

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -145,7 +145,7 @@ Example usage:
def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh:
with open(args.config) as fh:
config = yaml.safe_load(fh)
if not args.server_name:
args.server_name = config["server_name"]

View File

@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date
from synapse.util.stringutils import strtobool
# Check that we're not running on an unsupported Python version.
if sys.version_info < (3, 8):
#
# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
# if-statement completely.
py_version = sys.version_info
if py_version < (3, 8):
print("Synapse requires Python 3.8 or above.")
sys.exit(1)
@ -78,7 +82,7 @@ try:
except ImportError:
pass
import synapse.util
import synapse.util # noqa: E402
__version__ = synapse.util.SYNAPSE_VERSION

View File

@ -123,7 +123,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"rooms": ["is_public", "has_auth_chain_index"],
"users": ["shadow_banned", "approved"],
"users": ["shadow_banned", "approved", "locked"],
"un_partial_stated_event_stream": ["rejection_status_changed"],
"users_who_share_rooms": ["share_private"],
"per_user_experimental_features": ["enabled"],
@ -1205,10 +1205,10 @@ class CursesProgress(Progress):
self.total_processed = 0
self.total_remaining = 0
super(CursesProgress, self).__init__()
super().__init__()
def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done)
super().update(table, num_done)
self.total_processed = 0
self.total_remaining = 0
@ -1304,7 +1304,7 @@ class TerminalProgress(Progress):
"""Just prints progress to the terminal"""
def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done)
super().update(table, num_done)
data = self.tables[table]

View File

@ -38,7 +38,7 @@ class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config: HomeServerConfig):
super(MockHomeserver, self).__init__(
super().__init__(
hostname=config.server.server_name,
config=config,
reactor=reactor,

View File

@ -60,6 +60,7 @@ class Auth(Protocol):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
"""Get a registered user's ID.

View File

@ -58,6 +58,7 @@ class InternalAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
"""Get a registered user's ID.
@ -79,7 +80,7 @@ class InternalAuth(BaseAuth):
parent_span = active_span()
with start_active_span("get_user_by_req"):
requester = await self._wrapped_get_user_by_req(
request, allow_guest, allow_expired
request, allow_guest, allow_expired, allow_locked
)
if parent_span:
@ -107,6 +108,7 @@ class InternalAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool,
allow_expired: bool,
allow_locked: bool,
) -> Requester:
"""Helper for get_user_by_req
@ -126,6 +128,17 @@ class InternalAuth(BaseAuth):
access_token, allow_expired=allow_expired
)
# Deny the request if the user account is locked.
if not allow_locked and await self.store.get_user_locked_status(
requester.user.to_string()
):
raise AuthError(
401,
"User account has been locked",
errcode=Codes.USER_LOCKED,
additional_fields={"soft_logout": True},
)
# Deny the request if the user account has expired.
# This check is only done for regular users, not appservice ones.
if not allow_expired:

View File

@ -27,6 +27,7 @@ from twisted.web.http_headers import Headers
from synapse.api.auth.base import BaseAuth
from synapse.api.errors import (
AuthError,
Codes,
HttpResponseException,
InvalidClientTokenError,
OAuthInsufficientScopeError,
@ -38,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.expiringcache import ExpiringCache
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -105,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
self._clock = hs.get_clock()
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
cache_name="introspection_token_cache",
clock=self._clock,
max_len=10000,
expiry_ms=5 * 60 * 1000,
)
if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided"
@ -143,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns:
The introspection response
"""
# check the cache before doing a request
introspection_token = self._token_cache.get(token, None)
if introspection_token:
# check the expiration field of the token (if it exists)
exp = introspection_token.get("exp", None)
if exp:
time_now = self._clock.time()
expired = time_now > exp
if not expired:
return introspection_token
else:
return introspection_token
metadata = await self._issuer_metadata.get()
introspection_endpoint = metadata.get("introspection_endpoint")
raw_headers: Dict[str, str] = {
@ -156,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth):
# Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
method="POST",
uri=introspection_endpoint,
headers=raw_headers,
body=body,
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
@ -186,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth):
"The introspection endpoint returned an invalid JSON response."
)
return IntrospectionToken(**resp)
expiration = resp.get("exp", None)
if expiration:
if self._clock.time() > expiration:
raise InvalidClientTokenError("Token is expired.")
introspection_token = IntrospectionToken(**resp)
# add token to cache
self._token_cache[token] = introspection_token
return introspection_token
async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope
@ -196,6 +233,7 @@ class MSC3861DelegatedAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
access_token = self.get_access_token_from_request(request)
@ -205,6 +243,17 @@ class MSC3861DelegatedAuth(BaseAuth):
# so that we don't provision the user if they don't have enough permission:
requester = await self.get_user_by_access_token(access_token, allow_expired)
# Deny the request if the user account is locked.
if not allow_locked and await self.store.get_user_locked_status(
requester.user.to_string()
):
raise AuthError(
401,
"User account has been locked",
errcode=Codes.USER_LOCKED,
additional_fields={"soft_logout": True},
)
if not allow_guest and requester.is_guest:
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])

View File

@ -18,8 +18,7 @@
"""Contains constants from the specification."""
import enum
from typing_extensions import Final
from typing import Final
# the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536

View File

@ -80,6 +80,8 @@ class Codes(str, Enum):
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
# USER_LOCKED = "M_USER_LOCKED"
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
# Part of MSC3848
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848

View File

@ -47,6 +47,10 @@ class CasConfig(Config):
required_attributes
)
self.idp_name = cas_config.get("idp_name", "CAS")
self.idp_icon = cas_config.get("idp_icon")
self.idp_brand = cas_config.get("idp_brand")
else:
self.cas_server_url = None
self.cas_service_url = None

View File

@ -89,8 +89,14 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
# refers to a SAML IdP entity ID
self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
# IdP properties for Matrix clients
self.idp_name = saml2_config.get("idp_name", "SAML")
self.idp_icon = saml2_config.get("idp_icon")
self.idp_brand = saml2_config.get("idp_brand")
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}

View File

@ -35,3 +35,4 @@ class UserDirectoryConfig(Config):
self.user_directory_search_prefer_local_users = user_directory_config.get(
"prefer_local_users", False
)
self.show_locked_users = user_directory_config.get("show_locked_users", False)

View File

@ -63,7 +63,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@ -1245,7 +1245,7 @@ class FederationServer(FederationBase):
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
# lock.
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
await self._federation_event_handler.on_receive_pdu(
origin, event

View File

@ -67,6 +67,7 @@ class AdminHandler:
"name",
"admin",
"deactivated",
"locked",
"shadow_banned",
"creation_ts",
"appservice_id",

View File

@ -76,12 +76,13 @@ class CasHandler:
self.idp_id = "cas"
# user-facing name of this auth provider
self.idp_name = "CAS"
self.idp_name = hs.config.cas.idp_name
# we do not currently support brands/icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
# MXC URI for icon for this auth provider
self.idp_icon = hs.config.cas.idp_icon
# optional brand identifier for this auth provider
self.idp_brand = hs.config.cas.idp_brand
self._sso_handler = hs.get_sso_handler()

View File

@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool
self.device_list_updater = DeviceListUpdater(hs, self)
@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
keys_for_device: Optional[JsonDict] = None,
) -> str:
"""Store a dehydrated device for a user. If the user had a previous
dehydrated device, it is removed.
"""Store a dehydrated device for a user, optionally storing the keys associated with
it as well. If the user had a previous dehydrated device, it is removed.
Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
device_id,
initial_device_display_name,
)
time_now = self.clock.time_msec()
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
user_id, device_id, device_data, time_now, keys_for_device
)
if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])
return device_id
async def rehydrate_device(

View File

@ -367,19 +367,6 @@ class DeviceMessageHandler:
errcode=Codes.INVALID_PARAM,
)
# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug(
"Deleted %d to-device messages up to %d for user_id %s device_id %s",
deleted,
since_stream_id,
user_id,
device_id,
)
to_token = self.event_sources.get_current_token().to_device_key
messages, stream_id = await self.store.get_messages_for_device(

View File

@ -53,7 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@ -1034,7 +1034,7 @@ class EventCreationHandler:
)
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
return await self._create_and_send_nonmember_event_locked(
requester=requester,
@ -1978,7 +1978,7 @@ class EventCreationHandler:
for room_id in room_ids:
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
dummy_event_sent = await self._send_dummy_event_for_room(room_id)

View File

@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
@ -46,9 +47,10 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
# This is used to avoid purging a room several time at the same moment,
# and also paginating during a purge. Pagination can trigger backfill,
# which would create old events locally, and would potentially clash with the room delete.
PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock"
@attr.s(slots=True, auto_attribs=True)
@ -363,7 +365,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=True
PURGE_PAGINATION_LOCK_NAME, room_id, write=True
):
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
@ -421,7 +423,10 @@ class PaginationHandler:
force: set true to skip checking for joined users.
"""
async with self._worker_locks.acquire_multi_read_write_lock(
[(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
[
(PURGE_PAGINATION_LOCK_NAME, room_id),
(NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id),
],
write=True,
):
# first check that we have no users in this room
@ -483,7 +488,7 @@ class PaginationHandler:
room_token = from_token.room_key
async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=False
PURGE_PAGINATION_LOCK_NAME, room_id, write=False
):
(membership, member_event_id) = (None, None)
if not use_admin_priviledge:
@ -761,7 +766,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=True
PURGE_PAGINATION_LOCK_NAME, room_id, write=True
):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[

View File

@ -30,9 +30,9 @@ from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
ContextManager,
Dict,
Generator,
Iterable,
@ -44,7 +44,6 @@ from typing import (
)
from prometheus_client import Counter
from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
@ -54,7 +53,10 @@ from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
@ -141,6 +143,8 @@ class BasePresenceHandler(abc.ABC):
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
self._presence_enabled = hs.config.server.use_presence
self._federation = None
if hs.should_send_federation():
self._federation = hs.get_federation_sender()
@ -149,6 +153,15 @@ class BasePresenceHandler(abc.ABC):
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
self.VALID_PRESENCE: Tuple[str, ...] = (
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
)
if self._busy_presence_enabled:
self.VALID_PRESENCE += (PresenceState.BUSY,)
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@ -395,8 +408,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
self._presence_enabled = hs.config.server.use_presence
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
EduTypes.PRESENCE,
@ -421,8 +432,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
@ -490,7 +499,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
# what the spec wants: see comment in the BasePresenceHandler version
# of this function.
await self.set_state(
UserID.from_string(user_id), {"presence": presence_state}, True
UserID.from_string(user_id),
{"presence": presence_state},
ignore_status_msg=True,
)
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
@ -601,22 +612,13 @@ class WorkerPresenceHandler(BasePresenceHandler):
"""
presence = state["presence"]
valid_presence = (
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
PresenceState.BUSY,
)
if presence not in valid_presence or (
presence == PresenceState.BUSY and not self._busy_presence_enabled
):
if presence not in self.VALID_PRESENCE:
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
# If presence is disabled, no-op
if not self.hs.config.server.use_presence:
if not self._presence_enabled:
return
# Proxy request to instance that writes presence
@ -633,7 +635,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
if not self.hs.config.server.use_presence:
if not self._presence_enabled:
return
# Proxy request to instance that writes presence
@ -649,7 +651,6 @@ class PresenceHandler(BasePresenceHandler):
self.hs = hs
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@ -700,8 +701,6 @@ class PresenceHandler(BasePresenceHandler):
self._on_shutdown,
)
self._next_serial = 1
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self.user_to_num_current_syncs: Dict[str, int] = {}
@ -723,21 +722,16 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
def run_timeout_handler() -> Awaitable[None]:
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
self.clock.call_later(
30, self.clock.looping_call, run_timeout_handler, 5000
30, self.clock.looping_call, self._handle_timeouts, 5000
)
def run_persister() -> Awaitable[None]:
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
self.clock.call_later(
60,
self.clock.looping_call,
self._persist_unpersisted_changes,
60 * 1000,
)
LaterGauge(
"synapse_handlers_presence_wheel_timer_size",
@ -783,6 +777,7 @@ class PresenceHandler(BasePresenceHandler):
)
logger.info("Finished _on_shutdown")
@wrap_as_background_process("persist_presence_changes")
async def _persist_unpersisted_changes(self) -> None:
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
@ -898,6 +893,7 @@ class PresenceHandler(BasePresenceHandler):
states, [destination]
)
@wrap_as_background_process("handle_presence_timeouts")
async def _handle_timeouts(self) -> None:
"""Checks the presence of users that have timed out and updates as
appropriate.
@ -955,7 +951,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
if not self.hs.config.server.use_presence:
if not self._presence_enabled:
return
user_id = user.to_string()
@ -990,56 +986,51 @@ class PresenceHandler(BasePresenceHandler):
client that is being used by a user.
presence_state: The presence state indicated in the sync request
"""
# Override if it should affect the user's presence, if presence is
# disabled.
if not self.hs.config.server.use_presence:
affect_presence = False
if not affect_presence or not self._presence_enabled:
return _NullContextManager()
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_state = await self.current_state_for_user(user_id)
# If they're busy then they don't stop being busy just by syncing,
# so just update the last sync time.
if prev_state.state != PresenceState.BUSY:
# XXX: We set_state separately here and just update the last_active_ts above
# This keeps the logic as similar as possible between the worker and single
# process modes. Using set_state will actually cause last_active_ts to be
# updated always, which is not what the spec calls for, but synapse has done
# this for... forever, I think.
await self.set_state(
UserID.from_string(user_id),
{"presence": presence_state},
ignore_status_msg=True,
)
# Retrieve the new state for the logic below. This should come from the
# in-memory cache.
prev_state = await self.current_state_for_user(user_id)
# If they're busy then they don't stop being busy just by syncing,
# so just update the last sync time.
if prev_state.state != PresenceState.BUSY:
# XXX: We set_state separately here and just update the last_active_ts above
# This keeps the logic as similar as possible between the worker and single
# process modes. Using set_state will actually cause last_active_ts to be
# updated always, which is not what the spec calls for, but synapse has done
# this for... forever, I think.
await self.set_state(
UserID.from_string(user_id), {"presence": presence_state}, True
)
# Retrieve the new state for the logic below. This should come from the
# in-memory cache.
prev_state = await self.current_state_for_user(user_id)
# To keep the single process behaviour consistent with worker mode, run the
# same logic as `update_external_syncs_row`, even though it looks weird.
if prev_state.state == PresenceState.OFFLINE:
await self._update_states(
[
prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=self.clock.time_msec(),
last_user_sync_ts=self.clock.time_msec(),
)
]
)
# otherwise, set the new presence state & update the last sync time,
# but don't update last_active_ts as this isn't an indication that
# they've been active (even though it's probably been updated by
# set_state above)
else:
await self._update_states(
[
prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec()
)
]
)
# To keep the single process behaviour consistent with worker mode, run the
# same logic as `update_external_syncs_row`, even though it looks weird.
if prev_state.state == PresenceState.OFFLINE:
await self._update_states(
[
prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=self.clock.time_msec(),
last_user_sync_ts=self.clock.time_msec(),
)
]
)
# otherwise, set the new presence state & update the last sync time,
# but don't update last_active_ts as this isn't an indication that
# they've been active (even though it's probably been updated by
# set_state above)
else:
await self._update_states(
[prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())]
)
async def _end() -> None:
try:
@ -1061,8 +1052,7 @@ class PresenceHandler(BasePresenceHandler):
try:
yield
finally:
if affect_presence:
run_in_background(_end)
run_in_background(_end)
return _user_syncing()
@ -1229,20 +1219,11 @@ class PresenceHandler(BasePresenceHandler):
status_msg = state.get("status_msg", None)
presence = state["presence"]
valid_presence = (
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
PresenceState.BUSY,
)
if presence not in valid_presence or (
presence == PresenceState.BUSY and not self._busy_presence_enabled
):
if presence not in self.VALID_PRESENCE:
raise SynapseError(400, "Invalid presence state")
# If presence is disabled, no-op
if not self.hs.config.server.use_presence:
if not self._presence_enabled:
return
user_id = target_user.to_string()

View File

@ -39,7 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
@ -629,7 +629,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async with self.member_linearizer.queue(key):
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
diff = self.clock.time_msec() - then

View File

@ -74,12 +74,13 @@ class SamlHandler:
self.idp_id = "saml"
# user-facing name of this auth provider
self.idp_name = "SAML"
self.idp_name = hs.config.saml2.idp_name
# we do not currently support icons/brands for SAML auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
# MXC URI for icon for this auth provider
self.idp_icon = hs.config.saml2.idp_icon
# optional brand identifier for this auth provider
self.idp_brand = hs.config.saml2.idp_brand
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}

View File

@ -24,13 +24,14 @@ from typing import (
Iterable,
List,
Mapping,
NoReturn,
Optional,
Set,
)
from urllib.parse import urlencode
import attr
from typing_extensions import NoReturn, Protocol
from typing_extensions import Protocol
from twisted.web.iweb import IRequest
from twisted.web.server import Request
@ -791,7 +792,7 @@ class SsoHandler:
if code != 200:
raise Exception(
"GET request to download sso avatar image returned {}".format(code)
f"GET request to download sso avatar image returned {code}"
)
# upload name includes hash of the image file's content so that we can

View File

@ -14,9 +14,15 @@
# limitations under the License.
import logging
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
from typing_extensions import Counter as CounterType
from typing import (
TYPE_CHECKING,
Any,
Counter as CounterType,
Dict,
Iterable,
Optional,
Tuple,
)
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions

View File

@ -387,16 +387,16 @@ class SyncHandler:
from_token=since_token,
)
# if nothing has happened in any of the users' rooms since /sync was called,
# the resultant next_batch will be the same as since_token (since the result
# is generated when wait_for_events is first called, and not regenerated
# when wait_for_events times out).
#
# If that happens, we mustn't cache it, so that when the client comes back
# with the same cache token, we don't immediately return the same empty
# result, causing a tightloop. (#8518)
if result.next_batch == since_token:
cache_context.should_cache = False
# if nothing has happened in any of the users' rooms since /sync was called,
# the resultant next_batch will be the same as since_token (since the result
# is generated when wait_for_events is first called, and not regenerated
# when wait_for_events times out).
#
# If that happens, we mustn't cache it, so that when the client comes back
# with the same cache token, we don't immediately return the same empty
# result, causing a tightloop. (#8518)
if result.next_batch == since_token:
cache_context.should_cache = False
if result:
if sync_config.filter_collection.lazy_load_members():
@ -1442,11 +1442,9 @@ class SyncHandler:
# Now we have our list of joined room IDs, exclude as configured and freeze
joined_room_ids = frozenset(
(
room_id
for room_id in mutable_joined_room_ids
if room_id not in mutable_rooms_to_exclude
)
room_id
for room_id in mutable_joined_room_ids
if room_id not in mutable_rooms_to_exclude
)
logger.debug(

View File

@ -94,6 +94,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.worker.should_update_user_directory
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
self.show_locked_users = hs.config.userdirectory.show_locked_users
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._hs = hs
@ -144,7 +145,9 @@ class UserDirectoryHandler(StateDeltasHandler):
]
}
"""
results = await self.store.search_user_dir(user_id, search_term, limit)
results = await self.store.search_user_dir(
user_id, search_term, limit, self.show_locked_users
)
# Remove any spammy users from the results.
non_spammy_users = []

View File

@ -42,7 +42,11 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
# This lock is used to avoid creating an event while we are purging the room.
# We take a read lock when creating an event, and a write one when purging a room.
# This is because it is fine to create several events concurrently, since referenced events
# will not disappear under our feet as long as we don't delete the room.
NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
class WorkerLocksHandler:

View File

@ -18,10 +18,9 @@ import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
from typing import Callable, Optional
from typing import Callable, Deque, Optional
import attr
from typing_extensions import Deque
from zope.interface import implementer
from twisted.application.internet import ClientService

View File

@ -31,7 +31,7 @@ from typing import (
import attr
import jinja2
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
@ -885,7 +885,7 @@ class ModuleApi:
def run_db_interaction(
self,
desc: str,
func: Callable[P, T],
func: Callable[Concatenate[LoggingTransaction, P], T],
*args: P.args,
**kwargs: P.kwargs,
) -> "defer.Deferred[T]":

View File

@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks:
generally discouraged as it doesn't support internationalization.
"""
for callback in self._check_event_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(event))
if res is False or res == self.NOT_SPAM:
# This spam-checker accepts the event.
@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks:
True if the event should be silently dropped
"""
for callback in self._should_drop_federated_event_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res: Union[bool, str] = await delay_cancellation(callback(event))
if res:
return res
@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
"""
for callback in self._user_may_join_room_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(user_id, room_id, is_invited))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is True or res is self.NOT_SPAM:
@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_invite_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id)
)
@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_send_3pid_invite_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(inviter_userid, medium, address, room_id)
)
@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks:
userid: The ID of the user attempting to create a room
"""
for callback in self._user_may_create_room_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid))
if res is True or res is self.NOT_SPAM:
continue
@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._user_may_create_room_alias_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid, room_alias))
if res is True or res is self.NOT_SPAM:
continue
@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks:
room_id: The ID of the room that would be published
"""
for callback in self._user_may_publish_room_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid, room_id))
if res is True or res is self.NOT_SPAM:
continue
@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
res = await delay_cancellation(callback(user_profile.copy()))
@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_registration_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id)
)
@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_media_file_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(file_wrapper, file_info))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is False or res is self.NOT_SPAM:
@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_login_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(
user_id,

View File

@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Deque,
Dict,
Iterable,
Iterator,
@ -29,7 +30,6 @@ from typing import (
)
from prometheus_client import Counter
from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory

View File

@ -280,6 +280,17 @@ class UserRestServletV2(RestServlet):
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
)
lock = body.get("locked", False)
if not isinstance(lock, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean"
)
if deactivate and lock:
raise SynapseError(
HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked"
)
approved: Optional[bool] = None
if "approved" in body and self._msc3866_enabled:
approved = body["approved"]
@ -397,6 +408,12 @@ class UserRestServletV2(RestServlet):
target_user.to_string()
)
if "locked" in body:
if lock and not user["locked"]:
await self.store.set_user_locked_status(user_id, True)
elif not lock and user["locked"]:
await self.store.set_user_locked_status(user_id, False)
if "user_type" in body:
await self.store.set_user_type(target_user, user_type)

View File

@ -29,7 +29,6 @@ from synapse.http.servlet import (
parse_integer,
)
from synapse.http.site import SynapseRequest
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler
if hs.config.worker.worker_app is None:
# if main process
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
else:
# then a worker
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
"Device key(s) not found, these must be provided.",
)
# TODO: Those two operations, creating a device and storing the
# device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
)
# TODO: Do we need to do something with the result here?
await self.key_uploader(
user_id=user_id, device_id=submission.device_id, keys=submission.dict()
device_info,
)
return 200, {"device_id": device_id}

View File

@ -40,7 +40,9 @@ class LogoutRestServlet(RestServlet):
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
requester = await self.auth.get_user_by_req(
request, allow_expired=True, allow_locked=True
)
if requester.device_id is None:
# The access token wasn't associated with a device.
@ -67,7 +69,9 @@ class LogoutAllRestServlet(RestServlet):
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
requester = await self.auth.get_user_by_req(
request, allow_expired=True, allow_locked=True
)
user_id = requester.user.to_string()
# first delete all of the user's devices

View File

@ -32,6 +32,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -53,26 +54,32 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
self._push_rules_handler = hs.get_push_rules_handler()
self._push_rule_linearizer = Linearizer(name="push_rules")
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
async with self._push_rule_linearizer.queue(user_id):
return await self.handle_put(request, path, user_id)
async def handle_put(
self, request: SynapseRequest, path: str, user_id: str
) -> Tuple[int, JsonDict]:
spec = _rule_spec_from_path(path.split("/"))
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
requester = await self.auth.get_user_by_req(request)
if "/" in spec.rule_id or "\\" in spec.rule_id:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
if spec.attr:
try:
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
@ -126,11 +133,20 @@ class PushRuleRestServlet(RestServlet):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
spec = _rule_spec_from_path(path.split("/"))
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
async with self._push_rule_linearizer.queue(user_id):
return await self.handle_delete(request, path, user_id)
async def handle_delete(
self,
request: SynapseRequest,
path: str,
user_id: str,
) -> Tuple[int, JsonDict]:
spec = _rule_spec_from_path(path.split("/"))
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
try:

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@ -81,7 +81,7 @@ class RoomUpgradeRestServlet(RestServlet):
try:
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version

View File

@ -14,7 +14,7 @@
import logging
import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json
@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)
store_queries = []
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
if key_ids:
results: Mapping[
str, Optional[FetchKeyResultForRemote]
] = await self.store.get_server_keys_json_for_remote(
server_name, key_ids
)
else:
results = await self.store.get_all_server_keys_json_for_remote(
server_name
)
cached = await self.store.get_server_keys_json_for_remote(store_queries)
server_keys.update(
((server_name, key_id), res) for key_id, res in results.items()
)
json_results: Set[bytes] = set()
@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]
if key_id is None:
for (server_name, key_id), key_result in server_keys.items():
if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if key_result:
json_results.add(key_result.key_json)
continue
miss = False
if not results:
if key_result is None:
miss = True
else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist

View File

@ -238,6 +238,7 @@ class BackgroundUpdater:
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
self.hs = hs
self._database_name = database.name()
@ -758,6 +759,11 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
# override the global statement timeout to avoid accidentally squashing
# a long-running index creation process
timeout_sql = "SET SESSION statement_timeout = 0"
c.execute(timeout_sql)
sql = (
"CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
" ON %(table)s"
@ -778,6 +784,12 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
# mypy ignore - `statement_timeout` is defined on PostgresEngine
# reset the global timeout to the default
default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined]
undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
conn.cursor().execute(undo_timeout_sql)
conn.set_session(autocommit=False) # type: ignore
def create_index_sqlite(conn: Connection) -> None:

View File

@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import (
SynapseTags,
@ -357,7 +357,7 @@ class EventsPersistenceStorageController:
# it. We might already have taken out the lock, but since this is just a
# "read" lock its inherently reentrant.
async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task)

View File

@ -28,6 +28,7 @@ from typing import (
cast,
)
from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import EduTypes
@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
def _store_dehydrated_device_txn(
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
device_data: str,
time: int,
keys: Optional[JsonDict] = None,
) -> Optional[str]:
# TODO: make keys non-optional once support for msc2697 is dropped
if keys:
device_keys = keys.get("device_keys", None)
if device_keys:
# Type ignore - this function is defined on EndToEndKeyStore which we do
# have access to due to hs.get_datastore() "magic"
self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
txn, user_id, device_id, time, device_keys
)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append(
(
algorithm,
key_id,
encode_canonical_json(key_obj).decode("ascii"),
)
)
self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
fallback_keys = keys.get("fallback_keys", None)
if fallback_keys:
self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)
return old_device_id
async def store_dehydrated_device(
self, user_id: str, device_id: str, device_data: JsonDict
self,
user_id: str,
device_id: str,
device_data: JsonDict,
time_now: int,
keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
time_now: current time at the request in milliseconds
keys: keys for the dehydrated device
Returns:
device id of the user's previous dehydrated device, if any
"""
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
time_now,
keys,
)
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:

View File

@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
keys=(
"user_id",
"device_id",
"algorithm",
"key_id",
"ts_added_ms",
"key_json",
),
values=[
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
"add_e2e_one_time_keys_insert",
self._add_e2e_one_time_keys_txn,
user_id,
device_id,
time_now,
new_keys,
)
def _add_e2e_one_time_keys_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
time_now: int,
new_keys: Iterable[Tuple[str, str, str]],
) -> None:
"""Insert some new one time keys for a device. Errors if any of the keys already exist.
Args:
user_id: id of user to get keys for
device_id: id of device to get keys for
time_now: insertion time to record (ms since epoch)
new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
that the key JSON must be in canonical JSON form
"""
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
keys=(
"user_id",
"device_id",
"algorithm",
"key_id",
"ts_added_ms",
"key_json",
),
values=[
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
@cached(max_entries=10000)
@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id: str,
fallback_keys: JsonDict,
) -> None:
"""Set the user's e2e fallback keys.
Args:
user_id: the user whose keys are being set
device_id: the device whose keys are being set
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
@ -1304,43 +1333,70 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
"""
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
"set_e2e_device_keys",
self._set_e2e_device_keys_txn,
user_id,
device_id,
time_now,
device_keys,
)
def _set_e2e_device_keys_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
time_now: int,
device_keys: JsonDict,
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
Args:
user_id: user_id of the user to store keys for
device_id: device_id of the device to store keys for
time_now: time at the request to store the keys
device_keys: the keys to store
"""
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(

View File

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json

View File

@ -16,14 +16,13 @@
import itertools
import json
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.keys import FetchKeyResult
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
class KeyStore(SQLBaseStore):
class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate((((server_name, key_id),)))
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
)
@cached()
def _get_server_keys_json(
@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
@cached()
def get_server_key_json_for_remote(
self,
server_name: str,
key_id: str,
) -> Optional[FetchKeyResultForRemote]:
raise NotImplementedError()
@cachedList(
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
)
async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, Optional[FetchKeyResultForRemote]]:
"""Fetch the cached keys for the given server/key IDs.
Args:
server_keys: List of (server_name, key_id, source) triplets.
Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
If we have multiple entries for a given key ID, returns the most recent.
"""
def _get_server_keys_json_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
results = {}
for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name}
if key_id is not None:
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
rows = self.db_pool.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
)
results[(server_name, key_id, from_server)] = rows
return results
return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
rows = await self.db_pool.simple_select_many_batch(
table="server_keys_json",
column="key_id",
iterable=key_ids,
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
# We sort the rows so that the most recently added entry is picked up.
rows.sort(key=lambda r: r["ts_added_ms"])
return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}
async def get_all_server_keys_json_for_remote(
self,
server_name: str,
) -> Dict[str, FetchKeyResultForRemote]:
"""Fetch the cached keys for the given server.
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
rows.sort(key=lambda r: r["ts_added_ms"])
return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}

View File

@ -26,7 +26,6 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util import Clock
from synapse.util.stringutils import random_string
@ -96,6 +95,10 @@ class LockStore(SQLBaseStore):
self._acquiring_locks: Set[Tuple[str, str]] = set()
self._clock.looping_call(
self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0
)
@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@ -216,6 +219,7 @@ class LockStore(SQLBaseStore):
lock_name,
lock_key,
write,
db_autocommit=True,
)
except self.database_engine.module.IntegrityError:
return None
@ -233,61 +237,22 @@ class LockStore(SQLBaseStore):
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.
now = self._clock.time_msec()
token = random_string(6)
delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""
insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""
if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
self.db_pool.simple_insert_txn(
txn,
table="worker_read_write_locks",
values={
"lock_name": lock_name,
"lock_key": lock_key,
"write_lock": write,
"instance_name": self._instance_name,
"token": token,
"last_renewed_ts": now,
},
)
lock = Lock(
self._reactor,
@ -351,6 +316,24 @@ class LockStore(SQLBaseStore):
return locks
@wrap_as_background_process("_reap_stale_read_write_locks")
async def _reap_stale_read_write_locks(self) -> None:
delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ?
"""
def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None:
txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,))
if txn.rowcount:
logger.info("Reaped %d stale locks", txn.rowcount)
await self.db_pool.runInteraction(
"_reap_stale_read_write_locks",
reap_stale_read_write_locks_txn,
db_autocommit=True,
)
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is

View File

@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
name, password_hash, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved
COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked
FROM users
WHERE name = ?
""",
@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
if not isinstance(row[column], bool):
row[column] = bool(row[column])
row[column] = bool(row[column])
return row
@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Convert the integer into a boolean.
return res == 1
@cached()
async def get_user_locked_status(self, user_id: str) -> bool:
"""Retrieve the value for the `locked` property for the provided user.
Args:
user_id: The ID of the user to retrieve the status for.
Returns:
True if the user was locked, false if the user is still active.
"""
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="locked",
desc="get_user_locked_status",
)
# Convert the potential integer into a boolean.
return bool(res)
async def get_threepid_validation_session(
self,
medium: Optional[str],
@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
"""Set the `locked` property for the provided user to the provided value.
Args:
user_id: The ID of the user to set the status for.
locked: The value to set for `locked`.
"""
await self.db_pool.runInteraction(
"set_user_locked_status",
self.set_user_locked_status_txn,
user_id,
locked,
)
def set_user_locked_status_txn(
self, txn: LoggingTransaction, user_id: str, locked: bool
) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"locked": locked},
)
self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:

View File

@ -19,6 +19,7 @@ from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Counter,
Dict,
Iterable,
List,
@ -28,8 +29,6 @@ from typing import (
cast,
)
from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership

View File

@ -995,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
self,
user_id: str,
search_term: str,
limit: int,
show_locked_users: bool = False,
) -> SearchResult:
"""Searches for users in directory
@ -1029,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
if not show_locked_users:
where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)"
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
@ -1060,6 +1067,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM matching_users as t
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
ORDER BY
@ -1115,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
AND value MATCH ?

View File

@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
This is not provided by DBAPI2, and so needs engine-specific support.
"""
with open(filepath, "rt") as f:
with open(filepath) as f:
cls.executescript(cursor, f.read())

View File

@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
key_json: bytes # the full key JSON
valid_until_ts: int # how long we can use this key for, in milliseconds.
added_ts: int # When we added this key, in milliseconds.

View File

@ -16,10 +16,18 @@ import logging
import os
import re
from collections import Counter
from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple
from typing import (
Collection,
Counter as CounterType,
Generator,
Iterable,
List,
Optional,
TextIO,
Tuple,
)
import attr
from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction

View File

@ -0,0 +1,16 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL;

View File

@ -21,6 +21,7 @@ from typing import (
Any,
ClassVar,
Dict,
Final,
List,
Mapping,
Match,
@ -38,7 +39,7 @@ import attr
from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
from typing_extensions import Final, TypedDict
from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface

View File

@ -22,6 +22,7 @@ import logging
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
@ -42,7 +43,7 @@ from typing import (
)
import attr
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.internet import defer
from twisted.internet.defer import CancelledError

View File

@ -218,7 +218,7 @@ class MacaroonGenerator:
# to avoid validating those as guest tokens, we explicitely verify if
# the macaroon includes the "guest = true" caveat.
is_guest = any(
(caveat.caveat_id == "guest = true" for caveat in macaroon.caveats)
caveat.caveat_id == "guest = true" for caveat in macaroon.caveats
)
if not is_guest:

View File

@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
SynapseManhole, dict(globals, __name__="__console__")
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
# type-ignore: This is an error in Twisted's annotations. See
# https://github.com/twisted/twisted/issues/11812 and /11813 .
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type]
# conch has the wrong type on these dicts (says bytes to bytes,
# should be bytes to Keys judging by how it's used).

View File

@ -20,6 +20,7 @@ import typing
from typing import (
Any,
Callable,
ContextManager,
DefaultDict,
Dict,
Iterator,
@ -33,7 +34,6 @@ from typing import (
from weakref import WeakSet
from prometheus_client.core import Counter
from typing_extensions import ContextManager
from twisted.internet import defer

View File

@ -17,6 +17,7 @@ from enum import Enum, auto
from typing import (
Collection,
Dict,
Final,
FrozenSet,
List,
Mapping,
@ -27,7 +28,6 @@ from typing import (
)
import attr
from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase

View File

@ -69,6 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = simple_async_mock(False)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@ -293,6 +294,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = simple_async_mock(False)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@ -311,6 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True,
)
)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})

View File

@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
hs = super().make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check
# that it's not enabled

View File

@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server9", get_key_id(key1))]
)
result = self.get_success(d)
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""

View File

@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")
# Fetch the message of the dehydrated device again, which should return nothing
# and delete the old messages
# Fetch the message of the dehydrated device again, which should return
# the same message as it has not been deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
since_token=res["next_batch"],
since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
self.assertEqual(len(res["events"]), 0)
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")

View File

@ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
def test_introspection_token_cache(self) -> None:
access_token = "open_sesame"
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={"active": "true", "scope": "guest", "jti": access_token},
)
)
# first call should cache response
# Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
# for regular auth code via the config
self.get_success(
self.auth._introspect_token(access_token) # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], access_token)
# there's been one http request
self.http_client.request.assert_called_once()
# second call should pull from cache, there should still be only one http request
token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.http_client.request.assert_called_once()
self.assertEqual(token["jti"], access_token)
# advance past five minutes and check that cache expired - there should be more than one http call now
self.reactor.advance(360)
token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.assertEqual(self.http_client.request.call_count, 2)
self.assertEqual(token_2["jti"], access_token)
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={
"active": "true",
"scope": "guest",
"jti": "stale",
"exp": self.clock.time() + 100,
},
)
)
self.get_success(
self.auth._introspect_token("stale") # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], "stale")
self.assertEqual(self.http_client.request.call_count, 1)
# advance the reactor past the token expiry but less than the cache expiry
self.reactor.advance(120)
self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
# check that the next call causes another http request (which will fail because the token is technically expired
# but the important thing is we discard the token from the cache and try the network)
self.get_failure(
self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
)
self.assertEqual(self.http_client.request.call_count, 2)
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)

View File

@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(response.code, 200)
# Send the body
request.write('{ "a": 1 }'.encode("ascii"))
request.write(b'{ "a": 1 }')
request.finish()
self.reactor.pump((0.1,))

View File

@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias.
self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
# Let's try a room with no alias.
room_id, room_alias = self.get_success(

View File

@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]

View File

@ -29,7 +29,16 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import devices, login, logout, profile, register, room, sync
from synapse.rest.client import (
devices,
login,
logout,
profile,
register,
room,
sync,
user_directory,
)
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@ -1477,6 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
sync.register_servlets,
register.register_servlets,
user_directory.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@ -2464,6 +2474,105 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", channel.json_body)
def test_locked_user(self) -> None:
# User can sync
channel = self.make_request(
"GET",
"/_matrix/client/v3/sync",
access_token=self.other_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Lock user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"locked": True},
)
# User is not authorized to sync anymore
channel = self.make_request(
"GET",
"/_matrix/client/v3/sync",
access_token=self.other_user_token,
)
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"])
self.assertTrue(channel.json_body["soft_logout"])
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_locked_user_not_in_user_dir(self) -> None:
# User is available in the user dir
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"search_term": self.other_user},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("results", channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Lock user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"locked": True},
)
# User is not available anymore in the user dir
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"search_term": self.other_user},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("results", channel.json_body)
self.assertEqual(0, len(channel.json_body["results"]))
@override_config(
{
"user_directory": {
"enabled": True,
"search_all_users": True,
"show_locked_users": True,
}
}
)
def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None:
# User is available in the user dir
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"search_term": self.other_user},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("results", channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Lock user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"locked": True},
)
# User is still available in the user dir
channel = self.make_request(
"POST",
"/_matrix/client/v3/user_directory/search",
{"search_term": self.other_user},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("results", channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self) -> None:
"""

View File

@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, create_requester
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
},
},
"fallback_keys": {
"alg1:device1": "f4llb4ckk3y",
"signed_<algorithm>:<device_id>": {
"fallback": "true",
"key": "f4llb4ckk3y",
"signatures": {
"<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
},
},
},
"one_time_keys": {"alg1:k1": "0net1m3k3y"},
}
channel = self.make_request(
"PUT",
@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
}
self.assertEqual(device_data, expected_device_data)
# test that the keys are correctly uploaded
channel = self.make_request(
"POST",
"/_matrix/client/r0/keys/query",
{
"device_keys": {
user: ["device1"],
},
},
token,
)
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body["device_keys"][user][device_id]["keys"],
content["device_keys"]["keys"],
)
# first claim should return the onetime key we uploaded
res = self.get_success(
self.hs.get_e2e_keys_handler().claim_one_time_keys(
{user: {device_id: {"alg1": 1}}},
UserID.from_string(user),
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res,
{
"failures": {},
"one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
},
)
# second claim should return fallback key
res2 = self.get_success(
self.hs.get_e2e_keys_handler().claim_one_time_keys(
{user: {device_id: {"alg1": 1}}},
UserID.from_string(user),
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res2,
{
"failures": {},
"one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
},
)
# create another device for the user
(
new_device_id,
@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
expected_content = {"body": "test_message"}
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
# fetch messages again and make sure that the message was not deleted
channel = self.make_request(
"POST",
f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
content={},
access_token=token,
shorthand=False,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
next_batch_token = channel.json_body.get("next_batch")
# fetch messages again and make sure that the message was deleted and we are returned an
# empty array
# make sure fetching messages with next batch token works - there are no unfetched
# messages so we should receive an empty array
content = {"next_batch": next_batch_token}
channel = self.make_request(
"POST",

View File

@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase):
redact_event = timeline[-1]
self.assertEqual(redact_event["type"], EventTypes.Redaction)
# The redacts key should be in the content and the redacts keys.
self.assertEquals(redact_event["content"]["redacts"], event_id)
self.assertEquals(redact_event["redacts"], event_id)
self.assertEqual(redact_event["content"]["redacts"], event_id)
self.assertEqual(redact_event["redacts"], event_id)
# But it isn't actually part of the event.
def get_event(txn: LoggingTransaction) -> JsonDict:
@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase):
event_json = self.get_success(
main_datastore.db_pool.runInteraction("get_event", get_event)
)
self.assertEquals(event_json["type"], EventTypes.Redaction)
self.assertEqual(event_json["type"], EventTypes.Redaction)
if expect_content:
self.assertNotIn("redacts", event_json)
self.assertEquals(event_json["content"]["redacts"], event_id)
self.assertEqual(event_json["content"]["redacts"], event_id)
else:
self.assertEquals(event_json["redacts"], event_id)
self.assertEqual(event_json["redacts"], event_id)
self.assertNotIn("redacts", event_json["content"])

Some files were not shown because too many files have changed in this diff Show More