implement federation parts of cross-signing

This commit is contained in:
Hubert Chathi 2019-05-22 16:42:00 -04:00
parent 2761731634
commit 8d3542a64e
4 changed files with 179 additions and 10 deletions

View File

@ -366,10 +366,10 @@ class PerDestinationQueue(object):
Edu( Edu(
origin=self._server_name, origin=self._server_name,
destination=self._destination, destination=self._destination,
edu_type="m.device_list_update", edu_type=edu_type,
content=content, content=content,
) )
for content in results for (edu_type, content) in results
] ]
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"

View File

@ -458,7 +458,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id): def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
return {"user_id": user_id, "stream_id": stream_id, "devices": devices} master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = yield self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
return {
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
"master_key": master_key,
"self_signing_key": self_signing_key
}
@defer.inlineCallbacks @defer.inlineCallbacks
def user_left_room(self, user, room_id): def user_left_room(self, user, room_id):

View File

@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,10 +51,17 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._edu_updater = SigningKeyEduUpdater(hs, self)
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.signing_key_update", self._edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
hs.get_federation_registry().register_query_handler( federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )
@ -343,7 +352,15 @@ class E2eKeysHandler(object):
""" """
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query) res = yield self.query_local_devices(device_keys_query)
return {"device_keys": res} ret = {"device_keys": res}
# add in the cross-signing keys
cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query)
for key, value in iteritems(cross_signing_keys):
ret[key + "_keys"] = value
return ret
@trace @trace
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1047,3 +1064,98 @@ class SignatureListItem:
target_user_id = attr.ib() target_user_id = attr.ib()
target_device_id = attr.ib() target_device_id = attr.ib()
signature = attr.ib() signature = attr.ib()
class SigningKeyEduUpdater(object):
"Handles incoming signing key updates from federation and updates the DB"
def __init__(self, hs, e2e_keys_handler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
@defer.inlineCallbacks
def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
origin (string): the server that sent the EDU
edu_content (dict): the contents of the EDU
"""
user_id = edu_content.pop("user_id")
master_key = edu_content.pop("master_key", None)
self_signing_key = edu_content.pop("self_signing_key", None)
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
room_ids = yield self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
self._pending_updates.setdefault(user_id, []).append(
(master_key, self_signing_key, edu_content)
)
yield self._handle_signing_key_updates(user_id)
@defer.inlineCallbacks
def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.
Args:
user_id (string): the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
return
device_ids = []
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key, edu_content in pending_updates:
if master_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "master", master_key
)
device_id = \
get_verify_key_from_cross_signing_key(master_key)[1].version
device_ids.append(device_id)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
device_id = \
get_verify_key_from_cross_signing_key(self_signing_key)[1].version
device_ids.append(device_id)
yield device_handler.notify_device_update(user_id, device_ids)

View File

@ -94,9 +94,10 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get stream of updates to send to remote servers """Get stream of updates to send to remote servers
Returns: Returns:
Deferred[tuple[int, list[dict]]]: Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the current stream id (ie, the stream id of the last update included in the
response), and the list of updates response), and the list of updates, where each update is a pair of EDU
type and EDU contents
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()
@ -129,6 +130,25 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates: if not updates:
return now_stream_id, [] return now_stream_id, []
# get the cross-signing keys of the users the list
users = set(r[0] for r in updates)
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key)
master_key_by_user[user] = {
"key_info": cross_signing_key,
"pubkey": verify_key.version
}
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "self_signing")
key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key)
self_signing_key_by_user[user] = {
"key_info": cross_signing_key,
"pubkey": verify_key.version
}
# if we have exceeded the limit, we need to exclude any results with the # if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row. # same stream_id as the last row.
if len(updates) > limit: if len(updates) > limit:
@ -158,6 +178,10 @@ class DeviceWorkerStore(SQLBaseStore):
# Stop processing updates # Stop processing updates
break break
if update[1] == master_key_by_user[update[0]]["pubkey"] or \
update[1] == self_signing_key_by_user[update[0]]["pubkey"]:
continue
key = (update[0], update[1]) key = (update[0], update[1])
update_context = update[3] update_context = update[3]
@ -172,16 +196,37 @@ class DeviceWorkerStore(SQLBaseStore):
# means that there are more than limit updates all of which have the same # means that there are more than limit updates all of which have the same
# steam_id. # steam_id.
# figure out which cross-signing keys were changed by intersecting the
# update list with the master/self-signing key by user maps
cross_signing_keys_by_user = {}
for user_id, device_id, stream in updates:
if device_id == master_key_by_user[user_id]["pubkey"]:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = \
master_key_by_user[user_id]["key_info"]
elif device_id == self_signing_key_by_user[user_id]["pubkey"]:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = \
self_signing_key_by_user[user_id]["key_info"]
cross_signing_results = []
# add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user):
result["user_id"] = user_id
cross_signing_results.append(("m.signing_key_update", result))
# That should only happen if a client is spamming the server with new # That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just # devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next # skip that stream_id and return an empty list, and continue with the next
# stream_id next time. # stream_id next time.
if not query_map: if not query_map and not cross_signing_results:
return stream_id_cutoff, [] return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote( results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map destination, from_stream_id, query_map
) )
results.extend(cross_signing_results)
return now_stream_id, results return now_stream_id, results
@ -200,6 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns: Returns:
List: List of device updates List: List of device updates
""" """
# get the list of device updates that need to be sent
sql = """ sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@ -231,7 +277,7 @@ class DeviceWorkerStore(SQLBaseStore):
query_map.keys(), query_map.keys(),
include_all_devices=True, include_all_devices=True,
include_deleted_devices=True, include_deleted_devices=True,
) ) if query_map else {}
results = [] results = []
for user_id, user_devices in iteritems(devices): for user_id, user_devices in iteritems(devices):
@ -262,7 +308,7 @@ class DeviceWorkerStore(SQLBaseStore):
else: else:
result["deleted"] = True result["deleted"] = True
results.append(result) results.append(("m.device_list_update", result))
return results return results