Merge pull request #2206 from matrix-org/rav/one_time_key_upload_change_sig
Allow clients to upload one-time-keys with new sigs
This commit is contained in:
commit
5331cd150a
|
@ -440,6 +440,16 @@ class FederationServer(FederationBase):
|
|||
key_id: json.loads(json_bytes)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Claimed one-time-keys: %s",
|
||||
",".join((
|
||||
"%s for %s:%s" % (key_id, user_id, device_id)
|
||||
for user_id, user_keys in json_result.iteritems()
|
||||
for device_id, device_keys in user_keys.iteritems()
|
||||
for key_id, _ in device_keys.iteritems()
|
||||
)),
|
||||
)
|
||||
|
||||
defer.returnValue({"one_time_keys": json_result})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -21,7 +21,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import SynapseError, CodeMessageException
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -145,7 +145,7 @@ class E2eKeysHandler(object):
|
|||
"status": 503, "message": e.message
|
||||
}
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
yield make_deferred_yieldable(defer.gatherResults([
|
||||
preserve_fn(do_remote_query)(destination)
|
||||
for destination in remote_queries_not_in_cache
|
||||
]))
|
||||
|
@ -257,11 +257,21 @@ class E2eKeysHandler(object):
|
|||
"status": 503, "message": e.message
|
||||
}
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
yield make_deferred_yieldable(defer.gatherResults([
|
||||
preserve_fn(claim_client_keys)(destination)
|
||||
for destination in remote_queries
|
||||
]))
|
||||
|
||||
logger.info(
|
||||
"Claimed one-time-keys: %s",
|
||||
",".join((
|
||||
"%s for %s:%s" % (key_id, user_id, device_id)
|
||||
for user_id, user_keys in json_result.iteritems()
|
||||
for device_id, device_keys in user_keys.iteritems()
|
||||
for key_id, _ in device_keys.iteritems()
|
||||
)),
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"one_time_keys": json_result,
|
||||
"failures": failures
|
||||
|
@ -288,19 +298,8 @@ class E2eKeysHandler(object):
|
|||
|
||||
one_time_keys = keys.get("one_time_keys", None)
|
||||
if one_time_keys:
|
||||
logger.info(
|
||||
"Adding %d one_time_keys for device %r for user %r at %d",
|
||||
len(one_time_keys), device_id, user_id, time_now
|
||||
)
|
||||
key_list = []
|
||||
for key_id, key_json in one_time_keys.items():
|
||||
algorithm, key_id = key_id.split(":")
|
||||
key_list.append((
|
||||
algorithm, key_id, encode_canonical_json(key_json)
|
||||
))
|
||||
|
||||
yield self.store.add_e2e_one_time_keys(
|
||||
user_id, device_id, time_now, key_list
|
||||
yield self._upload_one_time_keys_for_user(
|
||||
user_id, device_id, time_now, one_time_keys,
|
||||
)
|
||||
|
||||
# the device should have been registered already, but it may have been
|
||||
|
@ -313,3 +312,58 @@ class E2eKeysHandler(object):
|
|||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
|
||||
defer.returnValue({"one_time_key_counts": result})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
|
||||
one_time_keys):
|
||||
logger.info(
|
||||
"Adding one_time_keys %r for device %r for user %r at %d",
|
||||
one_time_keys.keys(), device_id, user_id, time_now,
|
||||
)
|
||||
|
||||
# make a list of (alg, id, key) tuples
|
||||
key_list = []
|
||||
for key_id, key_obj in one_time_keys.items():
|
||||
algorithm, key_id = key_id.split(":")
|
||||
key_list.append((
|
||||
algorithm, key_id, key_obj
|
||||
))
|
||||
|
||||
# First we check if we have already persisted any of the keys.
|
||||
existing_key_map = yield self.store.get_e2e_one_time_keys(
|
||||
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
||||
)
|
||||
|
||||
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
|
||||
for algorithm, key_id, key in key_list:
|
||||
ex_json = existing_key_map.get((algorithm, key_id), None)
|
||||
if ex_json:
|
||||
if not _one_time_keys_match(ex_json, key):
|
||||
raise SynapseError(
|
||||
400,
|
||||
("One time key %s:%s already exists. "
|
||||
"Old key: %s; new key: %r") %
|
||||
(algorithm, key_id, ex_json, key)
|
||||
)
|
||||
else:
|
||||
new_keys.append((algorithm, key_id, encode_canonical_json(key)))
|
||||
|
||||
yield self.store.add_e2e_one_time_keys(
|
||||
user_id, device_id, time_now, new_keys
|
||||
)
|
||||
|
||||
|
||||
def _one_time_keys_match(old_key_json, new_key):
|
||||
old_key = json.loads(old_key_json)
|
||||
|
||||
# if either is a string rather than an object, they must match exactly
|
||||
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
|
||||
return old_key == new_key
|
||||
|
||||
# otherwise, we strip off the 'signatures' if any, because it's legitimate
|
||||
# for different upload attempts to have different signatures.
|
||||
old_key.pop("signatures", None)
|
||||
new_key_copy = dict(new_key)
|
||||
new_key_copy.pop("signatures", None)
|
||||
|
||||
return old_key == new_key_copy
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
@ -124,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||
"""Insert some new one time keys for a device.
|
||||
def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
|
||||
"""Retrieve a number of one-time keys for a user
|
||||
|
||||
Checks if any of the keys are already inserted, if they are then check
|
||||
if they match. If they don't then we raise an error.
|
||||
Args:
|
||||
user_id(str): id of user to get keys for
|
||||
device_id(str): id of device to get keys for
|
||||
key_ids(list[str]): list of key ids (excluding algorithm) to
|
||||
retrieve
|
||||
|
||||
Returns:
|
||||
deferred resolving to Dict[(str, str), str]: map from (algorithm,
|
||||
key_id) to json string for key
|
||||
"""
|
||||
|
||||
# First we check if we have already persisted any of the keys.
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
iterable=[key_id for _, key_id, _ in key_list],
|
||||
iterable=key_ids,
|
||||
retcols=("algorithm", "key_id", "key_json",),
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
|
@ -144,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
desc="add_e2e_one_time_keys_check",
|
||||
)
|
||||
|
||||
existing_key_map = {
|
||||
defer.returnValue({
|
||||
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
||||
}
|
||||
})
|
||||
|
||||
new_keys = [] # Keys that we need to insert
|
||||
for algorithm, key_id, json_bytes in key_list:
|
||||
ex_bytes = existing_key_map.get((algorithm, key_id), None)
|
||||
if ex_bytes:
|
||||
if json_bytes != ex_bytes:
|
||||
raise SynapseError(
|
||||
400, "One time key with key_id %r already exists" % (key_id,)
|
||||
)
|
||||
else:
|
||||
new_keys.append((algorithm, key_id, json_bytes))
|
||||
@defer.inlineCallbacks
|
||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
|
||||
"""Insert some new one time keys for a device. Errors if any of the
|
||||
keys already exist.
|
||||
|
||||
Args:
|
||||
user_id(str): id of user to get keys for
|
||||
device_id(str): id of device to get keys for
|
||||
time_now(long): insertion time to record (ms since epoch)
|
||||
new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
|
||||
(algorithm, key_id, key json)
|
||||
"""
|
||||
|
||||
def _add_e2e_one_time_keys(txn):
|
||||
# We are protected from race between lookup and insertion due to
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import mock
|
||||
from synapse.api import errors
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
|
@ -44,3 +45,134 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
local_user = "@boris:" + self.hs.hostname
|
||||
res = yield self.handler.query_local_devices({local_user: None})
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_reupload_one_time_keys(self):
|
||||
"""we should be able to re-upload the same keys"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {
|
||||
"key": "key2",
|
||||
"signatures": {"k1": "sig1"}
|
||||
},
|
||||
"alg2:k3": {
|
||||
"key": "key3",
|
||||
},
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||
})
|
||||
|
||||
# we should be able to change the signature without a problem
|
||||
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_change_one_time_keys(self):
|
||||
"""attempts to change one-time-keys should be rejected"""
|
||||
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {
|
||||
"key": "key2",
|
||||
"signatures": {"k1": "sig1"}
|
||||
},
|
||||
"alg2:k3": {
|
||||
"key": "key3",
|
||||
},
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||
})
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
|
||||
)
|
||||
self.fail("No error when changing string key")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
|
||||
)
|
||||
self.fail("No error when replacing dict key with string")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {
|
||||
"one_time_keys": {"alg1:k1": {"key": "key"}}
|
||||
},
|
||||
)
|
||||
self.fail("No error when replacing string key with dict")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {
|
||||
"one_time_keys": {
|
||||
"alg2:k2": {
|
||||
"key": "key3",
|
||||
"signatures": {"k1": "sig1"},
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
self.fail("No error when replacing dict key")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
@unittest.DEBUG
|
||||
@defer.inlineCallbacks
|
||||
def test_claim_one_time_key(self):
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1}
|
||||
})
|
||||
|
||||
res2 = yield self.handler.claim_one_time_keys({
|
||||
"one_time_keys": {
|
||||
local_user: {
|
||||
device_id: "alg1"
|
||||
}
|
||||
}
|
||||
}, timeout=None)
|
||||
self.assertEqual(res2, {
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_user: {
|
||||
device_id: {
|
||||
"alg1:k1": "key1"
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue