Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)

This should use fewer allocations and improves type hints.
This commit is contained in:
Patrick Cloke 2023-10-26 13:01:36 -04:00 committed by GitHub
parent c14a7de6af
commit 9407d5ba78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 607 additions and 507 deletions

1
changelog.d/16505.misc Normal file
View File

@ -0,0 +1 @@
Reduce memory allocations.

View File

@ -103,10 +103,10 @@ class DeactivateAccountHandler:
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in bound_threepids:
for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server
user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?

View File

@ -1206,10 +1206,7 @@ class SsoHandler:
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
for device in devices:
user_id = device["user_id"]
device_id = device["device_id"]
for user_id, device_id in devices:
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:

View File

@ -606,13 +606,16 @@ class DatabasePool:
If the background updates have not completed, wait 15 sec and check again.
"""
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
updates = cast(
List[Tuple[str]],
await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
),
)
background_update_names = [x["update_name"] for x in updates]
background_update_names = [x[0] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
@ -1804,9 +1807,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.
Args:
table: the table name
@ -1817,8 +1820,7 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
return await self.runInteraction(
desc,
@ -1836,9 +1838,9 @@ class DatabasePool:
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.
Args:
txn: Transaction object
@ -1849,8 +1851,7 @@ class DatabasePool:
retcols: the names of the columns to return
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@ -1863,7 +1864,7 @@ class DatabasePool:
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
return cls.cursor_to_dict(txn)
return txn.fetchall()
async def simple_select_many_batch(
self,

View File

@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
["account_data_type", "content"],
) -> Dict[str, JsonMapping]:
rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_list_txn(
txn,
table="room_account_data",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=["account_data_type", "content"],
),
)
return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
account_data_type: db_to_json(content)
for account_data_type, content in rows
}
return await self.db_pool.runInteraction(

View File

@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state.value}, ["as_id"]
results = cast(
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="application_services_state",
keyvalues={"state": state.value},
retcols=("as_id",),
),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
for res in results:
for (as_id,) in results:
for service in as_list:
if service.id == res["as_id"]:
if service.id == as_id:
services.append(service)
return services

View File

@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
if device_id is not None:
keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
res = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
)
return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
(user_id, device_id): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip=ip,
user_agent=user_agent,
last_seen=last_seen,
)
for d in res
for user_id, ip, user_agent, device_id, last_seen in res
}
async def _get_user_ip_and_agents_from_database(

View File

@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
async def get_devices_by_user(
self, user_id: str
) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
and "display_name" for each device.
and "display_name" for each device. Display name may be null.
"""
devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
devices = cast(
List[Tuple[str, str, Optional[str]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
),
)
return {d["device_id"]: d for d in devices}
return {
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
for d in devices
}
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
return await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
),
)
@trace
@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
devices = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
),
)
return {
device["device_id"]: db_to_json(device["content"]) for device in devices
}
return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
row_tuples = cast(
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)
return {row[0] for row in row_tuples}
else:
rows = cast(
List[Dict[str, str]],
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
),
)
return {row["user_id"] for row in rows}
return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"user_id",
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
rows = cast(
List[Tuple[str, str, int, int, int, str]],
await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
desc="get_e2e_room_keys",
),
desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
for (
room_id,
session_id,
first_message_index,
forwarded_count,
is_verified,
session_data,
) in rows:
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
room_entry["sessions"][session_id] = {
"first_message_index": first_message_index,
"forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": db_to_json(row["session_data"]),
"is_verified": bool(is_verified),
"session_data": db_to_json(session_data),
}
return sessions

View File

@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
rows = await self.db_pool.simple_select_list(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
retcols=("event_id", "event_json"),
desc="prune_staged_events_in_room_fetch",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="federation_inbound_events_staging",
keyvalues={"room_id": room_id},
retcols=("event_id", "event_json"),
desc="prune_staged_events_in_room_fetch",
),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
for row in rows:
event_id = row["event_id"]
for event_id, event_json in rows:
seen_events.add(event_id)
event_d = db_to_json(row["event_json"])
event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, FrozenSet
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
enabled = await self.db_pool.simple_select_list(
"per_user_experimental_features",
{"user_id": user_id, "enabled": True},
["feature"],
enabled = cast(
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="per_user_experimental_features",
keyvalues={"user_id": user_id, "enabled": True},
retcols=("feature",),
),
)
return frozenset(feature["feature"] for feature in enabled)
return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,

View File

@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
rows = cast(
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r["ts_added_ms"])
rows.sort(key=lambda r: r[2])
return {
row["key_id"]: FetchKeyResultForRemote(
key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
key_json=bytes(key_json),
valid_until_ts=ts_valid_until_ms,
added_ts=ts_added_ms,
)
for row in rows
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}

View File

@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
rows = await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
rows = cast(
List[Tuple[int, int, str, str, int]],
await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
),
desc="get_local_media_thumbnails",
),
desc="get_local_media_thumbnails",
)
return [
ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
) -> List[ThumbnailInfo]:
rows = await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
rows = cast(
List[Tuple[int, int, str, str, int]],
await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
),
desc="get_remote_media_thumbnails",
),
desc="get_remote_media_thumbnails",
)
return [
ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]

View File

@ -179,46 +179,44 @@ class PushRulesWorkerStore(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
rows = cast(
List[Tuple[str, int, int, str, str]],
await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
),
desc="get_push_rules_for_user",
),
desc="get_push_rules_for_user",
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
# Sort by highest priority_class, then highest priority.
rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(
[
(
row["rule_id"],
row["priority_class"],
row["conditions"],
row["actions"],
)
for row in rows
],
[(row[0], row[1], row[3], row[4]) for row in rows],
enabled_map,
self.hs.config.experimental,
)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
results = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
),
)
return {r["rule_id"]: bool(r["enabled"]) for r in results}
return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int

View File

@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
async def get_throttle_params_by_room(
self, pusher_id: int
) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room",
res = cast(
List[Tuple[str, Optional[int], Optional[int]]],
await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room",
),
)
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = ThrottleParams(
row["last_sent_ts"],
row["throttle_ms"],
for room_id, last_sent_ts, throttle_ms in res:
params_by_room[room_id] = ThrottleParams(
last_sent_ts or 0, throttle_ms or 0
)
return params_by_room

View File

@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
Tuples of (auth_provider, external_id)
"""
res = await self.db_pool.simple_select_list(
table="user_external_ids",
keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="user_external_ids",
keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user",
),
)
return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
results = await self.db_pool.simple_select_list(
"user_threepids",
keyvalues={"user_id": user_id},
retcols=["medium", "address", "validated_at", "added_at"],
desc="user_get_threepids",
results = cast(
List[Tuple[str, str, int, int]],
await self.db_pool.simple_select_list(
"user_threepids",
keyvalues={"user_id": user_id},
retcols=["medium", "address", "validated_at", "added_at"],
desc="user_get_threepids",
),
)
return [ThreepidResult(**r) for r in results]
return [
ThreepidResult(
medium=r[0],
address=r[1],
validated_at=r[2],
added_at=r[3],
)
for r in results
]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid",
)
async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for
Returns:
List of dictionaries containing the following keys:
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
List of tuples of two strings:
medium: The medium of the threepid (e.g "email")
address: The address of the threepid (e.g "bob@example.com")
"""
return await self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
desc="user_get_bound_threepids",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
desc="user_get_bound_threepids",
),
)
async def remove_user_bound_threepid(

View File

@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction,
) -> List[str]:
rows = self.db_pool.simple_select_list_txn(
txn=txn,
table="event_relations",
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_list_txn(
txn=txn,
table="event_relations",
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
),
)
return [row["event_id"] for row in rows]
return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event",

View File

@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
room_servers: Dict[str, PartialStateResyncInfo] = {}
rows = await self.db_pool.simple_select_list(
table="partial_state_rooms",
keyvalues={},
retcols=("room_id", "joined_via"),
desc="get_server_which_served_partial_join",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="partial_state_rooms",
keyvalues={},
retcols=("room_id", "joined_via"),
desc="get_server_which_served_partial_join",
),
)
for row in rows:
room_id = row["room_id"]
joined_via = row["joined_via"]
for room_id, joined_via in rows:
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
rows = await self.db_pool.simple_select_list(
"partial_state_rooms_servers",
keyvalues=None,
retcols=("room_id", "server_name"),
desc="get_partial_state_rooms",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
"partial_state_rooms_servers",
keyvalues=None,
retcols=("room_id", "server_name"),
desc="get_partial_state_rooms",
),
)
for row in rows:
room_id = row["room_id"]
server_name = row["server_name"]
for room_id, server_name in rows:
entry = room_servers.get(room_id)
if entry is None:
# There is a foreign key constraint which enforces that every room_id in

View File

@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
rows = await self.db_pool.simple_select_list(
"current_state_events",
keyvalues={"room_id": room_id},
retcols=("event_id", "membership"),
desc="has_completed_background_updates",
rows = cast(
List[Tuple[str, Optional[str]]],
await self.db_pool.simple_select_list(
"current_state_events",
keyvalues={"room_id": room_id},
retcols=("event_id", "membership"),
desc="has_completed_background_updates",
),
)
return {row["event_id"]: row["membership"] for row in rows}
return dict(rows)
# TODO This returns a mutable object, which is generally confusing when using a cache.
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]

View File

@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content.
"""
rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
),
)
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
for room_id, tag, content in rows:
room_tags = tags_by_room.setdefault(room_id, {})
room_tags[tag] = db_to_json(content)
return tags_by_room
async def get_all_updated_tags(
@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A mapping of tags to tag content.
"""
rows = await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
),
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}
return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict

View File

@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
for row in await self.db_pool.simple_select_list(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id},
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
results[row["stage_type"]] = db_to_json(row["result"])
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id},
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
),
)
for stage_type, result in rows:
results[stage_type] = db_to_json(result)
return results
@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
rows = await self.db_pool.simple_select_list(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id},
retcols=("user_agent", "ip"),
desc="get_user_agents_ips_to_ui_auth_session",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id},
retcols=("user_agent", "ip"),
desc="get_user_agents_ips_to_ui_auth_session",
),
)
return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""

View File

@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
delta_ids = self.db_pool.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
retcols=("type", "state_key", "event_id"),
delta_ids = cast(
List[Tuple[str, str, str]],
self.db_pool.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
retcols=("type", "state_key", "event_id"),
),
)
return _GetStateGroupDelta(
prev_group,
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
{
(event_type, state_key): event_id
for event_type, state_key, event_id in delta_ids
},
)
return await self.db_pool.runInteraction(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
async def get_all_room_state(self) -> List[Dict[str, Any]]:
return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
async def get_all_room_state(self) -> List[Optional[str]]:
rows = cast(
List[Tuple[Optional[str]]],
await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("topic",)
),
)
return [r[0] for r in rows]
def _get_current_stats(
self, stats_type: str, stat_id: str
@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
self.assertEqual(r[0], "foo")
def test_create_user(self) -> None:
"""

View File

@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None:
columns += expected_row.keys()
rows = self.get_success(
row_tuples = self.get_success(
self.store.db_pool.simple_select_list(
table=table,
keyvalues={
@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None:
self.assertEqual(
len(rows),
len(row_tuples),
1,
f"Background update did not leave behind latest receipt in {table}",
)
self.assertEqual(
rows[0],
{
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
**expected_row,
},
row_tuples[0],
(
room_id,
receipt_type,
user_id,
*expected_row.values(),
),
)
else:
self.assertEqual(
len(rows),
len(row_tuples),
0,
f"Background update did not remove all duplicate receipts from {table}",
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
import secrets
from typing import Generator, Tuple
from typing import Generator, List, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
)
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
res = self.get_success(
self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
yield from cast(
List[Tuple[int, str, str]],
self.get_success(
self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
),
)
for i in res:
yield (i["id"], i["username"], i["value"])
def test_upsert_many(self) -> None:
"""
Upsert_many will perform the upsert operation across a batch of data.

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple, cast
from unittest.mock import AsyncMock, Mock
import yaml
@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates()
# Check the correct values are in the new table.
rows = self.get_success(
self.store.db_pool.simple_select_list(
table="test_constraint",
keyvalues={},
retcols=("a", "b"),
)
rows = cast(
List[Tuple[int, int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="test_constraint",
keyvalues={},
retcols=("a", "b"),
)
),
)
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected.
self.get_failure(
@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates()
# Check the correct values are in the new table.
rows = self.get_success(
self.store.db_pool.simple_select_list(
table="test_constraint",
keyvalues={},
retcols=("a", "b"),
)
rows = cast(
List[Tuple[int, int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="test_constraint",
keyvalues={},
retcols=("a", "b"),
)
),
)
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected.
self.get_failure(

View File

@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield defer.ensureDeferred(
@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
self.assertEqual([(1,), (2,), (3,)], ret)
self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Tuple, cast
from unittest.mock import AsyncMock
from parameterized import parameterized
@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
self.pump(0)
result = self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents",
)
),
)
self.assertEqual(
result,
[
{
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": None,
"last_seen": 12345678000,
}
],
result, [("access_token", "ip", "user_agent", None, 12345678000)]
)
# Add another & trigger the storage loop
@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
self.pump(0)
result = self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents",
)
),
)
# Only one result, has been upserted.
self.assertEqual(
result,
[
{
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": None,
"last_seen": 12345878000,
}
],
result, [("access_token", "ip", "user_agent", None, 12345878000)]
)
@parameterized.expand([(False,), (True,)])
@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
else:
# Check that the new IP and user agent has not been stored yet
db_result = self.get_success(
self.store.db_pool.simple_select_list(
table="devices",
keyvalues={},
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
db_result = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
self.get_success(
self.store.db_pool.simple_select_list(
table="devices",
keyvalues={},
retcols=(
"user_id",
"ip",
"user_agent",
"device_id",
"last_seen",
),
),
),
)
self.assertEqual(
db_result,
[
{
"user_id": user_id,
"device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
},
],
)
self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# Check that the new IP and user agent has not been stored yet
db_result = self.get_success(
self.store.db_pool.simple_select_list(
table="devices",
keyvalues={},
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
db_result = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
self.get_success(
self.store.db_pool.simple_select_list(
table="devices",
keyvalues={},
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
),
)
self.assertCountEqual(
db_result,
[
{
"user_id": user_id,
"device_id": device_id_1,
"ip": "ip_1",
"user_agent": "user_agent_1",
"last_seen": 12345678000,
},
{
"user_id": user_id,
"device_id": device_id_2,
"ip": "ip_2",
"user_agent": "user_agent_2",
"last_seen": 12345678000,
},
(user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
(user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
],
)
@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# Check that the new IP and user agent has not been stored yet
db_result = self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={},
retcols=("access_token", "ip", "user_agent", "last_seen"),
db_result = cast(
List[Tuple[str, str, str, int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={},
retcols=("access_token", "ip", "user_agent", "last_seen"),
),
),
)
self.assertEqual(
db_result,
[
{
"access_token": "access_token",
"ip": "ip_1",
"user_agent": "user_agent_1",
"last_seen": 12345678000,
},
{
"access_token": "access_token",
"ip": "ip_2",
"user_agent": "user_agent_2",
"last_seen": 12345678000,
},
("access_token", "ip_1", "user_agent_1", 12345678000),
("access_token", "ip_2", "user_agent_2", 12345678000),
],
)
@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
# We should see that in the DB
result = self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents",
)
),
)
self.assertEqual(
result,
[
{
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": device_id,
"last_seen": 0,
}
],
[("access_token", "ip", "user_agent", device_id, 0)],
)
# Now advance by a couple of months
self.reactor.advance(60 * 24 * 60 * 60)
# We should get no results.
result = self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents",
)
),
)
self.assertEqual(result, [])
@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
# We should see that in the DB
result = self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={},
retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents",
)
),
)
# ensure user1 is filtered out
self.assertEqual(
result,
[
{
"access_token": access_token2,
"ip": "ip",
"user_agent": "user_agent",
"device_id": device_id2,
"last_seen": 0,
}
],
)
self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
class ClientIpAuthTestCase(unittest.HomeserverTestCase):

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import Membership
@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
res = self.get_success(
self.store.db_pool.simple_select_list(
"room_memberships",
{"user_id": "@alice:test"},
["display_name", "event_id"],
)
res = cast(
List[Tuple[Optional[str], str]],
self.get_success(
self.store.db_pool.simple_select_list(
"room_memberships",
{"user_id": "@alice:test"},
["display_name", "event_id"],
)
),
)
# Check that we only got one result back
self.assertEqual(len(res), 1)
# Check that alice's display name is "alice"
self.assertEqual(res[0]["display_name"], "alice")
self.assertEqual(res[0][0], "alice")
# Grab the event_id to use later
event_id = res[0]["event_id"]
event_id = res[0][1]
# Create a profile with the offending null byte in the display name
new_profile = {"displayname": "ali\u0000ce"}
@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
tok=self.t_alice,
)
res2 = self.get_success(
self.store.db_pool.simple_select_list(
"room_memberships",
{"user_id": "@alice:test"},
["display_name", "event_id"],
)
res2 = cast(
List[Tuple[Optional[str], str]],
self.get_success(
self.store.db_pool.simple_select_list(
"room_memberships",
{"user_id": "@alice:test"},
["display_name", "event_id"],
)
),
)
# Check that we only have two results
self.assertEqual(len(res2), 2)
# Filter out the previous event using the event_id we grabbed above
row = [row for row in res2 if row["event_id"] != event_id]
row = [row for row in res2 if row[1] != event_id]
# Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None)
self.assertIsNone(row[0][0])
def test_room_is_locally_forgotten(self) -> None:
"""Test that when the last local user has forgotten a room it is known as forgotten."""

View File

@ -13,6 +13,7 @@
# limitations under the License.
import logging
from typing import List, Tuple, cast
from immutabledict import immutabledict
@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
)
# check that only state events are in state_groups, and all state events are in state_groups
res = self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups",
keyvalues=None,
retcols=("event_id",),
)
res = cast(
List[Tuple[str]],
self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups",
keyvalues=None,
retcols=("event_id",),
)
),
)
events = []
for result in res:
self.assertNotIn(event3.event_id, result)
events.append(result.get("event_id"))
self.assertNotIn(event3.event_id, result) # XXX
events.append(result[0])
for event, _ in processed_events_and_context:
if event.is_state():
@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
# has an entry and prev event in state_group_edges
for event, context in processed_events_and_context:
if event.is_state():
state = self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups_state",
keyvalues={"state_group": context.state_group_after_event},
retcols=("type", "state_key"),
)
state = cast(
List[Tuple[str, str]],
self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups_state",
keyvalues={"state_group": context.state_group_after_event},
retcols=("type", "state_key"),
)
),
)
self.assertEqual(event.type, state[0].get("type"))
self.assertEqual(event.state_key, state[0].get("state_key"))
self.assertEqual(event.type, state[0][0])
self.assertEqual(event.state_key, state[0][1])
groups = self.get_success(
self.store.db_pool.simple_select_list(
table="state_group_edges",
keyvalues={"state_group": str(context.state_group_after_event)},
retcols=("*",),
)
)
self.assertEqual(
context.state_group_before_event, groups[0].get("prev_state_group")
groups = cast(
List[Tuple[str]],
self.get_success(
self.store.db_pool.simple_select_list(
table="state_group_edges",
keyvalues={
"state_group": str(context.state_group_after_event)
},
retcols=("prev_state_group",),
)
),
)
self.assertEqual(context.state_group_before_event, groups[0][0])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, Dict, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, cast
from unittest import mock
from unittest.mock import Mock, patch
@ -62,14 +62,13 @@ class GetUserDirectoryTables:
Returns a list of tuples (user_id, room_id) where room_id is public and
contains the user with the given id.
"""
r = await self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
r = cast(
List[Tuple[str, str]],
await self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
),
)
retval = set()
for i in r:
retval.add((i["user_id"], i["room_id"]))
return retval
return set(r)
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table.
@ -78,27 +77,30 @@ class GetUserDirectoryTables:
to the rows of `users_who_share_private_rooms`.
"""
rows = await self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
rows = cast(
List[Tuple[str, str, str]],
await self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
),
)
rv = set()
for row in rows:
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
return rv
return set(rows)
async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table.
This is useful when checking we've correctly excluded users from the directory.
"""
result = await self.store.db_pool.simple_select_list(
"user_directory",
None,
["user_id"],
result = cast(
List[Tuple[str]],
await self.store.db_pool.simple_select_list(
"user_directory",
None,
["user_id"],
),
)
return {row["user_id"] for row in result}
return {row[0] for row in result}
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
"""Fetch users and their profiles from the `user_directory` table.
@ -107,16 +109,17 @@ class GetUserDirectoryTables:
It's almost the entire contents of the `user_directory` table: the only
thing missing is an unused room_id column.
"""
rows = await self.store.db_pool.simple_select_list(
"user_directory",
None,
("user_id", "display_name", "avatar_url"),
rows = cast(
List[Tuple[str, Optional[str], Optional[str]]],
await self.store.db_pool.simple_select_list(
"user_directory",
None,
("user_id", "display_name", "avatar_url"),
),
)
return {
row["user_id"]: ProfileInfo(
display_name=row["display_name"], avatar_url=row["avatar_url"]
)
for row in rows
user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
for user_id, display_name, avatar_url in rows
}
async def get_tables(