Convert state resolution to async/await (#7942)

This commit is contained in:
Patrick Cloke 2020-07-24 10:59:51 -04:00 committed by GitHub
parent e739b20588
commit b975fa2e99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 198 additions and 184 deletions

1
changelog.d/7942.misc Normal file
View File

@ -0,0 +1 @@
Convert state resolution to async/await.

View File

@ -127,8 +127,10 @@ class Auth(object):
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
)
membership = member.membership if member else None
@ -665,8 +667,10 @@ class Auth(object):
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
)
if (
visibility

View File

@ -106,8 +106,8 @@ class EventBuilder(object):
Deferred[FrozenEvent]
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids
state_ids = yield defer.ensureDeferred(
self._state.get_current_state_ids(self.room_id, prev_event_ids)
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)

View File

@ -330,7 +330,9 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains = yield self.state.get_current_hosts_in_room(room_id)
domains = yield defer.ensureDeferred(
self.state.get_current_hosts_in_room(room_id)
)
domains = [
d
for d in domains

View File

@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
users = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, users))
states_d = await self.current_state_for_users(user_ids)

View File

@ -304,7 +304,9 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(
context.get_current_state_ids()
)
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))

View File

@ -16,14 +16,12 @@
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set
from typing import Awaitable, Dict, Iterable, List, Optional, Set
import attr
from frozendict import frozendict
from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@ -108,8 +107,7 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def get_current_state(
async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
@ -126,20 +124,20 @@ class StateHandler(object):
map from (type, state_key) to event
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = yield self.store.get_events(
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
@ -148,8 +146,7 @@ class StateHandler(object):
return state
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, latest_event_ids=None):
async def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room
Args:
@ -164,41 +161,38 @@ class StateHandler(object):
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
@defer.inlineCallbacks
def get_current_users_in_room(self, room_id, latest_event_ids=None):
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Args:
room_id (str): The ID of the room.
latest_event_ids (List[str]|None): Precomputed list of latest
event IDs. Will be computed if None.
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
profileinfo.
Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
@defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id):
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
async def get_current_hosts_in_room(self, room_id):
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
@defer.inlineCallbacks
def get_hosts_in_room_at_events(self, room_id, event_ids):
async def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
@ -208,12 +202,11 @@ class StateHandler(object):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
@defer.inlineCallbacks
def compute_event_context(
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
@ -278,7 +271,7 @@ class StateHandler(object):
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events(
entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@ -295,7 +288,7 @@ class StateHandler(object):
#
if not state_group_before_event:
state_group_before_event = yield self.state_store.store_state_group(
state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
@ -335,7 +328,7 @@ class StateHandler(object):
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
state_group_after_event = yield self.state_store.store_state_group(
state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
@ -353,8 +346,7 @@ class StateHandler(object):
)
@measure_func()
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
async def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@ -373,7 +365,7 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = yield self.state_store.get_state_groups_ids(
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
@ -382,7 +374,7 @@ class StateHandler(object):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
@ -391,9 +383,9 @@ class StateHandler(object):
delta_ids=delta_ids,
)
room_version = yield self.store.get_room_version_id(room_id)
room_version = await self.store.get_room_version_id(room_id)
result = yield self._state_resolution_handler.resolve_state_groups(
result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
@ -402,8 +394,7 @@ class StateHandler(object):
)
return result
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
async def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@ -414,7 +405,7 @@ class StateHandler(object):
state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
new_state = await resolve_events_with_store(
self.clock,
event.room_id,
room_version,
@ -451,9 +442,8 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
@ -479,13 +469,13 @@ class StateResolutionHandler(object):
state_res_store (StateResolutionStore)
Returns:
Deferred[_StateCacheEntry]: resolved state
_StateCacheEntry: resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)):
with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
@ -517,7 +507,7 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
@ -598,7 +588,7 @@ def resolve_events_with_store(
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
@ -619,8 +609,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:

View File

@ -15,9 +15,7 @@
import hashlib
import logging
from typing import Callable, Dict, List, Optional
from twisted.internet import defer
from typing import Awaitable, Callable, Dict, List, Optional
from synapse import event_auth
from synapse.api.constants import EventTypes
@ -32,12 +30,11 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks
def resolve_events_with_store(
async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable,
state_map_factory: Callable[[List[str]], Awaitable],
):
"""
Args:
@ -56,7 +53,7 @@ def resolve_events_with_store(
state_map_factory: will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
an Awaitable that resolves to a dict of event_id to event.
Returns:
Deferred[dict[(str, str), str]]:
@ -80,7 +77,7 @@ def resolve_events_with_store(
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@ -110,7 +107,7 @@ def resolve_events_with_store(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
state_map_new = yield state_map_factory(new_needed_events)
state_map_new = await state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(

View File

@ -18,8 +18,6 @@ import itertools
import logging
from typing import Dict, List, Optional
from twisted.internet import defer
import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
@ -32,14 +30,13 @@ from synapse.util import Clock
logger = logging.getLogger(__name__)
# We want to yield to the reactor occasionally during state res when dealing
# We want to await to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
# yielding to reactor during loops every N iterations.
_YIELD_AFTER_ITERATIONS = 100
# awaiting to reactor during loops every N iterations.
_AWAIT_AFTER_ITERATIONS = 100
@defer.inlineCallbacks
def resolve_events_with_store(
async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
@ -87,7 +84,7 @@ def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
full_conflicted_set = set(
itertools.chain(
@ -95,7 +92,7 @@ def resolve_events_with_store(
)
)
events = yield state_res_store.get_events(
events = await state_res_store.get_events(
[eid for eid in full_conflicted_set if eid not in event_map],
allow_rejected=True,
)
@ -118,14 +115,14 @@ def resolve_events_with_store(
eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
sorted_power_events = yield _reverse_topological_power_sort(
sorted_power_events = await _reverse_topological_power_sort(
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@ -148,13 +145,13 @@ def resolve_events_with_store(
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
leftover_events = await _mainline_sort(
clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@ -174,8 +171,7 @@ def resolve_events_with_store(
return resolved_state
@defer.inlineCallbacks
def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
Returns:
Deferred[int]
"""
event = yield _get_event(room_id, event_id, event_map, state_res_store)
event = await _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
aev = yield _get_event(
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
aev = yield _get_event(
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
return int(level)
@defer.inlineCallbacks
def _get_auth_chain_difference(state_sets, event_map, state_res_store):
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Deferred[set[str]]: Set of event IDs
"""
difference = yield state_res_store.get_auth_chain_difference(
difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
@ -292,8 +287,7 @@ def _is_power_event(event):
return False
@defer.inlineCallbacks
def _add_event_and_auth_chain_to_graph(
async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
"""Helper function for _reverse_topological_power_sort that add the event
@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop()
graph.setdefault(eid, set())
event = yield _get_event(room_id, eid, event_map, state_res_store)
event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph(
graph.setdefault(eid, set()).add(aid)
@defer.inlineCallbacks
def _reverse_topological_power_sort(
async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
@ -344,26 +337,26 @@ def _reverse_topological_power_sort(
graph = {}
for idx, event_id in enumerate(event_ids, start=1):
yield _add_event_and_auth_chain_to_graph(
await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
# We yield occasionally when we're working with large data sets to
# We await occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
if idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
event_to_pl = {}
for idx, event_id in enumerate(graph, start=1):
pl = yield _get_power_level_for_sender(
pl = await _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
# We yield occasionally when we're working with large data sets to
# We await occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
if idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
def _get_power_order(event_id):
ev = event_map[event_id]
@ -378,8 +371,7 @@ def _reverse_topological_power_sort(
return sorted_events
@defer.inlineCallbacks
def _iterative_auth_checks(
async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
"""Sequentially apply auth checks to each event in given list, updating the
@ -405,7 +397,7 @@ def _iterative_auth_checks(
auth_events = {}
for aid in event.auth_event_ids():
ev = yield _get_event(
ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
@ -420,7 +412,7 @@ def _iterative_auth_checks(
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
@ -438,16 +430,15 @@ def _iterative_auth_checks(
except AuthError:
pass
# We yield occasionally when we're working with large data sets to
# We await occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
if idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
return resolved_state
@defer.inlineCallbacks
def _mainline_sort(
async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
@ -474,21 +465,21 @@ def _mainline_sort(
idx = 0
while pl:
mainline.append(pl)
pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
ev = yield _get_event(
ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
# We yield occasionally when we're working with large data sets to
# We await occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
idx += 1
@ -498,23 +489,24 @@ def _mainline_sort(
order_map = {}
for idx, ev_id in enumerate(event_ids, start=1):
depth = yield _get_mainline_depth_for_event(
depth = await _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
# We yield occasionally when we're working with large data sets to
# We await occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
if idx % _AWAIT_AFTER_ITERATIONS == 0:
await clock.sleep(0)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
@defer.inlineCallbacks
def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
async def _get_mainline_depth_for_event(
event, mainline_map, event_map, state_res_store
):
"""Get the mainline depths for the given event based on the mainline map
Args:
@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
event = None
for aid in auth_events:
aev = yield _get_event(
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
return 0
@defer.inlineCallbacks
def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
Deferred[Optional[FrozenEvent]]
"""
if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True)
events = await state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
event = event_map.get(event_id)

View File

@ -259,7 +259,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)

View File

@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)

View File

@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
users_with_profile = yield state.get_current_users_in_room(room_id)
users_with_profile = yield defer.ensureDeferred(
state.get_current_users_in_room(room_id)
)
user_ids = set(users_with_profile)
# Update each user in the user directory.

View File

@ -29,7 +29,6 @@ from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
@ -648,6 +647,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
# Avoid a circular import.
from synapse.state import StateResolutionStore
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,

View File

@ -26,21 +26,24 @@ from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]),
)
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
# stub out get_current_hosts_in_room
mock_state_handler = hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):

View File

@ -14,6 +14,7 @@
# limitations under the License.
import itertools
from typing import List
import attr
@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
state_before = self.successResultOf(state_d)
state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
state = self.successResultOf(state_d)
state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
return defer.succeed(
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
)
def _get_auth_chain(self, event_ids):
def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
event_ids (list): The event IDs of the events to fetch the auth
event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
return set(chains[0]).union(*chains[1:]) - common
return defer.succeed(set(chains[0]).union(*chains[1:]) - common)

View File

@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
state = yield self.store.get_current_state(room_id=self.room.to_string())
state = yield defer.ensureDeferred(
self.store.get_current_state(room_id=self.room.to_string())
)
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
state = yield self.store.get_current_state(room_id=self.room.to_string())
state = yield defer.ensureDeferred(
self.store.get_current_state(room_id=self.room.to_string())
)
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(

View File

@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
return state_group
return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
return {
e_id: self._event_id_to_event[e_id]
for e_id in event_ids
if e_id in self._event_id_to_event
}
return defer.succeed(
{
e_id: self._event_id_to_event[e_id]
for e_id in event_ids
if e_id in self._event_id_to_event
}
)
def get_state_group_delta(self, name):
return None, None
return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(event, old_state=old_state)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
)
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(event, old_state=old_state)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
)
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
group_name = self.store.store_state_group(
group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
group_name = self.store.store_state_group(
group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
@defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
sg1 = self.store.store_state_group(
sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
sg2 = self.store.store_state_group(
sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
return self.state.compute_event_context(event)
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result

View File

@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
from typing import Awaitable, TypeVar
from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
async def make_awaitable(result: Any):
"""Create an awaitable that just returns a result."""
return result