Verify state and auth_chain in the same batch
This commit is contained in:
parent
74f7b44955
commit
0f2ac80305
|
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class FederationBase(object):
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
include_none=False):
|
||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of
|
||||
|
@ -56,51 +57,60 @@ class FederationBase(object):
|
|||
deferreds = self._check_sigs_and_hashes(pdus)
|
||||
|
||||
def callback(pdu):
|
||||
signed_pdus.append(pdu)
|
||||
return pdu
|
||||
|
||||
def errback(failure, pdu):
|
||||
failure.trap(SynapseError)
|
||||
return None
|
||||
|
||||
# Check local db.
|
||||
new_pdu = yield self.store.get_event(
|
||||
pdu.event_id,
|
||||
allow_rejected=True,
|
||||
allow_none=True,
|
||||
)
|
||||
if new_pdu:
|
||||
signed_pdus.append(new_pdu)
|
||||
return
|
||||
def try_local_db(res, pdu):
|
||||
if not res:
|
||||
# Check local db.
|
||||
return self.store.get_event(
|
||||
pdu.event_id,
|
||||
allow_rejected=True,
|
||||
allow_none=True,
|
||||
)
|
||||
return res
|
||||
|
||||
# Check pdu.origin
|
||||
if pdu.origin != origin:
|
||||
try:
|
||||
new_pdu = yield self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
def try_remote(res, pdu):
|
||||
if not res and pdu.origin != origin:
|
||||
return self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
).addErrback(lambda e: None)
|
||||
return res
|
||||
|
||||
if new_pdu:
|
||||
signed_pdus.append(new_pdu)
|
||||
return
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.warn(
|
||||
"Failed to find copy of %s with valid signature",
|
||||
pdu.event_id,
|
||||
)
|
||||
def warn(res, pdu):
|
||||
if not res:
|
||||
logger.warn(
|
||||
"Failed to find copy of %s with valid signature",
|
||||
pdu.event_id,
|
||||
)
|
||||
return res
|
||||
|
||||
for pdu, deferred in zip(pdus, deferreds):
|
||||
deferred.addCallbacks(callback, errback, errbackArgs=[pdu])
|
||||
deferred.addCallbacks(
|
||||
callback, errback, errbackArgs=[pdu]
|
||||
).addCallback(
|
||||
try_local_db, pdu
|
||||
).addCallback(
|
||||
try_remote, pdu
|
||||
).addCallback(
|
||||
warn, pdu
|
||||
)
|
||||
|
||||
yield defer.gatherResults(
|
||||
valid_pdus = yield defer.gatherResults(
|
||||
deferreds,
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(signed_pdus)
|
||||
if include_none:
|
||||
defer.returnValue(valid_pdus)
|
||||
else:
|
||||
defer.returnValue([p for p in valid_pdus if p])
|
||||
|
||||
def _check_sigs_and_hash(self, pdu):
|
||||
return self._check_sigs_and_hashes([pdu])[0]
|
||||
|
|
|
@ -380,17 +380,14 @@ class FederationClient(FederationBase):
|
|||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
signed_state, signed_auth = yield defer.gatherResults(
|
||||
[
|
||||
self._check_sigs_and_hash_and_fetch(
|
||||
destination, state, outlier=True
|
||||
),
|
||||
self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True
|
||||
)
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, state + auth_chain,
|
||||
outlier=True,
|
||||
include_none=True,
|
||||
)
|
||||
|
||||
signed_state = [p for p in valid_pdus[:len(state)] if p]
|
||||
signed_auth = [p for p in valid_pdus[len(state):] if p]
|
||||
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
|
|
Loading…
Reference in New Issue