diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0073e7c996..1a8144405a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,7 +21,7 @@ import itertools import logging from collections.abc import Container from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr from signedjson.key import decode_verify_key_bytes @@ -69,7 +69,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnInviteRestServlet, ) -from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( JsonDict, @@ -85,6 +85,9 @@ from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -116,7 +119,7 @@ class FederationHandler(BaseHandler): rooms. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs @@ -126,6 +129,7 @@ class FederationHandler(BaseHandler): self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() + self._state_resolution_handler = hs.get_state_resolution_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() self.action_generator = hs.get_action_generator() @@ -381,8 +385,7 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await resolve_events_with_store( - self.clock, + state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5a5ea39e01..98ede2ea4f 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -449,8 +449,7 @@ class StateHandler: state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, + new_state = await self._state_resolution_handler.resolve_events_with_store( event.room_id, room_version, state_set_ids, @@ -531,8 +530,7 @@ class StateResolutionHandler: state_groups_histogram.observe(len(state_groups_ids)) with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, + new_state = await self.resolve_events_with_store( room_id, room_version, list(state_groups_ids.values()), @@ -552,6 +550,51 @@ class StateResolutionHandler: return cache + def resolve_events_with_store( + self, + room_id: str, + room_version: str, + state_sets: Sequence[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", + ) -> Awaitable[StateMap[str]]: + """ + Args: + room_id: the room we are working in + + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map: + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_res_store. + + state_res_store: a place to fetch events from + + Returns: + a map from (type, state_key) to event_id. + """ + v = KNOWN_ROOM_VERSIONS[room_version] + if v.state_res == StateResolutionVersions.V1: + return v1.resolve_events_with_store( + room_id, state_sets, event_map, state_res_store.get_events + ) + else: + return v2.resolve_events_with_store( + self.clock, + room_id, + room_version, + state_sets, + event_map, + state_res_store, + ) + def _make_state_cache_entry( new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] @@ -605,47 +648,6 @@ def _make_state_cache_entry( ) -def resolve_events_with_store( - clock: Clock, - room_id: str, - room_version: str, - state_sets: Sequence[StateMap[str]], - event_map: Optional[Dict[str, EventBase]], - state_res_store: "StateResolutionStore", -) -> Awaitable[StateMap[str]]: - """ - Args: - room_id: the room we are working in - - room_version: Version of the room - - state_sets: List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - - event_map: - a dict from event_id to event, for any events that we happen to - have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing - events will be requested via state_map_factory. - - If None, all events will be fetched via state_res_store. - - state_res_store: a place to fetch events from - - Returns: - a map from (type, state_key) to event_id. - """ - v = KNOWN_ROOM_VERSIONS[room_version] - if v.state_res == StateResolutionVersions.V1: - return v1.resolve_events_with_store( - room_id, state_sets, event_map, state_res_store.get_events - ) - else: - return v2.resolve_events_with_store( - clock, room_id, room_version, state_sets, event_map, state_res_store - ) - - @attr.s(slots=True) class StateResolutionStore: """Interface that allows state resolution algorithms to access the database