Verify state and auth_chain in the same batch

This commit is contained in:
Erik Johnston 2015-06-24 14:51:10 +01:00
parent 74f7b44955
commit 0f2ac80305
2 changed files with 51 additions and 44 deletions

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@ -56,51 +57,60 @@ class FederationBase(object):
deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu):
signed_pdus.append(pdu)
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
return None
# Check local db.
new_pdu = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
def try_local_db(res, pdu):
if not res:
# Check local db.
return self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
return res
# Check pdu.origin
if pdu.origin != origin:
try:
new_pdu = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
timeout=10000,
)
def try_remote(res, pdu):
if not res and pdu.origin != origin:
return self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
timeout=10000,
).addErrback(lambda e: None)
return res
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
def warn(res, pdu):
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
return res
for pdu, deferred in zip(pdus, deferreds):
deferred.addCallbacks(callback, errback, errbackArgs=[pdu])
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
yield defer.gatherResults(
valid_pdus = yield defer.gatherResults(
deferreds,
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus)
if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
return self._check_sigs_and_hashes([pdu])[0]

View File

@ -380,17 +380,14 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
signed_state, signed_auth = yield defer.gatherResults(
[
self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
),
self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
],
consumeErrors=True
).addErrback(unwrapFirstError)
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, state + auth_chain,
outlier=True,
include_none=True,
)
signed_state = [p for p in valid_pdus[:len(state)] if p]
signed_auth = [p for p in valid_pdus[len(state):] if p]
auth_chain.sort(key=lambda e: e.depth)