make FederationHandler._check_for_soft_fail async

This commit is contained in:
Richard van der Hoff 2020-02-03 16:16:31 +00:00
parent dbdf843012
commit c3f296af32
1 changed files with 9 additions and 13 deletions

View File

@ -1997,27 +1997,23 @@ class FederationHandler(BaseHandler):
return context return context
@defer.inlineCallbacks async def _check_for_soft_fail(
def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
): ) -> None:
"""Checks if we should soft fail the event, if so marks the event as """Checks if we should soft fail the event; if so, marks the event as
such. such.
Args: Args:
event event
state: The state at the event if we don't have all the event's prev events state: The state at the event if we don't have all the event's prev events
backfilled: Whether the event is from backfill backfilled: Whether the event is from backfill
Returns:
Deferred
""" """
# For new (non-backfilled and non-outlier) events we check if the event # For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we # passes auth based on the current state. If it doesn't then we
# "soft-fail" the event. # "soft-fail" the event.
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
if do_soft_fail_check: if do_soft_fail_check:
extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids) extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids()) prev_event_ids = set(event.prev_event_ids())
@ -2028,7 +2024,7 @@ class FederationHandler(BaseHandler):
do_soft_fail_check = False do_soft_fail_check = False
if do_soft_fail_check: if do_soft_fail_check:
room_version = yield self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state". # Calculate the "current state".
@ -2045,19 +2041,19 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases # given state at the event. This should correctly handle cases
# like bans, especially with state res v2. # like bans, especially with state res v2.
state_sets = yield self.state_store.get_state_groups( state_sets = await self.state_store.get_state_groups(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets = list(state_sets.values()) state_sets = list(state_sets.values())
state_sets.append(state) state_sets.append(state)
current_state_ids = yield self.state_handler.resolve_events( current_state_ids = await self.state_handler.resolve_events(
room_version, state_sets, event room_version, state_sets, event
) )
current_state_ids = { current_state_ids = {
k: e.event_id for k, e in iteritems(current_state_ids) k: e.event_id for k, e in iteritems(current_state_ids)
} }
else: else:
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids event.room_id, latest_event_ids=extrem_ids
) )
@ -2073,7 +2069,7 @@ class FederationHandler(BaseHandler):
e for k, e in iteritems(current_state_ids) if k in auth_types e for k, e in iteritems(current_state_ids) if k in auth_types
] ]
current_auth_events = yield self.store.get_events(current_state_ids) current_auth_events = await self.store.get_events(current_state_ids)
current_auth_events = { current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values() (e.type, e.state_key): e for e in current_auth_events.values()
} }