Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
bf81ee4217
|
@ -0,0 +1 @@
|
|||
Add the ability to enable/disable registrations when in the OIDC flow.
|
|
@ -1 +1 @@
|
|||
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
|
||||
Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)).
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)).
|
|
@ -0,0 +1 @@
|
|||
Speed up unit tests when using SQLite3.
|
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug where some to_device messages could be dropped when using workers.
|
|
@ -0,0 +1 @@
|
|||
Make the `thread_id` column on `event_push_actions`, `event_push_actions_staging`, and `event_push_summary` non-null.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug introduced in Synapse 1.70.0 where the background sync from a faster join could spin for hours when one of the events involved had been marked for backoff.
|
|
@ -0,0 +1 @@
|
|||
Fix missing app variable in mail subject for password resets. Contributed by Cyberes.
|
|
@ -0,0 +1 @@
|
|||
Add some clarification to the doc/comments regarding TCP replication.
|
|
@ -25,7 +25,7 @@ position of all streams. The server then periodically sends `RDATA` commands
|
|||
which have the format `RDATA <stream_name> <instance_name> <token> <row>`, where
|
||||
the format of `<row>` is defined by the individual streams. The
|
||||
`<instance_name>` is the name of the Synapse process that generated the data
|
||||
(usually "master").
|
||||
(usually "master"). We expect an RDATA for every row in the DB.
|
||||
|
||||
Error reporting happens by either the client or server sending an ERROR
|
||||
command, and usually the connection will be closed.
|
||||
|
@ -107,7 +107,7 @@ reconnect, following the steps above.
|
|||
If the server sends messages faster than the client can consume them the
|
||||
server will first buffer a (fairly large) number of commands and then
|
||||
disconnect the client. This ensures that we don't queue up an unbounded
|
||||
number of commands in memory and gives us a potential oppurtunity to
|
||||
number of commands in memory and gives us a potential opportunity to
|
||||
squawk loudly. When/if the client recovers it can reconnect to the
|
||||
server and ask for missed messages.
|
||||
|
||||
|
@ -122,7 +122,7 @@ since these include tokens which can be used to restart the stream on
|
|||
connection errors.
|
||||
|
||||
The client should keep track of the token in the last RDATA command
|
||||
received for each stream so that on reconneciton it can start streaming
|
||||
received for each stream so that on reconnection it can start streaming
|
||||
from the correct place. Note: not all RDATA have valid tokens due to
|
||||
batching. See `RdataCommand` for more details.
|
||||
|
||||
|
@ -188,7 +188,8 @@ client (C):
|
|||
Two positions are included, the "new" position and the last position sent respectively.
|
||||
This allows servers to tell instances that the positions have advanced but no
|
||||
data has been written, without clients needlessly checking to see if they
|
||||
have missed any updates.
|
||||
have missed any updates. Instances will only fetch stuff if there is a gap between
|
||||
their current position and the given last position.
|
||||
|
||||
#### ERROR (S, C)
|
||||
|
||||
|
|
|
@ -3100,6 +3100,11 @@ Options for each entry include:
|
|||
match a pre-existing account instead of failing. This could be used if
|
||||
switching from password logins to OIDC. Defaults to false.
|
||||
|
||||
* `enable_registration`: set to 'false' to disable automatic registration of new
|
||||
users. This allows the OIDC SSO flow to be limited to sign in only, rather than
|
||||
automatically registering users that have a valid SSO login but do not have
|
||||
a pre-registered account. Defaults to true.
|
||||
|
||||
* `user_mapping_provider`: Configuration for how attributes returned from a OIDC
|
||||
provider are mapped onto a matrix user. This setting has the following
|
||||
sub-properties:
|
||||
|
@ -3216,6 +3221,7 @@ oidc_providers:
|
|||
userinfo_endpoint: "https://accounts.example.com/userinfo"
|
||||
jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
skip_verification: true
|
||||
enable_registration: true
|
||||
user_mapping_provider:
|
||||
config:
|
||||
subject_claim: "id"
|
||||
|
|
|
@ -27,7 +27,7 @@ from synapse.util import json_decoder
|
|||
|
||||
if typing.TYPE_CHECKING:
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -682,18 +682,27 @@ class FederationPullAttemptBackoffError(RuntimeError):
|
|||
Attributes:
|
||||
event_id: The event_id which we are refusing to pull
|
||||
message: A custom error message that gives more context
|
||||
retry_after_ms: The remaining backoff interval, in milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, event_ids: List[str], message: Optional[str]):
|
||||
self.event_ids = event_ids
|
||||
def __init__(
|
||||
self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int
|
||||
):
|
||||
event_ids = list(event_ids)
|
||||
|
||||
if message:
|
||||
error_message = message
|
||||
else:
|
||||
error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)."
|
||||
error_message = (
|
||||
f"Not attempting to pull event_ids={event_ids} because we already "
|
||||
"tried to pull them recently (backing off)."
|
||||
)
|
||||
|
||||
super().__init__(error_message)
|
||||
|
||||
self.event_ids = event_ids
|
||||
self.retry_after_ms = retry_after_ms
|
||||
|
||||
|
||||
class HttpResponseException(CodeMessageException):
|
||||
"""
|
||||
|
|
|
@ -30,7 +30,7 @@ from prometheus_client import Counter
|
|||
from typing_extensions import TypeGuard
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
|
||||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.api.errors import CodeMessageException, HttpResponseException
|
||||
from synapse.appservice import (
|
||||
ApplicationService,
|
||||
TransactionOneTimeKeysCount,
|
||||
|
@ -38,7 +38,7 @@ from synapse.appservice import (
|
|||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import SerializeEventConfig, serialize_event
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
|
||||
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
|
@ -393,7 +393,11 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||
"""Claim one time keys from an application service.
|
||||
|
||||
Note that any error (including a timeout) is treated as the application
|
||||
service having no information.
|
||||
|
||||
Args:
|
||||
service: The application service to query.
|
||||
query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
|
||||
Returns:
|
||||
|
@ -422,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
body,
|
||||
headers={"Authorization": [f"Bearer {service.hs_token}"]},
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
except HttpResponseException as e:
|
||||
# The appservice doesn't support this endpoint.
|
||||
if e.code == 404 or e.code == 405:
|
||||
if is_unknown_endpoint(e):
|
||||
return {}, query
|
||||
logger.warning("claim_keys to %s received %s", uri, e.code)
|
||||
return {}, query
|
||||
|
@ -444,6 +448,48 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
|
||||
return response, missing
|
||||
|
||||
async def query_keys(
|
||||
self, service: "ApplicationService", query: Dict[str, List[str]]
|
||||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Query the application service for keys.
|
||||
|
||||
Note that any error (including a timeout) is treated as the application
|
||||
service having no information.
|
||||
|
||||
Args:
|
||||
service: The application service to query.
|
||||
query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
|
||||
Returns:
|
||||
A map of device_keys/master_keys/self_signing_keys/user_signing_keys:
|
||||
|
||||
device_keys is a map of user ID -> a map device ID -> device info.
|
||||
"""
|
||||
if service.url is None:
|
||||
return {}
|
||||
|
||||
# This is required by the configuration.
|
||||
assert service.hs_token is not None
|
||||
|
||||
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
|
||||
try:
|
||||
response = await self.post_json_get_json(
|
||||
uri,
|
||||
query,
|
||||
headers={"Authorization": [f"Bearer {service.hs_token}"]},
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
# The appservice doesn't support this endpoint.
|
||||
if is_unknown_endpoint(e):
|
||||
return {}
|
||||
logger.warning("query_keys to %s received %s", uri, e.code)
|
||||
return {}
|
||||
except Exception as ex:
|
||||
logger.warning("query_keys to %s threw exception %s", uri, ex)
|
||||
return {}
|
||||
|
||||
return response
|
||||
|
||||
def _serialize(
|
||||
self, service: "ApplicationService", events: Iterable[EventBase]
|
||||
) -> List[JsonDict]:
|
||||
|
|
|
@ -79,6 +79,11 @@ class ExperimentalConfig(Config):
|
|||
"msc3983_appservice_otk_claims", False
|
||||
)
|
||||
|
||||
# MSC3984: Proxying key queries to exclusive ASes.
|
||||
self.msc3984_appservice_key_query: bool = experimental.get(
|
||||
"msc3984_appservice_key_query", False
|
||||
)
|
||||
|
||||
# MSC3706 (server-side support for partial state in /send_join responses)
|
||||
# Synapse will always serve partial state responses to requests using the stable
|
||||
# query parameter `omit_members`. If this flag is set, Synapse will also serve
|
||||
|
|
|
@ -136,6 +136,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
|||
"type": "array",
|
||||
"items": SsoAttributeRequirement.JSON_SCHEMA,
|
||||
},
|
||||
"enable_registration": {"type": "boolean"},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -306,6 +307,7 @@ def _parse_oidc_config_dict(
|
|||
user_mapping_provider_class=user_mapping_provider_class,
|
||||
user_mapping_provider_config=user_mapping_provider_config,
|
||||
attribute_requirements=attribute_requirements,
|
||||
enable_registration=oidc_config.get("enable_registration", True),
|
||||
)
|
||||
|
||||
|
||||
|
@ -405,3 +407,6 @@ class OidcProviderConfig:
|
|||
|
||||
# required attributes to require in userinfo to allow login/registration
|
||||
attribute_requirements: List[SsoAttributeRequirement]
|
||||
|
||||
# Whether automatic registrations are enabled in the ODIC flow. Defaults to True
|
||||
enable_registration: bool
|
||||
|
|
|
@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
|
|||
event_from_pdu_json,
|
||||
)
|
||||
from synapse.federation.transport.client import SendJoinResponse
|
||||
from synapse.http.client import is_unknown_endpoint
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
|
@ -759,43 +760,6 @@ class FederationClient(FederationBase):
|
|||
|
||||
return signed_auth
|
||||
|
||||
def _is_unknown_endpoint(
|
||||
self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if the response was due to an endpoint being unimplemented.
|
||||
|
||||
Args:
|
||||
e: The error response received from the remote server.
|
||||
synapse_error: The above error converted to a SynapseError. This is
|
||||
automatically generated if not provided.
|
||||
|
||||
"""
|
||||
if synapse_error is None:
|
||||
synapse_error = e.to_synapse_error()
|
||||
# MSC3743 specifies that servers should return a 404 or 405 with an errcode
|
||||
# of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
|
||||
# to an unknown method, respectively.
|
||||
#
|
||||
# Older versions of servers don't properly handle this. This needs to be
|
||||
# rather specific as some endpoints truly do return 404 errors.
|
||||
return (
|
||||
# 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
|
||||
(e.code == 404 or e.code == 405)
|
||||
and (
|
||||
# Older Dendrites returned a text or empty body.
|
||||
# Older Conduit returned an empty body.
|
||||
not e.response
|
||||
or e.response == b"404 page not found"
|
||||
# The proper response JSON with M_UNRECOGNIZED errcode.
|
||||
or synapse_error.errcode == Codes.UNRECOGNIZED
|
||||
)
|
||||
) or (
|
||||
# Older Synapses returned a 400 error.
|
||||
e.code == 400
|
||||
and synapse_error.errcode == Codes.UNRECOGNIZED
|
||||
)
|
||||
|
||||
async def _try_destination_list(
|
||||
self,
|
||||
description: str,
|
||||
|
@ -887,7 +851,7 @@ class FederationClient(FederationBase):
|
|||
elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
|
||||
failover = True
|
||||
|
||||
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
|
||||
elif failover_on_unknown_endpoint and is_unknown_endpoint(
|
||||
e, synapse_error
|
||||
):
|
||||
failover = True
|
||||
|
@ -1223,7 +1187,7 @@ class FederationClient(FederationBase):
|
|||
# If an error is received that is due to an unrecognised endpoint,
|
||||
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
|
||||
# and raise.
|
||||
if not self._is_unknown_endpoint(e):
|
||||
if not is_unknown_endpoint(e):
|
||||
raise
|
||||
|
||||
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
|
||||
|
@ -1297,7 +1261,7 @@ class FederationClient(FederationBase):
|
|||
# fallback to the v1 endpoint if the room uses old-style event IDs.
|
||||
# Otherwise, consider it a legitimate error and raise.
|
||||
err = e.to_synapse_error()
|
||||
if self._is_unknown_endpoint(e, err):
|
||||
if is_unknown_endpoint(e, err):
|
||||
if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
@ -1358,7 +1322,7 @@ class FederationClient(FederationBase):
|
|||
# If an error is received that is due to an unrecognised endpoint,
|
||||
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
|
||||
# and raise.
|
||||
if not self._is_unknown_endpoint(e):
|
||||
if not is_unknown_endpoint(e):
|
||||
raise
|
||||
|
||||
logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
|
||||
|
@ -1629,7 +1593,7 @@ class FederationClient(FederationBase):
|
|||
# If an error is received that is due to an unrecognised endpoint,
|
||||
# fallback to the unstable endpoint. Otherwise, consider it a
|
||||
# legitimate error and raise.
|
||||
if not self._is_unknown_endpoint(e):
|
||||
if not is_unknown_endpoint(e):
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
|
|
|
@ -18,6 +18,7 @@ from typing import (
|
|||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
|
@ -846,6 +847,10 @@ class ApplicationServicesHandler:
|
|||
]:
|
||||
"""Claim one time keys from application services.
|
||||
|
||||
Users which are exclusively owned by an application service are sent a
|
||||
key claim request to check if the application service provides keys
|
||||
directly.
|
||||
|
||||
Args:
|
||||
query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
|
||||
|
@ -901,3 +906,59 @@ class ApplicationServicesHandler:
|
|||
missing.extend(result[1])
|
||||
|
||||
return claimed_keys, missing
|
||||
|
||||
async def query_keys(
|
||||
self, query: Mapping[str, Optional[List[str]]]
|
||||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Query application services for device keys.
|
||||
|
||||
Users which are exclusively owned by an application service are queried
|
||||
for keys to check if the application service provides keys directly.
|
||||
|
||||
Args:
|
||||
query: map from user_id to a list of devices to query
|
||||
|
||||
Returns:
|
||||
A map from user_id -> device_id -> device details
|
||||
"""
|
||||
services = self.store.get_app_services()
|
||||
|
||||
# Partition the users by appservice.
|
||||
query_by_appservice: Dict[str, Dict[str, List[str]]] = {}
|
||||
for user_id, device_ids in query.items():
|
||||
if not self.store.get_if_app_services_interested_in_user(user_id):
|
||||
continue
|
||||
|
||||
# Find the associated appservice.
|
||||
for service in services:
|
||||
if service.is_exclusive_user(user_id):
|
||||
query_by_appservice.setdefault(service.id, {})[user_id] = (
|
||||
device_ids or []
|
||||
)
|
||||
continue
|
||||
|
||||
# Query each service in parallel.
|
||||
results = await make_deferred_yieldable(
|
||||
defer.DeferredList(
|
||||
[
|
||||
run_in_background(
|
||||
self.appservice_api.query_keys,
|
||||
# We know this must be an app service.
|
||||
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
|
||||
service_query,
|
||||
)
|
||||
for service_id, service_query in query_by_appservice.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Patch together the results -- they are all independent (since they
|
||||
# require exclusive control over the users). They get returned as a single
|
||||
# dictionary.
|
||||
key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for success, result in results:
|
||||
if success:
|
||||
key_queries.update(result)
|
||||
|
||||
return key_queries
|
||||
|
|
|
@ -91,6 +91,9 @@ class E2eKeysHandler:
|
|||
self._query_appservices_for_otks = (
|
||||
hs.config.experimental.msc3983_appservice_otk_claims
|
||||
)
|
||||
self._query_appservices_for_keys = (
|
||||
hs.config.experimental.msc3984_appservice_key_query
|
||||
)
|
||||
|
||||
@trace
|
||||
@cancellable
|
||||
|
@ -497,6 +500,19 @@ class E2eKeysHandler:
|
|||
local_query, include_displaynames
|
||||
)
|
||||
|
||||
# Check if the application services have any additional results.
|
||||
if self._query_appservices_for_keys:
|
||||
# Query the appservices for any keys.
|
||||
appservice_results = await self._appservice_handler.query_keys(query)
|
||||
|
||||
# Merge results, overriding with what the appservice returned.
|
||||
for user_id, devices in appservice_results.get("device_keys", {}).items():
|
||||
# Copy the appservice device info over the homeserver device info, but
|
||||
# don't completely overwrite it.
|
||||
results.setdefault(user_id, {}).update(devices)
|
||||
|
||||
# TODO Handle cross-signing keys.
|
||||
|
||||
# Build the result structure
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, device_info in device_keys.items():
|
||||
|
|
|
@ -1949,27 +1949,25 @@ class FederationHandler:
|
|||
)
|
||||
for event in events:
|
||||
for attempt in itertools.count():
|
||||
# We try a new destination on every iteration.
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await self._federation_event_handler.update_state_for_partial_state_event(
|
||||
destination, event
|
||||
)
|
||||
break
|
||||
except FederationPullAttemptBackoffError as exc:
|
||||
# Log a warning about why we failed to process the event (the error message
|
||||
# for `FederationPullAttemptBackoffError` is pretty good)
|
||||
logger.warning("_sync_partial_state_room: %s", exc)
|
||||
# We do not record a failed pull attempt when we backoff fetching a missing
|
||||
# `prev_event` because not being able to fetch the `prev_events` just means
|
||||
# we won't be able to de-outlier the pulled event. But we can still use an
|
||||
# `outlier` in the state/auth chain for another event. So we shouldn't stop
|
||||
# a downstream event from trying to pull it.
|
||||
#
|
||||
# This avoids a cascade of backoff for all events in the DAG downstream from
|
||||
# one event backoff upstream.
|
||||
except FederationError as e:
|
||||
# TODO: We should `record_event_failed_pull_attempt` here,
|
||||
# see https://github.com/matrix-org/synapse/issues/13700
|
||||
except FederationPullAttemptBackoffError as e:
|
||||
# We are in the backoff period for one of the event's
|
||||
# prev_events. Wait it out and try again after.
|
||||
logger.warning(
|
||||
"%s; waiting for %d ms...", e, e.retry_after_ms
|
||||
)
|
||||
await self.clock.sleep(e.retry_after_ms / 1000)
|
||||
|
||||
# Success, no need to try the rest of the destinations.
|
||||
break
|
||||
except FederationError as e:
|
||||
if attempt == len(destinations) - 1:
|
||||
# We have tried every remote server for this event. Give up.
|
||||
# TODO(faster_joins) giving up isn't the right thing to do
|
||||
|
@ -1986,6 +1984,8 @@ class FederationHandler:
|
|||
destination,
|
||||
e,
|
||||
)
|
||||
# TODO: We should `record_event_failed_pull_attempt` here,
|
||||
# see https://github.com/matrix-org/synapse/issues/13700
|
||||
raise
|
||||
|
||||
# Try the next remote server.
|
||||
|
|
|
@ -140,6 +140,7 @@ class FederationEventHandler:
|
|||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
|
@ -1038,8 +1039,8 @@ class FederationEventHandler:
|
|||
|
||||
Raises:
|
||||
FederationPullAttemptBackoffError if we are are deliberately not attempting
|
||||
to pull the given event over federation because we've already done so
|
||||
recently and are backing off.
|
||||
to pull one of the given event's `prev_event`s over federation because
|
||||
we've already done so recently and are backing off.
|
||||
FederationError if we fail to get the state from the remote server after any
|
||||
missing `prev_event`s.
|
||||
"""
|
||||
|
@ -1053,13 +1054,22 @@ class FederationEventHandler:
|
|||
# If we've already recently attempted to pull this missing event, don't
|
||||
# try it again so soon. Since we have to fetch all of the prev_events, we can
|
||||
# bail early here if we find any to ignore.
|
||||
prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
|
||||
prevs_with_pull_backoff = (
|
||||
await self._store.get_event_ids_to_not_pull_from_backoff(
|
||||
room_id, missing_prevs
|
||||
)
|
||||
if len(prevs_to_ignore) > 0:
|
||||
)
|
||||
if len(prevs_with_pull_backoff) > 0:
|
||||
raise FederationPullAttemptBackoffError(
|
||||
event_ids=prevs_to_ignore,
|
||||
message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
|
||||
event_ids=prevs_with_pull_backoff.keys(),
|
||||
message=(
|
||||
f"While computing context for event={event_id}, not attempting to "
|
||||
f"pull missing prev_events={list(prevs_with_pull_backoff.keys())} "
|
||||
"because we already tried to pull recently (backing off)."
|
||||
),
|
||||
retry_after_ms=(
|
||||
max(prevs_with_pull_backoff.values()) - self._clock.time_msec()
|
||||
),
|
||||
)
|
||||
|
||||
if not missing_prevs:
|
||||
|
|
|
@ -1239,6 +1239,7 @@ class OidcProvider:
|
|||
grandfather_existing_users,
|
||||
extra_attributes,
|
||||
auth_provider_session_id=sid,
|
||||
registration_enabled=self._config.enable_registration,
|
||||
)
|
||||
|
||||
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||
|
|
|
@ -383,6 +383,7 @@ class SsoHandler:
|
|||
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
|
||||
extra_login_attributes: Optional[JsonDict] = None,
|
||||
auth_provider_session_id: Optional[str] = None,
|
||||
registration_enabled: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||
|
@ -435,6 +436,10 @@ class SsoHandler:
|
|||
|
||||
auth_provider_session_id: An optional session ID from the IdP.
|
||||
|
||||
registration_enabled: An optional boolean to enable/disable automatic
|
||||
registrations of new users. If false and the user does not exist then the
|
||||
flow is aborted. Defaults to true.
|
||||
|
||||
Raises:
|
||||
MappingException if there was a problem mapping the response to a user.
|
||||
RedirectException: if the mapping provider needs to redirect the user
|
||||
|
@ -462,8 +467,16 @@ class SsoHandler:
|
|||
auth_provider_id, remote_user_id, user_id
|
||||
)
|
||||
|
||||
# Otherwise, generate a new user.
|
||||
if not user_id:
|
||||
if not user_id and not registration_enabled:
|
||||
logger.info(
|
||||
"User does not exist and registration are disabled for IdP '%s' and remote_user_id '%s'",
|
||||
auth_provider_id,
|
||||
remote_user_id,
|
||||
)
|
||||
raise MappingException(
|
||||
"User does not exist and registrations are disabled"
|
||||
)
|
||||
elif not user_id: # Otherwise, generate a new user.
|
||||
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
|
||||
|
||||
next_step_url = self._get_url_for_next_new_user_step(
|
||||
|
|
|
@ -966,3 +966,41 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
|||
|
||||
def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
|
||||
return self
|
||||
|
||||
|
||||
def is_unknown_endpoint(
|
||||
e: HttpResponseException, synapse_error: Optional[SynapseError] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if the response was due to an endpoint being unimplemented.
|
||||
|
||||
Args:
|
||||
e: The error response received from the remote server.
|
||||
synapse_error: The above error converted to a SynapseError. This is
|
||||
automatically generated if not provided.
|
||||
|
||||
"""
|
||||
if synapse_error is None:
|
||||
synapse_error = e.to_synapse_error()
|
||||
# MSC3743 specifies that servers should return a 404 or 405 with an errcode
|
||||
# of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
|
||||
# to an unknown method, respectively.
|
||||
#
|
||||
# Older versions of servers don't properly handle this. This needs to be
|
||||
# rather specific as some endpoints truly do return 404 errors.
|
||||
return (
|
||||
# 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
|
||||
(e.code == 404 or e.code == 405)
|
||||
and (
|
||||
# Older Dendrites returned a text body or empty body.
|
||||
# Older Conduit returned an empty body.
|
||||
not e.response
|
||||
or e.response == b"404 page not found"
|
||||
# The proper response JSON with M_UNRECOGNIZED errcode.
|
||||
or synapse_error.errcode == Codes.UNRECOGNIZED
|
||||
)
|
||||
) or (
|
||||
# Older Synapses returned a 400 error.
|
||||
e.code == 400
|
||||
and synapse_error.errcode == Codes.UNRECOGNIZED
|
||||
)
|
||||
|
|
|
@ -149,7 +149,7 @@ class Mailer:
|
|||
await self.send_email(
|
||||
email_address,
|
||||
self.email_subjects.password_reset
|
||||
% {"server_name": self.hs.config.server.server_name},
|
||||
% {"server_name": self.hs.config.server.server_name, "app": self.app_name},
|
||||
template_vars,
|
||||
)
|
||||
|
||||
|
|
|
@ -14,36 +14,7 @@
|
|||
"""This module contains the implementation of both the client and server
|
||||
protocols.
|
||||
|
||||
The basic structure of the protocol is line based, where the initial word of
|
||||
each line specifies the command. The rest of the line is parsed based on the
|
||||
command. For example, the `RDATA` command is defined as::
|
||||
|
||||
RDATA <stream_name> <token> <row_json>
|
||||
|
||||
(Note that `<row_json>` may contains spaces, but cannot contain newlines.)
|
||||
|
||||
Blank lines are ignored.
|
||||
|
||||
# Example
|
||||
|
||||
An example iteraction is shown below. Each line is prefixed with '>' or '<' to
|
||||
indicate which side is sending, these are *not* included on the wire::
|
||||
|
||||
* connection established *
|
||||
> SERVER localhost:8823
|
||||
> PING 1490197665618
|
||||
< NAME synapse.app.appservice
|
||||
< PING 1490197665618
|
||||
< REPLICATE
|
||||
> POSITION events 1
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
> RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
|
||||
> RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823",
|
||||
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]]
|
||||
< PING 1490197675618
|
||||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
An explanation of this protocol is available in docs/tcp_replication.md
|
||||
"""
|
||||
import fcntl
|
||||
import logging
|
||||
|
|
|
@ -152,8 +152,8 @@ class Stream:
|
|||
Returns:
|
||||
A triplet `(updates, new_last_token, limited)`, where `updates` is
|
||||
a list of `(token, row)` entries, `new_last_token` is the new
|
||||
position in stream, and `limited` is whether there are more updates
|
||||
to fetch.
|
||||
position in stream (ie the highest token returned in the updates),
|
||||
and `limited` is whether there are more updates to fetch.
|
||||
"""
|
||||
current_token = self.current_token(self.local_instance_name)
|
||||
updates, current_token, limited = await self.get_updates_since(
|
||||
|
|
|
@ -617,14 +617,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
# We limit like this as we might have multiple rows per stream_id, and
|
||||
# we want to make sure we always get all entries for any stream_id
|
||||
# we return.
|
||||
upper_pos = min(current_id, last_id + limit)
|
||||
upto_token = min(current_id, last_id + limit)
|
||||
sql = (
|
||||
"SELECT max(stream_id), user_id"
|
||||
" FROM device_inbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY user_id"
|
||||
)
|
||||
txn.execute(sql, (last_id, upper_pos))
|
||||
txn.execute(sql, (last_id, upto_token))
|
||||
updates = [(row[0], row[1:]) for row in txn]
|
||||
|
||||
sql = (
|
||||
|
@ -633,19 +633,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY destination"
|
||||
)
|
||||
txn.execute(sql, (last_id, upper_pos))
|
||||
txn.execute(sql, (last_id, upto_token))
|
||||
updates.extend((row[0], row[1:]) for row in txn)
|
||||
|
||||
# Order by ascending stream ordering
|
||||
updates.sort()
|
||||
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
return updates, upto_token, upto_token < current_id
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
|
|
|
@ -1544,7 +1544,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
self,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
) -> List[str]:
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
Filter down the events to ones that we've failed to pull before recently. Uses
|
||||
exponential backoff.
|
||||
|
@ -1554,7 +1554,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
event_ids: A list of events to filter down
|
||||
|
||||
Returns:
|
||||
List of event_ids that should not be attempted to be pulled
|
||||
A dictionary of event_ids that should not be attempted to be pulled and the
|
||||
next timestamp at which we may try pulling them again.
|
||||
"""
|
||||
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
|
||||
table="event_failed_pull_attempts",
|
||||
|
@ -1570,13 +1571,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
current_time = self._clock.time_msec()
|
||||
return [
|
||||
event_failed_pull_attempt["event_id"]
|
||||
for event_failed_pull_attempt in event_failed_pull_attempts
|
||||
|
||||
event_ids_with_backoff = {}
|
||||
for event_failed_pull_attempt in event_failed_pull_attempts:
|
||||
event_id = event_failed_pull_attempt["event_id"]
|
||||
# Exponential back-off (up to the upper bound) so we don't try to
|
||||
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
||||
if current_time
|
||||
< event_failed_pull_attempt["last_attempt_ts"]
|
||||
backoff_end_time = (
|
||||
event_failed_pull_attempt["last_attempt_ts"]
|
||||
+ (
|
||||
2
|
||||
** min(
|
||||
|
@ -1585,7 +1587,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
)
|
||||
* BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
|
||||
]
|
||||
)
|
||||
|
||||
if current_time < backoff_end_time: # `backoff_end_time` is exclusive
|
||||
event_ids_with_backoff[event_id] = backoff_end_time
|
||||
|
||||
return event_ids_with_backoff
|
||||
|
||||
async def get_missing_events(
|
||||
self,
|
||||
|
|
|
@ -100,7 +100,6 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||
from synapse.storage.databases.main.stream import StreamWorkerStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
|
@ -289,180 +288,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
unique=True,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"event_push_backfill_thread_id",
|
||||
self._background_backfill_thread_id,
|
||||
)
|
||||
|
||||
# Indexes which will be used to quickly make the thread_id column non-null.
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"event_push_actions_thread_id_null",
|
||||
index_name="event_push_actions_thread_id_null",
|
||||
table="event_push_actions",
|
||||
columns=["thread_id"],
|
||||
where_clause="thread_id IS NULL",
|
||||
)
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"event_push_summary_thread_id_null",
|
||||
index_name="event_push_summary_thread_id_null",
|
||||
table="event_push_summary",
|
||||
columns=["thread_id"],
|
||||
where_clause="thread_id IS NULL",
|
||||
)
|
||||
|
||||
# Check ASAP (and then later, every 1s) to see if we have finished
|
||||
# background updates the event_push_actions and event_push_summary tables.
|
||||
self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
|
||||
self._event_push_backfill_thread_id_done = False
|
||||
|
||||
@wrap_as_background_process("check_event_push_backfill_thread_id")
|
||||
async def _check_event_push_backfill_thread_id(self) -> None:
|
||||
"""
|
||||
Has thread_id finished backfilling?
|
||||
|
||||
If not, we need to just-in-time update it so the queries work.
|
||||
"""
|
||||
done = await self.db_pool.updates.has_completed_background_update(
|
||||
"event_push_backfill_thread_id"
|
||||
)
|
||||
|
||||
if done:
|
||||
self._event_push_backfill_thread_id_done = True
|
||||
else:
|
||||
# Reschedule to run.
|
||||
self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
|
||||
|
||||
async def _background_backfill_thread_id(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Fill in the thread_id field for event_push_actions and event_push_summary.
|
||||
|
||||
This is preparatory so that it can be made non-nullable in the future.
|
||||
|
||||
Because all current (null) data is done in an unthreaded manner this
|
||||
simply assumes it is on the "main" timeline. Since event_push_actions
|
||||
are periodically cleared it is not possible to correctly re-calculate
|
||||
the thread_id.
|
||||
"""
|
||||
event_push_actions_done = progress.get("event_push_actions_done", False)
|
||||
|
||||
def add_thread_id_txn(
|
||||
txn: LoggingTransaction, start_stream_ordering: int
|
||||
) -> int:
|
||||
sql = """
|
||||
SELECT stream_ordering
|
||||
FROM event_push_actions
|
||||
WHERE
|
||||
thread_id IS NULL
|
||||
AND stream_ordering > ?
|
||||
ORDER BY stream_ordering
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (start_stream_ordering, batch_size))
|
||||
|
||||
# No more rows to process.
|
||||
rows = txn.fetchall()
|
||||
if not rows:
|
||||
progress["event_push_actions_done"] = True
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "event_push_backfill_thread_id", progress
|
||||
)
|
||||
return 0
|
||||
|
||||
# Update the thread ID for any of those rows.
|
||||
max_stream_ordering = rows[-1][0]
|
||||
|
||||
sql = """
|
||||
UPDATE event_push_actions
|
||||
SET thread_id = 'main'
|
||||
WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
start_stream_ordering,
|
||||
max_stream_ordering,
|
||||
),
|
||||
)
|
||||
|
||||
# Update progress.
|
||||
processed_rows = txn.rowcount
|
||||
progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "event_push_backfill_thread_id", progress
|
||||
)
|
||||
|
||||
return processed_rows
|
||||
|
||||
def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
|
||||
min_user_id = progress.get("max_summary_user_id", "")
|
||||
min_room_id = progress.get("max_summary_room_id", "")
|
||||
|
||||
# Slightly overcomplicated query for getting the Nth user ID / room
|
||||
# ID tuple, or the last if there are less than N remaining.
|
||||
sql = """
|
||||
SELECT user_id, room_id FROM (
|
||||
SELECT user_id, room_id FROM event_push_summary
|
||||
WHERE (user_id, room_id) > (?, ?)
|
||||
AND thread_id IS NULL
|
||||
ORDER BY user_id, room_id
|
||||
LIMIT ?
|
||||
) AS e
|
||||
ORDER BY user_id DESC, room_id DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
txn.execute(sql, (min_user_id, min_room_id, batch_size))
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
return 0
|
||||
|
||||
max_user_id, max_room_id = row
|
||||
|
||||
sql = """
|
||||
UPDATE event_push_summary
|
||||
SET thread_id = 'main'
|
||||
WHERE
|
||||
(?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
|
||||
AND thread_id IS NULL
|
||||
"""
|
||||
txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
|
||||
processed_rows = txn.rowcount
|
||||
|
||||
progress["max_summary_user_id"] = max_user_id
|
||||
progress["max_summary_room_id"] = max_room_id
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "event_push_backfill_thread_id", progress
|
||||
)
|
||||
|
||||
return processed_rows
|
||||
|
||||
# First update the event_push_actions table, then the event_push_summary table.
|
||||
#
|
||||
# Note that the event_push_actions_staging table is ignored since it is
|
||||
# assumed that items in that table will only exist for a short period of
|
||||
# time.
|
||||
if not event_push_actions_done:
|
||||
result = await self.db_pool.runInteraction(
|
||||
"event_push_backfill_thread_id",
|
||||
add_thread_id_txn,
|
||||
progress.get("max_event_push_actions_stream_ordering", 0),
|
||||
)
|
||||
else:
|
||||
result = await self.db_pool.runInteraction(
|
||||
"event_push_backfill_thread_id",
|
||||
add_thread_id_summary_txn,
|
||||
)
|
||||
|
||||
# Only done after the event_push_summary table is done.
|
||||
if not result:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"event_push_backfill_thread_id"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
|
||||
"""Get the notification count by room for a user. Only considers notifications,
|
||||
not highlight or unread counts, and threads are currently aggregated under their room.
|
||||
|
@ -711,25 +536,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
|
||||
)
|
||||
|
||||
# First ensure that the existing rows have an updated thread_id field.
|
||||
if not self._event_push_backfill_thread_id_done:
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE event_push_summary
|
||||
SET thread_id = ?
|
||||
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
|
||||
""",
|
||||
(MAIN_TIMELINE, room_id, user_id),
|
||||
)
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE event_push_actions
|
||||
SET thread_id = ?
|
||||
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
|
||||
""",
|
||||
(MAIN_TIMELINE, room_id, user_id),
|
||||
)
|
||||
|
||||
# First we pull the counts from the summary table.
|
||||
#
|
||||
# We check that `last_receipt_stream_ordering` matches the stream ordering of the
|
||||
|
@ -1545,25 +1351,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
(room_id, user_id, stream_ordering, *thread_args),
|
||||
)
|
||||
|
||||
# First ensure that the existing rows have an updated thread_id field.
|
||||
if not self._event_push_backfill_thread_id_done:
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE event_push_summary
|
||||
SET thread_id = ?
|
||||
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
|
||||
""",
|
||||
(MAIN_TIMELINE, room_id, user_id),
|
||||
)
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE event_push_actions
|
||||
SET thread_id = ?
|
||||
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
|
||||
""",
|
||||
(MAIN_TIMELINE, room_id, user_id),
|
||||
)
|
||||
|
||||
# Fetch the notification counts between the stream ordering of the
|
||||
# latest receipt and what was previously summarised.
|
||||
unread_counts = self._get_notif_unread_count_for_user_room(
|
||||
|
@ -1698,19 +1485,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
|
||||
"""
|
||||
|
||||
# Ensure that any new actions have an updated thread_id.
|
||||
if not self._event_push_backfill_thread_id_done:
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE event_push_actions
|
||||
SET thread_id = ?
|
||||
WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
|
||||
""",
|
||||
(MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
|
||||
)
|
||||
|
||||
# XXX Do we need to update summaries here too?
|
||||
|
||||
# Calculate the new counts that should be upserted into event_push_summary
|
||||
sql = """
|
||||
SELECT user_id, room_id, thread_id,
|
||||
|
@ -1773,20 +1547,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
|
||||
logger.info("Rotating notifications, handling %d rows", len(summaries))
|
||||
|
||||
# Ensure that any updated threads have the proper thread_id.
|
||||
if not self._event_push_backfill_thread_id_done:
|
||||
txn.execute_batch(
|
||||
"""
|
||||
UPDATE event_push_summary
|
||||
SET thread_id = ?
|
||||
WHERE room_id = ? AND user_id = ? AND thread_id is NULL
|
||||
""",
|
||||
[
|
||||
(MAIN_TIMELINE, room_id, user_id)
|
||||
for user_id, room_id, _ in summaries
|
||||
],
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn,
|
||||
table="event_push_summary",
|
||||
|
|
|
@ -34,6 +34,13 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
|
|||
":memory:",
|
||||
)
|
||||
|
||||
# A connection to a database that has already been prepared, to use as a
|
||||
# base for an in-memory connection. This is used during unit tests to
|
||||
# speed up setting up the DB.
|
||||
self._prepped_conn: Optional[sqlite3.Connection] = database_config.get(
|
||||
"_TEST_PREPPED_CONN"
|
||||
)
|
||||
|
||||
if platform.python_implementation() == "PyPy":
|
||||
# pypy's sqlite3 module doesn't handle bytearrays, convert them
|
||||
# back to bytes.
|
||||
|
@ -84,6 +91,14 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
|
|||
# In memory databases need to be rebuilt each time. Ideally we'd
|
||||
# reuse the same connection as we do when starting up, but that
|
||||
# would involve using adbapi before we have started the reactor.
|
||||
#
|
||||
# If we have a `prepped_conn` we can use that to initialise the DB,
|
||||
# otherwise we need to call `prepare_database`.
|
||||
if self._prepped_conn is not None:
|
||||
# Initialise the new DB from the pre-prepared DB.
|
||||
assert isinstance(db_conn.conn, sqlite3.Connection)
|
||||
self._prepped_conn.backup(db_conn.conn)
|
||||
else:
|
||||
prepare_database(db_conn, self, config=None)
|
||||
|
||||
db_conn.create_function("rank", 1, _rank)
|
||||
|
|
|
@ -95,9 +95,9 @@ Changes in SCHEMA_VERSION = 74:
|
|||
|
||||
|
||||
SCHEMA_COMPAT_VERSION = (
|
||||
# The threads_id column must exist for event_push_actions, event_push_summary,
|
||||
# receipts_linearized, and receipts_graph.
|
||||
73
|
||||
# The threads_id column must written to with non-null values event_push_actions,
|
||||
# event_push_actions_staging, and event_push_summary.
|
||||
74
|
||||
)
|
||||
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|
||||
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Force the background updates from 06thread_notifications.sql to run in the
|
||||
-- foreground as code will now require those to be "done".
|
||||
|
||||
DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id';
|
||||
|
||||
-- Overwrite any null thread_id values.
|
||||
UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL;
|
||||
UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL;
|
||||
UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL;
|
||||
|
||||
-- Drop the background updates to calculate the indexes used to find null thread_ids.
|
||||
DELETE FROM background_updates WHERE update_name = 'event_push_actions_thread_id_null';
|
||||
DELETE FROM background_updates WHERE update_name = 'event_push_summary_thread_id_null';
|
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Drop the indexes used to find null thread_ids.
|
||||
DROP INDEX IF EXISTS event_push_actions_thread_id_null;
|
||||
DROP INDEX IF EXISTS event_push_summary_thread_id_null;
|
||||
|
||||
-- The thread_id columns can now be made non-nullable.
|
||||
ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL;
|
||||
ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL;
|
||||
ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL;
|
|
@ -0,0 +1,99 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- The thread_id columns can now be made non-nullable.
|
||||
--
|
||||
-- SQLite doesn't support modifying columns to an existing table, so it must
|
||||
-- be recreated.
|
||||
|
||||
-- Create the new tables.
|
||||
CREATE TABLE event_push_actions_staging_new (
|
||||
event_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
notif SMALLINT NOT NULL,
|
||||
highlight SMALLINT NOT NULL,
|
||||
unread SMALLINT,
|
||||
thread_id TEXT NOT NULL,
|
||||
inserted_ts BIGINT
|
||||
);
|
||||
|
||||
CREATE TABLE event_push_actions_new (
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
profile_tag VARCHAR(32),
|
||||
actions TEXT NOT NULL,
|
||||
topological_ordering BIGINT,
|
||||
stream_ordering BIGINT,
|
||||
notif SMALLINT,
|
||||
highlight SMALLINT,
|
||||
unread SMALLINT,
|
||||
thread_id TEXT NOT NULL,
|
||||
CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag)
|
||||
);
|
||||
|
||||
CREATE TABLE event_push_summary_new (
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
notif_count BIGINT NOT NULL,
|
||||
stream_ordering BIGINT NOT NULL,
|
||||
unread_count BIGINT,
|
||||
last_receipt_stream_ordering BIGINT,
|
||||
thread_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Copy the data.
|
||||
INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts)
|
||||
SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts
|
||||
FROM event_push_actions_staging;
|
||||
|
||||
INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id)
|
||||
SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id
|
||||
FROM event_push_actions;
|
||||
|
||||
INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id)
|
||||
SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id
|
||||
FROM event_push_summary;
|
||||
|
||||
-- Drop the old tables.
|
||||
DROP TABLE event_push_actions_staging;
|
||||
DROP TABLE event_push_actions;
|
||||
DROP TABLE event_push_summary;
|
||||
|
||||
-- Rename the tables.
|
||||
ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging;
|
||||
ALTER TABLE event_push_actions_new RENAME TO event_push_actions;
|
||||
ALTER TABLE event_push_summary_new RENAME TO event_push_summary;
|
||||
|
||||
-- Recreate the indexes.
|
||||
CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
|
||||
|
||||
CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering);
|
||||
CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering );
|
||||
CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id);
|
||||
CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id );
|
||||
CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering);
|
||||
|
||||
CREATE UNIQUE INDEX event_push_summary_unique_index2 ON event_push_summary (user_id, room_id, thread_id) ;
|
||||
|
||||
-- Recreate some indexes in the background, by re-running the background updates
|
||||
-- from 72/02event_push_actions_index.sql and 72/06thread_notifications.sql.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(7403, 'event_push_summary_unique_index2', '{}')
|
||||
ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}';
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(7403, 'event_push_actions_stream_highlight_index', '{}')
|
||||
ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}';
|
|
@ -960,7 +960,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
|
||||
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
|
||||
# Note: this user does not have to match the regex above
|
||||
sender="@as_main:test",
|
||||
)
|
||||
|
@ -1015,3 +1015,122 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
},
|
||||
)
|
||||
|
||||
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
|
||||
def test_query_local_devices_appservice(self) -> None:
|
||||
"""Test that querying of appservices for keys overrides responses from the database."""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_1 = "abc"
|
||||
device_2 = "def"
|
||||
device_3 = "ghi"
|
||||
|
||||
# There are 3 devices:
|
||||
#
|
||||
# 1. One which is uploaded to the homeserver.
|
||||
# 2. One which is uploaded to the homeserver, but a newer copy is returned
|
||||
# by the appservice.
|
||||
# 3. One which is only returned by the appservice.
|
||||
device_key_1: JsonDict = {
|
||||
"user_id": local_user,
|
||||
"device_id": device_1,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha2",
|
||||
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
|
||||
],
|
||||
"keys": {
|
||||
"ed25519:abc": "base64+ed25519+key",
|
||||
"curve25519:abc": "base64+curve25519+key",
|
||||
},
|
||||
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
|
||||
}
|
||||
device_key_2a: JsonDict = {
|
||||
"user_id": local_user,
|
||||
"device_id": device_2,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha2",
|
||||
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
|
||||
],
|
||||
"keys": {
|
||||
"ed25519:def": "base64+ed25519+key",
|
||||
"curve25519:def": "base64+curve25519+key",
|
||||
},
|
||||
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
|
||||
}
|
||||
|
||||
device_key_2b: JsonDict = {
|
||||
"user_id": local_user,
|
||||
"device_id": device_2,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha2",
|
||||
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
|
||||
],
|
||||
# The device ID is the same (above), but the keys are different.
|
||||
"keys": {
|
||||
"ed25519:xyz": "base64+ed25519+key",
|
||||
"curve25519:xyz": "base64+curve25519+key",
|
||||
},
|
||||
"signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
|
||||
}
|
||||
device_key_3: JsonDict = {
|
||||
"user_id": local_user,
|
||||
"device_id": device_3,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha2",
|
||||
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
|
||||
],
|
||||
"keys": {
|
||||
"ed25519:jkl": "base64+ed25519+key",
|
||||
"curve25519:jkl": "base64+curve25519+key",
|
||||
},
|
||||
"signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
|
||||
}
|
||||
|
||||
# Upload keys for devices 1 & 2a.
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_1, {"device_keys": device_key_1}
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_2, {"device_keys": device_key_2a}
|
||||
)
|
||||
)
|
||||
|
||||
# Inject an appservice interested in this user.
|
||||
appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
|
||||
# Note: this user does not have to match the regex above
|
||||
sender="@as_main:test",
|
||||
)
|
||||
self.hs.get_datastores().main.services_cache = [appservice]
|
||||
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
|
||||
[appservice]
|
||||
)
|
||||
|
||||
# Setup a response.
|
||||
self.appservice_api.query_keys.return_value = make_awaitable(
|
||||
{
|
||||
"device_keys": {
|
||||
local_user: {device_2: device_key_2b, device_3: device_key_3}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Request all devices.
|
||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||
self.assertIn(local_user, res)
|
||||
for res_key in res[local_user].values():
|
||||
res_key.pop("unsigned", None)
|
||||
self.assertDictEqual(
|
||||
res,
|
||||
{
|
||||
local_user: {
|
||||
device_1: device_key_1,
|
||||
device_2: device_key_2b,
|
||||
device_3: device_key_3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
|
|
@ -922,7 +922,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
auth_provider_session_id=None,
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": True}})
|
||||
def test_map_userinfo_to_user(self) -> None:
|
||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||
userinfo: dict = {
|
||||
|
@ -975,6 +975,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
"Mapping provider does not support de-duplicating Matrix IDs",
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": False}})
|
||||
def test_map_userinfo_to_user_does_not_register_new_user(self) -> None:
|
||||
"""Ensures new users are not registered if the enabled registration flag is disabled."""
|
||||
userinfo: dict = {
|
||||
"sub": "test_user",
|
||||
"username": "test_user",
|
||||
}
|
||||
request, _ = self.start_authorization(userinfo)
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.complete_sso_login.assert_not_called()
|
||||
self.assertRenderedError(
|
||||
"mapping_error",
|
||||
"User does not exist and registrations are disabled",
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||
def test_map_userinfo_to_existing_user(self) -> None:
|
||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||
|
|
|
@ -54,6 +54,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
if not hiredis:
|
||||
skip = "Requires hiredis"
|
||||
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
# Redis replication only takes place on Postgres
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
|
|
|
@ -37,11 +37,6 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
|
|||
# also one global update
|
||||
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
|
||||
|
||||
# tell the notifier to catch up to avoid duplicate rows.
|
||||
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||
# FIXME remove this when the above is fixed
|
||||
self.replicate()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import synapse
|
||||
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests.replication._base import BaseStreamTestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToDeviceStreamTestCase(BaseStreamTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
synapse.rest.client.login.register_servlets,
|
||||
]
|
||||
|
||||
def test_to_device_stream(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
user1 = self.register_user("user1", "pass")
|
||||
self.login("user1", "pass", "device")
|
||||
user2 = self.register_user("user2", "pass")
|
||||
self.login("user2", "pass", "device")
|
||||
|
||||
# connect to pull the updates related to users creation/login
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
# disconnect so we can accumulate the updates without pulling them
|
||||
self.disconnect()
|
||||
|
||||
msg: JsonDict = {}
|
||||
msg["sender"] = "@sender:example.org"
|
||||
msg["type"] = "m.new_device"
|
||||
|
||||
# add messages to the device inbox for user1 up until the
|
||||
# limit defined for a stream update batch
|
||||
for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
|
||||
msg["content"] = {"device": {}}
|
||||
messages = {user1: {"device": msg}}
|
||||
|
||||
self.get_success(
|
||||
store.add_messages_from_remote_to_device_inbox(
|
||||
"example.org",
|
||||
f"{i}",
|
||||
messages,
|
||||
)
|
||||
)
|
||||
|
||||
# add one more message, for user2 this time
|
||||
# this message would be dropped before fixing #15335
|
||||
msg["content"] = {"device": {}}
|
||||
messages = {user2: {"device": msg}}
|
||||
|
||||
self.get_success(
|
||||
store.add_messages_from_remote_to_device_inbox(
|
||||
"example.org",
|
||||
f"{_STREAM_UPDATE_TARGET_ROW_COUNT}",
|
||||
messages,
|
||||
)
|
||||
)
|
||||
|
||||
# replication is disconnected so we shouldn't get any updates yet
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now reconnect to pull the updates
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should receive the fact that we have to_device updates
|
||||
# for user1 and user2
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
self.assertEqual(len(received_rows), 2)
|
||||
self.assertEqual(received_rows[0][2].entity, user1)
|
||||
self.assertEqual(received_rows[1][2].entity, user2)
|
|
@ -16,6 +16,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import os.path
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
@ -79,7 +80,9 @@ from synapse.http.site import SynapseRequest
|
|||
from synapse.logging.context import ContextResourceUsage
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.database import LoggingDatabaseConnection
|
||||
from synapse.storage.engines import PostgresEngine, create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
from synapse.types import ISynapseReactor, JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -104,6 +107,10 @@ P = ParamSpec("P")
|
|||
# the type of thing that can be passed into `make_request` in the headers list
|
||||
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
|
||||
|
||||
# A pre-prepared SQLite DB that is used as a template when creating new SQLite
|
||||
# DB each test run. This dramatically speeds up test set up when using SQLite.
|
||||
PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None
|
||||
|
||||
|
||||
class TimedOutException(Exception):
|
||||
"""
|
||||
|
@ -899,6 +906,22 @@ def setup_test_homeserver(
|
|||
"args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
|
||||
}
|
||||
|
||||
# Check if we have set up a DB that we can use as a template.
|
||||
global PREPPED_SQLITE_DB_CONN
|
||||
if PREPPED_SQLITE_DB_CONN is None:
|
||||
temp_engine = create_engine(database_config)
|
||||
PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection(
|
||||
sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN"
|
||||
)
|
||||
|
||||
database = DatabaseConnectionConfig("master", database_config)
|
||||
config.database.databases = [database]
|
||||
prepare_database(
|
||||
PREPPED_SQLITE_DB_CONN, create_engine(database_config), config
|
||||
)
|
||||
|
||||
database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
|
||||
|
||||
if "db_txn_limit" in kwargs:
|
||||
database_config["txn_limit"] = kwargs["db_txn_limit"]
|
||||
|
||||
|
|
|
@ -1143,19 +1143,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
tok = self.login("alice", "test")
|
||||
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
|
||||
|
||||
failure_time = self.clock.time_msec()
|
||||
self.get_success(
|
||||
self.store.record_event_failed_pull_attempt(
|
||||
room_id, "$failed_event_id", "fake cause"
|
||||
)
|
||||
)
|
||||
|
||||
event_ids_to_backoff = self.get_success(
|
||||
event_ids_with_backoff = self.get_success(
|
||||
self.store.get_event_ids_to_not_pull_from_backoff(
|
||||
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
|
||||
self.assertEqual(
|
||||
event_ids_with_backoff,
|
||||
# We expect a 2^1 hour backoff after a single failed attempt.
|
||||
{"$failed_event_id": failure_time + 2 * 60 * 60 * 1000},
|
||||
)
|
||||
|
||||
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
|
||||
self,
|
||||
|
@ -1179,14 +1184,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
# attempt (2^1 hours).
|
||||
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
|
||||
|
||||
event_ids_to_backoff = self.get_success(
|
||||
event_ids_with_backoff = self.get_success(
|
||||
self.store.get_event_ids_to_not_pull_from_backoff(
|
||||
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
|
||||
)
|
||||
)
|
||||
# Since this function only returns events we should backoff from, time has
|
||||
# elapsed past the backoff range so there is no events to backoff from.
|
||||
self.assertEqual(event_ids_to_backoff, [])
|
||||
self.assertEqual(event_ids_with_backoff, {})
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
|
|
|
@ -146,6 +146,9 @@ class TestCase(unittest.TestCase):
|
|||
% (current_context(),)
|
||||
)
|
||||
|
||||
# Disable GC for duration of test. See below for why.
|
||||
gc.disable()
|
||||
|
||||
old_level = logging.getLogger().level
|
||||
if level is not None and old_level != level:
|
||||
|
||||
|
@ -163,12 +166,19 @@ class TestCase(unittest.TestCase):
|
|||
|
||||
return orig()
|
||||
|
||||
# We want to force a GC to workaround problems with deferreds leaking
|
||||
# logcontexts when they are GCed (see the logcontext docs).
|
||||
#
|
||||
# The easiest way to do this would be to do a full GC after each test
|
||||
# run, but that is very expensive. Instead, we disable GC (above) for
|
||||
# the duration of the test so that we only need to run a gen-0 GC, which
|
||||
# is a lot quicker.
|
||||
|
||||
@around(self)
|
||||
def tearDown(orig: Callable[[], R]) -> R:
|
||||
ret = orig()
|
||||
# force a GC to workaround problems with deferreds leaking logcontexts when
|
||||
# they are GCed (see the logcontext docs)
|
||||
gc.collect()
|
||||
gc.collect(0)
|
||||
gc.enable()
|
||||
set_current_context(SENTINEL_CONTEXT)
|
||||
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue