Initial stab at implementing a batched get_missing_pdus request
This commit is contained in:
parent
894a89d99b
commit
0ac2a79faa
|
@ -305,6 +305,78 @@ class FederationServer(FederationBase):
|
||||||
(200, send_content)
|
(200, send_content)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_missing_events(self, origin, room_id, earliest_events,
|
||||||
|
latest_events, limit, min_depth):
|
||||||
|
limit = max(limit, 50)
|
||||||
|
min_depth = max(min_depth, 0)
|
||||||
|
|
||||||
|
missing_events = yield self.store.get_missing_events(
|
||||||
|
room_id=room_id,
|
||||||
|
earliest_events=earliest_events,
|
||||||
|
latest_events=latest_events,
|
||||||
|
limit=limit,
|
||||||
|
min_depth=min_depth,
|
||||||
|
)
|
||||||
|
|
||||||
|
known_ids = {e.event_id for e in missing_events} | {earliest_events}
|
||||||
|
|
||||||
|
back_edges = {
|
||||||
|
e for e in missing_events
|
||||||
|
if {i for i, h in e.prev_events.items()} <= known_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded_auth_events = set()
|
||||||
|
state = {}
|
||||||
|
auth_events = set()
|
||||||
|
auth_and_state = {}
|
||||||
|
for event in back_edges:
|
||||||
|
state_pdus = yield self.handler.get_state_for_pdu(
|
||||||
|
origin, room_id, event.event_id,
|
||||||
|
do_auth=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
state[event.event_id] = [s.event_id for s in state_pdus]
|
||||||
|
|
||||||
|
auth_and_state.update({
|
||||||
|
s.event_id: s for s in state_pdus
|
||||||
|
})
|
||||||
|
|
||||||
|
state_ids = {pdu.event_id for pdu in state_pdus}
|
||||||
|
prev_ids = {i for i, h in event.prev_events.items()}
|
||||||
|
partial_auth_chain = yield self.store.get_auth_chain(
|
||||||
|
state_ids | prev_ids, have_ids=decoded_auth_events.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
for p in partial_auth_chain:
|
||||||
|
p.signatures.update(
|
||||||
|
compute_event_signature(
|
||||||
|
p,
|
||||||
|
self.hs.hostname,
|
||||||
|
self.hs.config.signing_key[0]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_events.update(
|
||||||
|
a.event_id for a in partial_auth_chain
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_and_state.update({
|
||||||
|
a.event_id: a for a in partial_auth_chain
|
||||||
|
})
|
||||||
|
|
||||||
|
time_now = self._clock.time_msec()
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
|
||||||
|
"state_for_events": state,
|
||||||
|
"auth_events": auth_events,
|
||||||
|
"event_map": {
|
||||||
|
k: ev.get_pdu_json(time_now)
|
||||||
|
for k, ev in auth_and_state.items()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||||
""" Get a PDU from the database with given origin and id.
|
""" Get a PDU from the database with given origin and id.
|
||||||
|
|
|
@ -581,12 +581,13 @@ class FederationHandler(BaseHandler):
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_for_pdu(self, origin, room_id, event_id):
|
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
if do_auth:
|
||||||
if not in_room:
|
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||||
raise AuthError(403, "Host not in room.")
|
if not in_room:
|
||||||
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
[event_id]
|
[event_id]
|
||||||
|
|
|
@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
|
||||||
and backfilling from another server respectively.
|
and backfilling from another server respectively.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_auth_chain(self, event_ids):
|
def get_auth_chain(self, event_ids, have_ids=set()):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_auth_chain",
|
"get_auth_chain",
|
||||||
self._get_auth_chain_txn,
|
self._get_auth_chain_txn,
|
||||||
event_ids
|
event_ids, have_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain_txn(self, txn, event_ids):
|
def _get_auth_chain_txn(self, txn, event_ids, have_ids):
|
||||||
results = self._get_auth_chain_ids_txn(txn, event_ids)
|
results = self._get_auth_chain_ids_txn(txn, event_ids, have_ids)
|
||||||
|
|
||||||
return self._get_events_txn(txn, results)
|
return self._get_events_txn(txn, results)
|
||||||
|
|
||||||
|
@ -51,8 +51,9 @@ class EventFederationStore(SQLBaseStore):
|
||||||
event_ids
|
event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain_ids_txn(self, txn, event_ids):
|
def _get_auth_chain_ids_txn(self, txn, event_ids, have_ids):
|
||||||
results = set()
|
results = set()
|
||||||
|
have_ids = set(have_ids)
|
||||||
|
|
||||||
base_sql = (
|
base_sql = (
|
||||||
"SELECT auth_id FROM event_auth WHERE event_id = ?"
|
"SELECT auth_id FROM event_auth WHERE event_id = ?"
|
||||||
|
@ -64,6 +65,10 @@ class EventFederationStore(SQLBaseStore):
|
||||||
for f in front:
|
for f in front:
|
||||||
txn.execute(base_sql, (f,))
|
txn.execute(base_sql, (f,))
|
||||||
new_front.update([r[0] for r in txn.fetchall()])
|
new_front.update([r[0] for r in txn.fetchall()])
|
||||||
|
|
||||||
|
new_front -= results
|
||||||
|
new_front -= have_ids
|
||||||
|
|
||||||
front = new_front
|
front = new_front
|
||||||
results.update(front)
|
results.update(front)
|
||||||
|
|
||||||
|
@ -378,3 +383,51 @@ class EventFederationStore(SQLBaseStore):
|
||||||
event_results += new_front
|
event_results += new_front
|
||||||
|
|
||||||
return self._get_events_txn(txn, event_results)
|
return self._get_events_txn(txn, event_results)
|
||||||
|
|
||||||
|
def get_missing_events(self, room_id, earliest_events, latest_events,
|
||||||
|
limit, min_depth):
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_missing_events",
|
||||||
|
self._get_missing_events,
|
||||||
|
room_id, earliest_events, latest_events, limit, min_depth
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
|
||||||
|
limit, min_depth):
|
||||||
|
|
||||||
|
earliest_events = set(earliest_events)
|
||||||
|
front = set(latest_events) - earliest_events
|
||||||
|
|
||||||
|
event_results = set()
|
||||||
|
|
||||||
|
query = (
|
||||||
|
"SELECT prev_event_id FROM event_edges "
|
||||||
|
"WHERE room_id = ? AND event_id = ? AND is_state = 0 "
|
||||||
|
"LIMIT ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
while front and len(event_results) < limit:
|
||||||
|
new_front = set()
|
||||||
|
for event_id in front:
|
||||||
|
txn.execute(
|
||||||
|
query,
|
||||||
|
(room_id, event_id, limit - len(event_results))
|
||||||
|
)
|
||||||
|
|
||||||
|
for e_id, in txn.fetchall():
|
||||||
|
new_front.add(e_id)
|
||||||
|
|
||||||
|
new_front -= earliest_events
|
||||||
|
new_front -= event_results
|
||||||
|
|
||||||
|
front = new_front
|
||||||
|
event_results |= new_front
|
||||||
|
|
||||||
|
events = self._get_events_txn(txn, event_results)
|
||||||
|
|
||||||
|
events = sorted(
|
||||||
|
[ev for ev in events if ev.depth >= min_depth],
|
||||||
|
key=lambda e: e.depth,
|
||||||
|
)
|
||||||
|
|
||||||
|
return events[:limit]
|
||||||
|
|
Loading…
Reference in New Issue