Format files with Ruff (#17643)

I thought ruff check would also format, but it doesn't.

This runs ruff format in CI and dev scripts. The first commit is just a
run of `ruff format .` in the root directory.
This commit is contained in:
Quentin Gliech 2024-09-02 13:39:04 +02:00 committed by GitHub
parent 709b7363fe
commit 7d52ce7d4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
152 changed files with 526 additions and 492 deletions

View File

@ -29,10 +29,14 @@ jobs:
with:
install-project: "false"
- name: Run ruff
- name: Run ruff check
continue-on-error: true
run: poetry run ruff check --fix .
- name: Run ruff format
continue-on-error: true
run: poetry run ruff format --quiet .
- run: cargo clippy --all-features --fix -- -D warnings
continue-on-error: true

View File

@ -131,9 +131,12 @@ jobs:
with:
install-project: "false"
- name: Check style
- name: Run ruff check
run: poetry run ruff check --output-format=github .
- name: Run ruff format
run: poetry run ruff format --check .
lint-mypy:
runs-on: ubuntu-latest
name: Typechecking

1
changelog.d/17643.misc Normal file
View File

@ -0,0 +1 @@
Replace `isort` and `black with `ruff`.

View File

@ -21,7 +21,8 @@
#
#
""" Starts a synapse client console. """
"""Starts a synapse client console."""
import argparse
import binascii
import cmd

View File

@ -31,6 +31,7 @@ Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. Se
until then, this script is a best effort to stop us from introducing type coersion bugs
(like the infamous stringy power levels fixed in room version 10).
"""
import argparse
import contextlib
import functools

View File

@ -109,6 +109,9 @@ set -x
# --quiet suppresses the update check.
ruff check --quiet --fix "${files[@]}"
# Reformat Python code.
ruff format --quiet "${files[@]}"
# Catch any common programming mistakes in Rust code.
#
# --bins, --examples, --lib, --tests combined explicitly disable checking

View File

@ -20,8 +20,7 @@
#
#
"""An interactive script for doing a release. See `cli()` below.
"""
"""An interactive script for doing a release. See `cli()` below."""
import glob
import json

View File

@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains *incomplete* type hints for txredisapi.
"""
"""Contains *incomplete* type hints for txredisapi."""
from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol

View File

@ -20,8 +20,7 @@
#
#
""" This is an implementation of a Matrix homeserver.
"""
"""This is an implementation of a Matrix homeserver."""
import os
import sys

View File

@ -171,7 +171,7 @@ def elide_http_methods_if_unconflicting(
"""
def paths_to_methods_dict(
methods_and_paths: Iterable[Tuple[str, str]]
methods_and_paths: Iterable[Tuple[str, str]],
) -> Dict[str, Set[str]]:
"""
Given (method, path) pairs, produces a dict from path to set of methods
@ -201,7 +201,7 @@ def elide_http_methods_if_unconflicting(
def simplify_path_regexes(
registrations: Dict[Tuple[str, str], EndpointDescription]
registrations: Dict[Tuple[str, str], EndpointDescription],
) -> Dict[Tuple[str, str], EndpointDescription]:
"""
Simplify all the path regexes for the dict of endpoint descriptions,

View File

@ -40,6 +40,7 @@ from synapse.storage.engines import create_engine
class ReviewConfig(RootConfig):
"A config class that just pulls out the database config"
config_classes = [DatabaseConfig]
@ -160,7 +161,11 @@ def main() -> None:
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
# This generates a type of Cursor, not LoggingTransaction.
user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type]
user_infos = get_recent_users(
db_conn.cursor(),
since_ms, # type: ignore[arg-type]
exclude_users_with_appservice,
)
for user_info in user_infos:
if exclude_users_with_email and user_info.emails:

View File

@ -717,9 +717,7 @@ class Porter:
return
# Check if all background updates are done, abort if not.
updates_complete = (
await self.sqlite_store.db_pool.updates.has_completed_background_updates()
)
updates_complete = await self.sqlite_store.db_pool.updates.has_completed_background_updates()
if not updates_complete:
end_error = (
"Pending background updates exist in the SQLite3 database."
@ -1095,10 +1093,10 @@ class Porter:
return done, remaining + done
async def _setup_state_group_id_seq(self) -> None:
curr_id: Optional[int] = (
await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
curr_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
if not curr_id:
@ -1186,13 +1184,13 @@ class Porter:
)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id: Optional[int] = (
await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",
allow_none=True,
)
curr_chain_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",
allow_none=True,
)
def r(txn: LoggingTransaction) -> None:

View File

@ -19,7 +19,8 @@
#
#
"""Contains the URL paths to prefix various aspects of the server with. """
"""Contains the URL paths to prefix various aspects of the server with."""
import hmac
from hashlib import sha256
from urllib.parse import urlencode

View File

@ -54,6 +54,7 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import (
TYPE_CHECKING,

View File

@ -200,16 +200,13 @@ class KeyConfig(Config):
)
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
return (
"""\
return """\
%(macaroon_secret_key)s
%(form_secret)s
signing_key_path: "%(base_key_name)s.signing.key"
trusted_key_servers:
- server_name: "matrix.org"
"""
% locals()
)
""" % locals()
def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
"""Read the signing keys in the given path.
@ -249,7 +246,9 @@ class KeyConfig(Config):
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment]
verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(
key_id, key_bytes
) # type: ignore[assignment]
verify_key.expired = key_data["expired_ts"]
keys[key_id] = verify_key
else:

View File

@ -157,12 +157,9 @@ class LoggingConfig(Config):
self, config_dir_path: str, server_name: str, **kwargs: Any
) -> str:
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return (
"""\
return """\
log_config: "%(log_config)s"
"""
% locals()
)
""" % locals()
def read_arguments(self, args: argparse.Namespace) -> None:
if args.no_redirect_stdio is not None:

View File

@ -828,13 +828,10 @@ class ServerConfig(Config):
).lstrip()
if not unsecure_listeners:
unsecure_http_bindings = (
"""- port: %(unsecure_port)s
unsecure_http_bindings = """- port: %(unsecure_port)s
tls: false
type: http
x_forwarded: true"""
% locals()
)
x_forwarded: true""" % locals()
if not open_private_ports:
unsecure_http_bindings += (
@ -853,16 +850,13 @@ class ServerConfig(Config):
if not secure_listeners:
secure_http_bindings = ""
return (
"""\
return """\
server_name: "%(server_name)s"
pid_file: %(pid_file)s
listeners:
%(secure_http_bindings)s
%(unsecure_http_bindings)s
"""
% locals()
)
""" % locals()
def read_arguments(self, args: argparse.Namespace) -> None:
if args.manhole is not None:

View File

@ -328,10 +328,11 @@ class WorkerConfig(Config):
)
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
self.instance_map: Dict[
str, InstanceLocationConfig
] = parse_and_validate_mapping(
instance_map, InstanceLocationConfig # type: ignore[arg-type]
self.instance_map: Dict[str, InstanceLocationConfig] = (
parse_and_validate_mapping(
instance_map,
InstanceLocationConfig, # type: ignore[arg-type]
)
)
# Map from type of streams to source, c.f. WriterLocations.

View File

@ -887,7 +887,8 @@ def _check_power_levels(
raise SynapseError(400, f"{v!r} must be an integer.")
if k in {"events", "notifications", "users"}:
if not isinstance(v, collections.abc.Mapping) or not all(
type(v) is int for v in v.values() # noqa: E721
type(v) is int
for v in v.values() # noqa: E721
):
raise SynapseError(
400,

View File

@ -80,7 +80,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(
f: Optional[Callable[P, R]]
f: Optional[Callable[P, R]],
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.

View File

@ -504,7 +504,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
def _encode_state_group_delta(
state_group_delta: Dict[Tuple[int, int], StateMap[str]]
state_group_delta: Dict[Tuple[int, int], StateMap[str]],
) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
if not state_group_delta:
return []
@ -517,7 +517,7 @@ def _encode_state_group_delta(
def _decode_state_group_delta(
input: List[Tuple[int, int, List[Tuple[str, str, str]]]]
input: List[Tuple[int, int, List[Tuple[str, str, str]]]],
) -> Dict[Tuple[int, int], StateMap[str]]:
if not input:
return {}
@ -544,7 +544,7 @@ def _encode_state_dict(
def _decode_state_dict(
input: Optional[List[Tuple[str, str, str]]]
input: Optional[List[Tuple[str, str, str]]],
) -> Optional[StateMap[str]]:
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:

View File

@ -19,5 +19,4 @@
#
#
""" This package includes all the federation specific logic.
"""
"""This package includes all the federation specific logic."""

View File

@ -20,7 +20,7 @@
#
#
""" This module contains all the persistence actions done by the federation
"""This module contains all the persistence actions done by the federation
package.
These actions are mostly only used by the :py:mod:`.replication` module.

View File

@ -859,7 +859,6 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet):
request: SynapseRequest,
media_id: str,
) -> None:
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")

View File

@ -19,7 +19,7 @@
#
#
""" Defines the JSON structure of the protocol units used by the server to
"""Defines the JSON structure of the protocol units used by the server to
server protocol.
"""

View File

@ -118,10 +118,10 @@ class AccountHandler:
}
if self._use_account_validity_in_account_status:
status["org.matrix.expired"] = (
await self._account_validity_handler.is_user_expired(
user_id.to_string()
)
status[
"org.matrix.expired"
] = await self._account_validity_handler.is_user_expired(
user_id.to_string()
)
return status

View File

@ -197,14 +197,15 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = (
await self._store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_key,
to_key=to_key,
limit=100,
direction=Direction.FORWARDS,
)
(
events,
_,
) = await self._store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_key,
to_key=to_key,
limit=100,
direction=Direction.FORWARDS,
)
if not events:
break

View File

@ -166,8 +166,7 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
"phone" not in identifier and "number" not in identifier
):
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM

View File

@ -265,9 +265,9 @@ class DirectoryHandler:
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
result: Optional[RoomAliasMapping] = (
await self.get_association_from_room_alias(room_alias)
)
result: Optional[
RoomAliasMapping
] = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@ -512,11 +512,9 @@ class DirectoryHandler:
raise SynapseError(403, "Not allowed to publish room")
# Check if publishing is blocked by a third party module
allowed_by_third_party_rules = (
await (
self._third_party_event_rules.check_visibility_can_be_modified(
room_id, visibility
)
allowed_by_third_party_rules = await (
self._third_party_event_rules.check_visibility_can_be_modified(
room_id, visibility
)
)
if not allowed_by_third_party_rules:

View File

@ -1001,11 +1001,11 @@ class FederationHandler:
)
if include_auth_user_id:
event_content[EventContentFields.AUTHORISING_USER] = (
await self._event_auth_handler.get_user_which_could_invite(
room_id,
state_ids,
)
event_content[
EventContentFields.AUTHORISING_USER
] = await self._event_auth_handler.get_user_which_could_invite(
room_id,
state_ids,
)
builder = self.event_builder_factory.for_room_version(

View File

@ -21,6 +21,7 @@
#
"""Utilities for interacting with Identity Servers"""
import logging
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

View File

@ -1225,10 +1225,9 @@ class EventCreationHandler:
)
if prev_event_ids is not None:
assert (
len(prev_event_ids) <= 10
), "Attempting to create an event with %i prev_events" % (
len(prev_event_ids),
assert len(prev_event_ids) <= 10, (
"Attempting to create an event with %i prev_events"
% (len(prev_event_ids),)
)
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)

View File

@ -507,15 +507,16 @@ class PaginationHandler:
# Initially fetch the events from the database. With any luck, we can return
# these without blocking on backfill (handled below).
events, next_key = (
await self.store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_token.room_key,
to_key=to_room_key,
direction=pagin_config.direction,
limit=pagin_config.limit,
event_filter=event_filter,
)
(
events,
next_key,
) = await self.store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_token.room_key,
to_key=to_room_key,
direction=pagin_config.direction,
limit=pagin_config.limit,
event_filter=event_filter,
)
if pagin_config.direction == Direction.BACKWARDS:
@ -584,15 +585,16 @@ class PaginationHandler:
# If we did backfill something, refetch the events from the database to
# catch anything new that might have been added since we last fetched.
if did_backfill:
events, next_key = (
await self.store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_token.room_key,
to_key=to_room_key,
direction=pagin_config.direction,
limit=pagin_config.limit,
event_filter=event_filter,
)
(
events,
next_key,
) = await self.store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_token.room_key,
to_key=to_room_key,
direction=pagin_config.direction,
limit=pagin_config.limit,
event_filter=event_filter,
)
else:
# Otherwise, we can backfill in the background for eventual

View File

@ -71,6 +71,7 @@ user state; this device follows the normal timeout logic (see above) and will
automatically be replaced with any information from currently available devices.
"""
import abc
import contextlib
import itertools
@ -493,9 +494,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by (user ID, device ID).
# Empty if _presence_enabled is false.
self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
{}
)
self._user_device_to_num_current_syncs: Dict[
Tuple[str, Optional[str]], int
] = {}
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
@ -818,9 +819,9 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
{}
)
self._user_device_to_num_current_syncs: Dict[
Tuple[str, Optional[str]], int
] = {}
# Keeps track of the number of *ongoing* syncs on other processes.
#

View File

@ -351,9 +351,9 @@ class ProfileHandler:
server_name = host
if self._is_mine_server_name(server_name):
media_info: Optional[Union[LocalMedia, RemoteMedia]] = (
await self.store.get_local_media(media_id)
)
media_info: Optional[
Union[LocalMedia, RemoteMedia]
] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)

View File

@ -188,13 +188,13 @@ class RelationsHandler:
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
return_value["original_event"] = (
await self._event_serializer.serialize_event(
event,
now,
bundle_aggregations=None,
config=serialize_options,
)
return_value[
"original_event"
] = await self._event_serializer.serialize_event(
event,
now,
bundle_aggregations=None,
config=serialize_options,
)
if next_token:

View File

@ -20,6 +20,7 @@
#
"""Contains functions for performing actions on rooms."""
import itertools
import logging
import math
@ -900,11 +901,9 @@ class RoomCreationHandler:
)
# Check whether this visibility value is blocked by a third party module
allowed_by_third_party_rules = (
await (
self._third_party_event_rules.check_visibility_can_be_modified(
room_id, visibility
)
allowed_by_third_party_rules = await (
self._third_party_event_rules.check_visibility_can_be_modified(
room_id, visibility
)
)
if not allowed_by_third_party_rules:

View File

@ -1302,11 +1302,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If this is going to be a local join, additional information must
# be included in the event content in order to efficiently validate
# the event.
content[EventContentFields.AUTHORISING_USER] = (
await self.event_auth_handler.get_user_which_could_invite(
room_id,
state_before_join,
)
content[
EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
state_before_join,
)
return False, []
@ -1415,9 +1415,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if requester is not None:
sender = UserID.from_string(event.sender)
assert (
sender == requester.user
), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" % (sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = types.create_requester(target_user)

View File

@ -423,9 +423,9 @@ class SearchHandler:
}
if search_result.room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})[
"room_id"
] = search_result.room_groups
rooms_cat_res.setdefault("groups", {})["room_id"] = (
search_result.room_groups
)
if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group

View File

@ -587,9 +587,7 @@ class SlidingSyncHandler:
Membership.LEAVE,
Membership.BAN,
):
to_bound = (
room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
)
to_bound = room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
timeline_from_bound = from_bound
if ignore_timeline_bound:

View File

@ -386,9 +386,9 @@ class SlidingSyncExtensionHandler:
if have_push_rules_changed:
global_account_data_map = dict(global_account_data_map)
# TODO: This should take into account the `from_token` and `to_token`
global_account_data_map[AccountDataTypes.PUSH_RULES] = (
await self.push_rules_handler.push_rules_for_user(sync_config.user)
)
global_account_data_map[
AccountDataTypes.PUSH_RULES
] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
else:
# TODO: This should take into account the `to_token`
all_global_account_data = await self.store.get_global_account_data_for_user(
@ -397,9 +397,9 @@ class SlidingSyncExtensionHandler:
global_account_data_map = dict(all_global_account_data)
# TODO: This should take into account the `to_token`
global_account_data_map[AccountDataTypes.PUSH_RULES] = (
await self.push_rules_handler.push_rules_for_user(sync_config.user)
)
global_account_data_map[
AccountDataTypes.PUSH_RULES
] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
# Fetch room account data
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}

View File

@ -293,10 +293,11 @@ class SlidingSyncRoomLists:
is_encrypted=is_encrypted,
)
newly_joined_room_ids, newly_left_room_map = (
await self._get_newly_joined_and_left_rooms(
user_id, from_token=from_token, to_token=to_token
)
(
newly_joined_room_ids,
newly_left_room_map,
) = await self._get_newly_joined_and_left_rooms(
user_id, from_token=from_token, to_token=to_token
)
dm_room_ids = await self._get_dm_rooms_for_user(user_id)
@ -958,10 +959,11 @@ class SlidingSyncRoomLists:
else:
rooms_for_user[room_id] = change_room_for_user
newly_joined_room_ids, newly_left_room_ids = (
await self._get_newly_joined_and_left_rooms(
user_id, to_token=to_token, from_token=from_token
)
(
newly_joined_room_ids,
newly_left_room_ids,
) = await self._get_newly_joined_and_left_rooms(
user_id, to_token=to_token, from_token=from_token
)
dm_room_ids = await self._get_dm_rooms_for_user(user_id)

View File

@ -183,10 +183,7 @@ class JoinedSyncResult:
to tell if room needs to be part of the sync result.
"""
return bool(
self.timeline
or self.state
or self.ephemeral
or self.account_data
self.timeline or self.state or self.ephemeral or self.account_data
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
@ -575,10 +572,10 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result: Union[SyncResult, E2eeSyncResult] = (
await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
)
result: Union[
SyncResult, E2eeSyncResult
] = await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
)
else:
# Otherwise, we wait for something to happen and report it to the user.
@ -673,10 +670,10 @@ class SyncHandler:
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
sync_result: Union[SyncResult, E2eeSyncResult] = (
await self.generate_sync_result(
sync_config, since_token, full_state
)
sync_result: Union[
SyncResult, E2eeSyncResult
] = await self.generate_sync_result(
sync_config, since_token, full_state
)
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
elif sync_version == SyncVersion.E2EE_SYNC:
@ -1488,13 +1485,16 @@ class SyncHandler:
# timeline here. The caller will then dedupe any redundant
# ones.
state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member) for member in members_to_fetch
),
await_full_state=False,
state_ids = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member)
for member in members_to_fetch
),
await_full_state=False,
)
)
return state_ids
@ -2166,18 +2166,18 @@ class SyncHandler:
if push_rules_changed:
global_account_data = dict(global_account_data)
global_account_data[AccountDataTypes.PUSH_RULES] = (
await self._push_rules_handler.push_rules_for_user(sync_config.user)
)
global_account_data[
AccountDataTypes.PUSH_RULES
] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
else:
all_global_account_data = await self.store.get_global_account_data_for_user(
user_id
)
global_account_data = dict(all_global_account_data)
global_account_data[AccountDataTypes.PUSH_RULES] = (
await self._push_rules_handler.push_rules_for_user(sync_config.user)
)
global_account_data[
AccountDataTypes.PUSH_RULES
] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
account_data_for_user = (
await sync_config.filter_collection.filter_global_account_data(

View File

@ -183,7 +183,7 @@ class WorkerLocksHandler:
return
def _wake_all_locks(
locks: Collection[Union[WaitingLock, WaitingMultiLock]]
locks: Collection[Union[WaitingLock, WaitingMultiLock]],
) -> None:
for lock in locks:
deferred = lock.deferred

View File

@ -1313,6 +1313,5 @@ def is_unknown_endpoint(
)
) or (
# Older Synapses returned a 400 error.
e.code == 400
and synapse_error.errcode == Codes.UNRECOGNIZED
e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
)

View File

@ -233,7 +233,7 @@ def return_html_error(
def wrap_async_request_handler(
h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]]
h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]],
) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.

View File

@ -22,6 +22,7 @@
"""
Log formatters that output terse JSON.
"""
import json
import logging

View File

@ -20,7 +20,7 @@
#
#
""" Thread-local-alike tracking of log contexts within synapse
"""Thread-local-alike tracking of log contexts within synapse
This module provides objects and utilities for tracking contexts through
synapse code, so that log lines can include a request identifier, and so that
@ -29,6 +29,7 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
import logging
import threading
import typing
@ -751,7 +752,7 @@ def preserve_fn(
f: Union[
Callable[P, R],
Callable[P, Awaitable[R]],
]
],
) -> Callable[P, "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background"""

View File

@ -169,6 +169,7 @@ Gotchas
than one caller? Will all of those calling functions have be in a context
with an active span?
"""
import contextlib
import enum
import inspect
@ -414,7 +415,7 @@ def ensure_active_span(
"""
def ensure_active_span_inner_1(
func: Callable[P, R]
func: Callable[P, R],
) -> Callable[P, Union[Optional[T], R]]:
@wraps(func)
def ensure_active_span_inner_2(
@ -700,7 +701,7 @@ def set_operation_name(operation_name: str) -> None:
@only_if_tracing
def force_tracing(
span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel
span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel,
) -> None:
"""Force sampling for the active/given span and its children.
@ -1093,9 +1094,10 @@ def trace_servlet(
# Mypy seems to think that start_context.tag below can be Optional[str], but
# that doesn't appear to be correct and works in practice.
request_tags[
SynapseTags.REQUEST_TAG
] = request.request_metrics.start_context.tag # type: ignore[assignment]
request_tags[SynapseTags.REQUEST_TAG] = (
request.request_metrics.start_context.tag # type: ignore[assignment]
)
# set the tags *after* the servlet completes, in case it decided to
# prioritise the span (tags will get dropped on unprioritised spans)

View File

@ -293,7 +293,7 @@ def wrap_as_background_process(
"""
def wrap_as_background_process_inner(
func: Callable[P, Awaitable[Optional[R]]]
func: Callable[P, Awaitable[Optional[R]]],
) -> Callable[P, "defer.Deferred[Optional[R]]"]:
@wraps(func)
def wrap_as_background_process_inner_2(

View File

@ -304,9 +304,9 @@ class BulkPushRuleEvaluator:
if relation_type == "m.thread" and event.content.get(
"m.relates_to", {}
).get("is_falling_back", False):
related_events["m.in_reply_to"][
"im.vector.is_falling_back"
] = ""
related_events["m.in_reply_to"]["im.vector.is_falling_back"] = (
""
)
return related_events
@ -372,7 +372,8 @@ class BulkPushRuleEvaluator:
gather_results(
(
run_in_background( # type: ignore[call-arg]
self.store.get_number_joined_users_in_room, event.room_id # type: ignore[arg-type]
self.store.get_number_joined_users_in_room,
event.room_id, # type: ignore[arg-type]
),
run_in_background(
self._get_power_levels_and_sender_level,

View File

@ -119,7 +119,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_events_parse"):
room_id = content["room_id"]
backfilled = content["backfilled"]

View File

@ -98,7 +98,9 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
self._store = hs.get_datastores().main
@staticmethod
async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict: # type: ignore[override]
async def _serialize_payload( # type: ignore[override]
user_id: str, old_room_id: str, new_room_id: str
) -> JsonDict:
return {}
async def _handle_request( # type: ignore[override]
@ -109,7 +111,6 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
old_room_id: str,
new_room_id: str,
) -> Tuple[int, JsonDict]:
await self._store.copy_push_rules_from_room_to_room_for_user(
old_room_id, new_room_id, user_id
)

View File

@ -18,8 +18,8 @@
# [This file includes modifications made by New Vector Limited]
#
#
"""A replication client for use by synapse workers.
"""