Implement device key caching over federation

This commit is contained in:
Erik Johnston 2017-01-26 16:06:54 +00:00
parent 51e9fe36e4
commit c974116f19
13 changed files with 381 additions and 57 deletions

View File

@ -126,6 +126,16 @@ class FederationClient(FederationBase):
destination, content, timeout
)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.inc("user_devices")
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function
def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.

View File

@ -416,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):

View File

@ -346,6 +346,32 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
{
"stream_id": "...",
"devices": [ { ... } ]
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/devices/" + user_id
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content, timeout):

View File

@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
def on_GET(self, origin, content, query, user_id):
return self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,

View File

@ -15,6 +15,7 @@
from synapse.api import errors
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.types import get_domain_from_id
from twisted.internet import defer
from ._base import BaseHandler
@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.federation = hs.get_federation_sender()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, device_id)
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, device_id)
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
attempts += 1
@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, device_id)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
device_id,
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, device_id)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler):
raise
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_id):
def notify_device_update(self, user_id, device_ids):
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
hosts = set()
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_id, list(hosts)
user_id, device_ids, list(hosts)
)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
self.federation.send_device_messages(host)
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def get_device_list_changes(self, user_id, room_ids, from_key):
@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler):
defer.returnValue(user_ids_changed)
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
return
logger.info("Got edu: %r", edu_content)
with (yield self._remote_edue_linearizer.queue(user_id)):
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_remote_extremity(user_id)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})

View File

@ -73,8 +73,7 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
remote_queries[user_id] = device_ids
# do the queries
failures = {}
@ -85,9 +84,40 @@ class E2eKeysHandler(object):
if user_id in local_query:
results[user_id] = keys
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
for user_id, device_ids in remote_queries.iteritems():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
)
for user_id, devices in remote_results.iteritems():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
@defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
destination_query = remote_queries_not_in_cache[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
@ -119,7 +149,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
for destination in remote_queries_not_in_cache
]))
defer.returnValue({
@ -259,7 +289,7 @@ class E2eKeysHandler(object):
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
yield self.device_handler.notify_device_update(user_id, device_id)
yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:

View File

@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
def get_device_list_remote_extremity(self, user_id):
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def update_remote_device_list_cache(self, user_id, devices, stream_id):
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
)
self._simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
"content": json.dumps(content),
}
for content in devices
]
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def get_devices_by_remote(self, destination, from_stream_id):
now_stream_id = self._device_list_id_gen.get_current_token()
@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
txn.execute(prev_sent_id_sql, (destination, user_id, True))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, result in user_devices.iteritems():
for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
key_json = result.get("key_json", None)
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = result.get("device_display_name", None)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results)
def get_user_devices_from_cache(self, query_list):
return self.runInteraction(
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
query_list,
)
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
content = self._simple_select_one_onecol_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
)
results.setdefault(user_id, {})[device_id] = json.loads(content)
else:
devices = self._simple_select_list_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
)
results[user_id] = {
device["device_id"]: json.loads(device["content"])
for device in devices
}
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id):
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
for user_id, user_devices in devices.iteritems():
results = []
for device_id, device in user_devices.iteritems():
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
defer.returnValue(set(row["user_id"] for row in rows))
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_id, hosts):
def add_device_change_to_streams(self, user_id, device_ids, hosts):
# device_lists_stream
# device_lists_outbound_pokes
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_id, hosts, stream_id,
user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
host, stream_id,
)
self._simple_insert_txn(
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values={
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids
]
)
self._simple_insert_many_txn(
@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
"sent": False,
}
for destination in hosts
for device_id in device_ids
]
)

View File

@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore):
query_params = []
for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?"
query_clause = "user_id = ?"
query_params.append(user_id)
if device_id:
query_clause += " AND k.device_id = ?"
query_clause += " AND device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)

View File

@ -13,18 +13,6 @@
* limitations under the License.
*/
CREATE TABLE device_list_streams_remote (
list_id TEXT NOT NULL,
origin TEXT NOT NULL,
user_id TEXT NOT NULL,
is_full BOOLEAN NOT NULL,
ts BIGINT NOT NULL
);
CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote(
origin, list_id, user_id
);
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache (
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
CREATE TABLE device_lists_remote_extremeties (
user_id TEXT NOT NULL,
stream_id TEXT NOT NULL
);
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,

View File

@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None)
self.handler = synapse.handlers.device.DeviceHandler(hs)
hs = yield utils.setup_test_homeserver()
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name"
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="theresa",
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display"
)
dev = yield self.handler.store.get_device("theresa", device_id)
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks

View File

@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.query_handlers = {}

View File

@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.query_handlers = {}

View File

@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.as_token = "token1"
self.as_url = "some_url"
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.db_pool = hs.get_db_pool()
self.as_list = [
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
ApplicationServiceStore(hs)
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm: