Include device name in /keys/query response

Add an 'unsigned' section which includes the device display name.
This commit is contained in:
Richard van der Hoff 2016-08-03 07:46:57 +01:00
parent 91fa69e029
commit 68264d7404
3 changed files with 143 additions and 20 deletions

View File

@ -117,10 +117,15 @@ class E2eKeysHandler(object):
results = yield self.store.get_e2e_device_keys(local_query) results = yield self.store.get_e2e_device_keys(local_query)
# un-jsonify the results # Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items(): for device_id, device_info in device_keys.items():
result_dict[user_id][device_id] = json.loads(json_bytes) r = json.loads(device_info["key_json"])
r["unsigned"] = {
"device_display_name": device_info["device_display_name"],
}
result_dict[user_id][device_id] = r
defer.returnValue(result_dict) defer.returnValue(result_dict)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import twisted.internet.defer import twisted.internet.defer
@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
key json byte strings. dict containing "key_json", "device_display_name".
""" """
def _get_e2e_device_keys(txn): if not query_list:
result = {} return {}
for user_id, device_id in query_list:
user_result = result.setdefault(user_id, {}) return self.runInteraction(
keyvalues = {"user_id": user_id} "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
if device_id:
keyvalues["device_id"] = device_id
rows = self._simple_select_list_txn(
txn, table="e2e_device_keys_json",
keyvalues=keyvalues,
retcols=["device_id", "key_json"]
) )
def _get_e2e_device_keys_txn(self, txn, query_list):
query_clauses = []
query_params = []
for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?"
query_params.append(user_id)
if device_id:
query_clause += " AND k.device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)
sql = (
"SELECT k.user_id, k.device_id, "
" d.display_name AS device_display_name, "
" k.key_json"
" FROM e2e_device_keys_json k"
" LEFT JOIN devices d ON d.user_id = k.user_id"
" AND d.device_id = k.device_id"
" WHERE %s"
) % (
" OR ".join("("+q+")" for q in query_clauses)
)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict)
for row in rows: for row in rows:
user_result[row["device_id"]] = row["key_json"] result[row["user_id"]][row["device_id"]] = row
return result return result
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn): def _add_e2e_one_time_keys(txn):

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.api.errors
import tests.unittest
import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"device_display_name": None,
}, dev)
@defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
yield self.store.store_device(
"user", "device", "display_name"
)
res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"device_display_name": "display_name",
}, dev)
@defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys(
"user1", "device2", now, 'json12')
yield self.store.set_e2e_device_keys(
"user2", "device1", now, 'json21')
yield self.store.set_e2e_device_keys(
"user2", "device2", now, 'json22')
res = yield self.store.get_e2e_device_keys((("user1", "device1"),
("user2", "device2")))
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
self.assertNotIn("device2", res["user1"])
self.assertIn("user2", res)
self.assertNotIn("device1", res["user2"])
self.assertIn("device2", res["user2"])