Add type hints to `synapse/storage/databases/main/end_to_end_keys.py` (#11551)
This commit is contained in:
parent
6da8591f2e
commit
1abfb15f07
|
@ -0,0 +1 @@
|
|||
Add missing type hints to storage classes.
|
4
mypy.ini
4
mypy.ini
|
@ -28,7 +28,6 @@ exclude = (?x)
|
|||
|synapse/storage/databases/main/cache.py
|
||||
|synapse/storage/databases/main/devices.py
|
||||
|synapse/storage/databases/main/e2e_room_keys.py
|
||||
|synapse/storage/databases/main/end_to_end_keys.py
|
||||
|synapse/storage/databases/main/event_federation.py
|
||||
|synapse/storage/databases/main/event_push_actions.py
|
||||
|synapse/storage/databases/main/events_bg_updates.py
|
||||
|
@ -189,6 +188,9 @@ disallow_untyped_defs = True
|
|||
[mypy-synapse.storage.databases.main.directory]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.storage.databases.main.end_to_end_keys]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.storage.databases.main.events_worker]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
|
@ -143,9 +143,6 @@ class DataStore(
|
|||
("device_lists_outbound_pokes", "stream_id"),
|
||||
],
|
||||
)
|
||||
self._cross_signing_id_gen = StreamIdGenerator(
|
||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||
)
|
||||
|
||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
|
|
|
@ -14,19 +14,32 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.enterprise.adbapi import Connection
|
||||
|
||||
from synapse.api.constants import DeviceKeyAlgorithms
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
|
|||
|
||||
|
||||
class EndToEndKeyBackgroundStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
|
@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._allow_device_name_lookup_over_federation = (
|
||||
|
@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
# Build the result structure, un-jsonify the results, and add the
|
||||
# "unsigned" section
|
||||
rv = {}
|
||||
rv: Dict[str, Dict[str, JsonDict]] = {}
|
||||
for user_id, device_keys in results.items():
|
||||
rv[user_id] = {}
|
||||
for device_id, device_info in device_keys.items():
|
||||
|
@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
# add each cross-signing signature to the correct device in the result dict.
|
||||
for (user_id, key_id, device_id, signature) in cross_sigs_result:
|
||||
target_device_result = result[user_id][device_id]
|
||||
# We've only looked up cross-signatures for non-deleted devices with key
|
||||
# data.
|
||||
assert target_device_result is not None
|
||||
assert target_device_result.keys is not None
|
||||
target_device_signatures = target_device_result.keys.setdefault(
|
||||
"signatures", {}
|
||||
)
|
||||
|
@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
return result
|
||||
|
||||
def _get_e2e_device_keys_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
query_list: Collection[Tuple[str, str]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
"""Get information on devices from the database
|
||||
|
||||
|
@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
return result
|
||||
|
||||
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
||||
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
|
||||
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str, str, str]]:
|
||||
"""Get cross-signing signatures for a given list of devices
|
||||
|
||||
|
@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
)
|
||||
|
||||
txn.execute(signature_sql, signature_query_params)
|
||||
return txn.fetchall()
|
||||
return cast(
|
||||
List[
|
||||
Tuple[
|
||||
str,
|
||||
str,
|
||||
str,
|
||||
str,
|
||||
]
|
||||
],
|
||||
txn.fetchall(),
|
||||
)
|
||||
|
||||
async def get_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str, key_ids: List[str]
|
||||
|
@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
||||
"""
|
||||
|
||||
def _add_e2e_one_time_keys(txn):
|
||||
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
|
||||
set_tag("user_id", user_id)
|
||||
set_tag("device_id", device_id)
|
||||
set_tag("new_keys", new_keys)
|
||||
|
@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
A mapping from algorithm to number of keys for that algorithm.
|
||||
"""
|
||||
|
||||
def _count_e2e_one_time_keys(txn):
|
||||
def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
|
||||
sql = (
|
||||
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
|
@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
)
|
||||
|
||||
def _set_e2e_fallback_keys_txn(
|
||||
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
fallback_keys: JsonDict,
|
||||
) -> None:
|
||||
# fallback_keys will usually only have one item in it, so using a for
|
||||
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||
|
@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
async def get_e2e_cross_signing_key(
|
||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||
) -> Optional[dict]:
|
||||
) -> Optional[JsonDict]:
|
||||
"""Returns a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
|
@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
return user_keys.get(key_type)
|
||||
|
||||
@cached(num_args=1)
|
||||
def _get_bare_e2e_cross_signing_keys(self, user_id):
|
||||
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
|
||||
"""Dummy function. Only used to make a cache for
|
||||
_get_bare_e2e_cross_signing_keys_bulk.
|
||||
"""
|
||||
|
@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
)
|
||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||
the signatures for the calling user need to be fetched.
|
||||
|
@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
their user ID will map to None.
|
||||
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_bare_e2e_cross_signing_keys_bulk",
|
||||
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
||||
user_ids,
|
||||
)
|
||||
|
||||
# The `Optional` comes from the `@cachedList` decorator.
|
||||
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
|
||||
|
||||
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||
self,
|
||||
txn: Connection,
|
||||
txn: LoggingTransaction,
|
||||
user_ids: Iterable[str],
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
) -> Dict[str, Dict[str, JsonDict]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||
the signatures for the calling user need to be fetched.
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
user_ids (list[str]): the users whose keys are being requested
|
||||
txn: db connection
|
||||
user_ids: the users whose keys are being requested
|
||||
|
||||
Returns:
|
||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||
data. If a user's cross-signing keys were not found, their user
|
||||
ID will not be in the dict.
|
||||
Mapping from user ID to key type to key data.
|
||||
If a user's cross-signing keys were not found, their user ID will not be in
|
||||
the dict.
|
||||
|
||||
"""
|
||||
result = {}
|
||||
result: Dict[str, Dict[str, JsonDict]] = {}
|
||||
|
||||
for user_chunk in batch_iter(user_ids, 100):
|
||||
clause, params = make_in_list_sql_clause(
|
||||
|
@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
user_id = row["user_id"]
|
||||
key_type = row["keytype"]
|
||||
key = db_to_json(row["keydata"])
|
||||
user_info = result.setdefault(user_id, {})
|
||||
user_info[key_type] = key
|
||||
user_keys = result.setdefault(user_id, {})
|
||||
user_keys[key_type] = key
|
||||
|
||||
return result
|
||||
|
||||
def _get_e2e_cross_signing_signatures_txn(
|
||||
self,
|
||||
txn: Connection,
|
||||
keys: Dict[str, Dict[str, dict]],
|
||||
txn: LoggingTransaction,
|
||||
keys: Dict[str, Optional[Dict[str, JsonDict]]],
|
||||
from_user_id: str,
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
|
||||
"""Returns the cross-signing signatures made by a user on a set of keys.
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
|
||||
key data. This dict will be modified to add signatures.
|
||||
from_user_id (str): fetch the signatures made by this user
|
||||
txn: db connection
|
||||
keys: a map of user ID to key type to key data.
|
||||
This dict will be modified to add signatures.
|
||||
from_user_id: fetch the signatures made by this user
|
||||
|
||||
Returns:
|
||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||
data. The return value will be the same as the keys argument,
|
||||
with the modifications included.
|
||||
Mapping from user ID to key type to key data.
|
||||
The return value will be the same as the keys argument, with the
|
||||
modifications included.
|
||||
"""
|
||||
|
||||
# find out what cross-signing keys (a.k.a. devices) we need to get
|
||||
# signatures for. This is a map of (user_id, device_id) to key type
|
||||
# (device_id is the key's public part).
|
||||
devices = {}
|
||||
devices: Dict[Tuple[str, str], str] = {}
|
||||
|
||||
for user_id, user_info in keys.items():
|
||||
if user_info is None:
|
||||
for user_id, user_keys in keys.items():
|
||||
if user_keys is None:
|
||||
continue
|
||||
for key_type, key in user_info.items():
|
||||
for key_type, key in user_keys.items():
|
||||
device_id = None
|
||||
for k in key["keys"].values():
|
||||
device_id = k
|
||||
# `key` ought to be a `CrossSigningKey`, whose .keys property is a
|
||||
# dictionary with a single entry:
|
||||
# "algorithm:base64_public_key": "base64_public_key"
|
||||
# See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
|
||||
assert isinstance(device_id, str)
|
||||
devices[(user_id, device_id)] = key_type
|
||||
|
||||
for batch in batch_iter(devices.keys(), size=100):
|
||||
|
@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
# and add the signatures to the appropriate keys
|
||||
for row in rows:
|
||||
key_id = row["key_id"]
|
||||
target_user_id = row["target_user_id"]
|
||||
target_device_id = row["target_device_id"]
|
||||
key_id: str = row["key_id"]
|
||||
target_user_id: str = row["target_user_id"]
|
||||
target_device_id: str = row["target_device_id"]
|
||||
key_type = devices[(target_user_id, target_device_id)]
|
||||
# We need to copy everything, because the result may have come
|
||||
# from the cache. dict.copy only does a shallow copy, so we
|
||||
# need to recursively copy the dicts that will be modified.
|
||||
user_info = keys[target_user_id] = keys[target_user_id].copy()
|
||||
target_user_key = user_info[key_type] = user_info[key_type].copy()
|
||||
user_keys = keys[target_user_id]
|
||||
# `user_keys` cannot be `None` because we only fetched signatures for
|
||||
# users with keys
|
||||
assert user_keys is not None
|
||||
user_keys = keys[target_user_id] = user_keys.copy()
|
||||
|
||||
target_user_key = user_keys[key_type] = user_keys[key_type].copy()
|
||||
if "signatures" in target_user_key:
|
||||
signatures = target_user_key["signatures"] = target_user_key[
|
||||
"signatures"
|
||||
|
@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
async def get_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: List[str], from_user_id: Optional[str] = None
|
||||
) -> Dict[str, Optional[Dict[str, dict]]]:
|
||||
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
|
||||
"""Returns the cross-signing keys for a set of users.
|
||||
|
||||
Args:
|
||||
|
@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def _get_all_user_signature_changes_for_remotes_txn(txn):
|
||||
def _get_all_user_signature_changes_for_remotes_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
sql = """
|
||||
SELECT stream_id, from_user_id AS user_id
|
||||
FROM user_signature_stream
|
||||
|
@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_simple(
|
||||
txn, user_id: str, device_id: str, algorithm: str
|
||||
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that don't support RETURNING.
|
||||
|
||||
|
@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_returning(
|
||||
txn, user_id: str, device_id: str, algorithm: str
|
||||
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that support RETURNING.
|
||||
|
||||
|
@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
key_id, key_json = otk_row
|
||||
return f"{algorithm}:{key_id}", key_json
|
||||
|
||||
results = {}
|
||||
results: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
if self.database_engine.supports_returning:
|
||||
# If we support RETURNING clause we can use a single query that
|
||||
|
@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._cross_signing_id_gen = StreamIdGenerator(
|
||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||
)
|
||||
|
||||
async def set_e2e_device_keys(
|
||||
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
||||
) -> bool:
|
||||
|
@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
or the keys were already in the database.
|
||||
"""
|
||||
|
||||
def _set_e2e_device_keys_txn(txn):
|
||||
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
|
||||
set_tag("user_id", user_id)
|
||||
set_tag("device_id", device_id)
|
||||
set_tag("time_now", time_now)
|
||||
|
@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
)
|
||||
|
||||
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||
def delete_e2e_keys_by_device_txn(txn):
|
||||
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
|
||||
log_kv(
|
||||
{
|
||||
"message": "Deleting keys for device",
|
||||
|
@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
|
||||
def _set_e2e_cross_signing_key_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
key_type: str,
|
||||
key: JsonDict,
|
||||
stream_id: int,
|
||||
) -> None:
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
user_id (str): the user to set the signing key for
|
||||
key_type (str): the type of key that is being set: either 'master'
|
||||
txn: db connection
|
||||
user_id: the user to set the signing key for
|
||||
key_type: the type of key that is being set: either 'master'
|
||||
for a master key, 'self_signing' for a self-signing key, or
|
||||
'user_signing' for a user-signing key
|
||||
key (dict): the key data
|
||||
stream_id (int)
|
||||
key: the key data
|
||||
stream_id
|
||||
"""
|
||||
# the 'key' dict will look something like:
|
||||
# {
|
||||
|
@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||
)
|
||||
|
||||
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
async def set_e2e_cross_signing_key(
|
||||
self, user_id: str, key_type: str, key: JsonDict
|
||||
) -> None:
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
user_id (str): the user to set the user-signing key for
|
||||
key_type (str): the type of cross-signing key to set
|
||||
key (dict): the key data
|
||||
user_id: the user to set the user-signing key for
|
||||
key_type: the type of cross-signing key to set
|
||||
key: the key data
|
||||
"""
|
||||
|
||||
async with self._cross_signing_id_gen.get_next() as stream_id:
|
||||
|
|
Loading…
Reference in New Issue