Move `resolve_events_with_store` into StateResolutionHandler

This commit is contained in:
Richard van der Hoff 2020-09-28 15:20:02 +01:00
parent c2bdf040aa
commit 937393abd8
2 changed files with 55 additions and 50 deletions

View File

@ -21,7 +21,7 @@ import itertools
import logging import logging
from collections.abc import Container from collections.abc import Container
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -69,7 +69,7 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet, ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet, ReplicationStoreRoomOnInviteRestServlet,
) )
from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
@ -85,6 +85,9 @@ from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server from synapse.visibility import filter_events_for_server
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -116,7 +119,7 @@ class FederationHandler(BaseHandler):
rooms. rooms.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
@ -126,6 +129,7 @@ class FederationHandler(BaseHandler):
self.state_store = self.storage.state self.state_store = self.storage.state
self.federation_client = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self._state_resolution_handler = hs.get_state_resolution_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()
@ -381,8 +385,7 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
state_map = await resolve_events_with_store( state_map = await self._state_resolution_handler.resolve_events_with_store(
self.clock,
room_id, room_id,
room_version, room_version,
state_maps, state_maps,

View File

@ -449,8 +449,7 @@ class StateHandler:
state_map = {ev.event_id: ev for st in state_sets for ev in st} state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = await resolve_events_with_store( new_state = await self._state_resolution_handler.resolve_events_with_store(
self.clock,
event.room_id, event.room_id,
room_version, room_version,
state_set_ids, state_set_ids,
@ -531,8 +530,7 @@ class StateResolutionHandler:
state_groups_histogram.observe(len(state_groups_ids)) state_groups_histogram.observe(len(state_groups_ids))
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = await resolve_events_with_store( new_state = await self.resolve_events_with_store(
self.clock,
room_id, room_id,
room_version, room_version,
list(state_groups_ids.values()), list(state_groups_ids.values()),
@ -552,6 +550,51 @@ class StateResolutionHandler:
return cache return cache
def resolve_events_with_store(
self,
room_id: str,
room_version: str,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
room_version: Version of the room
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_res_store.
state_res_store: a place to fetch events from
Returns:
a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
return v1.resolve_events_with_store(
room_id, state_sets, event_map, state_res_store.get_events
)
else:
return v2.resolve_events_with_store(
self.clock,
room_id,
room_version,
state_sets,
event_map,
state_res_store,
)
def _make_state_cache_entry( def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
@ -605,47 +648,6 @@ def _make_state_cache_entry(
) )
def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
room_version: Version of the room
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_res_store.
state_res_store: a place to fetch events from
Returns:
a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
return v1.resolve_events_with_store(
room_id, state_sets, event_map, state_res_store.get_events
)
else:
return v2.resolve_events_with_store(
clock, room_id, room_version, state_sets, event_map, state_res_store
)
@attr.s(slots=True) @attr.s(slots=True)
class StateResolutionStore: class StateResolutionStore:
"""Interface that allows state resolution algorithms to access the database """Interface that allows state resolution algorithms to access the database