Refactor `_get_e2e_device_keys_txn` to split large queries (#13956)
Instead of running a single large query, run a single query for user-only lookups and additional queries for batches of user device lookups. Resolves #13580. Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
parent
061739d10f
commit
d65862c41f
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.
|
|
@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
|
|||
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
|
||||
|
||||
|
||||
# These overloads ensure that `columns` and `iterable` values have the same length.
|
||||
# Suppress "Single overload definition, multiple required" complaint.
|
||||
@overload # type: ignore[misc]
|
||||
def make_tuple_in_list_sql_clause(
|
||||
database_engine: BaseDatabaseEngine,
|
||||
columns: Tuple[str, str],
|
||||
iterable: Collection[Tuple[Any, Any]],
|
||||
) -> Tuple[str, list]:
|
||||
...
|
||||
|
||||
|
||||
def make_tuple_in_list_sql_clause(
|
||||
database_engine: BaseDatabaseEngine,
|
||||
columns: Tuple[str, ...],
|
||||
iterable: Collection[Tuple[Any, ...]],
|
||||
) -> Tuple[str, list]:
|
||||
"""Returns an SQL clause that checks the given tuple of columns is in the iterable.
|
||||
|
||||
Args:
|
||||
database_engine
|
||||
columns: Names of the columns in the tuple.
|
||||
iterable: The tuples to check the columns against.
|
||||
|
||||
Returns:
|
||||
A tuple of SQL query and the args
|
||||
"""
|
||||
if len(columns) == 0:
|
||||
# Should be unreachable due to mypy, as long as the overloads are set up right.
|
||||
if () in iterable:
|
||||
return "TRUE", []
|
||||
else:
|
||||
return "FALSE", []
|
||||
|
||||
if len(columns) == 1:
|
||||
# Use `= ANY(?)` on postgres.
|
||||
return make_in_list_sql_clause(
|
||||
database_engine, next(iter(columns)), [values[0] for values in iterable]
|
||||
)
|
||||
|
||||
# There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
|
||||
# indices are not used when there are multiple columns. Instead, use an `IN`
|
||||
# expression.
|
||||
#
|
||||
# `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
|
||||
# `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
|
||||
# Thus, the latter is chosen.
|
||||
|
||||
if len(iterable) == 0:
|
||||
# A 0-length `VALUES` list is not allowed in sqlite or postgres.
|
||||
# Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
|
||||
# allowed in postgres.
|
||||
return "FALSE", []
|
||||
|
||||
tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
|
||||
return "(%s) IN (VALUES %s)" % (
|
||||
",".join(column for column in columns),
|
||||
",".join(tuple_sql for _ in iterable),
|
||||
), [value for values in iterable for value in values]
|
||||
|
||||
|
||||
KV = TypeVar("KV")
|
||||
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
make_tuple_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
def _get_e2e_device_keys_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
query_list: Collection[Tuple[str, str]],
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
|
@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
cross-signing signatures which have been added subsequently (for which, see
|
||||
get_e2e_device_keys_and_signatures)
|
||||
"""
|
||||
query_clauses = []
|
||||
query_params = []
|
||||
query_clauses: List[str] = []
|
||||
query_params_list: List[List[object]] = []
|
||||
|
||||
if include_all_devices is False:
|
||||
include_deleted_devices = False
|
||||
|
@ -297,40 +298,64 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
if include_deleted_devices:
|
||||
deleted_devices = set(query_list)
|
||||
|
||||
# Split the query list into queries for users and queries for particular
|
||||
# devices.
|
||||
user_list = []
|
||||
user_device_list = []
|
||||
for (user_id, device_id) in query_list:
|
||||
query_clause = "user_id = ?"
|
||||
query_params.append(user_id)
|
||||
if device_id is None:
|
||||
user_list.append(user_id)
|
||||
else:
|
||||
user_device_list.append((user_id, device_id))
|
||||
|
||||
if device_id is not None:
|
||||
query_clause += " AND device_id = ?"
|
||||
query_params.append(device_id)
|
||||
if user_list:
|
||||
user_id_in_list_clause, user_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", user_list
|
||||
)
|
||||
query_clauses.append(user_id_in_list_clause)
|
||||
query_params_list.append(user_args)
|
||||
|
||||
query_clauses.append(query_clause)
|
||||
|
||||
sql = (
|
||||
"SELECT user_id, device_id, "
|
||||
" d.display_name, "
|
||||
" k.key_json"
|
||||
" FROM devices d"
|
||||
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
|
||||
" WHERE %s AND NOT d.hidden"
|
||||
) % (
|
||||
"LEFT" if include_all_devices else "INNER",
|
||||
" OR ".join("(" + q + ")" for q in query_clauses),
|
||||
)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
if user_device_list:
|
||||
# Divide the device queries into batches, to avoid excessively large
|
||||
# queries.
|
||||
for user_device_batch in batch_iter(user_device_list, 1024):
|
||||
(
|
||||
user_device_id_in_list_clause,
|
||||
user_device_args,
|
||||
) = make_tuple_in_list_sql_clause(
|
||||
txn.database_engine, ("user_id", "device_id"), user_device_batch
|
||||
)
|
||||
query_clauses.append(user_device_id_in_list_clause)
|
||||
query_params_list.append(user_device_args)
|
||||
|
||||
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
|
||||
for (user_id, device_id, display_name, key_json) in txn:
|
||||
if include_deleted_devices:
|
||||
deleted_devices.remove((user_id, device_id))
|
||||
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
|
||||
display_name, db_to_json(key_json) if key_json else None
|
||||
for query_clause, query_params in zip(query_clauses, query_params_list):
|
||||
sql = (
|
||||
"SELECT user_id, device_id, "
|
||||
" d.display_name, "
|
||||
" k.key_json"
|
||||
" FROM devices d"
|
||||
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
|
||||
" WHERE %s AND NOT d.hidden"
|
||||
) % (
|
||||
"LEFT" if include_all_devices else "INNER",
|
||||
query_clause,
|
||||
)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
|
||||
for (user_id, device_id, display_name, key_json) in txn:
|
||||
assert device_id is not None
|
||||
if include_deleted_devices:
|
||||
deleted_devices.remove((user_id, device_id))
|
||||
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
|
||||
display_name, db_to_json(key_json) if key_json else None
|
||||
)
|
||||
|
||||
if include_deleted_devices:
|
||||
for user_id, device_id in deleted_devices:
|
||||
if device_id is None:
|
||||
continue
|
||||
result.setdefault(user_id, {})[device_id] = None
|
||||
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue