Add `include_event_in_state` to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one before it. This is a bit easier and requires fewer calls to get_events_from_store_or_dest.
This commit is contained in:
parent
be294d6fde
commit
20d5ba16e6
|
@ -0,0 +1 @@
|
|||
Refactor some code in the event authentication path for clarity.
|
|
@ -379,22 +379,10 @@ class FederationHandler(BaseHandler):
|
|||
(
|
||||
remote_state,
|
||||
got_auth_chain,
|
||||
) = yield self._get_state_for_room(origin, room_id, p)
|
||||
|
||||
# we want the state *after* p; _get_state_for_room returns the
|
||||
# state *before* p.
|
||||
remote_event = yield self.federation_client.get_pdu(
|
||||
[origin], p, room_version, outlier=True
|
||||
) = yield self._get_state_for_room(
|
||||
origin, room_id, p, include_event_in_state=True
|
||||
)
|
||||
|
||||
if remote_event is None:
|
||||
raise Exception(
|
||||
"Unable to get missing prev_event %s" % (p,)
|
||||
)
|
||||
|
||||
if remote_event.is_state():
|
||||
remote_state.append(remote_event)
|
||||
|
||||
# XXX hrm I'm not convinced that duplicate events will compare
|
||||
# for equality, so I'm not sure this does what the author
|
||||
# hoped.
|
||||
|
@ -583,13 +571,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _get_state_for_room(self, destination, room_id, event_id):
|
||||
def _get_state_for_room(self, destination, room_id, event_id, include_event_in_state):
|
||||
"""Requests all of the room state at a given event from a remote homeserver.
|
||||
|
||||
Args:
|
||||
destination (str): The remote homeserver to query for the state.
|
||||
room_id (str): The id of the room we're interested in.
|
||||
event_id (str): The id of the event we want the state at.
|
||||
include_event_in_state: if true, the event itself will be included in the
|
||||
returned state event list.
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[List[EventBase], List[EventBase]]]:
|
||||
|
@ -604,6 +594,10 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
desired_events = set(state_event_ids + auth_event_ids)
|
||||
|
||||
if include_event_in_state:
|
||||
desired_events.add(event_id)
|
||||
|
||||
event_map = yield self._get_events_from_store_or_dest(
|
||||
destination, room_id, desired_events
|
||||
)
|
||||
|
@ -616,12 +610,21 @@ class FederationHandler(BaseHandler):
|
|||
failed_to_fetch,
|
||||
)
|
||||
|
||||
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
|
||||
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
||||
remote_state = [
|
||||
event_map[e_id] for e_id in state_event_ids if e_id in event_map
|
||||
]
|
||||
|
||||
if include_event_in_state:
|
||||
remote_event = event_map.get(event_id)
|
||||
if not remote_event:
|
||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||
if remote_event.is_state():
|
||||
remote_state.append(remote_event)
|
||||
|
||||
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
return pdus, auth_chain
|
||||
return remote_state, auth_chain
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
|
||||
|
|
Loading…
Reference in New Issue