Claim local one-time-keys in bulk (#16565)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
David Robertson 2023-10-30 21:25:21 +00:00 committed by GitHub
parent 91aa52c911
commit de981ae567
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 308 additions and 114 deletions

View File

@ -0,0 +1 @@
Improve the performance of claiming encryption keys.

View File

@ -753,6 +753,16 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
keys: the body of a /keys/upload request.
Returns a dictionary with one field:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

View File

@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str, int]]
self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
@ -1121,120 +1121,52 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple pf:
A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled.
A copy of the input which has not been fulfilled. The returned counts
may be less than the input counts. In this case, the returned counts
are the number of claims that were not fulfilled.
"""
@trace
def _claim_e2e_one_time_key_simple(
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
"""
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
"""
txn.execute(sql, (user_id, device_id, algorithm, count))
otk_rows = list(txn)
if not otk_rows:
return []
self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
column="key_id",
values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
@trace
def _claim_e2e_one_time_key_returning(
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
"""
# We can use RETURNING to do the fetch and DELETE in once step.
sql = """
DELETE FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
)
RETURNING key_id, key_json
"""
txn.execute(
sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
)
otk_rows = list(txn)
if not otk_rows:
return []
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
if isinstance(self.database_engine, PostgresEngine):
# If we can use execute_values we can use a single batch query
# in autocommit mode.
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
db_autocommit = True
else:
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
unfulfilled_claim_counts[user_id, device_id, algorithm] = count
bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
self._claim_e2e_one_time_keys_bulk,
query_list,
db_autocommit=True,
)
for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
# Did we get enough OTKs?
missing = [
(user, device, alg, count)
for (user, device, alg), count in unfulfilled_claim_counts.items()
if count > 0
]
else:
for user_id, device_id, algorithm, count in query_list:
claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
self._claim_e2e_one_time_key_simple,
user_id,
device_id,
algorithm,
count,
db_autocommit=db_autocommit,
db_autocommit=False,
)
if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results
@trace
def _claim_e2e_one_time_key_simple(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
"""
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
"""
txn.execute(sql, (user_id, device_id, algorithm, count))
otk_rows = list(txn)
if not otk_rows:
return []
self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
column="key_id",
values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
@trace
def _claim_e2e_one_time_keys_bulk(
self,
txn: LoggingTransaction,
query_list: Iterable[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
Args:
query_list: Collection of tuples (user_id, device_id, algorithm, count)
as passed to claim_e2e_one_time_keys.
Returns:
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed.
"""
sql = """
WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ?
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
DELETE FROM e2e_one_time_keys_json k
WHERE (user_id, device_id, algorithm, key_id) IN (
SELECT user_id, device_id, algorithm, key_id
FROM ranked_keys
WHERE r <= claim_count
)
RETURNING user_id, device_id, algorithm, key_id, key_json;
"""
otk_rows = cast(
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return otk_rows
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(

View File

@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
# Apologies to the reader. This test is a little too verbose. It is particularly
# tricky to make assertions neatly with all these nested dictionaries in play.
# Three users with two devices each. Each device uses two algorithms.
# Each algorithm is invoked with two keys.
alice = f"@alice:{self.hs.hostname}"
brian = f"@brian:{self.hs.hostname}"
chris = f"@chris:{self.hs.hostname}"
one_time_keys = {
alice: {
"alice_dev_1": {
"alg1:k1": {"dummy_id": 1},
"alg1:k2": {"dummy_id": 2},
"alg2:k3": {"dummy_id": 3},
"alg2:k4": {"dummy_id": 4},
},
"alice_dev_2": {
"alg1:k5": {"dummy_id": 5},
"alg1:k6": {"dummy_id": 6},
"alg2:k7": {"dummy_id": 7},
"alg2:k8": {"dummy_id": 8},
},
},
brian: {
"brian_dev_1": {
"alg1:k9": {"dummy_id": 9},
"alg1:k10": {"dummy_id": 10},
"alg2:k11": {"dummy_id": 11},
"alg2:k12": {"dummy_id": 12},
},
"brian_dev_2": {
"alg1:k13": {"dummy_id": 13},
"alg1:k14": {"dummy_id": 14},
"alg2:k15": {"dummy_id": 15},
"alg2:k16": {"dummy_id": 16},
},
},
chris: {
"chris_dev_1": {
"alg1:k17": {"dummy_id": 17},
"alg1:k18": {"dummy_id": 18},
"alg2:k19": {"dummy_id": 19},
"alg2:k20": {"dummy_id": 20},
},
"chris_dev_2": {
"alg1:k21": {"dummy_id": 21},
"alg1:k22": {"dummy_id": 22},
"alg2:k23": {"dummy_id": 23},
"alg2:k24": {"dummy_id": 24},
},
},
}
for user_id, devices in one_time_keys.items():
for device_id, keys_dict in devices.items():
counts = self.get_success(
self.handler.upload_keys_for_user(
user_id,
device_id,
{"one_time_keys": keys_dict},
)
)
# The upload should report 2 keys per algorithm.
expected_counts = {
"one_time_key_counts": {
# See count_e2e_one_time_keys for why this is hardcoded.
"signed_curve25519": 0,
"alg1": 2,
"alg2": 2,
},
}
self.assertEqual(counts, expected_counts)
# Claim a variety of keys.
# Raw format, easier to make test assertions about.
claims_to_make = {
(alice, "alice_dev_1", "alg1"): 1,
(alice, "alice_dev_1", "alg2"): 2,
(alice, "alice_dev_2", "alg2"): 1,
(brian, "brian_dev_1", "alg1"): 2,
(brian, "brian_dev_2", "alg2"): 9001,
(chris, "chris_dev_2", "alg2"): 1,
}
# Convert to the format the handler wants.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for (user_id, device_id, algorithm), count in claims_to_make.items():
query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
claim_res = self.get_success(
self.handler.claim_one_time_keys(
query,
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
# No failures, please!
self.assertEqual(claim_res["failures"], {})
# Check that we get exactly the (user, device, algorithm)s we asked for.
got_otks = claim_res["one_time_keys"]
claimed_user_device_algorithms = {
(user_id, device_id, alg_key_id.split(":")[0])
for user_id, devices in got_otks.items()
for device_id, key_dict in devices.items()
for alg_key_id in key_dict
}
self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
# Now check the keys we got are what we expected.
def assertExactlyOneOtk(
user_id: str, device_id: str, *alg_key_pairs: str
) -> None:
key_dict = got_otks[user_id][device_id]
found = 0
for alg_key in alg_key_pairs:
if alg_key in key_dict:
expected_key_json = one_time_keys[user_id][device_id][alg_key]
self.assertEqual(key_dict[alg_key], expected_key_json)
found += 1
self.assertEqual(found, 1)
def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
key_dict = got_otks[user_id][device_id]
for alg_key in alg_key_pairs:
expected_key_json = one_time_keys[user_id][device_id][alg_key]
self.assertEqual(key_dict[alg_key], expected_key_json)
# Expect a single arbitrary key to be returned.
assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
# Now check the unused key counts.
for user_id, devices in one_time_keys.items():
for device_id in devices:
counts_by_alg = self.get_success(
self.store.count_e2e_one_time_keys(user_id, device_id)
)
# Somewhat fiddley to compute the expected count dict.
expected_counts_by_alg = {
"signed_curve25519": 0,
}
for alg in ["alg1", "alg2"]:
claim_count = claims_to_make.get((user_id, device_id, alg), 0)
remaining_count = max(0, 2 - claim_count)
if remaining_count > 0:
expected_counts_by_alg[alg] = remaining_count
self.assertEqual(
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
)
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"