Fix bug where we did not send the full auth chain to people that joined over federation
This commit is contained in:
parent
88484f684f
commit
96779d2490
|
@ -285,7 +285,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_event_auth(self, event_id):
|
||||
auth = yield self.store.get_auth_chain(event_id)
|
||||
auth = yield self.store.get_auth_chain([event_id])
|
||||
|
||||
for event in auth:
|
||||
event.signatures.update(
|
||||
|
@ -494,7 +494,10 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
yield self.replication_layer.send_pdu(new_pdu)
|
||||
|
||||
auth_chain = yield self.store.get_auth_chain(event.event_id)
|
||||
state_ids = [e.event_id for e in event.state_events.values()]
|
||||
auth_chain = yield self.store.get_auth_chain(set(
|
||||
[event.event_id] + state_ids
|
||||
))
|
||||
|
||||
defer.returnValue({
|
||||
"state": event.state_events.values(),
|
||||
|
|
|
@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
|
|||
and backfilling from another server respectively.
|
||||
"""
|
||||
|
||||
def get_auth_chain(self, event_id):
|
||||
def get_auth_chain(self, event_ids):
|
||||
return self.runInteraction(
|
||||
"get_auth_chain",
|
||||
self._get_auth_chain_txn,
|
||||
event_id
|
||||
event_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_txn(self, txn, event_id):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_id)
|
||||
def _get_auth_chain_txn(self, txn, event_ids):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_ids)
|
||||
|
||||
sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
rows = []
|
||||
|
@ -50,21 +50,21 @@ class EventFederationStore(SQLBaseStore):
|
|||
|
||||
return self._parse_events_txn(txn, rows)
|
||||
|
||||
def get_auth_chain_ids(self, event_id):
|
||||
def get_auth_chain_ids(self, event_ids):
|
||||
return self.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
self._get_auth_chain_ids_txn,
|
||||
event_id
|
||||
event_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_id):
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids):
|
||||
results = set()
|
||||
|
||||
base_sql = (
|
||||
"SELECT auth_id FROM event_auth WHERE %s"
|
||||
)
|
||||
|
||||
front = set([event_id])
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
sql = base_sql % (
|
||||
" OR ".join(["event_id=?"] * len(front)),
|
||||
|
|
Loading…
Reference in New Issue