look up cross-signing keys from the DB in bulk (#6486)
This commit is contained in:
parent
5bfd8855d6
commit
cb2db17994
|
@ -0,0 +1 @@
|
||||||
|
Improve performance of looking up cross-signing keys.
|
|
@ -264,6 +264,7 @@ class E2eKeysHandler(object):
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_cross_signing_keys_from_cache(self, query, from_user_id):
|
def get_cross_signing_keys_from_cache(self, query, from_user_id):
|
||||||
"""Get cross-signing keys for users from the database
|
"""Get cross-signing keys for users from the database
|
||||||
|
|
||||||
|
@ -283,14 +284,32 @@ class E2eKeysHandler(object):
|
||||||
self_signing_keys = {}
|
self_signing_keys = {}
|
||||||
user_signing_keys = {}
|
user_signing_keys = {}
|
||||||
|
|
||||||
# Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
|
user_ids = list(query)
|
||||||
return defer.succeed(
|
|
||||||
{
|
keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
|
||||||
|
|
||||||
|
for user_id, user_info in keys.items():
|
||||||
|
if user_info is None:
|
||||||
|
continue
|
||||||
|
if "master" in user_info:
|
||||||
|
master_keys[user_id] = user_info["master"]
|
||||||
|
if "self_signing" in user_info:
|
||||||
|
self_signing_keys[user_id] = user_info["self_signing"]
|
||||||
|
|
||||||
|
if (
|
||||||
|
from_user_id in keys
|
||||||
|
and keys[from_user_id] is not None
|
||||||
|
and "user_signing" in keys[from_user_id]
|
||||||
|
):
|
||||||
|
# users can see other users' master and self-signing keys, but can
|
||||||
|
# only see their own user-signing keys
|
||||||
|
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
|
||||||
|
|
||||||
|
return {
|
||||||
"master_keys": master_keys,
|
"master_keys": master_keys,
|
||||||
"self_signing_keys": self_signing_keys,
|
"self_signing_keys": self_signing_keys,
|
||||||
"user_signing_keys": user_signing_keys,
|
"user_signing_keys": user_signing_keys,
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -14,15 +14,18 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, json
|
from canonicaljson import encode_canonical_json, json
|
||||||
|
|
||||||
|
from twisted.enterprise.adbapi import Connection
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
|
@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
Args:
|
Args:
|
||||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
user_id (str): the user whose key is being requested
|
user_id (str): the user whose key is being requested
|
||||||
key_type (str): the type of key that is being set: either 'master'
|
key_type (str): the type of key that is being requested: either 'master'
|
||||||
for a master key, 'self_signing' for a self-signing key, or
|
for a master key, 'self_signing' for a self-signing key, or
|
||||||
'user_signing' for a user-signing key
|
'user_signing' for a user-signing key
|
||||||
from_user_id (str): if specified, signatures made by this user on
|
from_user_id (str): if specified, signatures made by this user on
|
||||||
|
@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
"""Returns a user's cross-signing key.
|
"""Returns a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user whose self-signing key is being requested
|
user_id (str): the user whose key is being requested
|
||||||
key_type (str): the type of cross-signing key to get
|
key_type (str): the type of key that is being requested: either 'master'
|
||||||
|
for a master key, 'self_signing' for a self-signing key, or
|
||||||
|
'user_signing' for a user-signing key
|
||||||
from_user_id (str): if specified, signatures made by this user on
|
from_user_id (str): if specified, signatures made by this user on
|
||||||
the self-signing key will be included in the result
|
the self-signing key will be included in the result
|
||||||
|
|
||||||
|
@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
from_user_id,
|
from_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(num_args=1)
|
||||||
|
def _get_bare_e2e_cross_signing_keys(self, user_id):
|
||||||
|
"""Dummy function. Only used to make a cache for
|
||||||
|
_get_bare_e2e_cross_signing_keys_bulk.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@cachedList(
|
||||||
|
cached_method_name="_get_bare_e2e_cross_signing_keys",
|
||||||
|
list_name="user_ids",
|
||||||
|
num_args=1,
|
||||||
|
)
|
||||||
|
def _get_bare_e2e_cross_signing_keys_bulk(
|
||||||
|
self, user_ids: List[str]
|
||||||
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
|
the signatures for the calling user need to be fetched.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_ids (list[str]): the users whose keys are being requested
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||||
|
data. If a user's cross-signing keys were not found, either
|
||||||
|
their user ID will not be in the dict, or their user ID will map
|
||||||
|
to None.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self.db.runInteraction(
|
||||||
|
"get_bare_e2e_cross_signing_keys_bulk",
|
||||||
|
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
||||||
|
user_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||||
|
self, txn: Connection, user_ids: List[str],
|
||||||
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
|
the signatures for the calling user need to be fetched.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
user_ids (list[str]): the users whose keys are being requested
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||||
|
data. If a user's cross-signing keys were not found, their user
|
||||||
|
ID will not be in the dict.
|
||||||
|
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
batch_size = 100
|
||||||
|
chunks = [
|
||||||
|
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
|
||||||
|
]
|
||||||
|
for user_chunk in chunks:
|
||||||
|
sql = """
|
||||||
|
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
|
||||||
|
FROM e2e_cross_signing_keys k
|
||||||
|
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
|
||||||
|
FROM e2e_cross_signing_keys
|
||||||
|
GROUP BY user_id, keytype) s
|
||||||
|
USING (user_id, stream_id, keytype)
|
||||||
|
WHERE k.user_id IN (%s)
|
||||||
|
""" % (
|
||||||
|
",".join("?" for u in user_chunk),
|
||||||
|
)
|
||||||
|
query_params = []
|
||||||
|
query_params.extend(user_chunk)
|
||||||
|
|
||||||
|
txn.execute(sql, query_params)
|
||||||
|
rows = self.db.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
user_id = row["user_id"]
|
||||||
|
key_type = row["keytype"]
|
||||||
|
key = json.loads(row["keydata"])
|
||||||
|
user_info = result.setdefault(user_id, {})
|
||||||
|
user_info[key_type] = key
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_e2e_cross_signing_signatures_txn(
|
||||||
|
self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
|
||||||
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
|
"""Returns the cross-signing signatures made by a user on a set of keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||||
|
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
|
||||||
|
key data. This dict will be modified to add signatures.
|
||||||
|
from_user_id (str): fetch the signatures made by this user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||||
|
data. The return value will be the same as the keys argument,
|
||||||
|
with the modifications included.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# find out what cross-signing keys (a.k.a. devices) we need to get
|
||||||
|
# signatures for. This is a map of (user_id, device_id) to key type
|
||||||
|
# (device_id is the key's public part).
|
||||||
|
devices = {}
|
||||||
|
|
||||||
|
for user_id, user_info in keys.items():
|
||||||
|
if user_info is None:
|
||||||
|
continue
|
||||||
|
for key_type, key in user_info.items():
|
||||||
|
device_id = None
|
||||||
|
for k in key["keys"].values():
|
||||||
|
device_id = k
|
||||||
|
devices[(user_id, device_id)] = key_type
|
||||||
|
|
||||||
|
device_list = list(devices)
|
||||||
|
|
||||||
|
# split into batches
|
||||||
|
batch_size = 100
|
||||||
|
chunks = [
|
||||||
|
device_list[i : i + batch_size]
|
||||||
|
for i in range(0, len(device_list), batch_size)
|
||||||
|
]
|
||||||
|
for user_chunk in chunks:
|
||||||
|
sql = """
|
||||||
|
SELECT target_user_id, target_device_id, key_id, signature
|
||||||
|
FROM e2e_cross_signing_signatures
|
||||||
|
WHERE user_id = ?
|
||||||
|
AND (%s)
|
||||||
|
""" % (
|
||||||
|
" OR ".join(
|
||||||
|
"(target_user_id = ? AND target_device_id = ?)" for d in devices
|
||||||
|
)
|
||||||
|
)
|
||||||
|
query_params = [from_user_id]
|
||||||
|
for item in devices:
|
||||||
|
# item is a (user_id, device_id) tuple
|
||||||
|
query_params.extend(item)
|
||||||
|
|
||||||
|
txn.execute(sql, query_params)
|
||||||
|
rows = self.db.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
# and add the signatures to the appropriate keys
|
||||||
|
for row in rows:
|
||||||
|
key_id = row["key_id"]
|
||||||
|
target_user_id = row["target_user_id"]
|
||||||
|
target_device_id = row["target_device_id"]
|
||||||
|
key_type = devices[(target_user_id, target_device_id)]
|
||||||
|
# We need to copy everything, because the result may have come
|
||||||
|
# from the cache. dict.copy only does a shallow copy, so we
|
||||||
|
# need to recursively copy the dicts that will be modified.
|
||||||
|
user_info = keys[target_user_id] = keys[target_user_id].copy()
|
||||||
|
target_user_key = user_info[key_type] = user_info[key_type].copy()
|
||||||
|
if "signatures" in target_user_key:
|
||||||
|
signatures = target_user_key["signatures"] = target_user_key[
|
||||||
|
"signatures"
|
||||||
|
].copy()
|
||||||
|
if from_user_id in signatures:
|
||||||
|
user_sigs = signatures[from_user_id] = signatures[from_user_id]
|
||||||
|
user_sigs[key_id] = row["signature"]
|
||||||
|
else:
|
||||||
|
signatures[from_user_id] = {key_id: row["signature"]}
|
||||||
|
else:
|
||||||
|
target_user_key["signatures"] = {
|
||||||
|
from_user_id: {key_id: row["signature"]}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_e2e_cross_signing_keys_bulk(
|
||||||
|
self, user_ids: List[str], from_user_id: str = None
|
||||||
|
) -> defer.Deferred:
|
||||||
|
"""Returns the cross-signing keys for a set of users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_ids (list[str]): the users whose keys are being requested
|
||||||
|
from_user_id (str): if specified, signatures made by this user on
|
||||||
|
the self-signing keys will be included in the result
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
|
||||||
|
key data. If a user's cross-signing keys were not found, either
|
||||||
|
their user ID will not be in the dict, or their user ID will map
|
||||||
|
to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
|
||||||
|
|
||||||
|
if from_user_id:
|
||||||
|
result = yield self.db.runInteraction(
|
||||||
|
"get_e2e_cross_signing_signatures",
|
||||||
|
self._get_e2e_cross_signing_signatures_txn,
|
||||||
|
result,
|
||||||
|
from_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
|
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
|
||||||
"""Return a list of changes from the user signature stream to notify remotes.
|
"""Return a list of changes from the user signature stream to notify remotes.
|
||||||
Note that the user signature stream represents when a user signs their
|
Note that the user signature stream represents when a user signs their
|
||||||
|
@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||||
"""Set a user's cross-signing key.
|
"""Set a user's cross-signing key.
|
||||||
|
|
||||||
|
|
|
@ -271,7 +271,7 @@ class _CacheDescriptorBase(object):
|
||||||
else:
|
else:
|
||||||
self.function_to_call = orig
|
self.function_to_call = orig
|
||||||
|
|
||||||
arg_spec = inspect.getargspec(orig)
|
arg_spec = inspect.getfullargspec(orig)
|
||||||
all_args = arg_spec.args
|
all_args = arg_spec.args
|
||||||
|
|
||||||
if "cache_context" in all_args:
|
if "cache_context" in all_args:
|
||||||
|
|
|
@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||||
|
|
||||||
test_replace_master_key.skip = (
|
|
||||||
"Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_reupload_signatures(self):
|
def test_reupload_signatures(self):
|
||||||
"""re-uploading a signature should not fail"""
|
"""re-uploading a signature should not fail"""
|
||||||
|
@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||||
],
|
],
|
||||||
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
||||||
)
|
)
|
||||||
|
|
||||||
test_upload_signatures.skip = (
|
|
||||||
"Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue