Add type hints to `synapse/storage/databases/main/end_to_end_keys.py` (#11551)

This commit is contained in:
Sean Quah 2021-12-13 16:28:26 +00:00 committed by GitHub
parent 6da8591f2e
commit 1abfb15f07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 150 additions and 69 deletions

1
changelog.d/11551.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -28,7 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/e2e_room_keys.py
|synapse/storage/databases/main/end_to_end_keys.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py
@ -189,6 +188,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.end_to_end_keys]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True

View File

@ -143,9 +143,6 @@ class DataStore(
("device_lists_outbound_pokes", "stream_id"),
],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
)
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")

View File

@ -14,19 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
cast,
)
import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
class EndToEndKeyBackgroundStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
)
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._allow_device_name_lookup_over_federation = (
@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
rv: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
assert target_device_result is not None
assert target_device_result.keys is not None
target_device_signatures = target_device_result.keys.setdefault(
"signatures", {}
)
@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
self,
txn: LoggingTransaction,
query_list: Collection[Tuple[str, str]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
txn.execute(signature_sql, signature_query_params)
return txn.fetchall()
return cast(
List[
Tuple[
str,
str,
str,
str,
]
],
txn.fetchall(),
)
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn):
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
def _set_e2e_fallback_keys_txn(
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
fallback_keys: JsonDict,
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
) -> Optional[JsonDict]:
"""Returns a user's cross-signing key.
Args:
@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
their user ID will map to None.
"""
return await self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
# The `Optional` comes from the `@cachedList` decorator.
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
txn: LoggingTransaction,
user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
) -> Dict[str, Dict[str, JsonDict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_ids (list[str]): the users whose keys are being requested
txn: db connection
user_ids: the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, their user
ID will not be in the dict.
Mapping from user ID to key type to key data.
If a user's cross-signing keys were not found, their user ID will not be in
the dict.
"""
result = {}
result: Dict[str, Dict[str, JsonDict]] = {}
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
user_keys = result.setdefault(user_id, {})
user_keys[key_type] = key
return result
def _get_e2e_cross_signing_signatures_txn(
self,
txn: Connection,
keys: Dict[str, Dict[str, dict]],
txn: LoggingTransaction,
keys: Dict[str, Optional[Dict[str, JsonDict]]],
from_user_id: str,
) -> Dict[str, Dict[str, dict]]:
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
key data. This dict will be modified to add signatures.
from_user_id (str): fetch the signatures made by this user
txn: db connection
keys: a map of user ID to key type to key data.
This dict will be modified to add signatures.
from_user_id: fetch the signatures made by this user
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. The return value will be the same as the keys argument,
with the modifications included.
Mapping from user ID to key type to key data.
The return value will be the same as the keys argument, with the
modifications included.
"""
# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
devices = {}
devices: Dict[Tuple[str, str], str] = {}
for user_id, user_info in keys.items():
if user_info is None:
for user_id, user_keys in keys.items():
if user_keys is None:
continue
for key_type, key in user_info.items():
for key_type, key in user_keys.items():
device_id = None
for k in key["keys"].values():
device_id = k
# `key` ought to be a `CrossSigningKey`, whose .keys property is a
# dictionary with a single entry:
# "algorithm:base64_public_key": "base64_public_key"
# See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
assert isinstance(device_id, str)
devices[(user_id, device_id)] = key_type
for batch in batch_iter(devices.keys(), size=100):
@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
key_id: str = row["key_id"]
target_user_id: str = row["target_user_id"]
target_device_id: str = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
user_info = keys[target_user_id] = keys[target_user_id].copy()
target_user_key = user_info[key_type] = user_info[key_type].copy()
user_keys = keys[target_user_id]
# `user_keys` cannot be `None` because we only fetched signatures for
# users with keys
assert user_keys is not None
user_keys = keys[target_user_id] = user_keys.copy()
target_user_key = user_keys[key_type] = user_keys[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, dict]]]:
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
if last_id == current_id:
return [], current_id, False
def _get_all_user_signature_changes_for_remotes_txn(txn):
def _get_all_user_signature_changes_for_remotes_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
@trace
def _claim_e2e_one_time_key_simple(
txn, user_id: str, device_id: str, algorithm: str
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
@trace
def _claim_e2e_one_time_key_returning(
txn, user_id: str, device_id: str, algorithm: str
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
results = {}
results: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
)
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(
{
"message": "Deleting keys for device",
@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
def _set_e2e_cross_signing_key_txn(
self,
txn: LoggingTransaction,
user_id: str,
key_type: str,
key: JsonDict,
stream_id: int,
) -> None:
"""Set a user's cross-signing key.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user to set the signing key for
key_type (str): the type of key that is being set: either 'master'
txn: db connection
user_id: the user to set the signing key for
key_type: the type of key that is being set: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
stream_id (int)
key: the key data
stream_id
"""
# the 'key' dict will look something like:
# {
@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
async def set_e2e_cross_signing_key(
self, user_id: str, key_type: str, key: JsonDict
) -> None:
"""Set a user's cross-signing key.
Args:
user_id (str): the user to set the user-signing key for
key_type (str): the type of cross-signing key to set
key (dict): the key data
user_id: the user to set the user-signing key for
key_type: the type of cross-signing key to set
key: the key data
"""
async with self._cross_signing_id_gen.get_next() as stream_id: