make FederationHandler._check_for_soft_fail async
This commit is contained in:
parent
dbdf843012
commit
c3f296af32
|
@ -1997,27 +1997,23 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _check_for_soft_fail(
|
||||||
def _check_for_soft_fail(
|
|
||||||
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
|
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
|
||||||
):
|
) -> None:
|
||||||
"""Checks if we should soft fail the event, if so marks the event as
|
"""Checks if we should soft fail the event; if so, marks the event as
|
||||||
such.
|
such.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event
|
event
|
||||||
state: The state at the event if we don't have all the event's prev events
|
state: The state at the event if we don't have all the event's prev events
|
||||||
backfilled: Whether the event is from backfill
|
backfilled: Whether the event is from backfill
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
# For new (non-backfilled and non-outlier) events we check if the event
|
# For new (non-backfilled and non-outlier) events we check if the event
|
||||||
# passes auth based on the current state. If it doesn't then we
|
# passes auth based on the current state. If it doesn't then we
|
||||||
# "soft-fail" the event.
|
# "soft-fail" the event.
|
||||||
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
|
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
|
||||||
if do_soft_fail_check:
|
if do_soft_fail_check:
|
||||||
extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id)
|
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||||
|
|
||||||
extrem_ids = set(extrem_ids)
|
extrem_ids = set(extrem_ids)
|
||||||
prev_event_ids = set(event.prev_event_ids())
|
prev_event_ids = set(event.prev_event_ids())
|
||||||
|
@ -2028,7 +2024,7 @@ class FederationHandler(BaseHandler):
|
||||||
do_soft_fail_check = False
|
do_soft_fail_check = False
|
||||||
|
|
||||||
if do_soft_fail_check:
|
if do_soft_fail_check:
|
||||||
room_version = yield self.store.get_room_version_id(event.room_id)
|
room_version = await self.store.get_room_version_id(event.room_id)
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
|
|
||||||
# Calculate the "current state".
|
# Calculate the "current state".
|
||||||
|
@ -2045,19 +2041,19 @@ class FederationHandler(BaseHandler):
|
||||||
# given state at the event. This should correctly handle cases
|
# given state at the event. This should correctly handle cases
|
||||||
# like bans, especially with state res v2.
|
# like bans, especially with state res v2.
|
||||||
|
|
||||||
state_sets = yield self.state_store.get_state_groups(
|
state_sets = await self.state_store.get_state_groups(
|
||||||
event.room_id, extrem_ids
|
event.room_id, extrem_ids
|
||||||
)
|
)
|
||||||
state_sets = list(state_sets.values())
|
state_sets = list(state_sets.values())
|
||||||
state_sets.append(state)
|
state_sets.append(state)
|
||||||
current_state_ids = yield self.state_handler.resolve_events(
|
current_state_ids = await self.state_handler.resolve_events(
|
||||||
room_version, state_sets, event
|
room_version, state_sets, event
|
||||||
)
|
)
|
||||||
current_state_ids = {
|
current_state_ids = {
|
||||||
k: e.event_id for k, e in iteritems(current_state_ids)
|
k: e.event_id for k, e in iteritems(current_state_ids)
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||||
event.room_id, latest_event_ids=extrem_ids
|
event.room_id, latest_event_ids=extrem_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2073,7 +2069,7 @@ class FederationHandler(BaseHandler):
|
||||||
e for k, e in iteritems(current_state_ids) if k in auth_types
|
e for k, e in iteritems(current_state_ids) if k in auth_types
|
||||||
]
|
]
|
||||||
|
|
||||||
current_auth_events = yield self.store.get_events(current_state_ids)
|
current_auth_events = await self.store.get_events(current_state_ids)
|
||||||
current_auth_events = {
|
current_auth_events = {
|
||||||
(e.type, e.state_key): e for e in current_auth_events.values()
|
(e.type, e.state_key): e for e in current_auth_events.values()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue