Rename `event_map` to `unpersisted_events` (#13603)

This commit is contained in:
David Robertson 2022-08-24 21:06:31 +01:00 committed by GitHub
parent 1a209efdb2
commit c406d50d2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 31 deletions

1
changelog.d/13603.misc Normal file
View File

@ -0,0 +1 @@
Rename `event_map` to `unpersisted_events` when computing the auth differences.

View File

@ -271,40 +271,41 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference( async def _get_auth_chain_difference(
room_id: str, room_id: str,
state_sets: Sequence[Mapping[Any, str]], state_sets: Sequence[Mapping[Any, str]],
event_map: Dict[str, EventBase], unpersisted_events: Dict[str, EventBase],
state_res_store: StateResolutionStore, state_res_store: StateResolutionStore,
) -> Set[str]: ) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events """Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains. that only appear in some, but not all of the auth chains.
Args: Args:
state_sets state_sets: The input state sets we are trying to resolve across.
event_map unpersisted_events: A map from event ID to EventBase containing all unpersisted
state_res_store events involved in this resolution.
state_res_store:
Returns: Returns:
Set of event IDs The auth difference of the given state sets, as a set of event IDs.
""" """
# The `StateResolutionStore.get_auth_chain_difference` function assumes that # The `StateResolutionStore.get_auth_chain_difference` function assumes that
# all events passed to it (and their auth chains) have been persisted # all events passed to it (and their auth chains) have been persisted
# previously. This is not the case for any events in the `event_map`, and so # previously. We need to manually handle any other events that are yet to be
# we need to manually handle those events. # persisted.
# #
# We do this by: # We do this in three steps:
# 1. calculating the auth chain difference for the state sets based on the # 1. Compute the set of unpersisted events belonging to the auth difference.
# events in `event_map` alone # 2. Replacing any unpersisted events in the state_sets with their auth events,
# 2. replacing any events in the state_sets that are also in `event_map` # recursively, until the state_sets contain only persisted events.
# with their auth events (recursively), and then calling # Then we call `store.get_auth_chain_difference` as normal, which computes
# `store.get_auth_chain_difference` as normal # the set of persisted events belonging to the auth difference.
# 3. adding the results of 1 and 2 together. # 3. Adding the results of 1 and 2 together.
# Map from event ID in `event_map` to their auth event IDs, and their auth # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth
# event IDs if they appear in the `event_map`. This is the intersection of # event IDs if they appear in the `unpersisted_events`. This is the intersection of
# the event's auth chain with the events in the `event_map` *plus* their # the event's auth chain with the events in `unpersisted_events` *plus* their
# auth event IDs. # auth event IDs.
events_to_auth_chain: Dict[str, Set[str]] = {} events_to_auth_chain: Dict[str, Set[str]] = {}
for event in event_map.values(): for event in unpersisted_events.values():
chain = {event.event_id} chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain events_to_auth_chain[event.event_id] = chain
@ -312,16 +313,16 @@ async def _get_auth_chain_difference(
while to_search: while to_search:
for auth_id in to_search.pop().auth_event_ids(): for auth_id in to_search.pop().auth_event_ids():
chain.add(auth_id) chain.add(auth_id)
auth_event = event_map.get(auth_id) auth_event = unpersisted_events.get(auth_id)
if auth_event: if auth_event:
to_search.append(auth_event) to_search.append(auth_event)
# We now a) calculate the auth chain difference for the unpersisted events # We now 1) calculate the auth chain difference for the unpersisted events
# and b) work out the state sets to pass to the store. # and 2) work out the state sets to pass to the store.
# #
# Note: If the `event_map` is empty (which is the common case), we can do a # Note: If there are no `unpersisted_events` (which is the common case), we can do a
# much simpler calculation. # much simpler calculation.
if event_map: if unpersisted_events:
# The list of state sets to pass to the store, where each state set is a set # The list of state sets to pass to the store, where each state set is a set
# of the event ids making up the state. This is similar to `state_sets`, # of the event ids making up the state. This is similar to `state_sets`,
# except that (a) we only have event ids, not the complete # except that (a) we only have event ids, not the complete
@ -344,14 +345,18 @@ async def _get_auth_chain_difference(
for event_id in state_set.values(): for event_id in state_set.values():
event_chain = events_to_auth_chain.get(event_id) event_chain = events_to_auth_chain.get(event_id)
if event_chain is not None: if event_chain is not None:
# We have an event in `event_map`. We add all the auth # We have an unpersisted event. We add all the auth
# events that it references (that aren't also in `event_map`). # events that it references which are also unpersisted.
set_ids.update(e for e in event_chain if e not in event_map) set_ids.update(
e for e in event_chain if e not in unpersisted_events
)
# We also add the full chain of unpersisted event IDs # We also add the full chain of unpersisted event IDs
# referenced by this state set, so that we can work out the # referenced by this state set, so that we can work out the
# auth chain difference of the unpersisted events. # auth chain difference of the unpersisted events.
unpersisted_ids.update(e for e in event_chain if e in event_map) unpersisted_ids.update(
e for e in event_chain if e in unpersisted_events
)
else: else:
set_ids.add(event_id) set_ids.add(event_id)
@ -361,15 +366,15 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
difference_from_event_map: Collection[str] = union - intersection auth_difference_unpersisted_part: Collection[str] = union - intersection
else: else:
difference_from_event_map = () auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets] state_sets_ids = [set(state_set.values()) for state_set in state_sets]
difference = await state_res_store.get_auth_chain_difference( difference = await state_res_store.get_auth_chain_difference(
room_id, state_sets_ids room_id, state_sets_ids
) )
difference.update(difference_from_event_map) difference.update(auth_difference_unpersisted_part)
return difference return difference