Remove manys calls to cursor_to_dict (#16431)
This avoids calling cursor_to_dict and then immediately unpacking the values in the dict for other users. By not creating the intermediate dictionary we can avoid allocating the dictionary and strings for the keys, which should generally be more performant. Additionally this improves type hints by avoid Dict[str, Any] dictionaries coming out of the database layer.
This commit is contained in:
parent
4e302b30b6
commit
fa907025f4
|
@ -1 +1 @@
|
|||
Reduce the size of each replication command instance.
|
||||
Reduce memory allocations.
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Reduce memory allocations.
|
|
@ -101,7 +101,7 @@ if TYPE_CHECKING:
|
|||
class PusherConfig:
|
||||
"""Parameters necessary to configure a pusher."""
|
||||
|
||||
id: Optional[str]
|
||||
id: Optional[int]
|
||||
user_name: str
|
||||
|
||||
profile_tag: str
|
||||
|
|
|
@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
sql += " AND content != '{}'"
|
||||
|
||||
txn.execute(sql, (user_id,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return {
|
||||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||
account_data_type: db_to_json(content)
|
||||
for account_data_type, content in txn
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
sql += " AND content != '{}'"
|
||||
|
||||
txn.execute(sql, (user_id,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
by_room: Dict[str, Dict[str, JsonDict]] = {}
|
||||
for row in rows:
|
||||
room_data = by_room.setdefault(row["room_id"], {})
|
||||
for room_id, account_data_type, content in txn:
|
||||
room_data = by_room.setdefault(room_id, {})
|
||||
|
||||
room_data[row["account_data_type"]] = db_to_json(row["content"])
|
||||
room_data[account_data_type] = db_to_json(content)
|
||||
|
||||
return by_room
|
||||
|
||||
|
|
|
@ -14,17 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
|
||||
|
||||
from synapse.appservice import (
|
||||
ApplicationService,
|
||||
|
@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
|
||||
def _get_oldest_unsent_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[Tuple[int, str]]:
|
||||
# Monotonically increasing txn ids, so just select the smallest
|
||||
# one in the txns table (we delete them when they are sent)
|
||||
txn.execute(
|
||||
"SELECT * FROM application_services_txns WHERE as_id=?"
|
||||
"SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
|
||||
" ORDER BY txn_id ASC LIMIT 1",
|
||||
(service.id,),
|
||||
)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
entry = rows[0]
|
||||
|
||||
return entry
|
||||
return cast(Optional[Tuple[int, str]], txn.fetchone())
|
||||
|
||||
entry = await self.db_pool.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
|
||||
|
@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
if not entry:
|
||||
return None
|
||||
|
||||
event_ids = db_to_json(entry["event_ids"])
|
||||
txn_id, event_ids_str = entry
|
||||
|
||||
event_ids = db_to_json(event_ids_str)
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
|
||||
# TODO: to-device messages, one-time key counts, device list summaries and unused
|
||||
|
@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
# We likely want to populate those for reliability.
|
||||
return AppServiceTransaction(
|
||||
service=service,
|
||||
id=entry["txn_id"],
|
||||
id=txn_id,
|
||||
events=events,
|
||||
ephemeral=[],
|
||||
to_device_messages=[],
|
||||
|
|
|
@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
|
||||
def get_devices_not_accessed_since_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> List[Tuple[str, str]]:
|
||||
sql = """
|
||||
SELECT user_id, device_id
|
||||
FROM devices WHERE last_seen < ? AND hidden = FALSE
|
||||
"""
|
||||
txn.execute(sql, (since_ms,))
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[Tuple[str, str]], txn.fetchall())
|
||||
|
||||
rows = await self.db_pool.runInteraction(
|
||||
"get_devices_not_accessed_since",
|
||||
|
@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
)
|
||||
|
||||
devices: Dict[str, List[str]] = {}
|
||||
for row in rows:
|
||||
for user_id, device_id in rows:
|
||||
# Remote devices are never stale from our point of view.
|
||||
if self.hs.is_mine_id(row["user_id"]):
|
||||
user_devices = devices.setdefault(row["user_id"], [])
|
||||
user_devices.append(row["device_id"])
|
||||
if self.hs.is_mine_id(user_id):
|
||||
user_devices = devices.setdefault(user_id, [])
|
||||
user_devices.append(device_id)
|
||||
|
||||
return devices
|
||||
|
||||
|
|
|
@ -921,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
}
|
||||
|
||||
txn.execute(sql, params)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
for row in rows:
|
||||
user_id = row["user_id"]
|
||||
key_type = row["keytype"]
|
||||
key = db_to_json(row["keydata"])
|
||||
for user_id, key_type, key_data, _ in txn:
|
||||
user_keys = result.setdefault(user_id, {})
|
||||
user_keys[key_type] = key
|
||||
user_keys[key_type] = db_to_json(key_data)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -988,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
query_params.extend(item)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
# and add the signatures to the appropriate keys
|
||||
for row in rows:
|
||||
key_id: str = row["key_id"]
|
||||
target_user_id: str = row["target_user_id"]
|
||||
target_device_id: str = row["target_device_id"]
|
||||
for target_user_id, target_device_id, key_id, signature in txn:
|
||||
key_type = devices[(target_user_id, target_device_id)]
|
||||
# We need to copy everything, because the result may have come
|
||||
# from the cache. dict.copy only does a shallow copy, so we
|
||||
|
@ -1012,13 +1004,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
].copy()
|
||||
if from_user_id in signatures:
|
||||
user_sigs = signatures[from_user_id] = signatures[from_user_id]
|
||||
user_sigs[key_id] = row["signature"]
|
||||
user_sigs[key_id] = signature
|
||||
else:
|
||||
signatures[from_user_id] = {key_id: row["signature"]}
|
||||
signatures[from_user_id] = {key_id: signature}
|
||||
else:
|
||||
target_user_key["signatures"] = {
|
||||
from_user_id: {key_id: row["signature"]}
|
||||
}
|
||||
target_user_key["signatures"] = {from_user_id: {key_id: signature}}
|
||||
|
||||
return keys
|
||||
|
||||
|
|
|
@ -1654,8 +1654,6 @@ class PersistEventsStore:
|
|||
) -> None:
|
||||
to_prefill = []
|
||||
|
||||
rows = []
|
||||
|
||||
ev_map = {e.event_id: e for e, _ in events_and_contexts}
|
||||
if not ev_map:
|
||||
return
|
||||
|
@ -1676,10 +1674,9 @@ class PersistEventsStore:
|
|||
)
|
||||
|
||||
txn.execute(sql + clause, args)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
for row in rows:
|
||||
event = ev_map[row["event_id"]]
|
||||
if not row["rejects"] and not row["redacts"]:
|
||||
for event_id, redacts, rejects in txn:
|
||||
event = ev_map[event_id]
|
||||
if not rejects and not redacts:
|
||||
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
|
||||
|
||||
async def external_prefill() -> None:
|
||||
|
|
|
@ -434,13 +434,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
|
||||
txn = db_conn.cursor()
|
||||
txn.execute(sql, (PresenceState.OFFLINE,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
txn.close()
|
||||
|
||||
for row in rows:
|
||||
row["currently_active"] = bool(row["currently_active"])
|
||||
|
||||
return [UserPresenceState(**row) for row in rows]
|
||||
return [
|
||||
UserPresenceState(
|
||||
user_id=user_id,
|
||||
state=state,
|
||||
last_active_ts=last_active_ts,
|
||||
last_federation_update_ts=last_federation_update_ts,
|
||||
last_user_sync_ts=last_user_sync_ts,
|
||||
status_msg=status_msg,
|
||||
currently_active=bool(currently_active),
|
||||
)
|
||||
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
|
||||
]
|
||||
|
||||
def take_presence_startup_info(self) -> List[UserPresenceState]:
|
||||
active_on_startup = self._presence_on_startup
|
||||
|
|
|
@ -47,6 +47,27 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The type of a row in the pushers table.
|
||||
PusherRow = Tuple[
|
||||
int, # id
|
||||
str, # user_name
|
||||
Optional[int], # access_token
|
||||
str, # profile_tag
|
||||
str, # kind
|
||||
str, # app_id
|
||||
str, # app_display_name
|
||||
str, # device_display_name
|
||||
str, # pushkey
|
||||
int, # ts
|
||||
str, # lang
|
||||
str, # data
|
||||
int, # last_stream_ordering
|
||||
int, # last_success
|
||||
int, # failing_since
|
||||
bool, # enabled
|
||||
str, # device_id
|
||||
]
|
||||
|
||||
|
||||
class PusherWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
|
@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
self._remove_deleted_email_pushers,
|
||||
)
|
||||
|
||||
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
|
||||
def _decode_pushers_rows(
|
||||
self,
|
||||
rows: Iterable[PusherRow],
|
||||
) -> Iterator[PusherConfig]:
|
||||
"""JSON-decode the data in the rows returned from the `pushers` table
|
||||
|
||||
Drops any rows whose data cannot be decoded
|
||||
"""
|
||||
for r in rows:
|
||||
data_json = r["data"]
|
||||
for (
|
||||
id,
|
||||
user_name,
|
||||
access_token,
|
||||
profile_tag,
|
||||
kind,
|
||||
app_id,
|
||||
app_display_name,
|
||||
device_display_name,
|
||||
pushkey,
|
||||
ts,
|
||||
lang,
|
||||
data,
|
||||
last_stream_ordering,
|
||||
last_success,
|
||||
failing_since,
|
||||
enabled,
|
||||
device_id,
|
||||
) in rows:
|
||||
try:
|
||||
r["data"] = db_to_json(data_json)
|
||||
data_json = db_to_json(data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Invalid JSON in data for pusher %d: %s, %s",
|
||||
r["id"],
|
||||
data_json,
|
||||
id,
|
||||
data,
|
||||
e.args[0],
|
||||
)
|
||||
continue
|
||||
|
||||
yield PusherConfig(
|
||||
id=id,
|
||||
user_name=user_name,
|
||||
profile_tag=profile_tag,
|
||||
kind=kind,
|
||||
app_id=app_id,
|
||||
app_display_name=app_display_name,
|
||||
device_display_name=device_display_name,
|
||||
pushkey=pushkey,
|
||||
ts=ts,
|
||||
lang=lang,
|
||||
data=data_json,
|
||||
last_stream_ordering=last_stream_ordering,
|
||||
last_success=last_success,
|
||||
failing_since=failing_since,
|
||||
# If we're using SQLite, then boolean values are integers. This is
|
||||
# troublesome since some code using the return value of this method might
|
||||
# expect it to be a boolean, or will expose it to clients (in responses).
|
||||
r["enabled"] = bool(r["enabled"])
|
||||
|
||||
yield PusherConfig(**r)
|
||||
enabled=bool(enabled),
|
||||
device_id=device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
def get_pushers_stream_token(self) -> int:
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
The pushers for which the given columns have the given values.
|
||||
"""
|
||||
|
||||
def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
|
||||
# We could technically use simple_select_list here, but we need to call
|
||||
# COALESCE on the 'enabled' column. While it is technically possible to give
|
||||
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
|
||||
|
@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[PusherRow], txn.fetchall())
|
||||
|
||||
ret = await self.db_pool.runInteraction(
|
||||
desc="get_pushers_by",
|
||||
|
@ -164,15 +221,23 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
return self._decode_pushers_rows(ret)
|
||||
|
||||
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
|
||||
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT id, user_name, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name, pushkey, ts, lang, data,
|
||||
last_stream_ordering, last_success, failing_since,
|
||||
enabled, device_id
|
||||
FROM pushers WHERE COALESCE(enabled, TRUE)
|
||||
"""
|
||||
)
|
||||
return cast(List[PusherRow], txn.fetchall())
|
||||
|
||||
return self._decode_pushers_rows(rows)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
return self._decode_pushers_rows(
|
||||
await self.db_pool.runInteraction(
|
||||
"get_enabled_pushers", get_enabled_pushers_txn
|
||||
)
|
||||
)
|
||||
|
||||
async def get_all_updated_pushers_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
|
@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def get_throttle_params_by_room(
|
||||
self, pusher_id: str
|
||||
self, pusher_id: int
|
||||
) -> Dict[str, ThrottleParams]:
|
||||
res = await self.db_pool.simple_select_list(
|
||||
"pusher_throttle",
|
||||
|
@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
return params_by_room
|
||||
|
||||
async def set_throttle_params(
|
||||
self, pusher_id: str, room_id: str, params: ThrottleParams
|
||||
self, pusher_id: int, room_id: str, params: ThrottleParams
|
||||
) -> None:
|
||||
await self.db_pool.simple_upsert(
|
||||
"pusher_throttle",
|
||||
|
@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
|||
(last_pusher_id, batch_size),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
if len(rows) == 0:
|
||||
return 0
|
||||
|
||||
|
@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
|||
txn=txn,
|
||||
table="pushers",
|
||||
key_names=("id",),
|
||||
key_values=[(row["pusher_id"],) for row in rows],
|
||||
key_values=[row[0] for row in rows],
|
||||
value_names=("device_id", "access_token"),
|
||||
# If there was already a device_id on the pusher, we only want to clear
|
||||
# the access_token column, so we keep the existing device_id. Otherwise,
|
||||
# we set the device_id we got from joining the access_tokens table.
|
||||
value_values=[
|
||||
(row["pusher_device_id"] or row["token_device_id"], None)
|
||||
for row in rows
|
||||
(pusher_device_id or token_device_id, None)
|
||||
for _, pusher_device_id, token_device_id in rows
|
||||
],
|
||||
)
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
|
||||
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
|
|
@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
) -> Sequence[JsonMapping]:
|
||||
"""See get_linearized_receipts_for_room"""
|
||||
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
|
||||
if from_key:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
"SELECT receipt_type, user_id, event_id, data"
|
||||
" FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id > ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, from_key, to_key))
|
||||
else:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
"SELECT receipt_type, user_id, event_id, data"
|
||||
" FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, to_key))
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return rows
|
||||
return cast(List[Tuple[str, str, str, str]], txn.fetchall())
|
||||
|
||||
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
|
||||
|
||||
|
@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
return []
|
||||
|
||||
content: JsonDict = {}
|
||||
for row in rows:
|
||||
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
|
||||
row["user_id"]
|
||||
] = db_to_json(row["data"])
|
||||
for receipt_type, user_id, event_id, data in rows:
|
||||
content.setdefault(event_id, {}).setdefault(receipt_type, {})[
|
||||
user_id
|
||||
] = db_to_json(data)
|
||||
|
||||
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
|
||||
|
||||
|
@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
if not room_ids:
|
||||
return {}
|
||||
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
def f(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
|
||||
if from_key:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
SELECT room_id, receipt_type, user_id, event_id, thread_id, data
|
||||
FROM receipts_linearized WHERE
|
||||
stream_id > ? AND stream_id <= ? AND
|
||||
"""
|
||||
clause, args = make_in_list_sql_clause(
|
||||
|
@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql + clause, [from_key, to_key] + list(args))
|
||||
else:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
SELECT room_id, receipt_type, user_id, event_id, thread_id, data
|
||||
FROM receipts_linearized WHERE
|
||||
stream_id <= ? AND
|
||||
"""
|
||||
|
||||
|
@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql + clause, [to_key] + list(args))
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(
|
||||
List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
|
||||
)
|
||||
|
||||
txn_results = await self.db_pool.runInteraction(
|
||||
"_get_linearized_receipts_for_rooms", f
|
||||
)
|
||||
|
||||
results: JsonDict = {}
|
||||
for row in txn_results:
|
||||
for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
|
||||
# We want a single event per room, since we want to batch the
|
||||
# receipts by room, event and type.
|
||||
room_event = results.setdefault(
|
||||
row["room_id"],
|
||||
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
|
||||
room_id,
|
||||
{"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
|
||||
)
|
||||
|
||||
# The content is of the form:
|
||||
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
|
||||
event_entry = room_event["content"].setdefault(row["event_id"], {})
|
||||
receipt_type = event_entry.setdefault(row["receipt_type"], {})
|
||||
event_entry = room_event["content"].setdefault(event_id, {})
|
||||
receipt_type_dict = event_entry.setdefault(receipt_type, {})
|
||||
|
||||
receipt_type[row["user_id"]] = db_to_json(row["data"])
|
||||
if row["thread_id"]:
|
||||
receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
|
||||
receipt_type_dict[user_id] = db_to_json(data)
|
||||
if thread_id:
|
||||
receipt_type_dict[user_id]["thread_id"] = thread_id
|
||||
|
||||
results = {
|
||||
room_id: [results[room_id]] if room_id in results else []
|
||||
|
@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
A dictionary of roomids to a list of receipts.
|
||||
"""
|
||||
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
|
||||
if from_key:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
SELECT room_id, receipt_type, user_id, event_id, data
|
||||
FROM receipts_linearized WHERE
|
||||
stream_id > ? AND stream_id <= ?
|
||||
ORDER BY stream_id DESC
|
||||
LIMIT 100
|
||||
|
@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, [from_key, to_key])
|
||||
else:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
SELECT room_id, receipt_type, user_id, event_id, data
|
||||
FROM receipts_linearized WHERE
|
||||
stream_id <= ?
|
||||
ORDER BY stream_id DESC
|
||||
LIMIT 100
|
||||
|
@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql, [to_key])
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
|
||||
|
||||
txn_results = await self.db_pool.runInteraction(
|
||||
"get_linearized_receipts_for_all_rooms", f
|
||||
)
|
||||
|
||||
results: JsonDict = {}
|
||||
for row in txn_results:
|
||||
for room_id, receipt_type, user_id, event_id, data in txn_results:
|
||||
# We want a single event per room, since we want to batch the
|
||||
# receipts by room, event and type.
|
||||
room_event = results.setdefault(
|
||||
row["room_id"],
|
||||
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
|
||||
room_id,
|
||||
{"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
|
||||
)
|
||||
|
||||
# The content is of the form:
|
||||
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
|
||||
event_entry = room_event["content"].setdefault(row["event_id"], {})
|
||||
receipt_type = event_entry.setdefault(row["receipt_type"], {})
|
||||
event_entry = room_event["content"].setdefault(event_id, {})
|
||||
receipt_type_dict = event_entry.setdefault(receipt_type, {})
|
||||
|
||||
receipt_type[row["user_id"]] = db_to_json(row["data"])
|
||||
receipt_type_dict[user_id] = db_to_json(data)
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||
"""Returns info about the user account, if it exists."""
|
||||
|
||||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
|
||||
# We could technically use simple_select_one here, but it would not perform
|
||||
# the COALESCEs (unless hacked into the column names), which could yield
|
||||
# confusing results.
|
||||
|
@ -213,34 +213,45 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
(user_id,),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
if len(rows) == 0:
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return rows[0]
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
desc="get_user_by_id",
|
||||
func=get_user_by_id_txn,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
(
|
||||
name,
|
||||
is_guest,
|
||||
admin,
|
||||
consent_version,
|
||||
consent_ts,
|
||||
consent_server_notice_sent,
|
||||
appservice_id,
|
||||
creation_ts,
|
||||
user_type,
|
||||
deactivated,
|
||||
shadow_banned,
|
||||
approved,
|
||||
locked,
|
||||
) = row
|
||||
|
||||
return UserInfo(
|
||||
appservice_id=row["appservice_id"],
|
||||
consent_server_notice_sent=row["consent_server_notice_sent"],
|
||||
consent_version=row["consent_version"],
|
||||
consent_ts=row["consent_ts"],
|
||||
creation_ts=row["creation_ts"],
|
||||
is_admin=bool(row["admin"]),
|
||||
is_deactivated=bool(row["deactivated"]),
|
||||
is_guest=bool(row["is_guest"]),
|
||||
is_shadow_banned=bool(row["shadow_banned"]),
|
||||
user_id=UserID.from_string(row["name"]),
|
||||
user_type=row["user_type"],
|
||||
approved=bool(row["approved"]),
|
||||
locked=bool(row["locked"]),
|
||||
appservice_id=appservice_id,
|
||||
consent_server_notice_sent=consent_server_notice_sent,
|
||||
consent_version=consent_version,
|
||||
consent_ts=consent_ts,
|
||||
creation_ts=creation_ts,
|
||||
is_admin=bool(admin),
|
||||
is_deactivated=bool(deactivated),
|
||||
is_guest=bool(is_guest),
|
||||
is_shadow_banned=bool(shadow_banned),
|
||||
user_id=UserID.from_string(name),
|
||||
user_type=user_type,
|
||||
approved=bool(approved),
|
||||
locked=bool(locked),
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_user_by_id",
|
||||
func=get_user_by_id_txn,
|
||||
)
|
||||
|
||||
async def is_trial_user(self, user_id: str) -> bool:
|
||||
|
@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"""
|
||||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
row = txn.fetchone()
|
||||
|
||||
if rows:
|
||||
row = rows[0]
|
||||
if row:
|
||||
(
|
||||
user_id,
|
||||
is_guest,
|
||||
shadow_banned,
|
||||
token_id,
|
||||
device_id,
|
||||
valid_until_ms,
|
||||
token_owner,
|
||||
token_used,
|
||||
) = row
|
||||
|
||||
return TokenLookupResult(
|
||||
user_id=user_id,
|
||||
is_guest=is_guest,
|
||||
shadow_banned=shadow_banned,
|
||||
token_id=token_id,
|
||||
device_id=device_id,
|
||||
valid_until_ms=valid_until_ms,
|
||||
token_owner=token_owner,
|
||||
# This field is nullable, ensure it comes out as a boolean
|
||||
if row["token_used"] is None:
|
||||
row["token_used"] = False
|
||||
|
||||
return TokenLookupResult(**row)
|
||||
token_used=bool(token_used),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
def _count_users(txn: LoggingTransaction) -> int:
|
||||
txn.execute("SELECT COUNT(*) AS users FROM users")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if rows:
|
||||
return rows[0]["users"]
|
||||
return 0
|
||||
txn.execute("SELECT COUNT(*) FROM users")
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
return row[0]
|
||||
|
||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||
|
||||
|
@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"""Counts all users without a special user_type registered on the homeserver."""
|
||||
|
||||
def _count_users(txn: LoggingTransaction) -> int:
|
||||
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if rows:
|
||||
return rows[0]["users"]
|
||||
return 0
|
||||
txn.execute("SELECT COUNT(*) FROM users where user_type is null")
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
return row[0]
|
||||
|
||||
return await self.db_pool.runInteraction("count_real_users", _count_users)
|
||||
|
||||
|
@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
txn.execute(sql, [])
|
||||
|
||||
res = self.db_pool.cursor_to_dict(txn)
|
||||
if res:
|
||||
for user in res:
|
||||
self.set_expiration_date_for_user_txn(
|
||||
txn, user["name"], use_delta=True
|
||||
)
|
||||
for (name,) in txn.fetchall():
|
||||
self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"get_users_with_no_expiration_date",
|
||||
|
@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
(user_id,),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
|
||||
# We cast to bool because the value returned by the database engine might
|
||||
# be an integer if we're using SQLite.
|
||||
return bool(rows[0]["approved"])
|
||||
return bool(row[0])
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="is_user_pending_approval",
|
||||
|
@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
(last_user, batch_size),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
|
||||
if not rows:
|
||||
return True, 0
|
||||
|
||||
rows_processed_nb = 0
|
||||
|
||||
for user in rows:
|
||||
if not user["count_tokens"] and not user["count_threepids"]:
|
||||
self.set_user_deactivated_status_txn(txn, user["name"], True)
|
||||
for name, count_tokens, count_threepids in rows:
|
||||
if not count_tokens and not count_threepids:
|
||||
self.set_user_deactivated_status_txn(txn, name, True)
|
||||
rows_processed_nb += 1
|
||||
|
||||
logger.info("Marked %d rows as deactivated", rows_processed_nb)
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
|
||||
txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
|
||||
)
|
||||
|
||||
if batch_size > len(rows):
|
||||
|
|
|
@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
def get_retention_policy_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, Optional[int]]]:
|
||||
) -> Optional[Tuple[Optional[int], Optional[int]]]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT min_lifetime, max_lifetime FROM room_retention
|
||||
|
@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
(room_id,),
|
||||
)
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
|
||||
|
||||
ret = await self.db_pool.runInteraction(
|
||||
"get_retention_policy_for_room",
|
||||
|
@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
max_lifetime=self.config.retention.retention_default_max_lifetime,
|
||||
)
|
||||
|
||||
min_lifetime = ret[0]["min_lifetime"]
|
||||
max_lifetime = ret[0]["max_lifetime"]
|
||||
min_lifetime, max_lifetime = ret
|
||||
|
||||
# If one of the room's policy's attributes isn't defined, use the matching
|
||||
# attribute from the default policy.
|
||||
|
@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rooms_dict = {}
|
||||
|
||||
for row in rows:
|
||||
rooms_dict[row["room_id"]] = RetentionPolicy(
|
||||
min_lifetime=row["min_lifetime"],
|
||||
max_lifetime=row["max_lifetime"],
|
||||
rooms_dict = {
|
||||
room_id: RetentionPolicy(
|
||||
min_lifetime=min_lifetime,
|
||||
max_lifetime=max_lifetime,
|
||||
)
|
||||
for room_id, min_lifetime, max_lifetime in txn
|
||||
}
|
||||
|
||||
if include_null:
|
||||
# If required, do a second query that retrieves all of the rooms we know
|
||||
|
@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
txn.execute(sql)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
# If a room isn't already in the dict (i.e. it doesn't have a retention
|
||||
# policy in its state), add it with a null policy.
|
||||
for row in rows:
|
||||
if row["room_id"] not in rooms_dict:
|
||||
rooms_dict[row["room_id"]] = RetentionPolicy()
|
||||
for (room_id,) in txn:
|
||||
if room_id not in rooms_dict:
|
||||
rooms_dict[room_id] = RetentionPolicy()
|
||||
|
||||
return rooms_dict
|
||||
|
||||
|
@ -1703,24 +1699,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
(last_room, batch_size),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
|
||||
if not rows:
|
||||
return True
|
||||
|
||||
for row in rows:
|
||||
if not row["json"]:
|
||||
for room_id, event_id, json in rows:
|
||||
if not json:
|
||||
retention_policy = {}
|
||||
else:
|
||||
ev = db_to_json(row["json"])
|
||||
ev = db_to_json(json)
|
||||
retention_policy = ev["content"]
|
||||
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn=txn,
|
||||
table="room_retention",
|
||||
values={
|
||||
"room_id": row["room_id"],
|
||||
"event_id": row["event_id"],
|
||||
"room_id": room_id,
|
||||
"event_id": event_id,
|
||||
"min_lifetime": retention_policy.get("min_lifetime"),
|
||||
"max_lifetime": retention_policy.get("max_lifetime"),
|
||||
},
|
||||
|
@ -1729,7 +1725,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
logger.info("Inserted %d rows into room_retention", len(rows))
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
|
||||
txn, "insert_room_retention", {"room_id": rows[-1][0]}
|
||||
)
|
||||
|
||||
if batch_size > len(rows):
|
||||
|
|
|
@ -1349,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
min_stream_id = rows[-1]["stream_ordering"]
|
||||
min_stream_id = rows[-1][0]
|
||||
|
||||
to_update = []
|
||||
for row in rows:
|
||||
event_id = row["event_id"]
|
||||
room_id = row["room_id"]
|
||||
for _, event_id, room_id, json in rows:
|
||||
try:
|
||||
event_json = db_to_json(row["json"])
|
||||
event_json = db_to_json(json)
|
||||
content = event_json["content"]
|
||||
except Exception:
|
||||
continue
|
||||
|
|
|
@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
|||
# store_search_entries_txn with a generator function, but that
|
||||
# would mean having two cursors open on the database at once.
|
||||
# Instead we just build a list of results.
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
rows = txn.fetchall()
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
min_stream_id = rows[-1]["stream_ordering"]
|
||||
min_stream_id = rows[-1][0]
|
||||
|
||||
event_search_rows = []
|
||||
for row in rows:
|
||||
for (
|
||||
stream_ordering,
|
||||
event_id,
|
||||
room_id,
|
||||
etype,
|
||||
json,
|
||||
origin_server_ts,
|
||||
) in rows:
|
||||
try:
|
||||
event_id = row["event_id"]
|
||||
room_id = row["room_id"]
|
||||
etype = row["type"]
|
||||
stream_ordering = row["stream_ordering"]
|
||||
origin_server_ts = row["origin_server_ts"]
|
||||
try:
|
||||
event_json = db_to_json(row["json"])
|
||||
event_json = db_to_json(json)
|
||||
content = event_json["content"]
|
||||
except Exception:
|
||||
continue
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
|
||||
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import (
|
||||
|
@ -27,6 +27,8 @@ from synapse.util import json_encoder
|
|||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
|
||||
|
||||
|
||||
class TaskSchedulerWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
|
@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
super().__init__(database, db_conn, hs)
|
||||
|
||||
@staticmethod
|
||||
def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
|
||||
row["status"] = TaskStatus(row["status"])
|
||||
if row["params"] is not None:
|
||||
row["params"] = db_to_json(row["params"])
|
||||
if row["result"] is not None:
|
||||
row["result"] = db_to_json(row["result"])
|
||||
return ScheduledTask(**row)
|
||||
def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
|
||||
task_id, action, status, timestamp, resource_id, params, result, error = row
|
||||
return ScheduledTask(
|
||||
id=task_id,
|
||||
action=action,
|
||||
status=TaskStatus(status),
|
||||
timestamp=timestamp,
|
||||
resource_id=resource_id,
|
||||
params=db_to_json(params) if params is not None else None,
|
||||
result=db_to_json(result) if result is not None else None,
|
||||
error=error,
|
||||
)
|
||||
|
||||
async def get_scheduled_tasks(
|
||||
self,
|
||||
|
@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
|
||||
"""
|
||||
|
||||
def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
|
||||
clauses: List[str] = []
|
||||
args: List[Any] = []
|
||||
if resource_id:
|
||||
|
@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
args.append(limit)
|
||||
|
||||
txn.execute(sql, args)
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[ScheduledTaskRow], txn.fetchall())
|
||||
|
||||
rows = await self.db_pool.runInteraction(
|
||||
"get_scheduled_tasks", get_scheduled_tasks_txn
|
||||
|
@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
desc="get_scheduled_task",
|
||||
)
|
||||
|
||||
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
|
||||
return (
|
||||
TaskSchedulerWorkerStore._convert_row_to_task(
|
||||
(
|
||||
row["id"],
|
||||
row["action"],
|
||||
row["status"],
|
||||
row["timestamp"],
|
||||
row["resource_id"],
|
||||
row["params"],
|
||||
row["result"],
|
||||
row["error"],
|
||||
)
|
||||
)
|
||||
if row
|
||||
else None
|
||||
)
|
||||
|
||||
async def delete_scheduled_task(self, id: str) -> None:
|
||||
"""Delete a specific task from its id.
|
||||
|
|
Loading…
Reference in New Issue