Convert state and stream stores and related code to async (#8194)
This commit is contained in:
parent
b055dc9322
commit
aec7085179
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
old_room_member_state_events = await self.store.get_events(
|
||||
old_room_member_state_ids.values()
|
||||
)
|
||||
for k, old_event in old_room_member_state_events.items():
|
||||
for old_event in old_room_member_state_events.values():
|
||||
# Only transfer ban events
|
||||
if (
|
||||
"membership" in old_event.content
|
||||
|
|
|
@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
|
|||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import StateMap
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
|
@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
return create_event
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
def get_current_state_ids(self, room_id):
|
||||
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
|
||||
"""Get the current state event ids for a room based on the
|
||||
current_state_events table.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room to get the state IDs of.
|
||||
|
||||
Returns:
|
||||
deferred: dict of (type, state_key) -> event_id
|
||||
The current state of the room.
|
||||
"""
|
||||
|
||||
def _get_current_state_ids_txn(txn):
|
||||
|
@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_state_ids", _get_current_state_ids_txn
|
||||
)
|
||||
|
||||
# FIXME: how should this be cached?
|
||||
def get_filtered_current_state_ids(
|
||||
async def get_filtered_current_state_ids(
|
||||
self, room_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state event of a given type for a room based on the
|
||||
current_state_events table. This may not be as up-to-date as the result
|
||||
of doing a fresh state resolution as per state_handler.get_current_state
|
||||
|
@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
from the database.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
|
||||
Map from type/state_key to event ID.
|
||||
"""
|
||||
|
||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||
|
||||
if not where_clause:
|
||||
# We delegate to the cached version
|
||||
return self.get_current_state_ids(room_id)
|
||||
return await self.get_current_state_ids(room_id)
|
||||
|
||||
def _get_filtered_current_state_ids_txn(txn):
|
||||
results = {}
|
||||
|
@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -14,8 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
||||
|
@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class StateDeltasStore(SQLBaseStore):
|
||||
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
|
||||
async def get_current_state_deltas(
|
||||
self, prev_stream_id: int, max_stream_id: int
|
||||
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
"""Fetch a list of room state changes since the given stream id
|
||||
|
||||
Each entry in the result contains the following fields:
|
||||
|
@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
|
|||
if it's new state.
|
||||
|
||||
Args:
|
||||
prev_stream_id (int): point to get changes since (exclusive)
|
||||
max_stream_id (int): the point that we know has been correctly persisted
|
||||
prev_stream_id: point to get changes since (exclusive)
|
||||
max_stream_id: the point that we know has been correctly persisted
|
||||
- ie, an upper limit to return changes from.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[int, list[dict]]: A tuple consisting of:
|
||||
A tuple consisting of:
|
||||
- the stream id which these results go up to
|
||||
- list of current_state_delta_stream rows. If it is empty, we are
|
||||
up to date.
|
||||
|
@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
|
|||
# if the CSDs haven't changed between prev_stream_id and now, we
|
||||
# know for certain that they haven't changed between prev_stream_id and
|
||||
# max_stream_id.
|
||||
return defer.succeed((max_stream_id, []))
|
||||
return (max_stream_id, [])
|
||||
|
||||
def get_current_state_deltas_txn(txn):
|
||||
# First we calculate the max stream id that will give us less than
|
||||
|
@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
|
|||
txn.execute(sql, (prev_stream_id, clipped_stream_id))
|
||||
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_state_deltas", get_current_state_deltas_txn
|
||||
)
|
||||
|
||||
|
@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
|
|||
retcol="COALESCE(MAX(stream_id), -1)",
|
||||
)
|
||||
|
||||
def get_max_stream_id_in_current_state_deltas(self):
|
||||
return self.db_pool.runInteraction(
|
||||
async def get_max_stream_id_in_current_state_deltas(self):
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_max_stream_id_in_current_state_deltas",
|
||||
self._get_max_stream_id_in_current_state_deltas_txn,
|
||||
)
|
||||
|
|
|
@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return rows, token
|
||||
|
||||
def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
|
||||
async def get_room_event_before_stream_ordering(
|
||||
self, room_id: str, stream_ordering: int
|
||||
) -> Tuple[int, int, str]:
|
||||
"""Gets details of the first event in a room at or before a stream ordering
|
||||
|
||||
Args:
|
||||
|
@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
stream_ordering:
|
||||
|
||||
Returns:
|
||||
Deferred[(int, int, str)]:
|
||||
(stream ordering, topological ordering, event_id)
|
||||
A tuple of (stream ordering, topological ordering, event_id)
|
||||
"""
|
||||
|
||||
def _f(txn):
|
||||
|
@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
txn.execute(sql, (room_id, stream_ordering))
|
||||
return txn.fetchone()
|
||||
|
||||
return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_room_event_before_stream_ordering", _f
|
||||
)
|
||||
|
||||
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
|
||||
"""Returns the current token for rooms stream.
|
||||
|
|
|
@ -17,8 +17,6 @@ import logging
|
|||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Set, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
)
|
||||
|
||||
@cached(max_entries=10000, iterable=True)
|
||||
def get_state_group_delta(self, state_group):
|
||||
async def get_state_group_delta(self, state_group):
|
||||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
|
@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_state_group_delta", _get_state_group_delta_txn
|
||||
)
|
||||
|
||||
|
@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
fetched_keys=non_member_types,
|
||||
)
|
||||
|
||||
def store_state_group(
|
||||
async def store_state_group(
|
||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
):
|
||||
) -> int:
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
|
@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
to event_id.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: The state group ID
|
||||
The state group ID
|
||||
"""
|
||||
|
||||
def _store_state_group_txn(txn):
|
||||
|
@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
|
||||
return state_group
|
||||
|
||||
return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
|
||||
return await self.db_pool.runInteraction(
|
||||
"store_state_group", _store_state_group_txn
|
||||
)
|
||||
|
||||
def purge_unreferenced_state_groups(
|
||||
async def purge_unreferenced_state_groups(
|
||||
self, room_id: str, state_groups_to_delete
|
||||
) -> defer.Deferred:
|
||||
) -> None:
|
||||
"""Deletes no longer referenced state groups and de-deltas any state
|
||||
groups that reference them.
|
||||
|
||||
|
@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
to delete.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"purge_unreferenced_state_groups",
|
||||
self._purge_unreferenced_state_groups,
|
||||
room_id,
|
||||
|
@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
|
||||
return {row["state_group"]: row["prev_state_group"] for row in rows}
|
||||
|
||||
def purge_room_state(self, room_id, state_groups_to_delete):
|
||||
async def purge_room_state(self, room_id, state_groups_to_delete):
|
||||
"""Deletes all record of a room from state tables
|
||||
|
||||
Args:
|
||||
|
@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
state_groups_to_delete (list[int]): State groups to delete
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"purge_room_state",
|
||||
self._purge_room_state_txn,
|
||||
room_id,
|
||||
|
|
|
@ -333,7 +333,7 @@ class StateGroupStorage(object):
|
|||
def __init__(self, hs, stores):
|
||||
self.stores = stores
|
||||
|
||||
def get_state_group_delta(self, state_group: int):
|
||||
async def get_state_group_delta(self, state_group: int):
|
||||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
|
@ -341,11 +341,11 @@ class StateGroupStorage(object):
|
|||
state_group: The state group used to retrieve state deltas.
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
|
||||
Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||
(prev_group, delta_ids)
|
||||
"""
|
||||
|
||||
return self.stores.state.get_state_group_delta(state_group)
|
||||
return await self.stores.state.get_state_group_delta(state_group)
|
||||
|
||||
async def get_state_groups_ids(
|
||||
self, _room_id: str, event_ids: Iterable[str]
|
||||
|
@ -525,7 +525,7 @@ class StateGroupStorage(object):
|
|||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
A dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = await self.get_state_ids_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
@ -546,14 +546,14 @@ class StateGroupStorage(object):
|
|||
"""
|
||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
||||
|
||||
def store_state_group(
|
||||
async def store_state_group(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
prev_group: Optional[int],
|
||||
delta_ids: Optional[dict],
|
||||
current_state_ids: dict,
|
||||
):
|
||||
) -> int:
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
|
@ -567,8 +567,8 @@ class StateGroupStorage(object):
|
|||
to event_id.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: The state group ID
|
||||
The state group ID
|
||||
"""
|
||||
return self.stores.state.store_state_group(
|
||||
return await self.stores.state.store_state_group(
|
||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue