Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)
This should use fewer allocations and improves type hints.
This commit is contained in:
parent
c14a7de6af
commit
9407d5ba78
|
@ -0,0 +1 @@
|
||||||
|
Reduce memory allocations.
|
|
@ -103,10 +103,10 @@ class DeactivateAccountHandler:
|
||||||
# Attempt to unbind any known bound threepids to this account from identity
|
# Attempt to unbind any known bound threepids to this account from identity
|
||||||
# server(s).
|
# server(s).
|
||||||
bound_threepids = await self.store.user_get_bound_threepids(user_id)
|
bound_threepids = await self.store.user_get_bound_threepids(user_id)
|
||||||
for threepid in bound_threepids:
|
for medium, address in bound_threepids:
|
||||||
try:
|
try:
|
||||||
result = await self._identity_handler.try_unbind_threepid(
|
result = await self._identity_handler.try_unbind_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"], id_server
|
user_id, medium, address, id_server
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Do we want this to be a fatal error or should we carry on?
|
# Do we want this to be a fatal error or should we carry on?
|
||||||
|
|
|
@ -1206,10 +1206,7 @@ class SsoHandler:
|
||||||
# We have no guarantee that all the devices of that session are for the same
|
# We have no guarantee that all the devices of that session are for the same
|
||||||
# `user_id`. Hence, we have to iterate over the list of devices and log them out
|
# `user_id`. Hence, we have to iterate over the list of devices and log them out
|
||||||
# one by one.
|
# one by one.
|
||||||
for device in devices:
|
for user_id, device_id in devices:
|
||||||
user_id = device["user_id"]
|
|
||||||
device_id = device["device_id"]
|
|
||||||
|
|
||||||
# If the user_id associated with that device/session is not the one we got
|
# If the user_id associated with that device/session is not the one we got
|
||||||
# out of the `sub` claim, skip that device and show log an error.
|
# out of the `sub` claim, skip that device and show log an error.
|
||||||
if expected_user_id is not None and user_id != expected_user_id:
|
if expected_user_id is not None and user_id != expected_user_id:
|
||||||
|
|
|
@ -606,13 +606,16 @@ class DatabasePool:
|
||||||
|
|
||||||
If the background updates have not completed, wait 15 sec and check again.
|
If the background updates have not completed, wait 15 sec and check again.
|
||||||
"""
|
"""
|
||||||
updates = await self.simple_select_list(
|
updates = cast(
|
||||||
"background_updates",
|
List[Tuple[str]],
|
||||||
keyvalues=None,
|
await self.simple_select_list(
|
||||||
retcols=["update_name"],
|
"background_updates",
|
||||||
desc="check_background_updates",
|
keyvalues=None,
|
||||||
|
retcols=["update_name"],
|
||||||
|
desc="check_background_updates",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
background_update_names = [x["update_name"] for x in updates]
|
background_update_names = [x[0] for x in updates]
|
||||||
|
|
||||||
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
|
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
|
||||||
if update_name not in background_update_names:
|
if update_name not in background_update_names:
|
||||||
|
@ -1804,9 +1807,9 @@ class DatabasePool:
|
||||||
keyvalues: Optional[Dict[str, Any]],
|
keyvalues: Optional[Dict[str, Any]],
|
||||||
retcols: Collection[str],
|
retcols: Collection[str],
|
||||||
desc: str = "simple_select_list",
|
desc: str = "simple_select_list",
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of tuples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table: the table name
|
table: the table name
|
||||||
|
@ -1817,8 +1820,7 @@ class DatabasePool:
|
||||||
desc: description of the transaction, for logging and metrics
|
desc: description of the transaction, for logging and metrics
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of dictionaries, one per result row, each a mapping between the
|
A list of tuples, one per result row, each the retcolumn's value for the row.
|
||||||
column names from `retcols` and that column's value for the row.
|
|
||||||
"""
|
"""
|
||||||
return await self.runInteraction(
|
return await self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
|
@ -1836,9 +1838,9 @@ class DatabasePool:
|
||||||
table: str,
|
table: str,
|
||||||
keyvalues: Optional[Dict[str, Any]],
|
keyvalues: Optional[Dict[str, Any]],
|
||||||
retcols: Iterable[str],
|
retcols: Iterable[str],
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of tuples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn: Transaction object
|
txn: Transaction object
|
||||||
|
@ -1849,8 +1851,7 @@ class DatabasePool:
|
||||||
retcols: the names of the columns to return
|
retcols: the names of the columns to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of dictionaries, one per result row, each a mapping between the
|
A list of tuples, one per result row, each the retcolumn's value for the row.
|
||||||
column names from `retcols` and that column's value for the row.
|
|
||||||
"""
|
"""
|
||||||
if keyvalues:
|
if keyvalues:
|
||||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||||
|
@ -1863,7 +1864,7 @@ class DatabasePool:
|
||||||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
|
|
||||||
return cls.cursor_to_dict(txn)
|
return txn.fetchall()
|
||||||
|
|
||||||
async def simple_select_many_batch(
|
async def simple_select_many_batch(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
|
|
||||||
def get_account_data_for_room_txn(
|
def get_account_data_for_room_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> Dict[str, JsonDict]:
|
) -> Dict[str, JsonMapping]:
|
||||||
rows = self.db_pool.simple_select_list_txn(
|
rows = cast(
|
||||||
txn,
|
List[Tuple[str, str]],
|
||||||
"room_account_data",
|
self.db_pool.simple_select_list_txn(
|
||||||
{"user_id": user_id, "room_id": room_id},
|
txn,
|
||||||
["account_data_type", "content"],
|
table="room_account_data",
|
||||||
|
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||||
|
retcols=["account_data_type", "content"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
account_data_type: db_to_json(content)
|
||||||
|
for account_data_type, content in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
|
|
@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
Returns:
|
Returns:
|
||||||
A list of ApplicationServices, which may be empty.
|
A list of ApplicationServices, which may be empty.
|
||||||
"""
|
"""
|
||||||
results = await self.db_pool.simple_select_list(
|
results = cast(
|
||||||
"application_services_state", {"state": state.value}, ["as_id"]
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_list(
|
||||||
|
table="application_services_state",
|
||||||
|
keyvalues={"state": state.value},
|
||||||
|
retcols=("as_id",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||||
as_list = self.get_app_services()
|
as_list = self.get_app_services()
|
||||||
services = []
|
services = []
|
||||||
|
|
||||||
for res in results:
|
for (as_id,) in results:
|
||||||
for service in as_list:
|
for service in as_list:
|
||||||
if service.id == res["as_id"]:
|
if service.id == as_id:
|
||||||
services.append(service)
|
services.append(service)
|
||||||
return services
|
return services
|
||||||
|
|
||||||
|
|
|
@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
|
||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
keyvalues["device_id"] = device_id
|
keyvalues["device_id"] = device_id
|
||||||
|
|
||||||
res = await self.db_pool.simple_select_list(
|
res = cast(
|
||||||
table="devices",
|
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
|
||||||
keyvalues=keyvalues,
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
table="devices",
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
|
(user_id, device_id): DeviceLastConnectionInfo(
|
||||||
user_id=d["user_id"],
|
user_id=user_id,
|
||||||
device_id=d["device_id"],
|
device_id=device_id,
|
||||||
ip=d["ip"],
|
ip=ip,
|
||||||
user_agent=d["user_agent"],
|
user_agent=user_agent,
|
||||||
last_seen=d["last_seen"],
|
last_seen=last_seen,
|
||||||
)
|
)
|
||||||
for d in res
|
for user_id, ip, user_agent, device_id, last_seen in res
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _get_user_ip_and_agents_from_database(
|
async def _get_user_ip_and_agents_from_database(
|
||||||
|
|
|
@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
|
async def get_devices_by_user(
|
||||||
|
self, user_id: str
|
||||||
|
) -> Dict[str, Dict[str, Optional[str]]]:
|
||||||
"""Retrieve all of a user's registered devices. Only returns devices
|
"""Retrieve all of a user's registered devices. Only returns devices
|
||||||
that are not marked as hidden.
|
that are not marked as hidden.
|
||||||
|
|
||||||
|
@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
user_id:
|
user_id:
|
||||||
Returns:
|
Returns:
|
||||||
A mapping from device_id to a dict containing "device_id", "user_id"
|
A mapping from device_id to a dict containing "device_id", "user_id"
|
||||||
and "display_name" for each device.
|
and "display_name" for each device. Display name may be null.
|
||||||
"""
|
"""
|
||||||
devices = await self.db_pool.simple_select_list(
|
devices = cast(
|
||||||
table="devices",
|
List[Tuple[str, str, Optional[str]]],
|
||||||
keyvalues={"user_id": user_id, "hidden": False},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("user_id", "device_id", "display_name"),
|
table="devices",
|
||||||
desc="get_devices_by_user",
|
keyvalues={"user_id": user_id, "hidden": False},
|
||||||
|
retcols=("user_id", "device_id", "display_name"),
|
||||||
|
desc="get_devices_by_user",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {d["device_id"]: d for d in devices}
|
return {
|
||||||
|
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
|
||||||
|
for d in devices
|
||||||
|
}
|
||||||
|
|
||||||
async def get_devices_by_auth_provider_session_id(
|
async def get_devices_by_auth_provider_session_id(
|
||||||
self, auth_provider_id: str, auth_provider_session_id: str
|
self, auth_provider_id: str, auth_provider_session_id: str
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Tuple[str, str]]:
|
||||||
"""Retrieve the list of devices associated with a SSO IdP session ID.
|
"""Retrieve the list of devices associated with a SSO IdP session ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
A list of dicts containing the device_id and the user_id of each device
|
A list of dicts containing the device_id and the user_id of each device
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.simple_select_list(
|
return cast(
|
||||||
table="device_auth_providers",
|
List[Tuple[str, str]],
|
||||||
keyvalues={
|
await self.db_pool.simple_select_list(
|
||||||
"auth_provider_id": auth_provider_id,
|
table="device_auth_providers",
|
||||||
"auth_provider_session_id": auth_provider_session_id,
|
keyvalues={
|
||||||
},
|
"auth_provider_id": auth_provider_id,
|
||||||
retcols=("user_id", "device_id"),
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
desc="get_devices_by_auth_provider_session_id",
|
},
|
||||||
|
retcols=("user_id", "device_id"),
|
||||||
|
desc="get_devices_by_auth_provider_session_id",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
|
@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
async def get_cached_devices_for_user(
|
async def get_cached_devices_for_user(
|
||||||
self, user_id: str
|
self, user_id: str
|
||||||
) -> Mapping[str, JsonMapping]:
|
) -> Mapping[str, JsonMapping]:
|
||||||
devices = await self.db_pool.simple_select_list(
|
devices = cast(
|
||||||
table="device_lists_remote_cache",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"user_id": user_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("device_id", "content"),
|
table="device_lists_remote_cache",
|
||||||
desc="get_cached_devices_for_user",
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("device_id", "content"),
|
||||||
|
desc="get_cached_devices_for_user",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return {
|
return {device[0]: db_to_json(device[1]) for device in devices}
|
||||||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_cached_device_list_changes(
|
def get_cached_device_list_changes(
|
||||||
self,
|
self,
|
||||||
|
@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
The IDs of users whose device lists need resync.
|
The IDs of users whose device lists need resync.
|
||||||
"""
|
"""
|
||||||
if user_ids:
|
if user_ids:
|
||||||
row_tuples = cast(
|
rows = cast(
|
||||||
List[Tuple[str]],
|
List[Tuple[str]],
|
||||||
await self.db_pool.simple_select_many_batch(
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_resync",
|
table="device_lists_remote_resync",
|
||||||
|
@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row[0] for row in row_tuples}
|
|
||||||
else:
|
else:
|
||||||
rows = cast(
|
rows = cast(
|
||||||
List[Dict[str, str]],
|
List[Tuple[str]],
|
||||||
await self.db_pool.simple_select_list(
|
await self.db_pool.simple_select_list(
|
||||||
table="device_lists_remote_resync",
|
table="device_lists_remote_resync",
|
||||||
keyvalues=None,
|
keyvalues=None,
|
||||||
|
@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row["user_id"] for row in rows}
|
return {row[0] for row in rows}
|
||||||
|
|
||||||
async def mark_remote_users_device_caches_as_stale(
|
async def mark_remote_users_device_caches_as_stale(
|
||||||
self, user_ids: StrCollection
|
self, user_ids: StrCollection
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
|
||||||
|
|
||||||
from typing_extensions import Literal, TypedDict
|
from typing_extensions import Literal, TypedDict
|
||||||
|
|
||||||
|
@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
||||||
if session_id:
|
if session_id:
|
||||||
keyvalues["session_id"] = session_id
|
keyvalues["session_id"] = session_id
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="e2e_room_keys",
|
List[Tuple[str, str, int, int, int, str]],
|
||||||
keyvalues=keyvalues,
|
await self.db_pool.simple_select_list(
|
||||||
retcols=(
|
table="e2e_room_keys",
|
||||||
"user_id",
|
keyvalues=keyvalues,
|
||||||
"room_id",
|
retcols=(
|
||||||
"session_id",
|
"room_id",
|
||||||
"first_message_index",
|
"session_id",
|
||||||
"forwarded_count",
|
"first_message_index",
|
||||||
"is_verified",
|
"forwarded_count",
|
||||||
"session_data",
|
"is_verified",
|
||||||
|
"session_data",
|
||||||
|
),
|
||||||
|
desc="get_e2e_room_keys",
|
||||||
),
|
),
|
||||||
desc="get_e2e_room_keys",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sessions: Dict[
|
sessions: Dict[
|
||||||
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
|
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
|
||||||
] = {"rooms": {}}
|
] = {"rooms": {}}
|
||||||
for row in rows:
|
for (
|
||||||
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
|
room_id,
|
||||||
room_entry["sessions"][row["session_id"]] = {
|
session_id,
|
||||||
"first_message_index": row["first_message_index"],
|
first_message_index,
|
||||||
"forwarded_count": row["forwarded_count"],
|
forwarded_count,
|
||||||
|
is_verified,
|
||||||
|
session_data,
|
||||||
|
) in rows:
|
||||||
|
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
|
||||||
|
room_entry["sessions"][session_id] = {
|
||||||
|
"first_message_index": first_message_index,
|
||||||
|
"forwarded_count": forwarded_count,
|
||||||
# is_verified must be returned to the client as a boolean
|
# is_verified must be returned to the client as a boolean
|
||||||
"is_verified": bool(row["is_verified"]),
|
"is_verified": bool(is_verified),
|
||||||
"session_data": db_to_json(row["session_data"]),
|
"session_data": db_to_json(session_data),
|
||||||
}
|
}
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
|
@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
# keeping only the forward extremities (i.e. the events not referenced
|
# keeping only the forward extremities (i.e. the events not referenced
|
||||||
# by other events in the queue). We do this so that we can always
|
# by other events in the queue). We do this so that we can always
|
||||||
# backpaginate in all the events we have dropped.
|
# backpaginate in all the events we have dropped.
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="federation_inbound_events_staging",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"room_id": room_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("event_id", "event_json"),
|
table="federation_inbound_events_staging",
|
||||||
desc="prune_staged_events_in_room_fetch",
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=("event_id", "event_json"),
|
||||||
|
desc="prune_staged_events_in_room_fetch",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find the set of events referenced by those in the queue, as well as
|
# Find the set of events referenced by those in the queue, as well as
|
||||||
# collecting all the event IDs in the queue.
|
# collecting all the event IDs in the queue.
|
||||||
referenced_events: Set[str] = set()
|
referenced_events: Set[str] = set()
|
||||||
seen_events: Set[str] = set()
|
seen_events: Set[str] = set()
|
||||||
for row in rows:
|
for event_id, event_json in rows:
|
||||||
event_id = row["event_id"]
|
|
||||||
seen_events.add(event_id)
|
seen_events.add(event_id)
|
||||||
event_d = db_to_json(row["event_json"])
|
event_d = db_to_json(event_json)
|
||||||
|
|
||||||
# We don't bother parsing the dicts into full blown event objects,
|
# We don't bother parsing the dicts into full blown event objects,
|
||||||
# as that is needlessly expensive.
|
# as that is needlessly expensive.
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, FrozenSet
|
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
|
||||||
|
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||||
|
@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
the features currently enabled for the user
|
the features currently enabled for the user
|
||||||
"""
|
"""
|
||||||
enabled = await self.db_pool.simple_select_list(
|
enabled = cast(
|
||||||
"per_user_experimental_features",
|
List[Tuple[str]],
|
||||||
{"user_id": user_id, "enabled": True},
|
await self.db_pool.simple_select_list(
|
||||||
["feature"],
|
table="per_user_experimental_features",
|
||||||
|
keyvalues={"user_id": user_id, "enabled": True},
|
||||||
|
retcols=("feature",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return frozenset(feature["feature"] for feature in enabled)
|
return frozenset(feature[0] for feature in enabled)
|
||||||
|
|
||||||
async def set_features_for_user(
|
async def set_features_for_user(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
If we have multiple entries for a given key ID, returns the most recent.
|
If we have multiple entries for a given key ID, returns the most recent.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="server_keys_json",
|
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
|
||||||
keyvalues={"server_name": server_name},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=(
|
table="server_keys_json",
|
||||||
"key_id",
|
keyvalues={"server_name": server_name},
|
||||||
"from_server",
|
retcols=(
|
||||||
"ts_added_ms",
|
"key_id",
|
||||||
"ts_valid_until_ms",
|
"from_server",
|
||||||
"key_json",
|
"ts_added_ms",
|
||||||
|
"ts_valid_until_ms",
|
||||||
|
"key_json",
|
||||||
|
),
|
||||||
|
desc="get_server_keys_json_for_remote",
|
||||||
),
|
),
|
||||||
desc="get_server_keys_json_for_remote",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
|
@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
# We sort the rows by ts_added_ms so that the most recently added entry
|
# We sort the rows by ts_added_ms so that the most recently added entry
|
||||||
# will stomp over older entries in the dictionary.
|
# will stomp over older entries in the dictionary.
|
||||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
rows.sort(key=lambda r: r[2])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row["key_id"]: FetchKeyResultForRemote(
|
key_id: FetchKeyResultForRemote(
|
||||||
# Cast to bytes since postgresql returns a memoryview.
|
# Cast to bytes since postgresql returns a memoryview.
|
||||||
key_json=bytes(row["key_json"]),
|
key_json=bytes(key_json),
|
||||||
valid_until_ts=row["ts_valid_until_ms"],
|
valid_until_ts=ts_valid_until_ms,
|
||||||
added_ts=row["ts_added_ms"],
|
added_ts=ts_added_ms,
|
||||||
)
|
)
|
||||||
for row in rows
|
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
|
||||||
}
|
}
|
||||||
|
|
|
@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
|
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"local_media_repository_thumbnails",
|
List[Tuple[int, int, str, str, int]],
|
||||||
{"media_id": media_id},
|
await self.db_pool.simple_select_list(
|
||||||
(
|
"local_media_repository_thumbnails",
|
||||||
"thumbnail_width",
|
{"media_id": media_id},
|
||||||
"thumbnail_height",
|
(
|
||||||
"thumbnail_method",
|
"thumbnail_width",
|
||||||
"thumbnail_type",
|
"thumbnail_height",
|
||||||
"thumbnail_length",
|
"thumbnail_method",
|
||||||
|
"thumbnail_type",
|
||||||
|
"thumbnail_length",
|
||||||
|
),
|
||||||
|
desc="get_local_media_thumbnails",
|
||||||
),
|
),
|
||||||
desc="get_local_media_thumbnails",
|
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
ThumbnailInfo(
|
ThumbnailInfo(
|
||||||
width=row["thumbnail_width"],
|
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||||
height=row["thumbnail_height"],
|
|
||||||
method=row["thumbnail_method"],
|
|
||||||
type=row["thumbnail_type"],
|
|
||||||
length=row["thumbnail_length"],
|
|
||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
|
@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
async def get_remote_media_thumbnails(
|
async def get_remote_media_thumbnails(
|
||||||
self, origin: str, media_id: str
|
self, origin: str, media_id: str
|
||||||
) -> List[ThumbnailInfo]:
|
) -> List[ThumbnailInfo]:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"remote_media_cache_thumbnails",
|
List[Tuple[int, int, str, str, int]],
|
||||||
{"media_origin": origin, "media_id": media_id},
|
await self.db_pool.simple_select_list(
|
||||||
(
|
"remote_media_cache_thumbnails",
|
||||||
"thumbnail_width",
|
{"media_origin": origin, "media_id": media_id},
|
||||||
"thumbnail_height",
|
(
|
||||||
"thumbnail_method",
|
"thumbnail_width",
|
||||||
"thumbnail_type",
|
"thumbnail_height",
|
||||||
"thumbnail_length",
|
"thumbnail_method",
|
||||||
|
"thumbnail_type",
|
||||||
|
"thumbnail_length",
|
||||||
|
),
|
||||||
|
desc="get_remote_media_thumbnails",
|
||||||
),
|
),
|
||||||
desc="get_remote_media_thumbnails",
|
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
ThumbnailInfo(
|
ThumbnailInfo(
|
||||||
width=row["thumbnail_width"],
|
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
|
||||||
height=row["thumbnail_height"],
|
|
||||||
method=row["thumbnail_method"],
|
|
||||||
type=row["thumbnail_type"],
|
|
||||||
length=row["thumbnail_length"],
|
|
||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
|
|
|
@ -179,46 +179,44 @@ class PushRulesWorkerStore(
|
||||||
|
|
||||||
@cached(max_entries=5000)
|
@cached(max_entries=5000)
|
||||||
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
|
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="push_rules",
|
List[Tuple[str, int, int, str, str]],
|
||||||
keyvalues={"user_name": user_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=(
|
table="push_rules",
|
||||||
"user_name",
|
keyvalues={"user_name": user_id},
|
||||||
"rule_id",
|
retcols=(
|
||||||
"priority_class",
|
"rule_id",
|
||||||
"priority",
|
"priority_class",
|
||||||
"conditions",
|
"priority",
|
||||||
"actions",
|
"conditions",
|
||||||
|
"actions",
|
||||||
|
),
|
||||||
|
desc="get_push_rules_for_user",
|
||||||
),
|
),
|
||||||
desc="get_push_rules_for_user",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
# Sort by highest priority_class, then highest priority.
|
||||||
|
rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
|
||||||
|
|
||||||
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
||||||
|
|
||||||
return _load_rules(
|
return _load_rules(
|
||||||
[
|
[(row[0], row[1], row[3], row[4]) for row in rows],
|
||||||
(
|
|
||||||
row["rule_id"],
|
|
||||||
row["priority_class"],
|
|
||||||
row["conditions"],
|
|
||||||
row["actions"],
|
|
||||||
)
|
|
||||||
for row in rows
|
|
||||||
],
|
|
||||||
enabled_map,
|
enabled_map,
|
||||||
self.hs.config.experimental,
|
self.hs.config.experimental,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||||
results = await self.db_pool.simple_select_list(
|
results = cast(
|
||||||
table="push_rules_enable",
|
List[Tuple[str, Optional[Union[int, bool]]]],
|
||||||
keyvalues={"user_name": user_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("rule_id", "enabled"),
|
table="push_rules_enable",
|
||||||
desc="get_push_rules_enabled_for_user",
|
keyvalues={"user_name": user_id},
|
||||||
|
retcols=("rule_id", "enabled"),
|
||||||
|
desc="get_push_rules_enabled_for_user",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return {r["rule_id"]: bool(r["enabled"]) for r in results}
|
return {r[0]: bool(r[1]) for r in results}
|
||||||
|
|
||||||
async def have_push_rules_changed_for_user(
|
async def have_push_rules_changed_for_user(
|
||||||
self, user_id: str, last_id: int
|
self, user_id: str, last_id: int
|
||||||
|
|
|
@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
async def get_throttle_params_by_room(
|
async def get_throttle_params_by_room(
|
||||||
self, pusher_id: int
|
self, pusher_id: int
|
||||||
) -> Dict[str, ThrottleParams]:
|
) -> Dict[str, ThrottleParams]:
|
||||||
res = await self.db_pool.simple_select_list(
|
res = cast(
|
||||||
"pusher_throttle",
|
List[Tuple[str, Optional[int], Optional[int]]],
|
||||||
{"pusher": pusher_id},
|
await self.db_pool.simple_select_list(
|
||||||
["room_id", "last_sent_ts", "throttle_ms"],
|
"pusher_throttle",
|
||||||
desc="get_throttle_params_by_room",
|
{"pusher": pusher_id},
|
||||||
|
["room_id", "last_sent_ts", "throttle_ms"],
|
||||||
|
desc="get_throttle_params_by_room",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
params_by_room = {}
|
params_by_room = {}
|
||||||
for row in res:
|
for room_id, last_sent_ts, throttle_ms in res:
|
||||||
params_by_room[row["room_id"]] = ThrottleParams(
|
params_by_room[room_id] = ThrottleParams(
|
||||||
row["last_sent_ts"],
|
last_sent_ts or 0, throttle_ms or 0
|
||||||
row["throttle_ms"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return params_by_room
|
return params_by_room
|
||||||
|
|
|
@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
Tuples of (auth_provider, external_id)
|
Tuples of (auth_provider, external_id)
|
||||||
"""
|
"""
|
||||||
res = await self.db_pool.simple_select_list(
|
return cast(
|
||||||
table="user_external_ids",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"user_id": mxid},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("auth_provider", "external_id"),
|
table="user_external_ids",
|
||||||
desc="get_external_ids_by_user",
|
keyvalues={"user_id": mxid},
|
||||||
|
retcols=("auth_provider", "external_id"),
|
||||||
|
desc="get_external_ids_by_user",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return [(r["auth_provider"], r["external_id"]) for r in res]
|
|
||||||
|
|
||||||
async def count_all_users(self) -> int:
|
async def count_all_users(self) -> int:
|
||||||
"""Counts all users registered on the homeserver."""
|
"""Counts all users registered on the homeserver."""
|
||||||
|
@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
|
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
|
||||||
results = await self.db_pool.simple_select_list(
|
results = cast(
|
||||||
"user_threepids",
|
List[Tuple[str, str, int, int]],
|
||||||
keyvalues={"user_id": user_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=["medium", "address", "validated_at", "added_at"],
|
"user_threepids",
|
||||||
desc="user_get_threepids",
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=["medium", "address", "validated_at", "added_at"],
|
||||||
|
desc="user_get_threepids",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return [ThreepidResult(**r) for r in results]
|
return [
|
||||||
|
ThreepidResult(
|
||||||
|
medium=r[0],
|
||||||
|
address=r[1],
|
||||||
|
validated_at=r[2],
|
||||||
|
added_at=r[3],
|
||||||
|
)
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
async def user_delete_threepid(
|
async def user_delete_threepid(
|
||||||
self, user_id: str, medium: str, address: str
|
self, user_id: str, medium: str, address: str
|
||||||
|
@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
desc="add_user_bound_threepid",
|
desc="add_user_bound_threepid",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
|
async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
|
||||||
"""Get the threepids that a user has bound to an identity server through the homeserver
|
"""Get the threepids that a user has bound to an identity server through the homeserver
|
||||||
The homeserver remembers where binds to an identity server occurred. Using this
|
The homeserver remembers where binds to an identity server occurred. Using this
|
||||||
method can retrieve those threepids.
|
method can retrieve those threepids.
|
||||||
|
@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
user_id: The ID of the user to retrieve threepids for
|
user_id: The ID of the user to retrieve threepids for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of dictionaries containing the following keys:
|
List of tuples of two strings:
|
||||||
medium (str): The medium of the threepid (e.g "email")
|
medium: The medium of the threepid (e.g "email")
|
||||||
address (str): The address of the threepid (e.g "bob@example.com")
|
address: The address of the threepid (e.g "bob@example.com")
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.simple_select_list(
|
return cast(
|
||||||
table="user_threepid_id_server",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"user_id": user_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=["medium", "address"],
|
table="user_threepid_id_server",
|
||||||
desc="user_get_bound_threepids",
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=["medium", "address"],
|
||||||
|
desc="user_get_bound_threepids",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_user_bound_threepid(
|
async def remove_user_bound_threepid(
|
||||||
|
|
|
@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
def get_all_relation_ids_for_event_txn(
|
def get_all_relation_ids_for_event_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
rows = self.db_pool.simple_select_list_txn(
|
rows = cast(
|
||||||
txn=txn,
|
List[Tuple[str]],
|
||||||
table="event_relations",
|
self.db_pool.simple_select_list_txn(
|
||||||
keyvalues={"relates_to_id": event_id},
|
txn=txn,
|
||||||
retcols=["event_id"],
|
table="event_relations",
|
||||||
|
keyvalues={"relates_to_id": event_id},
|
||||||
|
retcols=["event_id"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [row["event_id"] for row in rows]
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
desc="get_all_relation_ids_for_event",
|
desc="get_all_relation_ids_for_event",
|
||||||
|
|
|
@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
"""
|
"""
|
||||||
room_servers: Dict[str, PartialStateResyncInfo] = {}
|
room_servers: Dict[str, PartialStateResyncInfo] = {}
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="partial_state_rooms",
|
List[Tuple[str, str]],
|
||||||
keyvalues={},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("room_id", "joined_via"),
|
table="partial_state_rooms",
|
||||||
desc="get_server_which_served_partial_join",
|
keyvalues={},
|
||||||
|
retcols=("room_id", "joined_via"),
|
||||||
|
desc="get_server_which_served_partial_join",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in rows:
|
for room_id, joined_via in rows:
|
||||||
room_id = row["room_id"]
|
|
||||||
joined_via = row["joined_via"]
|
|
||||||
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
|
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"partial_state_rooms_servers",
|
List[Tuple[str, str]],
|
||||||
keyvalues=None,
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("room_id", "server_name"),
|
"partial_state_rooms_servers",
|
||||||
desc="get_partial_state_rooms",
|
keyvalues=None,
|
||||||
|
retcols=("room_id", "server_name"),
|
||||||
|
desc="get_partial_state_rooms",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in rows:
|
for room_id, server_name in rows:
|
||||||
room_id = row["room_id"]
|
|
||||||
server_name = row["server_name"]
|
|
||||||
entry = room_servers.get(room_id)
|
entry = room_servers.get(room_id)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
# There is a foreign key constraint which enforces that every room_id in
|
# There is a foreign key constraint which enforces that every room_id in
|
||||||
|
|
|
@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
for fully-joined rooms.
|
for fully-joined rooms.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"current_state_events",
|
List[Tuple[str, Optional[str]]],
|
||||||
keyvalues={"room_id": room_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("event_id", "membership"),
|
"current_state_events",
|
||||||
desc="has_completed_background_updates",
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=("event_id", "membership"),
|
||||||
|
desc="has_completed_background_updates",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return {row["event_id"]: row["membership"] for row in rows}
|
return dict(rows)
|
||||||
|
|
||||||
# TODO This returns a mutable object, which is generally confusing when using a cache.
|
# TODO This returns a mutable object, which is generally confusing when using a cache.
|
||||||
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
|
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
|
||||||
|
|
|
@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
tag content.
|
tag content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
List[Tuple[str, str, str]],
|
||||||
|
await self.db_pool.simple_select_list(
|
||||||
|
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
|
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
|
||||||
for row in rows:
|
for room_id, tag, content in rows:
|
||||||
room_tags = tags_by_room.setdefault(row["room_id"], {})
|
room_tags = tags_by_room.setdefault(room_id, {})
|
||||||
room_tags[row["tag"]] = db_to_json(row["content"])
|
room_tags[tag] = db_to_json(content)
|
||||||
return tags_by_room
|
return tags_by_room
|
||||||
|
|
||||||
async def get_all_updated_tags(
|
async def get_all_updated_tags(
|
||||||
|
@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
A mapping of tags to tag content.
|
A mapping of tags to tag content.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="room_tags",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"user_id": user_id, "room_id": room_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("tag", "content"),
|
table="room_tags",
|
||||||
desc="get_tags_for_room",
|
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||||
|
retcols=("tag", "content"),
|
||||||
|
desc="get_tags_for_room",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return {row["tag"]: db_to_json(row["content"]) for row in rows}
|
return {tag: db_to_json(content) for tag, content in rows}
|
||||||
|
|
||||||
async def add_tag_to_room(
|
async def add_tag_to_room(
|
||||||
self, user_id: str, room_id: str, tag: str, content: JsonDict
|
self, user_id: str, room_id: str, tag: str, content: JsonDict
|
||||||
|
|
|
@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
that auth-type.
|
that auth-type.
|
||||||
"""
|
"""
|
||||||
results = {}
|
results = {}
|
||||||
for row in await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
table="ui_auth_sessions_credentials",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"session_id": session_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("stage_type", "result"),
|
table="ui_auth_sessions_credentials",
|
||||||
desc="get_completed_ui_auth_stages",
|
keyvalues={"session_id": session_id},
|
||||||
):
|
retcols=("stage_type", "result"),
|
||||||
results[row["stage_type"]] = db_to_json(row["result"])
|
desc="get_completed_ui_auth_stages",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for stage_type, result in rows:
|
||||||
|
results[stage_type] = db_to_json(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
List of user_agent/ip pairs
|
List of user_agent/ip pairs
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_list(
|
return cast(
|
||||||
table="ui_auth_sessions_ips",
|
List[Tuple[str, str]],
|
||||||
keyvalues={"session_id": session_id},
|
await self.db_pool.simple_select_list(
|
||||||
retcols=("user_agent", "ip"),
|
table="ui_auth_sessions_ips",
|
||||||
desc="get_user_agents_ips_to_ui_auth_session",
|
keyvalues={"session_id": session_id},
|
||||||
|
retcols=("user_agent", "ip"),
|
||||||
|
desc="get_user_agents_ips_to_ui_auth_session",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return [(row["user_agent"], row["ip"]) for row in rows]
|
|
||||||
|
|
||||||
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
|
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
if not prev_group:
|
if not prev_group:
|
||||||
return _GetStateGroupDelta(None, None)
|
return _GetStateGroupDelta(None, None)
|
||||||
|
|
||||||
delta_ids = self.db_pool.simple_select_list_txn(
|
delta_ids = cast(
|
||||||
txn,
|
List[Tuple[str, str, str]],
|
||||||
table="state_groups_state",
|
self.db_pool.simple_select_list_txn(
|
||||||
keyvalues={"state_group": state_group},
|
txn,
|
||||||
retcols=("type", "state_key", "event_id"),
|
table="state_groups_state",
|
||||||
|
keyvalues={"state_group": state_group},
|
||||||
|
retcols=("type", "state_key", "event_id"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return _GetStateGroupDelta(
|
return _GetStateGroupDelta(
|
||||||
prev_group,
|
prev_group,
|
||||||
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
|
{
|
||||||
|
(event_type, state_key): event_id
|
||||||
|
for event_type, state_key, event_id in delta_ids
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_all_room_state(self) -> List[Dict[str, Any]]:
|
async def get_all_room_state(self) -> List[Optional[str]]:
|
||||||
return await self.store.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
|
List[Tuple[Optional[str]]],
|
||||||
|
await self.store.db_pool.simple_select_list(
|
||||||
|
"room_stats_state", None, retcols=("topic",)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
return [r[0] for r in rows]
|
||||||
|
|
||||||
def _get_current_stats(
|
def _get_current_stats(
|
||||||
self, stats_type: str, stat_id: str
|
self, stats_type: str, stat_id: str
|
||||||
|
@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
|
||||||
r = self.get_success(self.get_all_room_state())
|
r = self.get_success(self.get_all_room_state())
|
||||||
|
|
||||||
self.assertEqual(len(r), 1)
|
self.assertEqual(len(r), 1)
|
||||||
self.assertEqual(r[0]["topic"], "foo")
|
self.assertEqual(r[0], "foo")
|
||||||
|
|
||||||
def test_create_user(self) -> None:
|
def test_create_user(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
if expected_row is not None:
|
if expected_row is not None:
|
||||||
columns += expected_row.keys()
|
columns += expected_row.keys()
|
||||||
|
|
||||||
rows = self.get_success(
|
row_tuples = self.get_success(
|
||||||
self.store.db_pool.simple_select_list(
|
self.store.db_pool.simple_select_list(
|
||||||
table=table,
|
table=table,
|
||||||
keyvalues={
|
keyvalues={
|
||||||
|
@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
if expected_row is not None:
|
if expected_row is not None:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(rows),
|
len(row_tuples),
|
||||||
1,
|
1,
|
||||||
f"Background update did not leave behind latest receipt in {table}",
|
f"Background update did not leave behind latest receipt in {table}",
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rows[0],
|
row_tuples[0],
|
||||||
{
|
(
|
||||||
"room_id": room_id,
|
room_id,
|
||||||
"receipt_type": receipt_type,
|
receipt_type,
|
||||||
"user_id": user_id,
|
user_id,
|
||||||
**expected_row,
|
*expected_row.values(),
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(rows),
|
len(row_tuples),
|
||||||
0,
|
0,
|
||||||
f"Background update did not remove all duplicate receipts from {table}",
|
f"Background update did not remove all duplicate receipts from {table}",
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Generator, Tuple
|
from typing import Generator, List, Tuple, cast
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
|
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
|
||||||
res = self.get_success(
|
yield from cast(
|
||||||
self.storage.db_pool.simple_select_list(
|
List[Tuple[int, str, str]],
|
||||||
self.table_name, None, ["id, username, value"]
|
self.get_success(
|
||||||
)
|
self.storage.db_pool.simple_select_list(
|
||||||
|
self.table_name, None, ["id, username, value"]
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in res:
|
|
||||||
yield (i["id"], i["username"], i["value"])
|
|
||||||
|
|
||||||
def test_upsert_many(self) -> None:
|
def test_upsert_many(self) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert_many will perform the upsert operation across a batch of data.
|
Upsert_many will perform the upsert operation across a batch of data.
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Tuple, cast
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
# Check the correct values are in the new table.
|
# Check the correct values are in the new table.
|
||||||
rows = self.get_success(
|
rows = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[int, int]],
|
||||||
table="test_constraint",
|
self.get_success(
|
||||||
keyvalues={},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("a", "b"),
|
table="test_constraint",
|
||||||
)
|
keyvalues={},
|
||||||
|
retcols=("a", "b"),
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
|
self.assertCountEqual(rows, [(1, 1), (3, 3)])
|
||||||
|
|
||||||
# And check that invalid rows get correctly rejected.
|
# And check that invalid rows get correctly rejected.
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
|
@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
# Check the correct values are in the new table.
|
# Check the correct values are in the new table.
|
||||||
rows = self.get_success(
|
rows = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[int, int]],
|
||||||
table="test_constraint",
|
self.get_success(
|
||||||
keyvalues={},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("a", "b"),
|
table="test_constraint",
|
||||||
)
|
keyvalues={},
|
||||||
|
retcols=("a", "b"),
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
|
self.assertCountEqual(rows, [(1, 1), (3, 3)])
|
||||||
|
|
||||||
# And check that invalid rows get correctly rejected.
|
# And check that invalid rows get correctly rejected.
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
|
|
|
@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 3
|
self.mock_txn.rowcount = 3
|
||||||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
|
||||||
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
||||||
|
|
||||||
ret = yield defer.ensureDeferred(
|
ret = yield defer.ensureDeferred(
|
||||||
|
@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
|
self.assertEqual([(1,), (2,), (3,)], ret)
|
||||||
self.mock_txn.execute.assert_called_with(
|
self.mock_txn.execute.assert_called_with(
|
||||||
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
|
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(200)
|
self.reactor.advance(200)
|
||||||
self.pump(0)
|
self.pump(0)
|
||||||
|
|
||||||
result = self.get_success(
|
result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str, str, Optional[str], int]],
|
||||||
table="user_ips",
|
self.get_success(
|
||||||
keyvalues={"user_id": user_id},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
table="user_ips",
|
||||||
desc="get_user_ip_and_agents",
|
keyvalues={"user_id": user_id},
|
||||||
)
|
retcols=[
|
||||||
|
"access_token",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result,
|
result, [("access_token", "ip", "user_agent", None, 12345678000)]
|
||||||
[
|
|
||||||
{
|
|
||||||
"access_token": "access_token",
|
|
||||||
"ip": "ip",
|
|
||||||
"user_agent": "user_agent",
|
|
||||||
"device_id": None,
|
|
||||||
"last_seen": 12345678000,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add another & trigger the storage loop
|
# Add another & trigger the storage loop
|
||||||
|
@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(10)
|
self.reactor.advance(10)
|
||||||
self.pump(0)
|
self.pump(0)
|
||||||
|
|
||||||
result = self.get_success(
|
result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str, str, Optional[str], int]],
|
||||||
table="user_ips",
|
self.get_success(
|
||||||
keyvalues={"user_id": user_id},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
table="user_ips",
|
||||||
desc="get_user_ip_and_agents",
|
keyvalues={"user_id": user_id},
|
||||||
)
|
retcols=[
|
||||||
|
"access_token",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# Only one result, has been upserted.
|
# Only one result, has been upserted.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result,
|
result, [("access_token", "ip", "user_agent", None, 12345878000)]
|
||||||
[
|
|
||||||
{
|
|
||||||
"access_token": "access_token",
|
|
||||||
"ip": "ip",
|
|
||||||
"user_agent": "user_agent",
|
|
||||||
"device_id": None,
|
|
||||||
"last_seen": 12345878000,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([(False,), (True,)])
|
@parameterized.expand([(False,), (True,)])
|
||||||
|
@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(10)
|
self.reactor.advance(10)
|
||||||
else:
|
else:
|
||||||
# Check that the new IP and user agent has not been stored yet
|
# Check that the new IP and user agent has not been stored yet
|
||||||
db_result = self.get_success(
|
db_result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
|
||||||
table="devices",
|
self.get_success(
|
||||||
keyvalues={},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
table="devices",
|
||||||
|
keyvalues={},
|
||||||
|
retcols=(
|
||||||
|
"user_id",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
|
||||||
db_result,
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
|
||||||
"device_id": device_id,
|
|
||||||
"ip": None,
|
|
||||||
"user_agent": None,
|
|
||||||
"last_seen": None,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = self.get_success(
|
result = self.get_success(
|
||||||
self.store.get_last_client_ip_by_device(user_id, device_id)
|
self.store.get_last_client_ip_by_device(user_id, device_id)
|
||||||
|
@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the new IP and user agent has not been stored yet
|
# Check that the new IP and user agent has not been stored yet
|
||||||
db_result = self.get_success(
|
db_result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
|
||||||
table="devices",
|
self.get_success(
|
||||||
keyvalues={},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
table="devices",
|
||||||
|
keyvalues={},
|
||||||
|
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
db_result,
|
db_result,
|
||||||
[
|
[
|
||||||
{
|
(user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
|
||||||
"user_id": user_id,
|
(user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
|
||||||
"device_id": device_id_1,
|
|
||||||
"ip": "ip_1",
|
|
||||||
"user_agent": "user_agent_1",
|
|
||||||
"last_seen": 12345678000,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
|
||||||
"device_id": device_id_2,
|
|
||||||
"ip": "ip_2",
|
|
||||||
"user_agent": "user_agent_2",
|
|
||||||
"last_seen": 12345678000,
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the new IP and user agent has not been stored yet
|
# Check that the new IP and user agent has not been stored yet
|
||||||
db_result = self.get_success(
|
db_result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str, str, int]],
|
||||||
table="user_ips",
|
self.get_success(
|
||||||
keyvalues={},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("access_token", "ip", "user_agent", "last_seen"),
|
table="user_ips",
|
||||||
|
keyvalues={},
|
||||||
|
retcols=("access_token", "ip", "user_agent", "last_seen"),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
db_result,
|
db_result,
|
||||||
[
|
[
|
||||||
{
|
("access_token", "ip_1", "user_agent_1", 12345678000),
|
||||||
"access_token": "access_token",
|
("access_token", "ip_2", "user_agent_2", 12345678000),
|
||||||
"ip": "ip_1",
|
|
||||||
"user_agent": "user_agent_1",
|
|
||||||
"last_seen": 12345678000,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"access_token": "access_token",
|
|
||||||
"ip": "ip_2",
|
|
||||||
"user_agent": "user_agent_2",
|
|
||||||
"last_seen": 12345678000,
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(200)
|
self.reactor.advance(200)
|
||||||
|
|
||||||
# We should see that in the DB
|
# We should see that in the DB
|
||||||
result = self.get_success(
|
result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str, str, Optional[str], int]],
|
||||||
table="user_ips",
|
self.get_success(
|
||||||
keyvalues={"user_id": user_id},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
table="user_ips",
|
||||||
desc="get_user_ip_and_agents",
|
keyvalues={"user_id": user_id},
|
||||||
)
|
retcols=[
|
||||||
|
"access_token",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result,
|
result,
|
||||||
[
|
[("access_token", "ip", "user_agent", device_id, 0)],
|
||||||
{
|
|
||||||
"access_token": "access_token",
|
|
||||||
"ip": "ip",
|
|
||||||
"user_agent": "user_agent",
|
|
||||||
"device_id": device_id,
|
|
||||||
"last_seen": 0,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now advance by a couple of months
|
# Now advance by a couple of months
|
||||||
self.reactor.advance(60 * 24 * 60 * 60)
|
self.reactor.advance(60 * 24 * 60 * 60)
|
||||||
|
|
||||||
# We should get no results.
|
# We should get no results.
|
||||||
result = self.get_success(
|
result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str, str, Optional[str], int]],
|
||||||
table="user_ips",
|
self.get_success(
|
||||||
keyvalues={"user_id": user_id},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
table="user_ips",
|
||||||
desc="get_user_ip_and_agents",
|
keyvalues={"user_id": user_id},
|
||||||
)
|
retcols=[
|
||||||
|
"access_token",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(result, [])
|
self.assertEqual(result, [])
|
||||||
|
@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(200)
|
self.reactor.advance(200)
|
||||||
|
|
||||||
# We should see that in the DB
|
# We should see that in the DB
|
||||||
result = self.get_success(
|
result = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str, str, Optional[str], int]],
|
||||||
table="user_ips",
|
self.get_success(
|
||||||
keyvalues={},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
table="user_ips",
|
||||||
desc="get_user_ip_and_agents",
|
keyvalues={},
|
||||||
)
|
retcols=[
|
||||||
|
"access_token",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"device_id",
|
||||||
|
"last_seen",
|
||||||
|
],
|
||||||
|
desc="get_user_ip_and_agents",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ensure user1 is filtered out
|
# ensure user1 is filtered out
|
||||||
self.assertEqual(
|
self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
|
||||||
result,
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"access_token": access_token2,
|
|
||||||
"ip": "ip",
|
|
||||||
"user_agent": "user_agent",
|
|
||||||
"device_id": device_id2,
|
|
||||||
"last_seen": 0,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
|
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import List, Optional, Tuple, cast
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
|
@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
def test__null_byte_in_display_name_properly_handled(self) -> None:
|
def test__null_byte_in_display_name_properly_handled(self) -> None:
|
||||||
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
||||||
|
|
||||||
res = self.get_success(
|
res = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[Optional[str], str]],
|
||||||
"room_memberships",
|
self.get_success(
|
||||||
{"user_id": "@alice:test"},
|
self.store.db_pool.simple_select_list(
|
||||||
["display_name", "event_id"],
|
"room_memberships",
|
||||||
)
|
{"user_id": "@alice:test"},
|
||||||
|
["display_name", "event_id"],
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# Check that we only got one result back
|
# Check that we only got one result back
|
||||||
self.assertEqual(len(res), 1)
|
self.assertEqual(len(res), 1)
|
||||||
|
|
||||||
# Check that alice's display name is "alice"
|
# Check that alice's display name is "alice"
|
||||||
self.assertEqual(res[0]["display_name"], "alice")
|
self.assertEqual(res[0][0], "alice")
|
||||||
|
|
||||||
# Grab the event_id to use later
|
# Grab the event_id to use later
|
||||||
event_id = res[0]["event_id"]
|
event_id = res[0][1]
|
||||||
|
|
||||||
# Create a profile with the offending null byte in the display name
|
# Create a profile with the offending null byte in the display name
|
||||||
new_profile = {"displayname": "ali\u0000ce"}
|
new_profile = {"displayname": "ali\u0000ce"}
|
||||||
|
@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
tok=self.t_alice,
|
tok=self.t_alice,
|
||||||
)
|
)
|
||||||
|
|
||||||
res2 = self.get_success(
|
res2 = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[Optional[str], str]],
|
||||||
"room_memberships",
|
self.get_success(
|
||||||
{"user_id": "@alice:test"},
|
self.store.db_pool.simple_select_list(
|
||||||
["display_name", "event_id"],
|
"room_memberships",
|
||||||
)
|
{"user_id": "@alice:test"},
|
||||||
|
["display_name", "event_id"],
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# Check that we only have two results
|
# Check that we only have two results
|
||||||
self.assertEqual(len(res2), 2)
|
self.assertEqual(len(res2), 2)
|
||||||
|
|
||||||
# Filter out the previous event using the event_id we grabbed above
|
# Filter out the previous event using the event_id we grabbed above
|
||||||
row = [row for row in res2 if row["event_id"] != event_id]
|
row = [row for row in res2 if row[1] != event_id]
|
||||||
|
|
||||||
# Check that alice's display name is now None
|
# Check that alice's display name is now None
|
||||||
self.assertEqual(row[0]["display_name"], None)
|
self.assertIsNone(row[0][0])
|
||||||
|
|
||||||
def test_room_is_locally_forgotten(self) -> None:
|
def test_room_is_locally_forgotten(self) -> None:
|
||||||
"""Test that when the last local user has forgotten a room it is known as forgotten."""
|
"""Test that when the last local user has forgotten a room it is known as forgotten."""
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Tuple, cast
|
||||||
|
|
||||||
from immutabledict import immutabledict
|
from immutabledict import immutabledict
|
||||||
|
|
||||||
|
@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check that only state events are in state_groups, and all state events are in state_groups
|
# check that only state events are in state_groups, and all state events are in state_groups
|
||||||
res = self.get_success(
|
res = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str]],
|
||||||
table="state_groups",
|
self.get_success(
|
||||||
keyvalues=None,
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("event_id",),
|
table="state_groups",
|
||||||
)
|
keyvalues=None,
|
||||||
|
retcols=("event_id",),
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
for result in res:
|
for result in res:
|
||||||
self.assertNotIn(event3.event_id, result)
|
self.assertNotIn(event3.event_id, result) # XXX
|
||||||
events.append(result.get("event_id"))
|
events.append(result[0])
|
||||||
|
|
||||||
for event, _ in processed_events_and_context:
|
for event, _ in processed_events_and_context:
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
|
@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
# has an entry and prev event in state_group_edges
|
# has an entry and prev event in state_group_edges
|
||||||
for event, context in processed_events_and_context:
|
for event, context in processed_events_and_context:
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state = self.get_success(
|
state = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str, str]],
|
||||||
table="state_groups_state",
|
self.get_success(
|
||||||
keyvalues={"state_group": context.state_group_after_event},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("type", "state_key"),
|
table="state_groups_state",
|
||||||
)
|
keyvalues={"state_group": context.state_group_after_event},
|
||||||
|
retcols=("type", "state_key"),
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(event.type, state[0].get("type"))
|
self.assertEqual(event.type, state[0][0])
|
||||||
self.assertEqual(event.state_key, state[0].get("state_key"))
|
self.assertEqual(event.state_key, state[0][1])
|
||||||
|
|
||||||
groups = self.get_success(
|
groups = cast(
|
||||||
self.store.db_pool.simple_select_list(
|
List[Tuple[str]],
|
||||||
table="state_group_edges",
|
self.get_success(
|
||||||
keyvalues={"state_group": str(context.state_group_after_event)},
|
self.store.db_pool.simple_select_list(
|
||||||
retcols=("*",),
|
table="state_group_edges",
|
||||||
)
|
keyvalues={
|
||||||
)
|
"state_group": str(context.state_group_after_event)
|
||||||
self.assertEqual(
|
},
|
||||||
context.state_group_before_event, groups[0].get("prev_state_group")
|
retcols=("prev_state_group",),
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
self.assertEqual(context.state_group_before_event, groups[0][0])
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Set, Tuple
|
from typing import Any, Dict, List, Optional, Set, Tuple, cast
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
@ -62,14 +62,13 @@ class GetUserDirectoryTables:
|
||||||
Returns a list of tuples (user_id, room_id) where room_id is public and
|
Returns a list of tuples (user_id, room_id) where room_id is public and
|
||||||
contains the user with the given id.
|
contains the user with the given id.
|
||||||
"""
|
"""
|
||||||
r = await self.store.db_pool.simple_select_list(
|
r = cast(
|
||||||
"users_in_public_rooms", None, ("user_id", "room_id")
|
List[Tuple[str, str]],
|
||||||
|
await self.store.db_pool.simple_select_list(
|
||||||
|
"users_in_public_rooms", None, ("user_id", "room_id")
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
return set(r)
|
||||||
retval = set()
|
|
||||||
for i in r:
|
|
||||||
retval.add((i["user_id"], i["room_id"]))
|
|
||||||
return retval
|
|
||||||
|
|
||||||
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
|
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
|
||||||
"""Fetch the entire `users_who_share_private_rooms` table.
|
"""Fetch the entire `users_who_share_private_rooms` table.
|
||||||
|
@ -78,27 +77,30 @@ class GetUserDirectoryTables:
|
||||||
to the rows of `users_who_share_private_rooms`.
|
to the rows of `users_who_share_private_rooms`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.store.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"users_who_share_private_rooms",
|
List[Tuple[str, str, str]],
|
||||||
None,
|
await self.store.db_pool.simple_select_list(
|
||||||
["user_id", "other_user_id", "room_id"],
|
"users_who_share_private_rooms",
|
||||||
|
None,
|
||||||
|
["user_id", "other_user_id", "room_id"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
rv = set()
|
return set(rows)
|
||||||
for row in rows:
|
|
||||||
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
|
|
||||||
return rv
|
|
||||||
|
|
||||||
async def get_users_in_user_directory(self) -> Set[str]:
|
async def get_users_in_user_directory(self) -> Set[str]:
|
||||||
"""Fetch the set of users in the `user_directory` table.
|
"""Fetch the set of users in the `user_directory` table.
|
||||||
|
|
||||||
This is useful when checking we've correctly excluded users from the directory.
|
This is useful when checking we've correctly excluded users from the directory.
|
||||||
"""
|
"""
|
||||||
result = await self.store.db_pool.simple_select_list(
|
result = cast(
|
||||||
"user_directory",
|
List[Tuple[str]],
|
||||||
None,
|
await self.store.db_pool.simple_select_list(
|
||||||
["user_id"],
|
"user_directory",
|
||||||
|
None,
|
||||||
|
["user_id"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return {row["user_id"] for row in result}
|
return {row[0] for row in result}
|
||||||
|
|
||||||
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
|
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
|
||||||
"""Fetch users and their profiles from the `user_directory` table.
|
"""Fetch users and their profiles from the `user_directory` table.
|
||||||
|
@ -107,16 +109,17 @@ class GetUserDirectoryTables:
|
||||||
It's almost the entire contents of the `user_directory` table: the only
|
It's almost the entire contents of the `user_directory` table: the only
|
||||||
thing missing is an unused room_id column.
|
thing missing is an unused room_id column.
|
||||||
"""
|
"""
|
||||||
rows = await self.store.db_pool.simple_select_list(
|
rows = cast(
|
||||||
"user_directory",
|
List[Tuple[str, Optional[str], Optional[str]]],
|
||||||
None,
|
await self.store.db_pool.simple_select_list(
|
||||||
("user_id", "display_name", "avatar_url"),
|
"user_directory",
|
||||||
|
None,
|
||||||
|
("user_id", "display_name", "avatar_url"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
row["user_id"]: ProfileInfo(
|
user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
|
||||||
display_name=row["display_name"], avatar_url=row["avatar_url"]
|
for user_id, display_name, avatar_url in rows
|
||||||
)
|
|
||||||
for row in rows
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_tables(
|
async def get_tables(
|
||||||
|
|
Loading…
Reference in New Issue