Return immutable objects for cachedList decorators (#16350)

Patrick Cloke 2023-09-19 15:26:44 -04:00 committed by GitHub
parent 5a66ff2f5c
commit d7c89c5908
24 changed files with 134 additions and 100 deletions
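The change is mechanical: methods wrapped in the `@cached`/`@cachedList` decorators return objects that live in a shared cache, so their annotations are tightened from mutable types (`Dict`, `List`, `JsonDict`) to read-only ones (`Mapping`, `Sequence`, `JsonMapping`), and callers are updated to match. A minimal stand-alone sketch of why the read-only annotation matters, using hypothetical names rather than Synapse's real decorators, might look like this:

from typing import Dict, Mapping

# Hypothetical sketch, not Synapse code: a cached lookup whose result object is
# shared between callers.  Annotating the return type as the read-only Mapping
# rather than Dict lets mypy flag accidental mutation of the cached value.
_otk_counts: Dict[str, Dict[str, int]] = {}

def count_one_time_keys(user_id: str) -> Mapping[str, int]:
    # The same dict object is handed to every caller, so it must be treated as
    # read-only; Mapping has no __setitem__, so `counts["x"] = 1` fails type checking.
    return _otk_counts.setdefault(user_id, {"signed_curve25519": 0})

counts = count_one_time_keys("@alice:example.org")
print(dict(counts))  # take a copy before mutating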

changelog.d/16350.misc (new file)
View File

@@ -0,0 +1 @@
+Improve type hints.

View File

@@ -23,7 +23,7 @@ from netaddr import IPSet
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import DeviceListUpdates, JsonDict, UserID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID
 from synapse.util.caches.descriptors import _CacheContext, cached
 if TYPE_CHECKING:
@@ -379,8 +379,8 @@ class AppServiceTransaction:
 service: ApplicationService,
 id: int,
 events: Sequence[EventBase],
-ephemeral: List[JsonDict],
-to_device_messages: List[JsonDict],
+ephemeral: List[JsonMapping],
+to_device_messages: List[JsonMapping],
 one_time_keys_count: TransactionOneTimeKeysCount,
 unused_fallback_keys: TransactionUnusedFallbackKeys,
 device_list_summary: DeviceListUpdates,

View File

@@ -41,7 +41,7 @@ from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig, serialize_event
 from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
 from synapse.logging import opentracing
-from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 if TYPE_CHECKING:
@@ -306,8 +306,8 @@ class ApplicationServiceApi(SimpleHttpClient):
 self,
 service: "ApplicationService",
 events: Sequence[EventBase],
-ephemeral: List[JsonDict],
-to_device_messages: List[JsonDict],
+ephemeral: List[JsonMapping],
+to_device_messages: List[JsonMapping],
 one_time_keys_count: TransactionOneTimeKeysCount,
 unused_fallback_keys: TransactionUnusedFallbackKeys,
 device_list_summary: DeviceListUpdates,

View File

@@ -73,7 +73,7 @@ from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases.main import DataStore
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
 from synapse.util import Clock
 if TYPE_CHECKING:
@@ -121,8 +121,8 @@ class ApplicationServiceScheduler:
 self,
 appservice: ApplicationService,
 events: Optional[Collection[EventBase]] = None,
-ephemeral: Optional[Collection[JsonDict]] = None,
-to_device_messages: Optional[Collection[JsonDict]] = None,
+ephemeral: Optional[Collection[JsonMapping]] = None,
+to_device_messages: Optional[Collection[JsonMapping]] = None,
 device_list_summary: Optional[DeviceListUpdates] = None,
 ) -> None:
 """
@@ -180,9 +180,9 @@ class _ServiceQueuer:
 # dict of {service_id: [events]}
 self.queued_events: Dict[str, List[EventBase]] = {}
 # dict of {service_id: [events]}
-self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
+self.queued_ephemeral: Dict[str, List[JsonMapping]] = {}
 # dict of {service_id: [to_device_message_json]}
-self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
+self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {}
 # dict of {service_id: [device_list_summary]}
 self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
@@ -293,8 +293,8 @@ class _ServiceQueuer:
 self,
 service: ApplicationService,
 events: Iterable[EventBase],
-ephemerals: Iterable[JsonDict],
-to_device_messages: Iterable[JsonDict],
+ephemerals: Iterable[JsonMapping],
+to_device_messages: Iterable[JsonMapping],
 ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
 """
 Given a list of the events, ephemeral messages and to-device messages,
@@ -364,8 +364,8 @@ class _TransactionController:
 self,
 service: ApplicationService,
 events: Sequence[EventBase],
-ephemeral: Optional[List[JsonDict]] = None,
-to_device_messages: Optional[List[JsonDict]] = None,
+ephemeral: Optional[List[JsonMapping]] = None,
+to_device_messages: Optional[List[JsonMapping]] = None,
 one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
 unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
 device_list_summary: Optional[DeviceListUpdates] = None,

View File

@@ -46,6 +46,7 @@ from synapse.storage.databases.main.directory import RoomAliasMapping
 from synapse.types import (
 DeviceListUpdates,
 JsonDict,
+JsonMapping,
 RoomAlias,
 RoomStreamToken,
 StreamKeyType,
@@ -397,7 +398,7 @@ class ApplicationServicesHandler:
 async def _handle_typing(
 self, service: ApplicationService, new_token: int
-) -> List[JsonDict]:
+) -> List[JsonMapping]:
 """
 Return the typing events since the given stream token that the given application
 service should receive.
@@ -432,7 +433,7 @@ class ApplicationServicesHandler:
 async def _handle_receipts(
 self, service: ApplicationService, new_token: int
-) -> List[JsonDict]:
+) -> List[JsonMapping]:
 """
 Return the latest read receipts that the given application service should receive.
@@ -471,7 +472,7 @@ class ApplicationServicesHandler:
 service: ApplicationService,
 users: Collection[Union[str, UserID]],
 new_token: Optional[int],
-) -> List[JsonDict]:
+) -> List[JsonMapping]:
 """
 Return the latest presence updates that the given application service should receive.
@@ -491,7 +492,7 @@ class ApplicationServicesHandler:
 A list of json dictionaries containing data derived from the presence events
 that should be sent to the given application service.
 """
-events: List[JsonDict] = []
+events: List[JsonMapping] = []
 presence_source = self.event_sources.sources.presence
 from_key = await self.store.get_type_stream_id_for_appservice(
 service, "presence"

View File

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.types import (
 JsonDict,
+JsonMapping,
 UserID,
 get_domain_from_id,
 get_verify_key_from_cross_signing_key,
@@ -272,11 +273,7 @@ class E2eKeysHandler:
 delay_cancellation=True,
 )
-ret = {"device_keys": results, "failures": failures}
-ret.update(cross_signing_keys)
-return ret
+return {"device_keys": results, "failures": failures, **cross_signing_keys}
 @trace
 async def _query_devices_for_destination(
@@ -408,7 +405,7 @@ class E2eKeysHandler:
 @cancellable
 async def get_cross_signing_keys_from_cache(
 self, query: Iterable[str], from_user_id: Optional[str]
-) -> Dict[str, Dict[str, dict]]:
+) -> Dict[str, Dict[str, JsonMapping]]:
 """Get cross-signing keys for users from the database
 Args:
@@ -551,16 +548,13 @@ class E2eKeysHandler:
 self.config.federation.allow_device_name_lookup_over_federation
 ),
 )
-ret = {"device_keys": res}
 # add in the cross-signing keys
 cross_signing_keys = await self.get_cross_signing_keys_from_cache(
 device_keys_query, None
 )
-ret.update(cross_signing_keys)
-return ret
+return {"device_keys": res, **cross_signing_keys}
 async def claim_local_one_time_keys(
 self,
@@ -1127,7 +1121,7 @@ class E2eKeysHandler:
 user_id: str,
 master_key_id: str,
 signed_master_key: JsonDict,
-stored_master_key: JsonDict,
+stored_master_key: JsonMapping,
 devices: Dict[str, Dict[str, JsonDict]],
 ) -> List["SignatureListItem"]:
 """Check signatures of a user's master key made by their devices.
@@ -1278,7 +1272,7 @@ class E2eKeysHandler:
 async def _get_e2e_cross_signing_verify_key(
 self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-) -> Tuple[JsonDict, str, VerifyKey]:
+) -> Tuple[JsonMapping, str, VerifyKey]:
 """Fetch locally or remotely query for a cross-signing public key.
 First, attempt to fetch the cross-signing public key from storage.
@@ -1333,7 +1327,7 @@ class E2eKeysHandler:
 self,
 user: UserID,
 desired_key_type: str,
-) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
+) -> Optional[Tuple[JsonMapping, str, VerifyKey]]:
 """Queries cross-signing keys for a remote user and saves them to the database
 Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1474,7 +1468,7 @@ def _check_device_signature(
 user_id: str,
 verify_key: VerifyKey,
 signed_device: JsonDict,
-stored_device: JsonDict,
+stored_device: JsonMapping,
 ) -> None:
 """Check that a signature on a device or cross-signing key is correct and
 matches the copy of the device/key that we have stored. Throws an

View File

@@ -32,6 +32,7 @@ from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
 JsonDict,
+JsonMapping,
 Requester,
 RoomStreamToken,
 StreamKeyType,
@@ -454,7 +455,7 @@ class InitialSyncHandler:
 for s in states
 ]
-async def get_receipts() -> List[JsonDict]:
+async def get_receipts() -> List[JsonMapping]:
 receipts = await self.store.get_linearized_receipts_for_room(
 room_id, to_key=now_token.receipt_key
 )

View File

@@ -19,6 +19,7 @@ from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import (
 JsonDict,
+JsonMapping,
 ReadReceipt,
 StreamKeyType,
 UserID,
@@ -204,15 +205,15 @@ class ReceiptsHandler:
 await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource(EventSource[int, JsonDict]):
+class ReceiptEventSource(EventSource[int, JsonMapping]):
 def __init__(self, hs: "HomeServer"):
 self.store = hs.get_datastores().main
 self.config = hs.config
 @staticmethod
 def filter_out_private_receipts(
-rooms: Sequence[JsonDict], user_id: str
-) -> List[JsonDict]:
+rooms: Sequence[JsonMapping], user_id: str
+) -> List[JsonMapping]:
 """
 Filters a list of serialized receipts (as returned by /sync and /initialSync)
 and removes private read receipts of other users.
@@ -229,7 +230,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 The same as rooms, but filtered.
 """
-result = []
+result: List[JsonMapping] = []
 # Iterate through each room's receipt content.
 for room in rooms:
@@ -282,7 +283,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 room_ids: Iterable[str],
 is_guest: bool,
 explicit_room_id: Optional[str] = None,
-) -> Tuple[List[JsonDict], int]:
+) -> Tuple[List[JsonMapping], int]:
 from_key = int(from_key)
 to_key = self.get_current_key()
@@ -301,7 +302,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 async def get_new_events_as(
 self, from_key: int, to_key: int, service: ApplicationService
-) -> Tuple[List[JsonDict], int]:
+) -> Tuple[List[JsonMapping], int]:
 """Returns a set of new read receipt events that an appservice
 may be interested in.

View File

@@ -235,7 +235,7 @@ class SyncResult:
 archived: List[ArchivedSyncResult]
 to_device: List[JsonDict]
 device_lists: DeviceListUpdates
-device_one_time_keys_count: JsonDict
+device_one_time_keys_count: JsonMapping
 device_unused_fallback_key_types: List[str]
 def __bool__(self) -> bool:
@@ -1558,7 +1558,7 @@ class SyncHandler:
 logger.debug("Fetching OTK data")
 device_id = sync_config.device_id
-one_time_keys_count: JsonDict = {}
+one_time_keys_count: JsonMapping = {}
 unused_fallback_key_types: List[str] = []
 if device_id:
 # TODO: We should have a way to let clients differentiate between the states of:

View File

@@ -26,7 +26,14 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
+from synapse.types import (
+JsonDict,
+JsonMapping,
+Requester,
+StrCollection,
+StreamKeyType,
+UserID,
+)
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.retryutils import filter_destinations_by_retry_limiter
@@ -487,7 +494,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 raise Exception("Typing writer instance got typing info over replication")
-class TypingNotificationEventSource(EventSource[int, JsonDict]):
+class TypingNotificationEventSource(EventSource[int, JsonMapping]):
 def __init__(self, hs: "HomeServer"):
 self._main_store = hs.get_datastores().main
 self.clock = hs.get_clock()
@@ -497,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 #
 self.get_typing_handler = hs.get_typing_handler
-def _make_event_for(self, room_id: str) -> JsonDict:
+def _make_event_for(self, room_id: str) -> JsonMapping:
 typing = self.get_typing_handler()._room_typing[room_id]
 return {
 "type": EduTypes.TYPING,
@@ -507,7 +514,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 async def get_new_events_as(
 self, from_key: int, service: ApplicationService
-) -> Tuple[List[JsonDict], int]:
+) -> Tuple[List[JsonMapping], int]:
 """Returns a set of new typing events that an appservice
 may be interested in.
@@ -551,7 +558,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 room_ids: Iterable[str],
 is_guest: bool,
 explicit_room_id: Optional[str] = None,
-) -> Tuple[List[JsonDict], int]:
+) -> Tuple[List[JsonMapping], int]:
 with Measure(self.clock, "typing.get_new_events"):
 from_key = int(from_key)
 handler = self.get_typing_handler()

View File

@@ -131,7 +131,7 @@ class BulkPushRuleEvaluator:
 async def _get_rules_for_event(
 self,
 event: EventBase,
-) -> Dict[str, FilteredPushRules]:
+) -> Mapping[str, FilteredPushRules]:
 """Get the push rules for all users who may need to be notified about
 the event.

View File

@@ -45,7 +45,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import _CacheContext, cached
@@ -268,8 +268,8 @@ class ApplicationServiceTransactionWorkerStore(
 self,
 service: ApplicationService,
 events: Sequence[EventBase],
-ephemeral: List[JsonDict],
-to_device_messages: List[JsonDict],
+ephemeral: List[JsonMapping],
+to_device_messages: List[JsonMapping],
 one_time_keys_count: TransactionOneTimeKeysCount,
 unused_fallback_keys: TransactionUnusedFallbackKeys,
 device_list_summary: DeviceListUpdates,

View File

@@ -55,7 +55,12 @@ from synapse.storage.util.id_generators import (
 AbstractStreamIdGenerator,
 StreamIdGenerator,
 )
-from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
+from synapse.types import (
+JsonDict,
+JsonMapping,
+StrCollection,
+get_verify_key_from_cross_signing_key,
+)
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
@@ -746,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 @cancellable
 async def get_user_devices_from_cache(
 self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
-) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
+) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]:
 """Get the devices (and keys if any) for remote users from the cache.
 Args:
@@ -766,13 +771,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 # First fetch all the users which all devices are to be returned.
-results: Dict[str, Mapping[str, JsonDict]] = {}
+results: Dict[str, Mapping[str, JsonMapping]] = {}
 for user_id in user_ids:
 if user_id in user_ids_in_cache:
 results[user_id] = await self.get_cached_devices_for_user(user_id)
 # Then fetch all device-specific requests, but skip users we've already
 # fetched all devices for.
-device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+device_specific_results: Dict[str, Dict[str, JsonMapping]] = {}
 for user_id, device_id in user_and_device_ids:
 if user_id in user_ids_in_cache and user_id not in user_ids:
 device = await self._get_cached_user_device(user_id, device_id)
@@ -801,7 +806,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 return user_ids_in_cache
 @cached(num_args=2, tree=True)
-async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+async def _get_cached_user_device(
+self, user_id: str, device_id: str
+) -> JsonMapping:
 content = await self.db_pool.simple_select_one_onecol(
 table="device_lists_remote_cache",
 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -811,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 return db_to_json(content)
 @cached()
-async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
+async def get_cached_devices_for_user(
+self, user_id: str
+) -> Mapping[str, JsonMapping]:
 devices = await self.db_pool.simple_select_list(
 table="device_lists_remote_cache",
 keyvalues={"user_id": user_id},
@@ -1042,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 )
 async def get_device_list_last_stream_id_for_remotes(
 self, user_ids: Iterable[str]
-) -> Dict[str, Optional[str]]:
+) -> Mapping[str, Optional[str]]:
 rows = await self.db_pool.simple_select_many_batch(
 table="device_lists_remote_extremeties",
 column="user_id",

View File

@@ -52,7 +52,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.cancellation import cancellable
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 async def get_e2e_device_keys_for_federation_query(
 self, user_id: str
-) -> Tuple[int, List[JsonDict]]:
+) -> Tuple[int, Sequence[JsonMapping]]:
 """Get all devices (with any device keys) for a user
 Returns:
@@ -174,7 +174,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 @cached(iterable=True)
 async def _get_e2e_device_keys_for_federation_query_inner(
 self, user_id: str
-) -> List[JsonDict]:
+) -> Sequence[JsonMapping]:
 """Get all devices (with any device keys) for a user"""
 devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
@@ -578,7 +578,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 @cached(max_entries=10000)
 async def count_e2e_one_time_keys(
 self, user_id: str, device_id: str
-) -> Dict[str, int]:
+) -> Mapping[str, int]:
 """Count the number of one time keys the server has for a device
 Returns:
 A mapping from algorithm to number of keys for that algorithm.
@@ -812,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 async def get_e2e_cross_signing_key(
 self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-) -> Optional[JsonDict]:
+) -> Optional[JsonMapping]:
 """Returns a user's cross-signing key.
 Args:
@@ -833,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 return user_keys.get(key_type)
 @cached(num_args=1)
-def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
+def _get_bare_e2e_cross_signing_keys(
+self, user_id: str
+) -> Mapping[str, JsonMapping]:
 """Dummy function. Only used to make a cache for
 _get_bare_e2e_cross_signing_keys_bulk.
 """
@@ -846,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 )
 async def _get_bare_e2e_cross_signing_keys_bulk(
 self, user_ids: Iterable[str]
-) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
 """Returns the cross-signing keys for a set of users. The output of this
 function should be passed to _get_e2e_cross_signing_signatures_txn if
 the signatures for the calling user need to be fetched.
@@ -860,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 their user ID will map to None.
 """
-result = await self.db_pool.runInteraction(
+return await self.db_pool.runInteraction(
 "get_bare_e2e_cross_signing_keys_bulk",
 self._get_bare_e2e_cross_signing_keys_bulk_txn,
 user_ids,
 )
-# The `Optional` comes from the `@cachedList` decorator.
-return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
 def _get_bare_e2e_cross_signing_keys_bulk_txn(
 self,
 txn: LoggingTransaction,
@@ -1026,7 +1025,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 @cancellable
 async def get_e2e_cross_signing_keys_bulk(
 self, user_ids: List[str], from_user_id: Optional[str] = None
-) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
 """Returns the cross-signing keys for a set of users.
 Args:
@@ -1043,7 +1042,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 if from_user_id:
 result = cast(
-Dict[str, Optional[Mapping[str, JsonDict]]],
+Dict[str, Optional[Mapping[str, JsonMapping]]],
 await self.db_pool.runInteraction(
 "get_e2e_cross_signing_signatures",
 self._get_e2e_cross_signing_signatures_txn,

View File

@@ -24,6 +24,7 @@ from typing import (
 Dict,
 Iterable,
 List,
+Mapping,
 MutableMapping,
 Optional,
 Set,
@@ -1633,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore):
 self,
 room_id: str,
 event_ids: Collection[str],
-) -> Dict[str, bool]:
+) -> Mapping[str, bool]:
 """Helper for have_seen_events
 Returns:
@@ -2325,7 +2326,7 @@ class EventsWorkerStore(SQLBaseStore):
 @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
 async def get_partial_state_events(
 self, event_ids: Collection[str]
-) -> Dict[str, bool]:
+) -> Mapping[str, bool]:
 """Checks which of the given events have partial state
 Args:

View File

@@ -16,7 +16,7 @@
 import itertools
 import json
 import logging
-from typing import Dict, Iterable, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -130,7 +130,7 @@ class KeyStore(CacheInvalidationWorkerStore):
 )
 async def get_server_keys_json(
 self, server_name_and_key_ids: Iterable[Tuple[str, str]]
-) -> Dict[Tuple[str, str], FetchKeyResult]:
+) -> Mapping[Tuple[str, str], FetchKeyResult]:
 """
 Args:
 server_name_and_key_ids:
@@ -200,7 +200,7 @@ class KeyStore(CacheInvalidationWorkerStore):
 )
 async def get_server_keys_json_for_remote(
 self, server_name: str, key_ids: Iterable[str]
-) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+) -> Mapping[str, Optional[FetchKeyResultForRemote]]:
 """Fetch the cached keys for the given server/key IDs.
 If we have multiple entries for a given key ID, returns the most recent.

View File

@@ -11,7 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
+from typing import (
+TYPE_CHECKING,
+Any,
+Dict,
+Iterable,
+List,
+Mapping,
+Optional,
+Tuple,
+cast,
+)
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
@@ -249,7 +259,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
 )
 async def get_presence_for_users(
 self, user_ids: Iterable[str]
-) -> Dict[str, UserPresenceState]:
+) -> Mapping[str, UserPresenceState]:
 rows = await self.db_pool.simple_select_many_batch(
 table="presence_stream",
 column="user_id",

View File

@@ -216,7 +216,7 @@ class PushRulesWorkerStore(
 @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
 async def bulk_get_push_rules(
 self, user_ids: Collection[str]
-) -> Dict[str, FilteredPushRules]:
+) -> Mapping[str, FilteredPushRules]:
 if not user_ids:
 return {}

View File

@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import (
 MultiWriterIdGenerator,
 StreamIdGenerator,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 @cached()
 async def _get_receipts_for_user_with_orderings(
 self, user_id: str, receipt_type: str
-) -> JsonDict:
+) -> JsonMapping:
 """
 Fetch receipts for all rooms that the given user is joined to.
@@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 async def get_linearized_receipts_for_rooms(
 self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
-) -> List[dict]:
+) -> List[JsonMapping]:
 """Get receipts for multiple rooms for sending to clients.
 Args:
@@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 async def get_linearized_receipts_for_room(
 self, room_id: str, to_key: int, from_key: Optional[int] = None
-) -> Sequence[JsonDict]:
+) -> Sequence[JsonMapping]:
 """Get receipts for a single room for sending to clients.
 Args:
@@ -310,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 @cached(tree=True)
 async def _get_linearized_receipts_for_room(
 self, room_id: str, to_key: int, from_key: Optional[int] = None
-) -> Sequence[JsonDict]:
+) -> Sequence[JsonMapping]:
 """See get_linearized_receipts_for_room"""
 def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -353,7 +353,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 )
 async def _get_linearized_receipts_for_rooms(
 self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
-) -> Dict[str, Sequence[JsonDict]]:
+) -> Mapping[str, Sequence[JsonMapping]]:
 if not room_ids:
 return {}
@@ -415,7 +415,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 )
 async def get_linearized_receipts_for_all_rooms(
 self, to_key: int, from_key: Optional[int] = None
-) -> Mapping[str, JsonDict]:
+) -> Mapping[str, JsonMapping]:
 """Get receipts for all rooms between two stream_ids, up
 to a limit of the latest 100 read receipts.

View File

@@ -519,7 +519,7 @@ class RelationsWorkerStore(SQLBaseStore):
 @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
 async def get_applicable_edits(
 self, event_ids: Collection[str]
-) -> Dict[str, Optional[EventBase]]:
+) -> Mapping[str, Optional[EventBase]]:
 """Get the most recent edit (if any) that has happened for the given
 events.
@@ -605,7 +605,7 @@ class RelationsWorkerStore(SQLBaseStore):
 @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
 async def get_thread_summaries(
 self, event_ids: Collection[str]
-) -> Dict[str, Optional[Tuple[int, EventBase]]]:
+) -> Mapping[str, Optional[Tuple[int, EventBase]]]:
 """Get the number of threaded replies and the latest reply (if any) for the given events.
 Args:
@@ -779,7 +779,7 @@ class RelationsWorkerStore(SQLBaseStore):
 @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
 async def get_threads_participated(
 self, event_ids: Collection[str], user_id: str
-) -> Dict[str, bool]:
+) -> Mapping[str, bool]:
 """Get whether the requesting user participated in the given threads.
 This is separate from get_thread_summaries since that can be cached across

View File

@@ -191,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 )
 async def get_subset_users_in_room_with_profiles(
 self, room_id: str, user_ids: Collection[str]
-) -> Dict[str, ProfileInfo]:
+) -> Mapping[str, ProfileInfo]:
 """Get a mapping from user ID to profile information for a list of users
 in a given room.
@@ -676,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 )
 async def _get_rooms_for_users(
 self, user_ids: Collection[str]
-) -> Dict[str, FrozenSet[str]]:
+) -> Mapping[str, FrozenSet[str]]:
 """A batched version of `get_rooms_for_user`.
 Returns:
@@ -881,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 )
 async def _get_user_ids_from_membership_event_ids(
 self, event_ids: Iterable[str]
-) -> Dict[str, Optional[str]]:
+) -> Mapping[str, Optional[str]]:
 """For given set of member event_ids check if they point to a join
 event.
@@ -1191,7 +1191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 )
 async def get_membership_from_event_ids(
 self, member_event_ids: Iterable[str]
-) -> Dict[str, Optional[EventIdMembership]]:
+) -> Mapping[str, Optional[EventIdMembership]]:
 """Get user_id and membership of a set of event IDs.
 Returns:

View File

@@ -14,7 +14,17 @@
 # limitations under the License.
 import collections.abc
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+TYPE_CHECKING,
+Any,
+Collection,
+Dict,
+Iterable,
+Mapping,
+Optional,
+Set,
+Tuple,
+)
 import attr
@@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 )
 async def _get_state_group_for_events(
 self, event_ids: Collection[str]
-) -> Dict[str, int]:
+) -> Mapping[str, int]:
 """Returns mapping event_id -> state_group.
 Raises:

View File

@@ -14,7 +14,7 @@
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast
 import attr
 from canonicaljson import encode_canonical_json
@@ -210,7 +210,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 )
 async def get_destination_retry_timings_batch(
 self, destinations: StrCollection
-) -> Dict[str, Optional[DestinationRetryTimings]]:
+) -> Mapping[str, Optional[DestinationRetryTimings]]:
 rows = await self.db_pool.simple_select_many_batch(
 table="destinations",
 iterable=destinations,

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Iterable
+from typing import Iterable, Mapping
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
 return bool(result)
 @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
-async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
+async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]:
 """
 Checks which users in a list have requested erasure
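The batched `@cachedList` lookups above (`are_users_erased`, `bulk_get_push_rules`, `get_partial_state_events`, and so on) follow the same reasoning: the values in the returned mapping are the cached entries themselves, so the whole result is now annotated as an immutable `Mapping`. A rough stand-in for that pattern, with made-up names and without Synapse's real decorator, could be:

from typing import Dict, Iterable, Mapping

# Hypothetical stand-in for a @cachedList-style bulk lookup, not the real
# decorator.  The per-user dicts live in the cache and are shared with every
# caller, so the result is typed Mapping[str, Mapping[str, int]] to keep
# callers from mutating the shared entries.
_device_key_cache: Dict[str, Dict[str, int]] = {}

def bulk_count_device_keys(user_ids: Iterable[str]) -> Mapping[str, Mapping[str, int]]:
    result: Dict[str, Mapping[str, int]] = {}
    for user_id in user_ids:
        # setdefault stores a dict once and then reuses the same object on later calls.
        result[user_id] = _device_key_cache.setdefault(user_id, {"signed_curve25519": 0})
    return result

keys = bulk_count_device_keys(["@alice:example.org", "@bob:example.org"])
print({user: dict(counts) for user, counts in keys.items()})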