Remove many calls to cursor_to_dict (#16431)

This avoids calling cursor_to_dict and then immediately
unpacking the values in the dict for other users. By not
creating the intermediate dictionary we can avoid allocating
the dictionary and strings for the keys, which should generally
be more performant.

Additionally, this improves type hints by avoiding Dict[str, Any]
dictionaries coming out of the database layer.
This commit is contained in:
Patrick Cloke 2023-10-05 11:07:38 -04:00 committed by GitHub
parent 4e302b30b6
commit fa907025f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 319 additions and 227 deletions

View File

@ -1 +1 @@
Reduce the size of each replication command instance.
Reduce memory allocations.

1
changelog.d/16431.misc Normal file
View File

@ -0,0 +1 @@
Reduce memory allocations.

View File

@ -101,7 +101,7 @@ if TYPE_CHECKING:
class PusherConfig:
"""Parameters necessary to configure a pusher."""
id: Optional[str]
id: Optional[int]
user_name: str
profile_tag: str

View File

@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
account_data_type: db_to_json(content)
for account_data_type, content in txn
}
return await self.db_pool.runInteraction(
@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
for room_id, account_data_type, content in txn:
room_data = by_room.setdefault(room_id, {})
room_data[row["account_data_type"]] = db_to_json(row["content"])
room_data[account_data_type] = db_to_json(content)
return by_room

View File

@ -14,17 +14,7 @@
# limitations under the License.
import logging
import re
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Pattern,
Sequence,
Tuple,
cast,
)
from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
from synapse.appservice import (
ApplicationService,
@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
"SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return None
entry = rows[0]
return entry
return cast(Optional[Tuple[int, str]], txn.fetchone())
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
event_ids = db_to_json(entry["event_ids"])
txn_id, event_ids_str = entry
event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts, device list summaries and unused
@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
id=entry["txn_id"],
id=txn_id,
events=events,
ephemeral=[],
to_device_messages=[],

View File

@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_devices_not_accessed_since_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str]]:
sql = """
SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE
"""
txn.execute(sql, (since_ms,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since",
@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
devices: Dict[str, List[str]] = {}
for row in rows:
for user_id, device_id in rows:
# Remote devices are never stale from our point of view.
if self.hs.is_mine_id(row["user_id"]):
user_devices = devices.setdefault(row["user_id"], [])
user_devices.append(row["device_id"])
if self.hs.is_mine_id(user_id):
user_devices = devices.setdefault(user_id, [])
user_devices.append(device_id)
return devices

View File

@ -921,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
}
txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
for user_id, key_type, key_data, _ in txn:
user_keys = result.setdefault(user_id, {})
user_keys[key_type] = key
user_keys[key_type] = db_to_json(key_data)
return result
@ -988,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item)
txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
for row in rows:
key_id: str = row["key_id"]
target_user_id: str = row["target_user_id"]
target_device_id: str = row["target_device_id"]
for target_user_id, target_device_id, key_id, signature in txn:
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
@ -1012,13 +1004,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"]
user_sigs[key_id] = signature
else:
signatures[from_user_id] = {key_id: row["signature"]}
signatures[from_user_id] = {key_id: signature}
else:
target_user_key["signatures"] = {
from_user_id: {key_id: row["signature"]}
}
target_user_key["signatures"] = {from_user_id: {key_id: signature}}
return keys

View File

@ -1654,8 +1654,6 @@ class PersistEventsStore:
) -> None:
to_prefill = []
rows = []
ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map:
return
@ -1676,10 +1674,9 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
for event_id, redacts, rejects in txn:
event = ev_map[event_id]
if not rejects and not redacts:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
async def external_prefill() -> None:

View File

@ -434,13 +434,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
txn.close()
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
return [
UserPresenceState(
user_id=user_id,
state=state,
last_active_ts=last_active_ts,
last_federation_update_ts=last_federation_update_ts,
last_user_sync_ts=last_user_sync_ts,
status_msg=status_msg,
currently_active=bool(currently_active),
)
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
]
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup

View File

@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# The type of a row in the pushers table.
PusherRow = Tuple[
int, # id
str, # user_name
Optional[int], # access_token
str, # profile_tag
str, # kind
str, # app_id
str, # app_display_name
str, # device_display_name
str, # pushkey
int, # ts
str, # lang
str, # data
int, # last_stream_ordering
int, # last_success
int, # failing_since
bool, # enabled
str, # device_id
]
class PusherWorkerStore(SQLBaseStore):
def __init__(
@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
def _decode_pushers_rows(
self,
rows: Iterable[PusherRow],
) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
for r in rows:
data_json = r["data"]
for (
id,
user_name,
access_token,
profile_tag,
kind,
app_id,
app_display_name,
device_display_name,
pushkey,
ts,
lang,
data,
last_stream_ordering,
last_success,
failing_since,
enabled,
device_id,
) in rows:
try:
r["data"] = db_to_json(data_json)
data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
data_json,
id,
data,
e.args[0],
)
continue
yield PusherConfig(
id=id,
user_name=user_name,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
ts=ts,
lang=lang,
data=data_json,
last_stream_ordering=last_stream_ordering,
last_success=last_success,
failing_since=failing_since,
# If we're using SQLite, then boolean values are integers. This is
# troublesome since some code using the return value of this method might
# expect it to be a boolean, or will expose it to clients (in responses).
r["enabled"] = bool(r["enabled"])
yield PusherConfig(**r)
enabled=bool(enabled),
device_id=device_id,
access_token=access_token,
)
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""
def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values()))
return self.db_pool.cursor_to_dict(txn)
return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@ -164,15 +221,23 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
rows = self.db_pool.cursor_to_dict(txn)
def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
txn.execute(
"""
SELECT id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, ts, lang, data,
last_stream_ordering, last_success, failing_since,
enabled, device_id
FROM pushers WHERE COALESCE(enabled, TRUE)
"""
)
return cast(List[PusherRow], txn.fetchall())
return self._decode_pushers_rows(rows)
return await self.db_pool.runInteraction(
return self._decode_pushers_rows(
await self.db_pool.runInteraction(
"get_enabled_pushers", get_enabled_pushers_txn
)
)
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
)
async def get_throttle_params_by_room(
self, pusher_id: str
self, pusher_id: int
) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams
self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if len(rows) == 0:
return 0
@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
key_values=[(row["pusher_id"],) for row in rows],
key_values=[row[0] for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
(row["pusher_device_id"] or row["token_device_id"], None)
for row in rows
(pusher_device_id or token_device_id, None)
for _, pusher_device_id, token_device_id in rows
],
)
self.db_pool.updates._background_update_progress_txn(
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)
return len(rows)

View File

@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
"SELECT receipt_type, user_id, event_id, data"
" FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(sql, (room_id, from_key, to_key))
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
"SELECT receipt_type, user_id, event_id, data"
" FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(sql, (room_id, to_key))
rows = self.db_pool.cursor_to_dict(txn)
return rows
return cast(List[Tuple[str, str, str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []
content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
] = db_to_json(row["data"])
for receipt_type, user_id, event_id, data in rows:
content.setdefault(event_id, {}).setdefault(receipt_type, {})[
user_id
] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not room_ids:
return {}
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def f(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
return self.db_pool.cursor_to_dict(txn)
return cast(
List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
)
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results: JsonDict = {}
for row in txn_results:
for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
room_id,
{"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
event_entry = room_event["content"].setdefault(event_id, {})
receipt_type_dict = event_entry.setdefault(receipt_type, {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
if row["thread_id"]:
receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
receipt_type_dict[user_id] = db_to_json(data)
if thread_id:
receipt_type_dict[user_id]["thread_id"] = thread_id
results = {
room_id: [results[room_id]] if room_id in results else []
@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
SELECT room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [to_key])
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results: JsonDict = {}
for row in txn_results:
for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
room_id,
{"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
event_entry = room_event["content"].setdefault(event_id, {})
receipt_type_dict = event_entry.setdefault(receipt_type, {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
receipt_type_dict[user_id] = db_to_json(data)
return results

View File

@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@ -213,34 +213,45 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
row = txn.fetchone()
if not row:
return None
return rows[0]
row = await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is None:
return None
(
name,
is_guest,
admin,
consent_version,
consent_ts,
consent_server_notice_sent,
appservice_id,
creation_ts,
user_type,
deactivated,
shadow_banned,
approved,
locked,
) = row
return UserInfo(
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
appservice_id=appservice_id,
consent_server_notice_sent=consent_server_notice_sent,
consent_version=consent_version,
consent_ts=consent_ts,
creation_ts=creation_ts,
is_admin=bool(admin),
is_deactivated=bool(deactivated),
is_guest=bool(is_guest),
is_shadow_banned=bool(shadow_banned),
user_id=UserID.from_string(name),
user_type=user_type,
approved=bool(approved),
locked=bool(locked),
)
return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
async def is_trial_user(self, user_id: str) -> bool:
@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
row = txn.fetchone()
if rows:
row = rows[0]
if row:
(
user_id,
is_guest,
shadow_banned,
token_id,
device_id,
valid_until_ms,
token_owner,
token_used,
) = row
return TokenLookupResult(
user_id=user_id,
is_guest=is_guest,
shadow_banned=shadow_banned,
token_id=token_id,
device_id=device_id,
valid_until_ms=valid_until_ms,
token_owner=token_owner,
# This field is nullable, ensure it comes out as a boolean
if row["token_used"] is None:
row["token_used"] = False
return TokenLookupResult(**row)
token_used=bool(token_used),
)
return None
@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
txn.execute("SELECT COUNT(*) FROM users")
row = txn.fetchone()
assert row is not None
return row[0]
return await self.db_pool.runInteraction("count_users", _count_users)
@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
txn.execute("SELECT COUNT(*) FROM users where user_type is null")
row = txn.fetchone()
assert row is not None
return row[0]
return await self.db_pool.runInteraction("count_real_users", _count_users)
@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])
res = self.db_pool.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
for (name,) in txn.fetchall():
self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
row = txn.fetchone()
assert row is not None
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
return bool(rows[0]["approved"])
return bool(row[0])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if not rows:
return True, 0
rows_processed_nb = 0
for user in rows:
if not user["count_tokens"] and not user["count_threepids"]:
self.set_user_deactivated_status_txn(txn, user["name"], True)
for name, count_tokens, count_threepids in rows:
if not count_tokens and not count_threepids:
self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)
if batch_size > len(rows):

View File

@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Optional[int]]]:
) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,),
)
return self.db_pool.cursor_to_dict(txn)
return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room",
@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime,
)
min_lifetime = ret[0]["min_lifetime"]
max_lifetime = ret[0]["max_lifetime"]
min_lifetime, max_lifetime = ret
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, args)
rows = self.db_pool.cursor_to_dict(txn)
rooms_dict = {}
for row in rows:
rooms_dict[row["room_id"]] = RetentionPolicy(
min_lifetime=row["min_lifetime"],
max_lifetime=row["max_lifetime"],
rooms_dict = {
room_id: RetentionPolicy(
min_lifetime=min_lifetime,
max_lifetime=max_lifetime,
)
for room_id, min_lifetime, max_lifetime in txn
}
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql)
rows = self.db_pool.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
for row in rows:
if row["room_id"] not in rooms_dict:
rooms_dict[row["room_id"]] = RetentionPolicy()
for (room_id,) in txn:
if room_id not in rooms_dict:
rooms_dict[room_id] = RetentionPolicy()
return rooms_dict
@ -1703,24 +1699,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if not rows:
return True
for row in rows:
if not row["json"]:
for room_id, event_id, json in rows:
if not json:
retention_policy = {}
else:
ev = db_to_json(row["json"])
ev = db_to_json(json)
retention_policy = ev["content"]
self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
"room_id": row["room_id"],
"event_id": row["event_id"],
"room_id": room_id,
"event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"),
},
@ -1729,7 +1725,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
self.db_pool.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
txn, "insert_room_retention", {"room_id": rows[-1][0]}
)
if batch_size > len(rows):

View File

@ -1349,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1]["stream_ordering"]
min_stream_id = rows[-1][0]
to_update = []
for row in rows:
event_id = row["event_id"]
room_id = row["room_id"]
for _, event_id, room_id, json in rows:
try:
event_json = db_to_json(row["json"])
event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue

View File

@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
rows = self.db_pool.cursor_to_dict(txn)
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1]["stream_ordering"]
min_stream_id = rows[-1][0]
event_search_rows = []
for row in rows:
for (
stream_ordering,
event_id,
room_id,
etype,
json,
origin_server_ts,
) in rows:
try:
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
event_json = db_to_json(row["json"])
event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
@staticmethod
def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
row["status"] = TaskStatus(row["status"])
if row["params"] is not None:
row["params"] = db_to_json(row["params"])
if row["result"] is not None:
row["result"] = db_to_json(row["result"])
return ScheduledTask(**row)
def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
task_id, action, status, timestamp, resource_id, params, result, error = row
return ScheduledTask(
id=task_id,
action=action,
status=TaskStatus(status),
timestamp=timestamp,
resource_id=resource_id,
params=db_to_json(params) if params is not None else None,
result=db_to_json(result) if result is not None else None,
error=error,
)
async def get_scheduled_tasks(
self,
@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = []
args: List[Any] = []
if resource_id:
@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
return self.db_pool.cursor_to_dict(txn)
return cast(List[ScheduledTaskRow], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn
@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task",
)
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
return (
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.