diff --git a/changelog.d/8194.misc b/changelog.d/8194.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8194.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1419d72e94..9d5b1828df 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler): old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) - for k, old_event in old_room_member_state_events.items(): + for old_event in old_room_member_state_events.values(): # Only transfer ban events if ( "membership" in old_event.content diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 458f169617..5c6168e301 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): + async def get_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. Args: - room_id (str) + room_id: The room to get the state IDs of. Returns: - deferred: dict of (type, state_key) -> event_id + The current state of the room. """ def _get_current_state_ids_txn(txn): @@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - def get_filtered_current_state_ids( + async def get_filtered_current_state_ids( self, room_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state @@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): from the database. Returns: - defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. + Map from type/state_key to event ID. """ where_clause, where_args = state_filter.make_sql_filter_clause() if not where_clause: # We delegate to the cached version - return self.get_current_state_ids(room_id) + return await self.get_current_state_ids(room_id) def _get_filtered_current_state_ids_txn(txn): results = {} @@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 0d963c98ff..356623fc6e 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Any, Dict, List, Tuple from synapse.storage._base import SQLBaseStore @@ -23,7 +22,9 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore): if it's new state. Args: - prev_stream_id (int): point to get changes since (exclusive) - max_stream_id (int): the point that we know has been correctly persisted + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted - ie, an upper limit to return changes from. Returns: - Deferred[tuple[int, list[dict]]: A tuple consisting of: + A tuple consisting of: - the stream id which these results go up to - list of current_state_delta_stream rows. If it is empty, we are up to date. @@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. - return defer.succeed((max_stream_id, [])) + return (max_stream_id, []) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore): txn.execute(sql, (prev_stream_id, clipped_stream_id)) return clipped_stream_id, self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) @@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore): retcol="COALESCE(MAX(stream_id), -1)", ) - def get_max_stream_id_in_current_state_deltas(self): - return self.db_pool.runInteraction( + async def get_max_stream_id_in_current_state_deltas(self): + return await self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 497f607703..24f44a7e36 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): + async def get_room_event_before_stream_ordering( + self, room_id: str, stream_ordering: int + ) -> Tuple[int, int, str]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): stream_ordering: Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) + A tuple of (stream ordering, topological ordering, event_id) """ def _f(txn): @@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f) + return await self.db_pool.runInteraction( + "get_room_event_before_stream_ordering", _f + ) async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7f104ad936..e924f1ca3b 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): + async def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between the old and the new. @@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_state_group_delta", _get_state_group_delta_txn ) @@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ def _store_state_group_txn(txn): @@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_group - return self.db_pool.runInteraction("store_state_group", _store_state_group_txn) + return await self.db_pool.runInteraction( + "store_state_group", _store_state_group_txn + ) - def purge_unreferenced_state_groups( + async def purge_unreferenced_state_groups( self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: + ) -> None: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): to delete. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return {row["state_group"]: row["prev_state_group"] for row in rows} - def purge_room_state(self, room_id, state_groups_to_delete): + async def purge_room_state(self, room_id, state_groups_to_delete): """Deletes all record of a room from state tables Args: @@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_groups_to_delete (list[int]): State groups to delete """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 534883361f..96a1b59d64 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -333,7 +333,7 @@ class StateGroupStorage(object): def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group: int): + async def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. @@ -341,11 +341,11 @@ class StateGroupStorage(object): state_group: The state group used to retrieve state deltas. Returns: - Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: + Tuple[Optional[int], Optional[StateMap[str]]]: (prev_group, delta_ids) """ - return self.stores.state.get_state_group_delta(state_group) + return await self.stores.state.get_state_group_delta(state_group) async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] @@ -525,7 +525,7 @@ class StateGroupStorage(object): state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] @@ -546,14 +546,14 @@ class StateGroupStorage(object): """ return self.stores.state._get_state_for_groups(groups, state_filter) - def store_state_group( + async def store_state_group( self, event_id: str, room_id: str, prev_group: Optional[int], delta_ids: Optional[dict], current_state_ids: dict, - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -567,8 +567,8 @@ class StateGroupStorage(object): to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ - return self.stores.state.store_state_group( + return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids )