Initial stab at implementing a batched get_missing_pdus request
This commit is contained in:
parent
894a89d99b
commit
0ac2a79faa
|
@ -305,6 +305,78 @@ class FederationServer(FederationBase):
|
|||
(200, send_content)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_missing_events(self, origin, room_id, earliest_events,
|
||||
latest_events, limit, min_depth):
|
||||
limit = max(limit, 50)
|
||||
min_depth = max(min_depth, 0)
|
||||
|
||||
missing_events = yield self.store.get_missing_events(
|
||||
room_id=room_id,
|
||||
earliest_events=earliest_events,
|
||||
latest_events=latest_events,
|
||||
limit=limit,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
|
||||
known_ids = {e.event_id for e in missing_events} | {earliest_events}
|
||||
|
||||
back_edges = {
|
||||
e for e in missing_events
|
||||
if {i for i, h in e.prev_events.items()} <= known_ids
|
||||
}
|
||||
|
||||
decoded_auth_events = set()
|
||||
state = {}
|
||||
auth_events = set()
|
||||
auth_and_state = {}
|
||||
for event in back_edges:
|
||||
state_pdus = yield self.handler.get_state_for_pdu(
|
||||
origin, room_id, event.event_id,
|
||||
do_auth=False,
|
||||
)
|
||||
|
||||
state[event.event_id] = [s.event_id for s in state_pdus]
|
||||
|
||||
auth_and_state.update({
|
||||
s.event_id: s for s in state_pdus
|
||||
})
|
||||
|
||||
state_ids = {pdu.event_id for pdu in state_pdus}
|
||||
prev_ids = {i for i, h in event.prev_events.items()}
|
||||
partial_auth_chain = yield self.store.get_auth_chain(
|
||||
state_ids | prev_ids, have_ids=decoded_auth_events.keys()
|
||||
)
|
||||
|
||||
for p in partial_auth_chain:
|
||||
p.signatures.update(
|
||||
compute_event_signature(
|
||||
p,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
|
||||
auth_events.update(
|
||||
a.event_id for a in partial_auth_chain
|
||||
)
|
||||
|
||||
auth_and_state.update({
|
||||
a.event_id: a for a in partial_auth_chain
|
||||
})
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
defer.returnValue({
|
||||
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
|
||||
"state_for_events": state,
|
||||
"auth_events": auth_events,
|
||||
"event_map": {
|
||||
k: ev.get_pdu_json(time_now)
|
||||
for k, ev in auth_and_state.items()
|
||||
},
|
||||
})
|
||||
|
||||
@log_function
|
||||
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
|
|
@ -581,9 +581,10 @@ class FederationHandler(BaseHandler):
|
|||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, origin, room_id, event_id):
|
||||
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
|
||||
yield run_on_reactor()
|
||||
|
||||
if do_auth:
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
|
|
@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
|
|||
and backfilling from another server respectively.
|
||||
"""
|
||||
|
||||
def get_auth_chain(self, event_ids):
|
||||
def get_auth_chain(self, event_ids, have_ids=set()):
|
||||
return self.runInteraction(
|
||||
"get_auth_chain",
|
||||
self._get_auth_chain_txn,
|
||||
event_ids
|
||||
event_ids, have_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_txn(self, txn, event_ids):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_ids)
|
||||
def _get_auth_chain_txn(self, txn, event_ids, have_ids):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_ids, have_ids)
|
||||
|
||||
return self._get_events_txn(txn, results)
|
||||
|
||||
|
@ -51,8 +51,9 @@ class EventFederationStore(SQLBaseStore):
|
|||
event_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids):
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids, have_ids):
|
||||
results = set()
|
||||
have_ids = set(have_ids)
|
||||
|
||||
base_sql = (
|
||||
"SELECT auth_id FROM event_auth WHERE event_id = ?"
|
||||
|
@ -64,6 +65,10 @@ class EventFederationStore(SQLBaseStore):
|
|||
for f in front:
|
||||
txn.execute(base_sql, (f,))
|
||||
new_front.update([r[0] for r in txn.fetchall()])
|
||||
|
||||
new_front -= results
|
||||
new_front -= have_ids
|
||||
|
||||
front = new_front
|
||||
results.update(front)
|
||||
|
||||
|
@ -378,3 +383,51 @@ class EventFederationStore(SQLBaseStore):
|
|||
event_results += new_front
|
||||
|
||||
return self._get_events_txn(txn, event_results)
|
||||
|
||||
def get_missing_events(self, room_id, earliest_events, latest_events,
|
||||
limit, min_depth):
|
||||
return self.runInteraction(
|
||||
"get_missing_events",
|
||||
self._get_missing_events,
|
||||
room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
|
||||
limit, min_depth):
|
||||
|
||||
earliest_events = set(earliest_events)
|
||||
front = set(latest_events) - earliest_events
|
||||
|
||||
event_results = set()
|
||||
|
||||
query = (
|
||||
"SELECT prev_event_id FROM event_edges "
|
||||
"WHERE room_id = ? AND event_id = ? AND is_state = 0 "
|
||||
"LIMIT ?"
|
||||
)
|
||||
|
||||
while front and len(event_results) < limit:
|
||||
new_front = set()
|
||||
for event_id in front:
|
||||
txn.execute(
|
||||
query,
|
||||
(room_id, event_id, limit - len(event_results))
|
||||
)
|
||||
|
||||
for e_id, in txn.fetchall():
|
||||
new_front.add(e_id)
|
||||
|
||||
new_front -= earliest_events
|
||||
new_front -= event_results
|
||||
|
||||
front = new_front
|
||||
event_results |= new_front
|
||||
|
||||
events = self._get_events_txn(txn, event_results)
|
||||
|
||||
events = sorted(
|
||||
[ev for ev in events if ev.depth >= min_depth],
|
||||
key=lambda e: e.depth,
|
||||
)
|
||||
|
||||
return events[:limit]
|
||||
|
|
Loading…
Reference in New Issue