Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444)
This commit is contained in:
parent
d6b7d49a61
commit
a4904dcb04
|
@ -0,0 +1 @@
|
|||
Reduce memory allocations.
|
|
@ -1874,9 +1874,9 @@ class DatabasePool:
|
|||
keyvalues: Optional[Dict[str, Any]] = None,
|
||||
desc: str = "simple_select_many_batch",
|
||||
batch_size: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
more rows.
|
||||
|
||||
Filters rows by whether the value of `column` is in `iterable`.
|
||||
|
||||
|
@ -1888,10 +1888,13 @@ class DatabasePool:
|
|||
keyvalues: dict of column names and values to select the rows with
|
||||
desc: description of the transaction, for logging and metrics
|
||||
batch_size: the number of rows for each select query
|
||||
|
||||
Returns:
|
||||
The results as a list of tuples.
|
||||
"""
|
||||
keyvalues = keyvalues or {}
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
results: List[Tuple[Any, ...]] = []
|
||||
|
||||
for chunk in batch_iter(iterable, batch_size):
|
||||
rows = await self.runInteraction(
|
||||
|
@ -1918,9 +1921,9 @@ class DatabasePool:
|
|||
iterable: Collection[Any],
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
more rows.
|
||||
|
||||
Filters rows by whether the value of `column` is in `iterable`.
|
||||
|
||||
|
@ -1931,6 +1934,9 @@ class DatabasePool:
|
|||
iterable: list
|
||||
keyvalues: dict of column names and values to select the rows with
|
||||
retcols: list of strings giving the names of the columns to return
|
||||
|
||||
Returns:
|
||||
The results as a list of tuples.
|
||||
"""
|
||||
if not iterable:
|
||||
return []
|
||||
|
@ -1949,7 +1955,7 @@ class DatabasePool:
|
|||
)
|
||||
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
return txn.fetchall()
|
||||
|
||||
async def simple_update(
|
||||
self,
|
||||
|
|
|
@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
# Note that this is more efficient than just dropping `device_id` from the query,
|
||||
# since device_inbox has an index on `(user_id, device_id, stream_id)`
|
||||
if not device_ids_to_query:
|
||||
user_device_dicts = self.db_pool.simple_select_many_txn(
|
||||
user_device_dicts = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="devices",
|
||||
column="user_id",
|
||||
iterable=user_ids_to_query,
|
||||
keyvalues={"hidden": False},
|
||||
retcols=("device_id",),
|
||||
),
|
||||
)
|
||||
|
||||
device_ids_to_query.update(
|
||||
{row["device_id"] for row in user_device_dicts}
|
||||
)
|
||||
device_ids_to_query.update({row[0] for row in user_device_dicts})
|
||||
|
||||
if not device_ids_to_query:
|
||||
# We've ended up with no devices to query.
|
||||
|
@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
|
||||
# We exclude hidden devices (such as cross-signing keys) here as they are
|
||||
# not expected to receive to-device messages.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "hidden": False},
|
||||
column="device_id",
|
||||
iterable=devices,
|
||||
retcols=("device_id",),
|
||||
),
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
for (device_id,) in rows:
|
||||
# Only insert into the local inbox if the device exists on
|
||||
# this server
|
||||
device_id = row["device_id"]
|
||||
|
||||
with start_active_span("serialise_to_device_message"):
|
||||
msg = messages_by_device[device_id]
|
||||
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
|
||||
|
|
|
@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
async def get_device_list_last_stream_id_for_remotes(
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Mapping[str, Optional[str]]:
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
retcols=("user_id", "stream_id"),
|
||||
desc="get_device_list_last_stream_id_for_remotes",
|
||||
),
|
||||
)
|
||||
|
||||
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
||||
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
||||
results.update(rows)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -1077,19 +1080,27 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
The IDs of users whose device lists need resync.
|
||||
"""
|
||||
if user_ids:
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
row_tuples = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_resync",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
retcols=("user_id",),
|
||||
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
||||
),
|
||||
)
|
||||
|
||||
return {row[0] for row in row_tuples}
|
||||
else:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
rows = cast(
|
||||
List[Dict[str, str]],
|
||||
await self.db_pool.simple_select_list(
|
||||
table="device_lists_remote_resync",
|
||||
keyvalues=None,
|
||||
retcols=("user_id",),
|
||||
desc="get_user_ids_requiring_device_list_resync",
|
||||
),
|
||||
)
|
||||
|
||||
return {row["user_id"] for row in rows}
|
||||
|
|
|
@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
A map from (algorithm, key_id) to json string for key
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
iterable=key_ids,
|
||||
retcols=("algorithm", "key_id", "key_json"),
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
desc="add_e2e_one_time_keys_check",
|
||||
),
|
||||
)
|
||||
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
|
||||
result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
|
||||
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
|
||||
return result
|
||||
|
||||
|
|
|
@ -1049,7 +1049,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
Args:
|
||||
event_ids: The event IDs to calculate the max depth of.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, int]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
@ -1058,6 +1060,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"depth",
|
||||
),
|
||||
desc="get_max_depth_of",
|
||||
),
|
||||
)
|
||||
|
||||
if not rows:
|
||||
|
@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
else:
|
||||
max_depth_event_id = ""
|
||||
current_max_depth = 0
|
||||
for row in rows:
|
||||
if row["depth"] > current_max_depth:
|
||||
max_depth_event_id = row["event_id"]
|
||||
current_max_depth = row["depth"]
|
||||
for event_id, depth in rows:
|
||||
if depth > current_max_depth:
|
||||
max_depth_event_id = event_id
|
||||
current_max_depth = depth
|
||||
|
||||
return max_depth_event_id, current_max_depth
|
||||
|
||||
|
@ -1078,7 +1081,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
Args:
|
||||
event_ids: The event IDs to calculate the max depth of.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, int]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
@ -1087,6 +1092,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"depth",
|
||||
),
|
||||
desc="get_min_depth_of",
|
||||
),
|
||||
)
|
||||
|
||||
if not rows:
|
||||
|
@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
else:
|
||||
min_depth_event_id = ""
|
||||
current_min_depth = MAX_DEPTH
|
||||
for row in rows:
|
||||
if row["depth"] < current_min_depth:
|
||||
min_depth_event_id = row["event_id"]
|
||||
current_min_depth = row["depth"]
|
||||
for event_id, depth in rows:
|
||||
if depth < current_min_depth:
|
||||
min_depth_event_id = event_id
|
||||
current_min_depth = depth
|
||||
|
||||
return min_depth_event_id, current_min_depth
|
||||
|
||||
|
@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
A filtered down list of `event_ids` that have previous failed pull attempts.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="event_failed_pull_attempts",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
keyvalues={},
|
||||
retcols=("event_id",),
|
||||
desc="get_event_ids_with_failed_pull_attempts",
|
||||
),
|
||||
)
|
||||
event_ids_with_failed_pull_attempts: Set[str] = {
|
||||
row["event_id"] for row in rows
|
||||
}
|
||||
|
||||
return event_ids_with_failed_pull_attempts
|
||||
return {row[0] for row in rows}
|
||||
|
||||
@trace
|
||||
async def get_event_ids_to_not_pull_from_backoff(
|
||||
|
@ -1585,7 +1590,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
A dictionary of event_ids that should not be attempted to be pulled and the
|
||||
next timestamp at which we may try pulling them again.
|
||||
"""
|
||||
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
|
||||
event_failed_pull_attempts = cast(
|
||||
List[Tuple[str, int, int]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="event_failed_pull_attempts",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
@ -1596,21 +1603,21 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"num_attempts",
|
||||
),
|
||||
desc="get_event_ids_to_not_pull_from_backoff",
|
||||
),
|
||||
)
|
||||
|
||||
current_time = self._clock.time_msec()
|
||||
|
||||
event_ids_with_backoff = {}
|
||||
for event_failed_pull_attempt in event_failed_pull_attempts:
|
||||
event_id = event_failed_pull_attempt["event_id"]
|
||||
for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
|
||||
# Exponential back-off (up to the upper bound) so we don't try to
|
||||
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
||||
backoff_end_time = (
|
||||
event_failed_pull_attempt["last_attempt_ts"]
|
||||
last_attempt_ts
|
||||
+ (
|
||||
2
|
||||
** min(
|
||||
event_failed_pull_attempt["num_attempts"],
|
||||
num_attempts,
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
|||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
@ -501,16 +502,19 @@ class PersistEventsStore:
|
|||
|
||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||
# for.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[Union[int, bool]]]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="rooms",
|
||||
column="room_id",
|
||||
iterable={event.room_id for event in events if event.is_state()},
|
||||
keyvalues={},
|
||||
retcols=("room_id", "has_auth_chain_index"),
|
||||
),
|
||||
)
|
||||
rooms_using_chain_index = {
|
||||
row["room_id"] for row in rows if row["has_auth_chain_index"]
|
||||
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
|
||||
}
|
||||
|
||||
state_events = {
|
||||
|
@ -571,19 +575,18 @@ class PersistEventsStore:
|
|||
# We check if there are any events that need to be handled in the rooms
|
||||
# we're looking at. These should just be out of band memberships, where
|
||||
# we didn't have the auth chain when we first persisted.
|
||||
rows = db_pool.simple_select_many_txn(
|
||||
auth_chain_to_calc_rows = cast(
|
||||
List[Tuple[str, str, str]],
|
||||
db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={},
|
||||
column="room_id",
|
||||
iterable=set(event_to_room_id.values()),
|
||||
retcols=("event_id", "type", "state_key"),
|
||||
),
|
||||
)
|
||||
for row in rows:
|
||||
event_id = row["event_id"]
|
||||
event_type = row["type"]
|
||||
state_key = row["state_key"]
|
||||
|
||||
for event_id, event_type, state_key in auth_chain_to_calc_rows:
|
||||
# (We could pull out the auth events for all rows at once using
|
||||
# simple_select_many, but this case happens rarely and almost always
|
||||
# with a single row.)
|
||||
|
@ -753,7 +756,9 @@ class PersistEventsStore:
|
|||
# Step 1, fetch all existing links from all the chains we've seen
|
||||
# referenced.
|
||||
chain_links = _LinkMap()
|
||||
rows = db_pool.simple_select_many_txn(
|
||||
auth_chain_rows = cast(
|
||||
List[Tuple[int, int, int, int]],
|
||||
db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_links",
|
||||
column="origin_chain_id",
|
||||
|
@ -765,11 +770,17 @@ class PersistEventsStore:
|
|||
"target_chain_id",
|
||||
"target_sequence_number",
|
||||
),
|
||||
),
|
||||
)
|
||||
for row in rows:
|
||||
for (
|
||||
origin_chain_id,
|
||||
origin_sequence_number,
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
) in auth_chain_rows:
|
||||
chain_links.add_link(
|
||||
(row["origin_chain_id"], row["origin_sequence_number"]),
|
||||
(row["target_chain_id"], row["target_sequence_number"]),
|
||||
(origin_chain_id, origin_sequence_number),
|
||||
(target_chain_id, target_sequence_number),
|
||||
new=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
|
||||
for chunk in chunks:
|
||||
ev_rows = self.db_pool.simple_select_many_txn(
|
||||
ev_rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
column="event_id",
|
||||
iterable=chunk,
|
||||
retcols=["event_id", "json"],
|
||||
keyvalues={},
|
||||
),
|
||||
)
|
||||
|
||||
for row in ev_rows:
|
||||
event_id = row["event_id"]
|
||||
event_json = db_to_json(row["json"])
|
||||
for event_id, json in ev_rows:
|
||||
event_json = db_to_json(json)
|
||||
try:
|
||||
origin_server_ts = event_json["origin_server_ts"]
|
||||
except (KeyError, AttributeError):
|
||||
|
@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
if deleted:
|
||||
# We now need to invalidate the caches of these rooms
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="events",
|
||||
column="event_id",
|
||||
iterable=to_delete,
|
||||
keyvalues={},
|
||||
retcols=("room_id",),
|
||||
),
|
||||
)
|
||||
room_ids = {row["room_id"] for row in rows}
|
||||
room_ids = {row[0] for row in rows}
|
||||
for room_id in room_ids:
|
||||
txn.call_after(
|
||||
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
|
||||
|
@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
count = len(rows)
|
||||
|
||||
# We also need to fetch the auth events for them.
|
||||
auth_events = self.db_pool.simple_select_many_txn(
|
||||
auth_events = cast(
|
||||
List[Tuple[str, str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
column="event_id",
|
||||
iterable=event_to_room_id,
|
||||
keyvalues={},
|
||||
retcols=("event_id", "auth_id"),
|
||||
),
|
||||
)
|
||||
|
||||
event_to_auth_chain: Dict[str, List[str]] = {}
|
||||
for row in auth_events:
|
||||
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
|
||||
for event_id, auth_id in auth_events:
|
||||
event_to_auth_chain.setdefault(event_id, []).append(auth_id)
|
||||
|
||||
# Calculate and persist the chain cover index for this set of events.
|
||||
#
|
||||
|
|
|
@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
),
|
||||
)
|
||||
|
||||
return {r["event_id"] for r in rows}
|
||||
return {r[0] for r in rows}
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
|
@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
a dict mapping from event id to partial-stateness. We return True for
|
||||
any of the events which are unknown (or are outliers).
|
||||
"""
|
||||
result = await self.db_pool.simple_select_many_batch(
|
||||
result = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="partial_state_events",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
retcols=["event_id"],
|
||||
desc="get_partial_state_events",
|
||||
),
|
||||
)
|
||||
# convert the result to a dict, to make @cachedList work
|
||||
partial = {r["event_id"] for r in result}
|
||||
partial = {r[0] for r in result}
|
||||
return {e_id: e_id in partial for e_id in event_ids}
|
||||
|
||||
@cached()
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import itertools
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Iterable, Mapping, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
@ -205,7 +205,9 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
|
||||
If we have multiple entries for a given key ID, returns the most recent.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="server_keys_json",
|
||||
column="key_id",
|
||||
iterable=key_ids,
|
||||
|
@ -218,22 +220,24 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
"key_json",
|
||||
),
|
||||
desc="get_server_keys_json_for_remote",
|
||||
),
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return {}
|
||||
|
||||
# We sort the rows so that the most recently added entry is picked up.
|
||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||
# We sort the rows by ts_added_ms so that the most recently added entry
|
||||
# will stomp over older entries in the dictionary.
|
||||
rows.sort(key=lambda r: r[2])
|
||||
|
||||
return {
|
||||
row["key_id"]: FetchKeyResultForRemote(
|
||||
key_id: FetchKeyResultForRemote(
|
||||
# Cast to bytes since postgresql returns a memoryview.
|
||||
key_json=bytes(row["key_json"]),
|
||||
valid_until_ts=row["ts_valid_until_ms"],
|
||||
added_ts=row["ts_added_ms"],
|
||||
key_json=bytes(key_json),
|
||||
valid_until_ts=ts_valid_until_ms,
|
||||
added_ts=ts_added_ms,
|
||||
)
|
||||
for row in rows
|
||||
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
|
||||
}
|
||||
|
||||
async def get_all_server_keys_json_for_remote(
|
||||
|
@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore):
|
|||
if not rows:
|
||||
return {}
|
||||
|
||||
# We sort the rows by ts_added_ms so that the most recently added entry
|
||||
# will stomp over older entries in the dictionary.
|
||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||
|
||||
return {
|
||||
|
|
|
@ -261,7 +261,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
async def get_presence_for_users(
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Mapping[str, UserPresenceState]:
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
# TODO All these columns are nullable, but we don't expect that:
|
||||
# https://github.com/matrix-org/synapse/issues/16467
|
||||
rows = cast(
|
||||
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="presence_stream",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
|
@ -276,12 +280,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
"currently_active",
|
||||
),
|
||||
desc="get_presence_for_users",
|
||||
),
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
row["currently_active"] = bool(row["currently_active"])
|
||||
|
||||
return {row["user_id"]: UserPresenceState(**row) for row in rows}
|
||||
return {
|
||||
user_id: UserPresenceState(
|
||||
user_id=user_id,
|
||||
state=state,
|
||||
last_active_ts=last_active_ts,
|
||||
last_federation_update_ts=last_federation_update_ts,
|
||||
last_user_sync_ts=last_user_sync_ts,
|
||||
status_msg=status_msg,
|
||||
currently_active=bool(currently_active),
|
||||
)
|
||||
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
|
||||
}
|
||||
|
||||
async def should_user_receive_full_presence_with_token(
|
||||
self,
|
||||
|
@ -386,6 +399,8 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
limit = 100
|
||||
offset = 0
|
||||
while True:
|
||||
# TODO All these columns are nullable, but we don't expect that:
|
||||
# https://github.com/matrix-org/synapse/issues/16467
|
||||
rows = cast(
|
||||
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -62,20 +62,34 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _load_rules(
|
||||
rawrules: List[JsonDict],
|
||||
rawrules: List[Tuple[str, int, str, str]],
|
||||
enabled_map: Dict[str, bool],
|
||||
experimental_config: ExperimentalConfig,
|
||||
) -> FilteredPushRules:
|
||||
"""Take the DB rows returned from the DB and convert them into a full
|
||||
`FilteredPushRules` object.
|
||||
|
||||
Args:
|
||||
rawrules: List of tuples of:
|
||||
* rule ID
|
||||
* Priority lass
|
||||
* Conditions (as serialized JSON)
|
||||
* Actions (as serialized JSON)
|
||||
enabled_map: A dictionary of rule ID to a boolean of whether the rule is
|
||||
enabled. This might not include all rule IDs from rawrules.
|
||||
experimental_config: The `experimental_features` section of the Synapse
|
||||
config. (Used to check if various features are enabled.)
|
||||
|
||||
Returns:
|
||||
A new FilteredPushRules object.
|
||||
"""
|
||||
|
||||
ruleslist = [
|
||||
PushRule.from_db(
|
||||
rule_id=rawrule["rule_id"],
|
||||
priority_class=rawrule["priority_class"],
|
||||
conditions=rawrule["conditions"],
|
||||
actions=rawrule["actions"],
|
||||
rule_id=rawrule[0],
|
||||
priority_class=rawrule[1],
|
||||
conditions=rawrule[2],
|
||||
actions=rawrule[3],
|
||||
)
|
||||
for rawrule in rawrules
|
||||
]
|
||||
|
@ -183,7 +197,19 @@ class PushRulesWorkerStore(
|
|||
|
||||
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
||||
|
||||
return _load_rules(rows, enabled_map, self.hs.config.experimental)
|
||||
return _load_rules(
|
||||
[
|
||||
(
|
||||
row["rule_id"],
|
||||
row["priority_class"],
|
||||
row["conditions"],
|
||||
row["actions"],
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
enabled_map,
|
||||
self.hs.config.experimental,
|
||||
)
|
||||
|
||||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||
results = await self.db_pool.simple_select_list(
|
||||
|
@ -221,21 +247,36 @@ class PushRulesWorkerStore(
|
|||
if not user_ids:
|
||||
return {}
|
||||
|
||||
raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
|
||||
raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
|
||||
user_id: [] for user_id in user_ids
|
||||
}
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, int, int, str, str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
retcols=("*",),
|
||||
retcols=(
|
||||
"user_name",
|
||||
"rule_id",
|
||||
"priority_class",
|
||||
"priority",
|
||||
"conditions",
|
||||
"actions",
|
||||
),
|
||||
desc="bulk_get_push_rules",
|
||||
batch_size=1000,
|
||||
),
|
||||
)
|
||||
|
||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
||||
# Sort by highest priority_class, then highest priority.
|
||||
rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
|
||||
|
||||
for row in rows:
|
||||
raw_rules.setdefault(row["user_name"], []).append(row)
|
||||
for user_name, rule_id, priority_class, _, conditions, actions in rows:
|
||||
raw_rules.setdefault(user_name, []).append(
|
||||
(rule_id, priority_class, conditions, actions)
|
||||
)
|
||||
|
||||
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
|
||||
|
||||
|
@ -256,17 +297,19 @@ class PushRulesWorkerStore(
|
|||
|
||||
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, Optional[int]]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules_enable",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
retcols=("user_name", "rule_id", "enabled"),
|
||||
desc="bulk_get_push_rules_enabled",
|
||||
batch_size=1000,
|
||||
),
|
||||
)
|
||||
for row in rows:
|
||||
enabled = bool(row["enabled"])
|
||||
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
|
||||
for user_name, rule_id, enabled in rows:
|
||||
results.setdefault(user_name, {})[rule_id] = bool(enabled)
|
||||
return results
|
||||
|
||||
async def get_all_push_rule_updates(
|
||||
|
|
|
@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
def get_all_relation_ids_for_event_with_types_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[str]:
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn=txn,
|
||||
table="event_relations",
|
||||
column="relation_type",
|
||||
iterable=relation_types,
|
||||
keyvalues={"relates_to_id": event_id},
|
||||
retcols=["event_id"],
|
||||
),
|
||||
)
|
||||
|
||||
return [row["event_id"] for row in rows]
|
||||
return [row[0] for row in rows]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_all_relation_ids_for_event_with_types",
|
||||
|
|
|
@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
complete.
|
||||
"""
|
||||
|
||||
rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="partial_state_rooms",
|
||||
column="room_id",
|
||||
iterable=room_ids,
|
||||
retcols=("room_id",),
|
||||
desc="is_partial_state_room_batched",
|
||||
),
|
||||
)
|
||||
partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
|
||||
partial_state_rooms = {row[0] for row in rows}
|
||||
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
|
||||
|
||||
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
|
||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
|||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -683,7 +684,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
Map from user_id to set of rooms that is currently in.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="current_state_events",
|
||||
column="state_key",
|
||||
iterable=user_ids,
|
||||
|
@ -696,12 +699,13 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
"membership": Membership.JOIN,
|
||||
},
|
||||
desc="get_rooms_for_users",
|
||||
),
|
||||
)
|
||||
|
||||
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
|
||||
|
||||
for row in rows:
|
||||
user_rooms[row["state_key"]].add(row["room_id"])
|
||||
for state_key, room_id in rows:
|
||||
user_rooms[state_key].add(room_id)
|
||||
|
||||
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
|
||||
|
||||
|
@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
Map from event ID to `user_id`, or None if event is not a join.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
retcols=("user_id", "event_id"),
|
||||
retcols=("event_id", "user_id"),
|
||||
keyvalues={"membership": Membership.JOIN},
|
||||
batch_size=1000,
|
||||
desc="_get_user_ids_from_membership_event_ids",
|
||||
),
|
||||
)
|
||||
|
||||
return {row["event_id"]: row["user_id"] for row in rows}
|
||||
return dict(rows)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def is_host_joined(self, room_id: str, host: str) -> bool:
|
||||
|
@ -1202,7 +1209,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
membership event, otherwise the value is None.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, str, str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=member_event_ids,
|
||||
|
@ -1210,13 +1219,12 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
keyvalues={},
|
||||
batch_size=500,
|
||||
desc="get_membership_from_event_ids",
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
row["event_id"]: EventIdMembership(
|
||||
membership=row["membership"], user_id=row["user_id"]
|
||||
)
|
||||
for row in rows
|
||||
event_id: EventIdMembership(membership=membership, user_id=user_id)
|
||||
for user_id, membership, event_id in rows
|
||||
}
|
||||
|
||||
async def is_local_host_in_room_ignoring_users(
|
||||
|
|
|
@ -20,10 +20,12 @@ from typing import (
|
|||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
Raises:
|
||||
RuntimeError if the state is unknown at any of the given events
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, int]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
keyvalues={},
|
||||
retcols=("event_id", "state_group"),
|
||||
desc="_get_state_group_for_events",
|
||||
),
|
||||
)
|
||||
|
||||
res = {row["event_id"]: row["state_group"] for row in rows}
|
||||
res = dict(rows)
|
||||
for e in event_ids:
|
||||
if e not in res:
|
||||
raise RuntimeError("No state group for unknown or outlier event %s" % e)
|
||||
|
@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
The subset of state groups that are referenced.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[int]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="state_group",
|
||||
iterable=state_groups,
|
||||
keyvalues={},
|
||||
retcols=("DISTINCT state_group",),
|
||||
desc="get_referenced_state_groups",
|
||||
),
|
||||
)
|
||||
|
||||
return {row["state_group"] for row in rows}
|
||||
return {row[0] for row in rows}
|
||||
|
||||
async def update_state_for_partial_state_event(
|
||||
self,
|
||||
|
@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||
# potentially stale, since there may have been a period where the
|
||||
# server didn't share a room with the remote user and therefore may
|
||||
# have missed any device updates.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
column="room_id",
|
||||
iterable=to_delete,
|
||||
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
|
||||
keyvalues={
|
||||
"type": EventTypes.Member,
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
retcols=("state_key",),
|
||||
),
|
||||
)
|
||||
|
||||
potentially_left_users = {row["state_key"] for row in rows}
|
||||
potentially_left_users = {row[0] for row in rows}
|
||||
|
||||
# Now lets actually delete the rooms from the DB.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
|
|
|
@ -506,7 +506,9 @@ class StatsStore(StateDeltasStore):
|
|||
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
|
||||
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
|
||||
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
column="type",
|
||||
|
@ -522,9 +524,10 @@ class StatsStore(StateDeltasStore):
|
|||
],
|
||||
keyvalues={"room_id": room_id, "state_key": ""},
|
||||
retcols=["event_id"],
|
||||
),
|
||||
)
|
||||
|
||||
event_ids = cast(List[str], [row["event_id"] for row in rows])
|
||||
event_ids = [row[0] for row in rows]
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
|
|
|
@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
|||
async def get_destination_retry_timings_batch(
|
||||
self, destinations: StrCollection
|
||||
) -> Mapping[str, Optional[DestinationRetryTimings]]:
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="destinations",
|
||||
iterable=destinations,
|
||||
column="destination",
|
||||
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
|
||||
retcols=(
|
||||
"destination",
|
||||
"failure_ts",
|
||||
"retry_last_ts",
|
||||
"retry_interval",
|
||||
),
|
||||
desc="get_destination_retry_timings_batch",
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
row.pop("destination"): DestinationRetryTimings(**row)
|
||||
for row in rows
|
||||
if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
|
||||
destination: DestinationRetryTimings(
|
||||
failure_ts, retry_last_ts, retry_interval
|
||||
)
|
||||
for destination, failure_ts, retry_last_ts, retry_interval in rows
|
||||
if retry_last_ts and failure_ts and retry_interval
|
||||
}
|
||||
|
||||
async def set_destination_retry_timings(
|
||||
|
|
|
@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
|
||||
# If a registration token was used, decrement the pending counter
|
||||
# before deleting the session.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions_credentials",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
|
||||
retcols=["result"],
|
||||
),
|
||||
)
|
||||
|
||||
# Get the tokens used and how much pending needs to be decremented by.
|
||||
|
@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
# registration token stage for that session will be True.
|
||||
# If a token was used to authenticate, but registration was
|
||||
# never completed, the result will be the token used.
|
||||
token = db_to_json(r["result"])
|
||||
token = db_to_json(r[0])
|
||||
if isinstance(token, str):
|
||||
token_counts[token] = token_counts.get(token, 0) + 1
|
||||
|
||||
# Update the `pending` counters.
|
||||
if len(token_counts) > 0:
|
||||
token_rows = self.db_pool.simple_select_many_txn(
|
||||
token_rows = cast(
|
||||
List[Tuple[str, int]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="registration_tokens",
|
||||
column="token",
|
||||
iterable=list(token_counts.keys()),
|
||||
keyvalues={},
|
||||
retcols=["token", "pending"],
|
||||
),
|
||||
)
|
||||
for token_row in token_rows:
|
||||
token = token_row["token"]
|
||||
new_pending = token_row["pending"] - token_counts[token]
|
||||
for token, pending in token_rows:
|
||||
new_pending = pending - token_counts[token]
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="registration_tokens",
|
||||
|
|
|
@ -410,7 +410,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
)
|
||||
|
||||
# Next fetch their profiles. Note that not all users have profiles.
|
||||
profile_rows = self.db_pool.simple_select_many_txn(
|
||||
profile_rows = cast(
|
||||
List[Tuple[str, Optional[str], Optional[str]]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="profiles",
|
||||
column="full_user_id",
|
||||
|
@ -421,14 +423,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
"avatar_url",
|
||||
),
|
||||
keyvalues={},
|
||||
),
|
||||
)
|
||||
profiles = {
|
||||
row["full_user_id"]: _UserDirProfile(
|
||||
row["full_user_id"],
|
||||
row["displayname"],
|
||||
row["avatar_url"],
|
||||
)
|
||||
for row in profile_rows
|
||||
full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
|
||||
for full_user_id, displayname, avatar_url in profile_rows
|
||||
}
|
||||
|
||||
profiles_to_insert = [
|
||||
|
@ -517,7 +516,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
|
||||
]
|
||||
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[str]]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="users",
|
||||
column="name",
|
||||
|
@ -526,9 +527,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
"deactivated": 0,
|
||||
},
|
||||
retcols=("name", "user_type"),
|
||||
),
|
||||
)
|
||||
|
||||
return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
|
||||
return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
|
||||
|
||||
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
|
||||
"""Check if the room is either world_readable or publically joinable"""
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Iterable, Mapping
|
||||
from typing import Iterable, List, Mapping, Tuple, cast
|
||||
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
|
@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
|
|||
Returns:
|
||||
for each user, whether the user has requested erasure.
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[str]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="erased_users",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
retcols=("user_id",),
|
||||
desc="are_users_erased",
|
||||
),
|
||||
)
|
||||
erased_users = {row["user_id"] for row in rows}
|
||||
erased_users = {row[0] for row in rows}
|
||||
|
||||
return {u: u in erased_users for u in user_ids}
|
||||
|
||||
|
|
|
@ -13,7 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
"[purge] found %i state groups to delete", len(state_groups_to_delete)
|
||||
)
|
||||
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
rows = cast(
|
||||
List[Tuple[int]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
column="prev_state_group",
|
||||
iterable=state_groups_to_delete,
|
||||
keyvalues={},
|
||||
retcols=("state_group",),
|
||||
),
|
||||
)
|
||||
|
||||
remaining_state_groups = {
|
||||
row["state_group"]
|
||||
for row in rows
|
||||
if row["state_group"] not in state_groups_to_delete
|
||||
state_group
|
||||
for state_group, in rows
|
||||
if state_group not in state_groups_to_delete
|
||||
}
|
||||
|
||||
logger.info(
|
||||
|
@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
A mapping from state group to previous state group.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
rows = cast(
|
||||
List[Tuple[int, int]],
|
||||
await self.db_pool.simple_select_many_batch(
|
||||
table="state_group_edges",
|
||||
column="prev_state_group",
|
||||
iterable=state_groups,
|
||||
keyvalues={},
|
||||
retcols=("prev_state_group", "state_group"),
|
||||
retcols=("state_group", "prev_state_group"),
|
||||
desc="get_previous_state_groups",
|
||||
),
|
||||
)
|
||||
|
||||
return {row["state_group"]: row["prev_state_group"] for row in rows}
|
||||
return dict(rows)
|
||||
|
||||
async def purge_room_state(
|
||||
self, room_id: str, state_groups_to_delete: Collection[int]
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import Dict, List, Set, Tuple, cast
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.trial import unittest
|
||||
|
@ -421,7 +421,9 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
|||
self, events: List[EventBase]
|
||||
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
|
||||
# Fetch the map from event ID -> (chain ID, sequence number)
|
||||
rows = self.get_success(
|
||||
rows = cast(
|
||||
List[Tuple[str, int, int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_many_batch(
|
||||
table="event_auth_chains",
|
||||
column="event_id",
|
||||
|
@ -429,14 +431,18 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
|||
retcols=("event_id", "chain_id", "sequence_number"),
|
||||
keyvalues={},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
chain_map = {
|
||||
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
|
||||
event_id: (chain_id, sequence_number)
|
||||
for event_id, chain_id, sequence_number in rows
|
||||
}
|
||||
|
||||
# Fetch all the links and pass them to the _LinkMap.
|
||||
rows = self.get_success(
|
||||
auth_chain_rows = cast(
|
||||
List[Tuple[int, int, int, int]],
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_select_many_batch(
|
||||
table="event_auth_chain_links",
|
||||
column="origin_chain_id",
|
||||
|
@ -449,13 +455,19 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
|||
),
|
||||
keyvalues={},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
link_map = _LinkMap()
|
||||
for row in rows:
|
||||
for (
|
||||
origin_chain_id,
|
||||
origin_sequence_number,
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
) in auth_chain_rows:
|
||||
added = link_map.add_link(
|
||||
(row["origin_chain_id"], row["origin_sequence_number"]),
|
||||
(row["target_chain_id"], row["target_sequence_number"]),
|
||||
(origin_chain_id, origin_sequence_number),
|
||||
(target_chain_id, target_sequence_number),
|
||||
)
|
||||
|
||||
# We shouldn't have persisted any redundant links
|
||||
|
|
Loading…
Reference in New Issue