From 5758dcf30c245efa1032385cd1af7853d39642a9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Aug 2020 14:25:27 -0400 Subject: [PATCH] Add type hints for state. (#8140) --- changelog.d/8140.misc | 1 + stubs/frozendict.pyi | 47 +++++ synapse/federation/sender/__init__.py | 4 +- synapse/handlers/federation.py | 10 +- synapse/handlers/presence.py | 6 +- synapse/handlers/room_member.py | 20 +- synapse/state/__init__.py | 192 ++++++++++++------- synapse/state/v1.py | 87 ++++++--- synapse/state/v2.py | 255 +++++++++++++++++--------- tox.ini | 1 + 10 files changed, 420 insertions(+), 203 deletions(-) create mode 100644 changelog.d/8140.misc create mode 100644 stubs/frozendict.pyi diff --git a/changelog.d/8140.misc b/changelog.d/8140.misc new file mode 100644 index 0000000000..78d8834328 --- /dev/null +++ b/changelog.d/8140.misc @@ -0,0 +1 @@ +Add type hints to `synapse.state`. diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi new file mode 100644 index 0000000000..3f3af59f26 --- /dev/null +++ b/stubs/frozendict.pyi @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stub for frozendict. + +from typing import ( + Any, + Hashable, + Iterable, + Iterator, + Mapping, + overload, + Tuple, + TypeVar, +) + +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. + +class frozendict(Mapping[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + def __getitem__(self, key: _KT) -> _VT: ... + def __contains__(self, key: Any) -> bool: ... + def copy(self, **add_or_replace: Any) -> frozendict: ... + def __iter__(self) -> Iterator[_KT]: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index e53b6ac456..4662008bfd 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -329,10 +329,10 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = await self.state.get_current_hosts_in_room(room_id) + domains_set = await self.state.get_current_hosts_in_room(room_id) domains = [ d - for d in domains + for d in domains_set if d != self.server_name and self._federation_shard_config.should_handle(self._instance_name, d) ] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5b270228e7..f8b234cee2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler): ) state_sets = list(state_sets.values()) state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( + current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + current_state_ids = {k: e.event_id for k, e in current_states.items()} else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler): # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) - current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] - auth_events_map = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids_list) current_auth_events = { (e.type, e.state_key): e for e in auth_events_map.values() } diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 24e1940ee5..1846068150 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -1318,7 +1318,7 @@ async def get_interested_parties( async def get_interested_remotes( store: DataStore, states: List[UserPresenceState], state_handler: StateHandler -) -> List[Tuple[List[str], List[UserPresenceState]]]: +) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1334,7 +1334,7 @@ async def get_interested_remotes( each tuple the list of UserPresenceState should be sent to each destination """ - hosts_and_states = [] + hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a03cb02792..52548087a9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID +from synapse.types import ( + Collection, + JsonDict, + Requester, + RoomAlias, + RoomID, + StateMap, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -738,9 +746,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -969,9 +975,7 @@ class RoomMemberHandler(object): ) return stream_id - async def _is_host_in_room( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dba8d91eef..a601303fa3 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,11 +16,22 @@ import logging from collections import namedtuple -from typing import Awaitable, Dict, Iterable, List, Optional, Set +from typing import ( + Awaitable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Union, + overload, +) import attr from frozendict import frozendict from prometheus_client import Histogram +from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions @@ -30,7 +41,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import StateMap +from synapse.types import Collection, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -68,8 +79,14 @@ def _gen_state_id(): class _StateCacheEntry(object): __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group, prev_group=None, delta_ids=None): - # dict[(str, str), str] map from (type, state_key) to event_id + def __init__( + self, + state: StateMap[str], + state_group: Optional[int], + prev_group: Optional[int] = None, + delta_ids: Optional[StateMap[str]] = None, + ): + # A map from (type, state_key) to event_id. self.state = frozendict(state) # the ID of a state group if one and only one is involved. @@ -107,24 +124,49 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() + @overload async def get_current_state( - self, room_id, event_type=None, state_key="", latest_event_ids=None - ): - """ Retrieves the current state for the room. This is done by + self, + room_id: str, + event_type: Literal[None] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> StateMap[EventBase]: + ... + + @overload + async def get_current_state( + self, + room_id: str, + event_type: str, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Optional[EventBase]: + ... + + async def get_current_state( + self, + room_id: str, + event_type: Optional[str] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Union[Optional[EventBase], StateMap[EventBase]]: + """Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. This is equivalent to getting the state of an event that were to send next before receiving any new events. - If `event_type` is specified, then the method returns only the one - event (or None) with that `event_type` and `state_key`. - Returns: - map from (type, state_key) to event + If `event_type` is specified, then the method returns only the one + event (or None) with that `event_type` and `state_key`. + + Otherwise, a map from (type, state_key) to event. """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) @@ -140,34 +182,30 @@ class StateHandler(object): state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) - state = { + return { key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } - return state - - async def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids( + self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None + ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room Args: - room_id (str): - - latest_event_ids (iterable[str]|None): if given, the forward - extremities to resolve. If None, we look them up from the - database (via a cache) + room_id: + latest_event_ids: if given, the forward extremities to resolve. If + None, we look them up from the database (via a cache). Returns: - Deferred[dict[(str, str), str)]]: the state dict, mapping from - (event_type, state_key) -> event_id + the state dict, mapping from (event_type, state_key) -> event_id """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = ret.state - - return state + return dict(ret.state) async def get_current_users_in_room( self, room_id: str, latest_event_ids: Optional[List[str]] = None @@ -183,32 +221,34 @@ class StateHandler(object): """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None + logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = await self.store.get_joined_users_from_state(room_id, entry) - return joined_users + return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id): + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: event_ids = await self.store.get_latest_event_ids_in_room(room_id) return await self.get_hosts_in_room_at_events(room_id, event_ids) - async def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events( + self, room_id: str, event_ids: List[str] + ) -> Set[str]: """Get the hosts that were in a room at the given event ids Args: - room_id (str): - event_ids (list[str]): + room_id: + event_ids: Returns: - Deferred[list[str]]: the hosts in the room at the given events + The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = await self.store.get_joined_hosts(room_id, entry) - return joined_hosts + return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None - ): + ) -> EventContext: """Build an EventContext structure for the event. This works out what the current state should be for the event, and @@ -221,7 +261,7 @@ class StateHandler(object): when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. Returns: - synapse.events.snapshot.EventContext: + The event context. """ if event.internal_metadata.is_outlier(): @@ -275,7 +315,7 @@ class StateHandler(object): event.room_id, event.prev_event_ids() ) - state_ids_before_event = entry.state + state_ids_before_event = dict(entry.state) state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids @@ -346,19 +386,18 @@ class StateHandler(object): ) @measure_func() - async def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events( + self, room_id: str, event_ids: Iterable[str] + ) -> _StateCacheEntry: """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: - room_id (str) - event_ids (list[str]) - explicit_room_version (str|None): If set uses the the given room - version to choose the resolution algorithm. If None, then - checks the database for room version. + room_id + event_ids Returns: - Deferred[_StateCacheEntry]: resolved state + The resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -394,7 +433,12 @@ class StateHandler(object): ) return result - async def resolve_events(self, room_version, state_sets, event): + async def resolve_events( + self, + room_version: str, + state_sets: Collection[Iterable[EventBase]], + event: EventBase, + ) -> StateMap[EventBase]: logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,9 +458,7 @@ class StateHandler(object): state_res_store=StateResolutionStore(self.store), ) - new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()} - - return new_state + return {key: state_map[ev_id] for key, ev_id in new_state.items()} class StateResolutionHandler(object): @@ -444,7 +486,12 @@ class StateResolutionHandler(object): @log_function async def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store + self, + room_id: str, + room_version: str, + state_groups_ids: Dict[int, StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", ): """Resolves conflicts between a set of state groups @@ -452,13 +499,13 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging and sanity checks) - room_version (str): version of the room - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group + room_id: room we are resolving for (used for logging and sanity checks) + room_version: version of the room + state_groups_ids: + A map from state group id to the state in that state group (where 'state' is a map from state key to event id) - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -466,10 +513,10 @@ class StateResolutionHandler(object): If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store Returns: - _StateCacheEntry: resolved state + The resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) @@ -530,21 +577,22 @@ class StateResolutionHandler(object): return cache -def _make_state_cache_entry(new_state, state_groups_ids): +def _make_state_cache_entry( + new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] +) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. Args: - new_state (dict[(str, str), str]): resolved state map (mapping from - (type, state_key) to event_id) + new_state: resolved state map (mapping from (type, state_key) to event_id) - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group - (where 'state' is a map from state key to event id) + state_groups_ids: + map from state group id to the state in that state group (where + 'state' is a map from state key to event id) Returns: - _StateCacheEntry + The cache entry. """ # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -585,7 +633,7 @@ def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ) -> Awaitable[StateMap[str]]: @@ -633,15 +681,17 @@ class StateResolutionStore(object): store = attr.ib() - def get_events(self, event_ids, allow_rejected=False): + def get_events( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - allow_rejected (bool): If True return rejected events. + event_ids: The event_ids of the events to fetch + allow_rejected: If True return rejected events. Returns: - Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. + An awaitable which resolves to a dict from event_id to event. """ return self.store.get_events( @@ -651,7 +701,9 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + def get_auth_chain_difference( + self, state_sets: List[Set[str]] + ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -660,7 +712,7 @@ class StateResolutionStore(object): chain. Returns: - Deferred[Set[str]]: Set of event IDs. + An awaitable that resolves to a set of event IDs. """ return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index ab5e24841d..0eb7fdd9e5 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,7 +15,17 @@ import hashlib import logging -from typing import Awaitable, Callable, Dict, List, Optional +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "") async def resolve_events_with_store( room_id: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[List[str]], Awaitable], -): + state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], +) -> StateMap[str]: """ Args: room_id: the room we are working in @@ -56,8 +66,7 @@ async def resolve_events_with_store( an Awaitable that resolves to a dict of event_id to event. Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -75,8 +84,8 @@ async def resolve_events_with_store( "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) + # A map from state event id to event. Only includes the state events which + # are in conflict (and those in event_map). state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -91,8 +100,6 @@ async def resolve_events_with_store( # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -122,29 +129,30 @@ async def resolve_events_with_store( ) -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. Args: - state_sets(iterable[dict[(str, str), str]]): + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: + A tuple of (unconflicted_state, conflicted_state), where: - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} + conflicted_state = {} # type: StateMap[Set[str]] for state_set in state_set_iterator: for key, value in state_set.items(): @@ -171,7 +179,21 @@ def _seperate(state_sets): return unconflicted_state, conflicted_state -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): +def _create_auth_events_from_maps( + unconflicted_state: StateMap[str], + conflicted_state: StateMap[Set[str]], + state_map: Dict[str, EventBase], +) -> StateMap[str]: + """ + + Args: + unconflicted_state: The unconflicted state map. + conflicted_state: The conflicted state map. + state_map: + + Returns: + A map from state key to event id. + """ auth_events = {} for event_ids in conflicted_state.values(): for event_id in event_ids: @@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma keys = event_auth.auth_types_for_event(state_map[event_id]) for key in keys: if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id + auth_event_id = unconflicted_state.get(key, None) + if auth_event_id: + auth_events[key] = auth_event_id return auth_events def _resolve_with_state( - unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map + unconflicted_state_ids: StateMap[str], + conflicted_state_ids: StateMap[Set[str]], + auth_event_ids: StateMap[str], + state_map: Dict[str, EventBase], ): conflicted_state = {} for key, event_ids in conflicted_state_ids.items(): @@ -215,7 +240,9 @@ def _resolve_with_state( return new_state -def _resolve_state_events(conflicted_state, auth_events): +def _resolve_state_events( + conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] +) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. @@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events): return resolved_state -def _resolve_auth_events(events, auth_events): +def _resolve_auth_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: reverse = list(reversed(_ordered_events(events))) auth_keys = { @@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events): return event -def _resolve_normal_events(events, auth_events): +def _resolve_normal_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: for event in _ordered_events(events): try: # The signatures have already been checked at this point @@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events): return event -def _ordered_events(events): +def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]: def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 6634955cdc..0e9ffbd6e6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,7 +16,21 @@ import heapq import itertools import logging -from typing import Dict, List, Optional +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + overload, +) + +from typing_extensions import Literal import synapse.state from synapse import event_auth @@ -40,10 +54,10 @@ async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", -): +) -> StateMap[str]: """Resolves the state using the v2 state resolution algorithm Args: @@ -63,8 +77,7 @@ async def resolve_events_with_store( state_res_store: Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ logger.debug("Computing conflicted state") @@ -171,18 +184,23 @@ async def resolve_events_with_store( return resolved_state -async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Return the power level of the sender of the given event according to their auth events. Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + room_id + event_id + event_map + state_res_store Returns: - Deferred[int] + The power level. """ event = await _get_event(room_id, event_id, event_map, state_res_store) @@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st return int(level) -async def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference( + state_sets: Sequence[StateMap[str]], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. Args: - state_sets (list) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + state_sets + event_map + state_res_store Returns: - Deferred[set[str]]: Set of event IDs + Set of event IDs """ difference = await state_res_store.get_auth_chain_difference( @@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store): return difference -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Return the unconflicted and conflicted state. This is different than in the original algorithm, as this defines a key to be conflicted if one of the state sets doesn't have that key. Args: - state_sets (list) + state_sets Returns: - tuple[dict, dict]: A tuple of unconflicted and conflicted state. The - conflicted state dict is a map from type/state_key to set of event IDs + A tuple of unconflicted and conflicted state. The conflicted state dict + is a map from type/state_key to set of event IDs """ unconflicted_state = {} conflicted_state = {} @@ -260,18 +284,20 @@ def _seperate(state_sets): event_ids.discard(None) conflicted_state[key] = event_ids - return unconflicted_state, conflicted_state + # mypy doesn't understand that discarding None above means that conflicted + # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]]. + return unconflicted_state, conflicted_state # type: ignore -def _is_power_event(event): +def _is_power_event(event: EventBase) -> bool: """Return whether or not the event is a "power event", as defined by the v2 state resolution algorithm Args: - event (FrozenEvent) + event Returns: - boolean + True if the event is a power event. """ if (event.type, event.state_key) in ( (EventTypes.PowerLevels, ""), @@ -288,19 +314,23 @@ def _is_power_event(event): async def _add_event_and_auth_chain_to_graph( - graph, room_id, event_id, event_map, state_res_store, auth_diff -): + graph: Dict[str, Set[str]], + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> None: """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph Args: - graph (dict[str, set[str]]): A map from event ID to the events auth - event IDs - room_id (str): the room we are working in - event_id (str): Event to add to the graph - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + graph: A map from event ID to the events auth event IDs + room_id: the room we are working in + event_id: Event to add to the graph + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. """ state = [event_id] @@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph( async def _reverse_topological_power_sort( - clock, room_id, event_ids, event_map, state_res_store, auth_diff -): + clock: Clock, + room_id: str, + event_ids: Iterable[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> List[str]: """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: - clock (Clock) - room_id (str): the room we are working in - event_ids (list[str]): The events to sort - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + clock + room_id: the room we are working in + event_ids: The events to sort + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. Returns: - Deferred[list[str]]: The sorted list + The sorted list """ - graph = {} + graph = {} # type: Dict[str, Set[str]] for idx, event_id in enumerate(event_ids, start=1): await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff @@ -372,22 +407,28 @@ async def _reverse_topological_power_sort( async def _iterative_auth_checks( - clock, room_id, room_version, event_ids, base_state, event_map, state_res_store -): + clock: Clock, + room_id: str, + room_version: str, + event_ids: List[str], + base_state: StateMap[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> StateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: - clock (Clock) - room_id (str) - room_version (str) - event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (StateMap[str]): The set of state to start with - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id + room_version + event_ids: Ordered list of events to apply auth checks to + base_state: The set of state to start with + event_map + state_res_store Returns: - Deferred[StateMap[str]]: Returns the final updated state + Returns the final updated state """ resolved_state = base_state.copy() room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -439,21 +480,26 @@ async def _iterative_auth_checks( async def _mainline_sort( - clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store -): + clock: Clock, + room_id: str, + event_ids: List[str], + resolved_power_event_id: Optional[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> List[str]: """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: - clock (Clock) - room_id (str): room we're working in - event_ids (list[str]): Events to sort - resolved_power_event_id (str): The final resolved power level event ID - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id: room we're working in + event_ids: Events to sort + resolved_power_event_id: The final resolved power level event ID + event_map + state_res_store Returns: - Deferred[list[str]]: The sorted list + The sorted list """ if not event_ids: # It's possible for there to be no event IDs here to sort, so we can @@ -505,59 +551,90 @@ async def _mainline_sort( async def _get_mainline_depth_for_event( - event, mainline_map, event_map, state_res_store -): + event: EventBase, + mainline_map: Dict[str, int], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Get the mainline depths for the given event based on the mainline map Args: - event (FrozenEvent) - mainline_map (dict[str, int]): Map from event_id to mainline depth for - events in the mainline. - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + event + mainline_map: Map from event_id to mainline depth for events in the mainline. + event_map + state_res_store Returns: - Deferred[int] + The mainline depth """ room_id = event.room_id + tmp_event = event # type: Optional[EventBase] # We do an iterative search, replacing `event with the power level in its # auth events (if any) - while event: + while tmp_event: depth = mainline_map.get(event.event_id) if depth is not None: return depth - auth_events = event.auth_event_ids() - event = None + auth_events = tmp_event.auth_event_ids() + tmp_event = None for aid in auth_events: aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): - event = aev + tmp_event = aev break # Didn't find a power level auth event, so we just return 0 return 0 -async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[False] = False, +) -> EventBase: + ... + + +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[True], +) -> Optional[EventBase]: + ... + + +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: bool = False, +) -> Optional[EventBase]: """Helper function to look up event in event_map, falling back to looking it up in the store Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - allow_none (bool): if the event is not found, return None rather than raising + room_id + event_id + event_map + state_res_store + allow_none: if the event is not found, return None rather than raising an exception Returns: - Deferred[Optional[FrozenEvent]] + The event, or none if the event does not exist (and allow_none is True). """ if event_id not in event_map: events = await state_res_store.get_events([event_id], allow_rejected=True) @@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F return event -def lexicographical_topological_sort(graph, key): +def lexicographical_topological_sort( + graph: Dict[str, Set[str]], key: Callable[[str], Any] +) -> Generator[str, None, None]: """Performs a lexicographic reverse topological sort on the graph. This returns a reverse topological sort (i.e. if node A references B then B @@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key): NOTE: `graph` is modified during the sort. Args: - graph (dict[str, set[str]]): A representation of the graph where each - node is a key in the dict and its value are the nodes edges. - key (func): A function that takes a node and returns a value that is - comparable and used to order nodes + graph: A representation of the graph where each node is a key in the + dict and its value are the nodes edges. + key: A function that takes a node and returns a value that is comparable + and used to order nodes Yields: - str: The next node in the topological sort + The next node in the topological sort """ # Note, this is basically Kahn's algorithm except we look at nodes with no # outgoing edges, c.f. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm outdegree_map = graph - reverse_graph = {} + reverse_graph = {} # type: Dict[str, Set[str]] # Lists of nodes with zero out degree. Is actually a tuple of # `(key(node), node)` so that sorting does the right thing diff --git a/tox.ini b/tox.ini index ea804108b5..edeb757f7b 100644 --- a/tox.ini +++ b/tox.ini @@ -209,6 +209,7 @@ commands = mypy \ synapse/server.py \ synapse/server_notices \ synapse/spam_checker_api \ + synapse/state \ synapse/storage/databases/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \