Implement a batch API for verify_json_objects_for_server

This commit is contained in:
Erik Johnston 2015-06-24 11:21:35 +01:00
parent f859e3ca37
commit a29319fefa
4 changed files with 319 additions and 182 deletions

View File

@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from OpenSSL import crypto
from collections import namedtuple
import urllib
import hashlib
import logging
@ -38,6 +38,9 @@ import logging
logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@ -49,141 +52,257 @@ class Keyring(object):
self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
return self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
def verify_json_objects_for_server(self, server_and_json):
server_to_key_groupings = {}
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
server_to_key_groupings.setdefault(server_name, []).append(group)
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
deferreds = self.get_server_verify_keys(
group_id_to_group
)
logger.info(
"Deferred count: %d vs. %d",
len(deferreds.items()),
len(server_and_json)
)
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
]
def get_server_verify_keys(self, group_id_to_group):
merged_results = {}
fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
group_deferreds = {
group_id: defer.Deferred()
for group_id in group_id_to_group
}
@defer.inlineCallbacks
def do_iterations():
missing_keys = {
group.server_name: key_id
for group in group_id_to_group.values()
for key_id in group.key_ids
}
for fn in fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_deferreds.pop(group.group_id).callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_deferreds.pop(group.group_id).errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_deferreds.values():
deferred.errback(err)
group_deferreds.clear()
do_iterations().addErrback(on_err)
return group_deferreds
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args:
server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
if cached:
defer.returnValue(cached[0])
return
download = self.key_downloads.get(server_name)
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.key_downloads[server_name] = download
@download.addBoth
def callback(ret):
del self.key_downloads[server_name]
return ret
r = yield download.observe()
defer.returnValue(r)
defer.returnValue(dict(zip(
[server_name for server_name, _ in server_name_and_key_ids],
res
)))
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys
server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
logging.info(
"Unable to getting key %r for %r from %r: %s %s",
key_ids, server_name, perspective_name,
logger.info(
"Unable to get key from %r: %s %s",
perspective_name,
type(e).__name__, str(e.message),
)
perspective_results = yield defer.gatherResults([
results = yield defer.gatherResults([
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
])
for results in perspective_results:
if results is not None:
keys = results
union_of_keys = {}
for result in results:
for server_name, keys in results.items():
union_of_keys.setdefault(server_name, {}).update(keys)
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
defer.returnValue(union_of_keys)
with limiter:
if not keys:
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logging.info(
logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
for key_id in key_ids:
if key_id in keys:
defer.returnValue(keys[key_id])
return
raise ValueError("No verification key found for given key ids")
keys = {server_name: keys}
defer.returnValue(keys)
results = yield defer.gatherResults([
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
])
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@ -204,6 +323,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
for server_name, key_ids in server_names_and_key_ids
}
},
)
@ -243,23 +363,24 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
response_keys = yield self.process_v2_response(
server_name, perspective_name, response
processed_response = yield self.process_v2_response(
perspective_name, response
)
keys.update(response_keys)
for server_name, response_keys in processed_response:
keys.setdefault(server_name, {}).update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=keys,
)
for server_name, response_keys in keys.items():
yield self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
for requested_key_id in key_ids:
@ -295,25 +416,25 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name,
requested_id=requested_key_id,
requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=keys,
)
for server_name, verify_keys in keys.items():
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=verify_keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json,
requested_id=None):
def process_v2_response(self, from_server, response_json,
requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@ -335,50 +456,50 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response_json,
server_name,
verify_keys[key_id]
)
results = {}
for server_name, keys_dict in response_json["signatures"].items():
for key_id in keys_dict:
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response_json,
server_name,
verify_keys[key_id]
)
signed_key_json = sign_json(
response_json,
self.config.server_name,
self.config.signing_key[0],
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set()
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
for key_id in updated_key_ids:
yield self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
signed_key_json = sign_json(
response_json,
self.config.server_name,
self.config.signing_key[0],
)
defer.returnValue(response_keys)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
raise ValueError("No verification key found for given key ids")
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
for key_id in updated_key_ids:
yield self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
results[server_name] = response_keys
defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):

View File

@ -99,35 +99,50 @@ class FederationBase(object):
defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
logger.warn(
"Signature check failed for %s",
pdu.event_id,
)
raise
return failure
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting.",
pdu.event_id,
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)
return deferreds

View File

@ -166,10 +166,7 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus],
consumeErrors=True,
).addErrback(unwrapFirstError)
pdus[:] = yield self._check_sigs_and_hashes(pdus)
defer.returnValue(pdus)
@ -230,7 +227,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@ -402,7 +399,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
logger.warn(
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)

View File

@ -101,7 +101,11 @@ class KeyStore(SQLBaseStore):
(list of VerifyKey): The verification keys.
"""
keys = yield self.get_all_server_verify_keys(server_name)
defer.returnValue([keys[k] for k in key_ids if k in keys])
defer.returnValue({
k: keys[k]
for k in key_ids
if k in keys and keys[k]
})
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,