Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Sean Quah 2023-03-31 10:17:16 +01:00
commit bf81ee4217
42 changed files with 738 additions and 398 deletions

View File

@ -0,0 +1 @@
Add the ability to enable/disable registrations when in the OIDC flow.

View File

@ -1 +1 @@
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)).

View File

@ -0,0 +1 @@
Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)).

1
changelog.d/15334.misc Normal file
View File

@ -0,0 +1 @@
Speed up unit tests when using SQLite3.

1
changelog.d/15349.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where some to_device messages could be dropped when using workers.

1
changelog.d/15350.misc Normal file
View File

@ -0,0 +1 @@
Make the `thread_id` column on `event_push_actions`, `event_push_actions_staging`, and `event_push_summary` non-null.

1
changelog.d/15351.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.70.0 where the background sync from a faster join could spin for hours when one of the events involved had been marked for backoff.

1
changelog.d/15352.bugfix Normal file
View File

@ -0,0 +1 @@
Fix missing app variable in mail subject for password resets. Contributed by Cyberes.

1
changelog.d/15354.misc Normal file
View File

@ -0,0 +1 @@
Add some clarification to the doc/comments regarding TCP replication.

View File

@ -25,7 +25,7 @@ position of all streams. The server then periodically sends `RDATA` commands
which have the format `RDATA <stream_name> <instance_name> <token> <row>`, where
the format of `<row>` is defined by the individual streams. The
`<instance_name>` is the name of the Synapse process that generated the data
(usually "master").
(usually "master"). We expect an RDATA for every row in the DB.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@ -107,7 +107,7 @@ reconnect, following the steps above.
If the server sends messages faster than the client can consume them the
server will first buffer a (fairly large) number of commands and then
disconnect the client. This ensures that we don't queue up an unbounded
number of commands in memory and gives us a potential oppurtunity to
number of commands in memory and gives us a potential opportunity to
squawk loudly. When/if the client recovers it can reconnect to the
server and ask for missed messages.
@ -122,7 +122,7 @@ since these include tokens which can be used to restart the stream on
connection errors.
The client should keep track of the token in the last RDATA command
received for each stream so that on reconneciton it can start streaming
received for each stream so that on reconnection it can start streaming
from the correct place. Note: not all RDATA have valid tokens due to
batching. See `RdataCommand` for more details.
@ -188,7 +188,8 @@ client (C):
Two positions are included, the "new" position and the last position sent respectively.
This allows servers to tell instances that the positions have advanced but no
data has been written, without clients needlessly checking to see if they
have missed any updates.
have missed any updates. Instances will only fetch stuff if there is a gap between
their current position and the given last position.
#### ERROR (S, C)

View File

@ -3100,6 +3100,11 @@ Options for each entry include:
match a pre-existing account instead of failing. This could be used if
switching from password logins to OIDC. Defaults to false.
* `enable_registration`: set to 'false' to disable automatic registration of new
users. This allows the OIDC SSO flow to be limited to sign in only, rather than
automatically registering users that have a valid SSO login but do not have
a pre-registered account. Defaults to true.
* `user_mapping_provider`: Configuration for how attributes returned from a OIDC
provider are mapped onto a matrix user. This setting has the following
sub-properties:
@ -3216,6 +3221,7 @@ oidc_providers:
userinfo_endpoint: "https://accounts.example.com/userinfo"
jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
skip_verification: true
enable_registration: true
user_mapping_provider:
config:
subject_claim: "id"

View File

@ -27,7 +27,7 @@ from synapse.util import json_decoder
if typing.TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig
from synapse.types import JsonDict
from synapse.types import JsonDict, StrCollection
logger = logging.getLogger(__name__)
@ -682,18 +682,27 @@ class FederationPullAttemptBackoffError(RuntimeError):
Attributes:
event_id: The event_id which we are refusing to pull
message: A custom error message that gives more context
retry_after_ms: The remaining backoff interval, in milliseconds
"""
def __init__(self, event_ids: List[str], message: Optional[str]):
self.event_ids = event_ids
def __init__(
self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int
):
event_ids = list(event_ids)
if message:
error_message = message
else:
error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)."
error_message = (
f"Not attempting to pull event_ids={event_ids} because we already "
"tried to pull them recently (backing off)."
)
super().__init__(error_message)
self.event_ids = event_ids
self.retry_after_ms = retry_after_ms
class HttpResponseException(CodeMessageException):
"""

View File

@ -30,7 +30,7 @@ from prometheus_client import Counter
from typing_extensions import TypeGuard
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.api.errors import CodeMessageException, HttpResponseException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeysCount,
@ -38,7 +38,7 @@ from synapse.appservice import (
)
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
@ -393,7 +393,11 @@ class ApplicationServiceApi(SimpleHttpClient):
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from an application service.
Note that any error (including a timeout) is treated as the application
service having no information.
Args:
service: The application service to query.
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
@ -422,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient):
body,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
except CodeMessageException as e:
except HttpResponseException as e:
# The appservice doesn't support this endpoint.
if e.code == 404 or e.code == 405:
if is_unknown_endpoint(e):
return {}, query
logger.warning("claim_keys to %s received %s", uri, e.code)
return {}, query
@ -444,6 +448,48 @@ class ApplicationServiceApi(SimpleHttpClient):
return response, missing
async def query_keys(
self, service: "ApplicationService", query: Dict[str, List[str]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Query the application service for keys.
Note that any error (including a timeout) is treated as the application
service having no information.
Args:
service: The application service to query.
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of device_keys/master_keys/self_signing_keys/user_signing_keys:
device_keys is a map of user ID -> a map device ID -> device info.
"""
if service.url is None:
return {}
# This is required by the configuration.
assert service.hs_token is not None
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
try:
response = await self.post_json_get_json(
uri,
query,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
except HttpResponseException as e:
# The appservice doesn't support this endpoint.
if is_unknown_endpoint(e):
return {}
logger.warning("query_keys to %s received %s", uri, e.code)
return {}
except Exception as ex:
logger.warning("query_keys to %s threw exception %s", uri, ex)
return {}
return response
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:

View File

@ -79,6 +79,11 @@ class ExperimentalConfig(Config):
"msc3983_appservice_otk_claims", False
)
# MSC3984: Proxying key queries to exclusive ASes.
self.msc3984_appservice_key_query: bool = experimental.get(
"msc3984_appservice_key_query", False
)
# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve

View File

@ -136,6 +136,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "array",
"items": SsoAttributeRequirement.JSON_SCHEMA,
},
"enable_registration": {"type": "boolean"},
},
}
@ -306,6 +307,7 @@ def _parse_oidc_config_dict(
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
attribute_requirements=attribute_requirements,
enable_registration=oidc_config.get("enable_registration", True),
)
@ -405,3 +407,6 @@ class OidcProviderConfig:
# required attributes to require in userinfo to allow login/registration
attribute_requirements: List[SsoAttributeRequirement]
# Whether automatic registrations are enabled in the ODIC flow. Defaults to True
enable_registration: bool

View File

@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
event_from_pdu_json,
)
from synapse.federation.transport.client import SendJoinResponse
from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
@ -759,43 +760,6 @@ class FederationClient(FederationBase):
return signed_auth
def _is_unknown_endpoint(
self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None
) -> bool:
"""
Returns true if the response was due to an endpoint being unimplemented.
Args:
e: The error response received from the remote server.
synapse_error: The above error converted to a SynapseError. This is
automatically generated if not provided.
"""
if synapse_error is None:
synapse_error = e.to_synapse_error()
# MSC3743 specifies that servers should return a 404 or 405 with an errcode
# of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
# to an unknown method, respectively.
#
# Older versions of servers don't properly handle this. This needs to be
# rather specific as some endpoints truly do return 404 errors.
return (
# 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
(e.code == 404 or e.code == 405)
and (
# Older Dendrites returned a text or empty body.
# Older Conduit returned an empty body.
not e.response
or e.response == b"404 page not found"
# The proper response JSON with M_UNRECOGNIZED errcode.
or synapse_error.errcode == Codes.UNRECOGNIZED
)
) or (
# Older Synapses returned a 400 error.
e.code == 400
and synapse_error.errcode == Codes.UNRECOGNIZED
)
async def _try_destination_list(
self,
description: str,
@ -887,7 +851,7 @@ class FederationClient(FederationBase):
elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
failover = True
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
elif failover_on_unknown_endpoint and is_unknown_endpoint(
e, synapse_error
):
failover = True
@ -1223,7 +1187,7 @@ class FederationClient(FederationBase):
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not self._is_unknown_endpoint(e):
if not is_unknown_endpoint(e):
raise
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
@ -1297,7 +1261,7 @@ class FederationClient(FederationBase):
# fallback to the v1 endpoint if the room uses old-style event IDs.
# Otherwise, consider it a legitimate error and raise.
err = e.to_synapse_error()
if self._is_unknown_endpoint(e, err):
if is_unknown_endpoint(e, err):
if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
raise SynapseError(
400,
@ -1358,7 +1322,7 @@ class FederationClient(FederationBase):
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not self._is_unknown_endpoint(e):
if not is_unknown_endpoint(e):
raise
logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
@ -1629,7 +1593,7 @@ class FederationClient(FederationBase):
# If an error is received that is due to an unrecognised endpoint,
# fallback to the unstable endpoint. Otherwise, consider it a
# legitimate error and raise.
if not self._is_unknown_endpoint(e):
if not is_unknown_endpoint(e):
raise
logger.debug(

View File

@ -18,6 +18,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
@ -846,6 +847,10 @@ class ApplicationServicesHandler:
]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
key claim request to check if the application service provides keys
directly.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
@ -901,3 +906,59 @@ class ApplicationServicesHandler:
missing.extend(result[1])
return claimed_keys, missing
async def query_keys(
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Query application services for device keys.
Users which are exclusively owned by an application service are queried
for keys to check if the application service provides keys directly.
Args:
query: map from user_id to a list of devices to query
Returns:
A map from user_id -> device_id -> device details
"""
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, Dict[str, List[str]]] = {}
for user_id, device_ids in query.items():
if not self.store.get_if_app_services_interested_in_user(user_id):
continue
# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, {})[user_id] = (
device_ids or []
)
continue
# Query each service in parallel.
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
self.appservice_api.query_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
service_query,
)
for service_id, service_query in query_by_appservice.items()
],
consumeErrors=True,
)
)
# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a single
# dictionary.
key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
key_queries.update(result)
return key_queries

View File

@ -91,6 +91,9 @@ class E2eKeysHandler:
self._query_appservices_for_otks = (
hs.config.experimental.msc3983_appservice_otk_claims
)
self._query_appservices_for_keys = (
hs.config.experimental.msc3984_appservice_key_query
)
@trace
@cancellable
@ -497,6 +500,19 @@ class E2eKeysHandler:
local_query, include_displaynames
)
# Check if the application services have any additional results.
if self._query_appservices_for_keys:
# Query the appservices for any keys.
appservice_results = await self._appservice_handler.query_keys(query)
# Merge results, overriding with what the appservice returned.
for user_id, devices in appservice_results.get("device_keys", {}).items():
# Copy the appservice device info over the homeserver device info, but
# don't completely overwrite it.
results.setdefault(user_id, {}).update(devices)
# TODO Handle cross-signing keys.
# Build the result structure
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():

View File

@ -1949,27 +1949,25 @@ class FederationHandler:
)
for event in events:
for attempt in itertools.count():
# We try a new destination on every iteration.
try:
while True:
try:
await self._federation_event_handler.update_state_for_partial_state_event(
destination, event
)
break
except FederationPullAttemptBackoffError as exc:
# Log a warning about why we failed to process the event (the error message
# for `FederationPullAttemptBackoffError` is pretty good)
logger.warning("_sync_partial_state_room: %s", exc)
# We do not record a failed pull attempt when we backoff fetching a missing
# `prev_event` because not being able to fetch the `prev_events` just means
# we won't be able to de-outlier the pulled event. But we can still use an
# `outlier` in the state/auth chain for another event. So we shouldn't stop
# a downstream event from trying to pull it.
#
# This avoids a cascade of backoff for all events in the DAG downstream from
# one event backoff upstream.
except FederationError as e:
# TODO: We should `record_event_failed_pull_attempt` here,
# see https://github.com/matrix-org/synapse/issues/13700
except FederationPullAttemptBackoffError as e:
# We are in the backoff period for one of the event's
# prev_events. Wait it out and try again after.
logger.warning(
"%s; waiting for %d ms...", e, e.retry_after_ms
)
await self.clock.sleep(e.retry_after_ms / 1000)
# Success, no need to try the rest of the destinations.
break
except FederationError as e:
if attempt == len(destinations) - 1:
# We have tried every remote server for this event. Give up.
# TODO(faster_joins) giving up isn't the right thing to do
@ -1986,6 +1984,8 @@ class FederationHandler:
destination,
e,
)
# TODO: We should `record_event_failed_pull_attempt` here,
# see https://github.com/matrix-org/synapse/issues/13700
raise
# Try the next remote server.

View File

@ -140,6 +140,7 @@ class FederationEventHandler:
"""
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@ -1038,8 +1039,8 @@ class FederationEventHandler:
Raises:
FederationPullAttemptBackoffError if we are are deliberately not attempting
to pull the given event over federation because we've already done so
recently and are backing off.
to pull one of the given event's `prev_event`s over federation because
we've already done so recently and are backing off.
FederationError if we fail to get the state from the remote server after any
missing `prev_event`s.
"""
@ -1053,13 +1054,22 @@ class FederationEventHandler:
# If we've already recently attempted to pull this missing event, don't
# try it again so soon. Since we have to fetch all of the prev_events, we can
# bail early here if we find any to ignore.
prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
prevs_with_pull_backoff = (
await self._store.get_event_ids_to_not_pull_from_backoff(
room_id, missing_prevs
)
if len(prevs_to_ignore) > 0:
)
if len(prevs_with_pull_backoff) > 0:
raise FederationPullAttemptBackoffError(
event_ids=prevs_to_ignore,
message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
event_ids=prevs_with_pull_backoff.keys(),
message=(
f"While computing context for event={event_id}, not attempting to "
f"pull missing prev_events={list(prevs_with_pull_backoff.keys())} "
"because we already tried to pull recently (backing off)."
),
retry_after_ms=(
max(prevs_with_pull_backoff.values()) - self._clock.time_msec()
),
)
if not missing_prevs:

View File

@ -1239,6 +1239,7 @@ class OidcProvider:
grandfather_existing_users,
extra_attributes,
auth_provider_session_id=sid,
registration_enabled=self._config.enable_registration,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:

View File

@ -383,6 +383,7 @@ class SsoHandler:
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
auth_provider_session_id: Optional[str] = None,
registration_enabled: bool = True,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@ -435,6 +436,10 @@ class SsoHandler:
auth_provider_session_id: An optional session ID from the IdP.
registration_enabled: An optional boolean to enable/disable automatic
registrations of new users. If false and the user does not exist then the
flow is aborted. Defaults to true.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
@ -462,8 +467,16 @@ class SsoHandler:
auth_provider_id, remote_user_id, user_id
)
# Otherwise, generate a new user.
if not user_id:
if not user_id and not registration_enabled:
logger.info(
"User does not exist and registration are disabled for IdP '%s' and remote_user_id '%s'",
auth_provider_id,
remote_user_id,
)
raise MappingException(
"User does not exist and registrations are disabled"
)
elif not user_id: # Otherwise, generate a new user.
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
next_step_url = self._get_url_for_next_new_user_step(

View File

@ -966,3 +966,41 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
return self
def is_unknown_endpoint(
e: HttpResponseException, synapse_error: Optional[SynapseError] = None
) -> bool:
"""
Returns true if the response was due to an endpoint being unimplemented.
Args:
e: The error response received from the remote server.
synapse_error: The above error converted to a SynapseError. This is
automatically generated if not provided.
"""
if synapse_error is None:
synapse_error = e.to_synapse_error()
# MSC3743 specifies that servers should return a 404 or 405 with an errcode
# of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
# to an unknown method, respectively.
#
# Older versions of servers don't properly handle this. This needs to be
# rather specific as some endpoints truly do return 404 errors.
return (
# 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
(e.code == 404 or e.code == 405)
and (
# Older Dendrites returned a text body or empty body.
# Older Conduit returned an empty body.
not e.response
or e.response == b"404 page not found"
# The proper response JSON with M_UNRECOGNIZED errcode.
or synapse_error.errcode == Codes.UNRECOGNIZED
)
) or (
# Older Synapses returned a 400 error.
e.code == 400
and synapse_error.errcode == Codes.UNRECOGNIZED
)

View File

@ -149,7 +149,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.password_reset
% {"server_name": self.hs.config.server.server_name},
% {"server_name": self.hs.config.server.server_name, "app": self.app_name},
template_vars,
)

View File

@ -14,36 +14,7 @@
"""This module contains the implementation of both the client and server
protocols.
The basic structure of the protocol is line based, where the initial word of
each line specifies the command. The rest of the line is parsed based on the
command. For example, the `RDATA` command is defined as::
RDATA <stream_name> <token> <row_json>
(Note that `<row_json>` may contains spaces, but cannot contain newlines.)
Blank lines are ignored.
# Example
An example iteraction is shown below. Each line is prefixed with '>' or '<' to
indicate which side is sending, these are *not* included on the wire::
* connection established *
> SERVER localhost:8823
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
> RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
> RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823",
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]]
< PING 1490197675618
> ERROR server stopping
* connection closed by server *
An explanation of this protocol is available in docs/tcp_replication.md
"""
import fcntl
import logging

View File

@ -152,8 +152,8 @@ class Stream:
Returns:
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
position in stream (ie the highest token returned in the updates),
and `limited` is whether there are more updates to fetch.
"""
current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(

View File

@ -617,14 +617,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_id, last_id + limit)
upto_token = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_id, upper_pos))
txn.execute(sql, (last_id, upto_token))
updates = [(row[0], row[1:]) for row in txn]
sql = (
@ -633,19 +633,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_id, upper_pos))
txn.execute(sql, (last_id, upto_token))
updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
updates.sort()
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return updates, upto_token, upto_token < current_id
return await self.db_pool.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn

View File

@ -1544,7 +1544,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self,
room_id: str,
event_ids: Collection[str],
) -> List[str]:
) -> Dict[str, int]:
"""
Filter down the events to ones that we've failed to pull before recently. Uses
exponential backoff.
@ -1554,7 +1554,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_ids: A list of events to filter down
Returns:
List of event_ids that should not be attempted to be pulled
A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again.
"""
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
@ -1570,13 +1571,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
current_time = self._clock.time_msec()
return [
event_failed_pull_attempt["event_id"]
for event_failed_pull_attempt in event_failed_pull_attempts
event_ids_with_backoff = {}
for event_failed_pull_attempt in event_failed_pull_attempts:
event_id = event_failed_pull_attempt["event_id"]
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
if current_time
< event_failed_pull_attempt["last_attempt_ts"]
backoff_end_time = (
event_failed_pull_attempt["last_attempt_ts"]
+ (
2
** min(
@ -1585,7 +1587,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
)
* BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
]
)
if current_time < backoff_end_time: # `backoff_end_time` is exclusive
event_ids_with_backoff[event_id] = backoff_end_time
return event_ids_with_backoff
async def get_missing_events(
self,

View File

@ -100,7 +100,6 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -289,180 +288,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
unique=True,
)
self.db_pool.updates.register_background_update_handler(
"event_push_backfill_thread_id",
self._background_backfill_thread_id,
)
# Indexes which will be used to quickly make the thread_id column non-null.
self.db_pool.updates.register_background_index_update(
"event_push_actions_thread_id_null",
index_name="event_push_actions_thread_id_null",
table="event_push_actions",
columns=["thread_id"],
where_clause="thread_id IS NULL",
)
self.db_pool.updates.register_background_index_update(
"event_push_summary_thread_id_null",
index_name="event_push_summary_thread_id_null",
table="event_push_summary",
columns=["thread_id"],
where_clause="thread_id IS NULL",
)
# Check ASAP (and then later, every 1s) to see if we have finished
# background updates the event_push_actions and event_push_summary tables.
self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
self._event_push_backfill_thread_id_done = False
@wrap_as_background_process("check_event_push_backfill_thread_id")
async def _check_event_push_backfill_thread_id(self) -> None:
"""
Has thread_id finished backfilling?
If not, we need to just-in-time update it so the queries work.
"""
done = await self.db_pool.updates.has_completed_background_update(
"event_push_backfill_thread_id"
)
if done:
self._event_push_backfill_thread_id_done = True
else:
# Reschedule to run.
self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
async def _background_backfill_thread_id(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Fill in the thread_id field for event_push_actions and event_push_summary.
This is preparatory so that it can be made non-nullable in the future.
Because all current (null) data is done in an unthreaded manner this
simply assumes it is on the "main" timeline. Since event_push_actions
are periodically cleared it is not possible to correctly re-calculate
the thread_id.
"""
event_push_actions_done = progress.get("event_push_actions_done", False)
def add_thread_id_txn(
txn: LoggingTransaction, start_stream_ordering: int
) -> int:
sql = """
SELECT stream_ordering
FROM event_push_actions
WHERE
thread_id IS NULL
AND stream_ordering > ?
ORDER BY stream_ordering
LIMIT ?
"""
txn.execute(sql, (start_stream_ordering, batch_size))
# No more rows to process.
rows = txn.fetchall()
if not rows:
progress["event_push_actions_done"] = True
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
return 0
# Update the thread ID for any of those rows.
max_stream_ordering = rows[-1][0]
sql = """
UPDATE event_push_actions
SET thread_id = 'main'
WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
"""
txn.execute(
sql,
(
start_stream_ordering,
max_stream_ordering,
),
)
# Update progress.
processed_rows = txn.rowcount
progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
return processed_rows
def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
min_user_id = progress.get("max_summary_user_id", "")
min_room_id = progress.get("max_summary_room_id", "")
# Slightly overcomplicated query for getting the Nth user ID / room
# ID tuple, or the last if there are less than N remaining.
sql = """
SELECT user_id, room_id FROM (
SELECT user_id, room_id FROM event_push_summary
WHERE (user_id, room_id) > (?, ?)
AND thread_id IS NULL
ORDER BY user_id, room_id
LIMIT ?
) AS e
ORDER BY user_id DESC, room_id DESC
LIMIT 1
"""
txn.execute(sql, (min_user_id, min_room_id, batch_size))
row = txn.fetchone()
if not row:
return 0
max_user_id, max_room_id = row
sql = """
UPDATE event_push_summary
SET thread_id = 'main'
WHERE
(?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
AND thread_id IS NULL
"""
txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
processed_rows = txn.rowcount
progress["max_summary_user_id"] = max_user_id
progress["max_summary_room_id"] = max_room_id
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
return processed_rows
# First update the event_push_actions table, then the event_push_summary table.
#
# Note that the event_push_actions_staging table is ignored since it is
# assumed that items in that table will only exist for a short period of
# time.
if not event_push_actions_done:
result = await self.db_pool.runInteraction(
"event_push_backfill_thread_id",
add_thread_id_txn,
progress.get("max_event_push_actions_stream_ordering", 0),
)
else:
result = await self.db_pool.runInteraction(
"event_push_backfill_thread_id",
add_thread_id_summary_txn,
)
# Only done after the event_push_summary table is done.
if not result:
await self.db_pool.updates._end_background_update(
"event_push_backfill_thread_id"
)
return result
async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
"""Get the notification count by room for a user. Only considers notifications,
not highlight or unread counts, and threads are currently aggregated under their room.
@ -711,25 +536,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
# First ensure that the existing rows have an updated thread_id field.
if not self._event_push_backfill_thread_id_done:
txn.execute(
"""
UPDATE event_push_summary
SET thread_id = ?
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
""",
(MAIN_TIMELINE, room_id, user_id),
)
txn.execute(
"""
UPDATE event_push_actions
SET thread_id = ?
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
""",
(MAIN_TIMELINE, room_id, user_id),
)
# First we pull the counts from the summary table.
#
# We check that `last_receipt_stream_ordering` matches the stream ordering of the
@ -1545,25 +1351,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(room_id, user_id, stream_ordering, *thread_args),
)
# First ensure that the existing rows have an updated thread_id field.
if not self._event_push_backfill_thread_id_done:
txn.execute(
"""
UPDATE event_push_summary
SET thread_id = ?
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
""",
(MAIN_TIMELINE, room_id, user_id),
)
txn.execute(
"""
UPDATE event_push_actions
SET thread_id = ?
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
""",
(MAIN_TIMELINE, room_id, user_id),
)
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
unread_counts = self._get_notif_unread_count_for_user_room(
@ -1698,19 +1485,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
"""
# Ensure that any new actions have an updated thread_id.
if not self._event_push_backfill_thread_id_done:
txn.execute(
"""
UPDATE event_push_actions
SET thread_id = ?
WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
""",
(MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
)
# XXX Do we need to update summaries here too?
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id, thread_id,
@ -1773,20 +1547,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications, handling %d rows", len(summaries))
# Ensure that any updated threads have the proper thread_id.
if not self._event_push_backfill_thread_id_done:
txn.execute_batch(
"""
UPDATE event_push_summary
SET thread_id = ?
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
""",
[
(MAIN_TIMELINE, room_id, user_id)
for user_id, room_id, _ in summaries
],
)
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",

View File

@ -34,6 +34,13 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
":memory:",
)
# A connection to a database that has already been prepared, to use as a
# base for an in-memory connection. This is used during unit tests to
# speed up setting up the DB.
self._prepped_conn: Optional[sqlite3.Connection] = database_config.get(
"_TEST_PREPPED_CONN"
)
if platform.python_implementation() == "PyPy":
# pypy's sqlite3 module doesn't handle bytearrays, convert them
# back to bytes.
@ -84,6 +91,14 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
# In memory databases need to be rebuilt each time. Ideally we'd
# reuse the same connection as we do when starting up, but that
# would involve using adbapi before we have started the reactor.
#
# If we have a `prepped_conn` we can use that to initialise the DB,
# otherwise we need to call `prepare_database`.
if self._prepped_conn is not None:
# Initialise the new DB from the pre-prepared DB.
assert isinstance(db_conn.conn, sqlite3.Connection)
self._prepped_conn.backup(db_conn.conn)
else:
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)

View File

@ -95,9 +95,9 @@ Changes in SCHEMA_VERSION = 74:
SCHEMA_COMPAT_VERSION = (
# The threads_id column must exist for event_push_actions, event_push_summary,
# receipts_linearized, and receipts_graph.
73
# The threads_id column must written to with non-null values event_push_actions,
# event_push_actions_staging, and event_push_summary.
74
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat

View File

@ -0,0 +1,28 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Force the background updates from 06thread_notifications.sql to run in the
-- foreground as code will now require those to be "done".
DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id';
-- Overwrite any null thread_id values.
UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL;
UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL;
UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL;
-- Drop the background updates to calculate the indexes used to find null thread_ids.
DELETE FROM background_updates WHERE update_name = 'event_push_actions_thread_id_null';
DELETE FROM background_updates WHERE update_name = 'event_push_summary_thread_id_null';

View File

@ -0,0 +1,23 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Drop the indexes used to find null thread_ids.
DROP INDEX IF EXISTS event_push_actions_thread_id_null;
DROP INDEX IF EXISTS event_push_summary_thread_id_null;
-- The thread_id columns can now be made non-nullable.
ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL;
ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL;
ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL;

View File

@ -0,0 +1,99 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- The thread_id columns can now be made non-nullable.
--
-- SQLite doesn't support modifying columns to an existing table, so it must
-- be recreated.
-- Create the new tables.
CREATE TABLE event_push_actions_staging_new (
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
actions TEXT NOT NULL,
notif SMALLINT NOT NULL,
highlight SMALLINT NOT NULL,
unread SMALLINT,
thread_id TEXT NOT NULL,
inserted_ts BIGINT
);
CREATE TABLE event_push_actions_new (
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
profile_tag VARCHAR(32),
actions TEXT NOT NULL,
topological_ordering BIGINT,
stream_ordering BIGINT,
notif SMALLINT,
highlight SMALLINT,
unread SMALLINT,
thread_id TEXT NOT NULL,
CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag)
);
CREATE TABLE event_push_summary_new (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
notif_count BIGINT NOT NULL,
stream_ordering BIGINT NOT NULL,
unread_count BIGINT,
last_receipt_stream_ordering BIGINT,
thread_id TEXT NOT NULL
);
-- Copy the data.
INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts)
SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts
FROM event_push_actions_staging;
INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id)
SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id
FROM event_push_actions;
INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id)
SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id
FROM event_push_summary;
-- Drop the old tables.
DROP TABLE event_push_actions_staging;
DROP TABLE event_push_actions;
DROP TABLE event_push_summary;
-- Rename the tables.
ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging;
ALTER TABLE event_push_actions_new RENAME TO event_push_actions;
ALTER TABLE event_push_summary_new RENAME TO event_push_summary;
-- Recreate the indexes.
CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering);
CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering );
CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id);
CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id );
CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering);
CREATE UNIQUE INDEX event_push_summary_unique_index2 ON event_push_summary (user_id, room_id, thread_id) ;
-- Recreate some indexes in the background, by re-running the background updates
-- from 72/02event_push_actions_index.sql and 72/06thread_notifications.sql.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7403, 'event_push_summary_unique_index2', '{}')
ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}';
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7403, 'event_push_actions_stream_highlight_index', '{}')
ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}';

View File

@ -960,7 +960,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
@ -1015,3 +1015,122 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
},
)
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
def test_query_local_devices_appservice(self) -> None:
"""Test that querying of appservices for keys overrides responses from the database."""
local_user = "@boris:" + self.hs.hostname
device_1 = "abc"
device_2 = "def"
device_3 = "ghi"
# There are 3 devices:
#
# 1. One which is uploaded to the homeserver.
# 2. One which is uploaded to the homeserver, but a newer copy is returned
# by the appservice.
# 3. One which is only returned by the appservice.
device_key_1: JsonDict = {
"user_id": local_user,
"device_id": device_1,
"algorithms": [
"m.olm.curve25519-aes-sha2",
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"ed25519:abc": "base64+ed25519+key",
"curve25519:abc": "base64+curve25519+key",
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
device_key_2a: JsonDict = {
"user_id": local_user,
"device_id": device_2,
"algorithms": [
"m.olm.curve25519-aes-sha2",
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"ed25519:def": "base64+ed25519+key",
"curve25519:def": "base64+curve25519+key",
},
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
device_key_2b: JsonDict = {
"user_id": local_user,
"device_id": device_2,
"algorithms": [
"m.olm.curve25519-aes-sha2",
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
# The device ID is the same (above), but the keys are different.
"keys": {
"ed25519:xyz": "base64+ed25519+key",
"curve25519:xyz": "base64+curve25519+key",
},
"signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
}
device_key_3: JsonDict = {
"user_id": local_user,
"device_id": device_3,
"algorithms": [
"m.olm.curve25519-aes-sha2",
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"ed25519:jkl": "base64+ed25519+key",
"curve25519:jkl": "base64+curve25519+key",
},
"signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
}
# Upload keys for devices 1 & 2a.
self.get_success(
self.handler.upload_keys_for_user(
local_user, device_1, {"device_keys": device_key_1}
)
)
self.get_success(
self.handler.upload_keys_for_user(
local_user, device_2, {"device_keys": device_key_2a}
)
)
# Inject an appservice interested in this user.
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
self.hs.get_datastores().main.services_cache = [appservice]
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
[appservice]
)
# Setup a response.
self.appservice_api.query_keys.return_value = make_awaitable(
{
"device_keys": {
local_user: {device_2: device_key_2b, device_3: device_key_3}
}
}
)
# Request all devices.
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertIn(local_user, res)
for res_key in res[local_user].values():
res_key.pop("unsigned", None)
self.assertDictEqual(
res,
{
local_user: {
device_1: device_key_1,
device_2: device_key_2b,
device_3: device_key_3,
}
},
)

View File

@ -922,7 +922,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
@override_config({"oidc_config": DEFAULT_CONFIG})
@override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": True}})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
userinfo: dict = {
@ -975,6 +975,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": False}})
def test_map_userinfo_to_user_does_not_register_new_user(self) -> None:
"""Ensures new users are not registered if the enabled registration flag is disabled."""
userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
request, _ = self.start_authorization(userinfo)
self.get_success(self.handler.handle_oidc_callback(request))
self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
"User does not exist and registrations are disabled",
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""

View File

@ -54,6 +54,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
if not USE_POSTGRES_FOR_TESTS:
# Redis replication only takes place on Postgres
skip = "Requires Postgres"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)

View File

@ -37,11 +37,6 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
# also one global update
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)

View File

@ -0,0 +1,89 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import synapse
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.types import JsonDict
from tests.replication._base import BaseStreamTestCase
logger = logging.getLogger(__name__)
class ToDeviceStreamTestCase(BaseStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
]
def test_to_device_stream(self) -> None:
store = self.hs.get_datastores().main
user1 = self.register_user("user1", "pass")
self.login("user1", "pass", "device")
user2 = self.register_user("user2", "pass")
self.login("user2", "pass", "device")
# connect to pull the updates related to users creation/login
self.reconnect()
self.replicate()
self.test_handler.received_rdata_rows.clear()
# disconnect so we can accumulate the updates without pulling them
self.disconnect()
msg: JsonDict = {}
msg["sender"] = "@sender:example.org"
msg["type"] = "m.new_device"
# add messages to the device inbox for user1 up until the
# limit defined for a stream update batch
for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
msg["content"] = {"device": {}}
messages = {user1: {"device": msg}}
self.get_success(
store.add_messages_from_remote_to_device_inbox(
"example.org",
f"{i}",
messages,
)
)
# add one more message, for user2 this time
# this message would be dropped before fixing #15335
msg["content"] = {"device": {}}
messages = {user2: {"device": msg}}
self.get_success(
store.add_messages_from_remote_to_device_inbox(
"example.org",
f"{_STREAM_UPDATE_TARGET_ROW_COUNT}",
messages,
)
)
# replication is disconnected so we shouldn't get any updates yet
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should receive the fact that we have to_device updates
# for user1 and user2
received_rows = self.test_handler.received_rdata_rows
self.assertEqual(len(received_rows), 2)
self.assertEqual(received_rows[0][2].entity, user1)
self.assertEqual(received_rows[1][2].entity, user2)

View File

@ -16,6 +16,7 @@ import json
import logging
import os
import os.path
import sqlite3
import time
import uuid
import warnings
@ -79,7 +80,9 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
@ -104,6 +107,10 @@ P = ParamSpec("P")
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
# A pre-prepared SQLite DB that is used as a template when creating new SQLite
# DB each test run. This dramatically speeds up test set up when using SQLite.
PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None
class TimedOutException(Exception):
"""
@ -899,6 +906,22 @@ def setup_test_homeserver(
"args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
}
# Check if we have set up a DB that we can use as a template.
global PREPPED_SQLITE_DB_CONN
if PREPPED_SQLITE_DB_CONN is None:
temp_engine = create_engine(database_config)
PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection(
sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN"
)
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
prepare_database(
PREPPED_SQLITE_DB_CONN, create_engine(database_config), config
)
database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]

View File

@ -1143,19 +1143,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
tok = self.login("alice", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
failure_time = self.clock.time_msec()
self.get_success(
self.store.record_event_failed_pull_attempt(
room_id, "$failed_event_id", "fake cause"
)
)
event_ids_to_backoff = self.get_success(
event_ids_with_backoff = self.get_success(
self.store.get_event_ids_to_not_pull_from_backoff(
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
)
)
self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
self.assertEqual(
event_ids_with_backoff,
# We expect a 2^1 hour backoff after a single failed attempt.
{"$failed_event_id": failure_time + 2 * 60 * 60 * 1000},
)
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
@ -1179,14 +1184,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# attempt (2^1 hours).
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
event_ids_to_backoff = self.get_success(
event_ids_with_backoff = self.get_success(
self.store.get_event_ids_to_not_pull_from_backoff(
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
)
)
# Since this function only returns events we should backoff from, time has
# elapsed past the backoff range so there is no events to backoff from.
self.assertEqual(event_ids_to_backoff, [])
self.assertEqual(event_ids_with_backoff, {})
@attr.s(auto_attribs=True)

View File

@ -146,6 +146,9 @@ class TestCase(unittest.TestCase):
% (current_context(),)
)
# Disable GC for duration of test. See below for why.
gc.disable()
old_level = logging.getLogger().level
if level is not None and old_level != level:
@ -163,12 +166,19 @@ class TestCase(unittest.TestCase):
return orig()
# We want to force a GC to workaround problems with deferreds leaking
# logcontexts when they are GCed (see the logcontext docs).
#
# The easiest way to do this would be to do a full GC after each test
# run, but that is very expensive. Instead, we disable GC (above) for
# the duration of the test so that we only need to run a gen-0 GC, which
# is a lot quicker.
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
gc.collect()
gc.collect(0)
gc.enable()
set_current_context(SENTINEL_CONTEXT)
return ret