Clean up verify_json_objects_for_server

This commit is contained in:
Mark Haines 2016-07-27 14:10:43 +01:00
parent c63b1697f4
commit fe1b369946
1 changed files with 74 additions and 67 deletions

View File

@ -44,7 +44,21 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids")) VerifyKeyRequest = namedtuple("VerifyRequest", (
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched
"""
class Keyring(object): class Keyring(object):
@ -74,39 +88,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each list of deferreds indicating success or failure to verify each
json object's signature for the given server_name. json object's signature for the given server_name.
""" """
group_id_to_json = {} verify_requests = []
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json: for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
if not key_ids: if not key_ids:
deferreds[group_id] = defer.fail(SynapseError( deferred = defer.fail(SynapseError(
400, 400,
"Not signed with a supported algorithm", "Not signed with a supported algorithm",
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) ))
else: else:
deferreds[group_id] = defer.Deferred() deferred = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids) verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
group_id_to_group[group_id] = group verify_requests.append(verify_request)
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_key_deferred(group, deferred): def handle_key_deferred(verify_request):
server_name = group.server_name server_name = verify_request.server_name
try: try:
_, _, key_id, verify_key = yield deferred _, key_id, verify_key = yield verify_request.deferred
except IOError as e: except IOError as e:
logger.warn( logger.warn(
"Got IOError when downloading keys for %s: %s %s", "Got IOError when downloading keys for %s: %s %s",
@ -128,7 +135,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
json_object = group_id_to_json[group.group_id] json_object = verify_request.json_object
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
@ -157,36 +164,34 @@ class Keyring(object):
# Actually start fetching keys. # Actually start fetching keys.
wait_on_deferred.addBoth( wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) lambda _: self.get_server_verify_keys(verify_requests)
) )
# When we've finished fetching all the keys for a given server_name, # When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed. # any lookups waiting will proceed.
server_to_gids = {} server_to_request_ids = {}
def remove_deferreds(res, server_name, group_id): def remove_deferreds(res, server_name, verify_request):
server_to_gids[server_name].discard(group_id) request_id = id(verify_request)
if not server_to_gids[server_name]: server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None) d = server_to_deferred.pop(server_name, None)
if d: if d:
d.callback(None) d.callback(None)
return res return res
for g_id, deferred in deferreds.items(): for verify_request in verify_requests:
server_name = group_id_to_group[g_id].server_name server_name = verify_request.server_name
server_to_gids.setdefault(server_name, set()).add(g_id) request_id = id(verify_request)
deferred.addBoth(remove_deferreds, server_name, g_id) server_to_request_ids.setdefault(server_name, set()).add(request_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
preserve_context_over_fn( preserve_context_over_fn(handle_key_deferred, verify_request)
handle_key_deferred, for verify_request in verify_requests
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
] ]
@defer.inlineCallbacks @defer.inlineCallbacks
@ -220,7 +225,7 @@ class Keyring(object):
d.addBoth(rm, server_name) d.addBoth(rm, server_name)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for """Takes a dict of KeyGroups and tries to find at least one key for
each group. each group.
""" """
@ -237,62 +242,64 @@ class Keyring(object):
merged_results = {} merged_results = {}
missing_keys = {} missing_keys = {}
for group in group_id_to_group.values(): for verify_request in verify_requests:
missing_keys.setdefault(group.server_name, set()).update( missing_keys.setdefault(verify_request.server_name, set()).update(
group.key_ids verify_request.key_ids
) )
for fn in key_fetch_fns: for fn in key_fetch_fns:
results = yield fn(missing_keys.items()) results = yield fn(missing_keys.items())
merged_results.update(results) merged_results.update(results)
# We now need to figure out which groups we have keys for # We now need to figure out which verify requests we have keys
# and which we don't # for and which we don't
missing_groups = {} missing_keys = {}
for group in group_id_to_group.values(): requests_missing_keys = []
for key_id in group.key_ids: for verify_request in verify_requests:
if key_id in merged_results[group.server_name]: server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext(): with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback(( verify_request.deferred.callback((
group.group_id, server_name,
group.server_name,
key_id, key_id,
merged_results[group.server_name][key_id], result_keys[key_id],
)) ))
break break
else: else:
missing_groups.setdefault( # The else block is only reached if the loop above
group.server_name, [] # doesn't break.
).append(group) missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_groups: if not missing_keys:
break break
missing_keys = { for verify_request in requests_missing_keys.values():
server_name: set( verify_request.deferred.errback(SynapseError(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401, 401,
"No key for %s with id %s" % ( "No key for %s with id %s" % (
group.server_name, group.key_ids, verify_request.server_name, verify_request.key_ids,
), ),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) ))
def on_err(err): def on_err(err):
for deferred in group_id_to_deferred.values(): for verify_request in verify_requests:
if not deferred.called: if not verify_request.deferred.called:
deferred.errback(err) verify_request.deferred.errback(err)
do_iterations().addErrback(on_err) do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults( res = yield defer.gatherResults(