diff --git a/changelog.d/11580.misc b/changelog.d/11580.misc new file mode 100644 index 0000000000..2c48e22de0 --- /dev/null +++ b/changelog.d/11580.misc @@ -0,0 +1 @@ +Add some safety checks that storage functions are used correctly. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a219999f15..2cacc7dd6c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -55,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -986,7 +987,7 @@ class DatabasePool: self, table: str, keys: Collection[str], - values: Iterable[Iterable[Any]], + values: Collection[Collection[Any]], desc: str, ) -> None: """Executes an INSERT query on the named table. @@ -1427,7 +1428,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", ) -> Dict[str, Any]: @@ -1438,7 +1439,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", ) -> Optional[Dict[str, Any]]: @@ -1448,7 +1449,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", ) -> Optional[Dict[str, Any]]: @@ -1618,7 +1619,7 @@ class DatabasePool: self, table: str, keyvalues: Optional[Dict[str, Any]], - retcols: Iterable[str], + retcols: Collection[str], desc: str = "simple_select_list", ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or @@ -1681,7 +1682,7 @@ class DatabasePool: table: str, column: str, iterable: Iterable[Any], - retcols: Iterable[str], + retcols: Collection[str], keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, @@ -1704,16 +1705,7 @@ class DatabasePool: results: List[Dict[str, Any]] = [] - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: + for chunk in batch_iter(iterable, batch_size): rows = await self.runInteraction( desc, self.simple_select_many_txn, @@ -1853,7 +1845,7 @@ class DatabasePool: txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: select_sql = "SELECT %s FROM %s WHERE %s" % ( @@ -2146,7 +2138,7 @@ class DatabasePool: table: str, term: Optional[str], col: str, - retcols: Iterable[str], + retcols: Collection[str], desc="simple_search_list", ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index b73ce53c91..7ab681ed6f 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -22,7 +22,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -196,27 +196,6 @@ class PusherWorkerStore(SQLBaseStore): # This only exists for the cachedList decorator raise NotImplementedError() - @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - ) - async def get_if_users_have_pushers( - self, user_ids: Iterable[str] - ) -> Dict[str, bool]: - rows = await self.db_pool.simple_select_many_batch( - table="pushers", - column="user_name", - iterable=user_ids, - retcols=["user_name"], - desc="get_if_users_have_pushers", - ) - - result = {user_id: False for user_id in user_ids} - result.update({r["user_name"]: True for r in rows}) - - return result - async def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ) -> None: