Remove more usages of cursor_to_dict. (#16551)
Mostly to improve type safety.
This commit is contained in: parent 85e5f2dc25, commit 679c691f6f
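Note on the overall pattern: `cursor_to_dict` materialises every row as an untyped `Dict[str, Any]`, so mypy cannot check which columns exist or what type each holds. The hunks below instead fetch plain tuples and annotate (or `cast`) them with the exact column types of each `SELECT`. A minimal sketch of the two styles, using `sqlite3` as a stand-in for Synapse's database layer (the table and names here are illustrative, not from the commit):

```python
import sqlite3
from typing import Any, Dict, List, Optional, Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE media (origin TEXT, media_id TEXT, quarantined_by TEXT)")
conn.execute("INSERT INTO media VALUES ('example.org', 'abc', NULL)")

# Old style: every row becomes a dict; each value is typed as Any.
cur = conn.execute("SELECT origin, media_id, quarantined_by FROM media")
names = [c[0] for c in cur.description]
dict_rows: List[Dict[str, Any]] = [dict(zip(names, row)) for row in cur.fetchall()]
assert dict_rows[0]["origin"] == "example.org"

# New style: rows stay tuples, typed per column, so bad accesses are caught statically.
cur = conn.execute("SELECT origin, media_id, quarantined_by FROM media")
tuple_rows = cast(List[Tuple[str, str, Optional[str]]], cur.fetchall())
assert tuple_rows == [("example.org", "abc", None)]
```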
@@ -0,0 +1 @@
+Improve type hints.
@@ -19,6 +19,8 @@ import logging
 import urllib.parse
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
@@ -357,9 +359,9 @@ class IdentityHandler:
 
         # Check to see if a session already exists and that it is not yet
         # marked as validated
-        if session and session.get("validated_at") is None:
-            session_id = session["session_id"]
-            last_send_attempt = session["last_send_attempt"]
+        if session and session.validated_at is None:
+            session_id = session.session_id
+            last_send_attempt = session.last_send_attempt
 
             # Check that the send_attempt is higher than previous attempts
             if send_attempt <= last_send_attempt:
@@ -480,7 +482,6 @@ class IdentityHandler:
 
         # We don't actually know which medium this 3PID is. Thus we first assume it's email,
         # and if validation fails we try msisdn
-        validation_session = None
 
         # Try to validate as email
        if self.hs.config.email.can_verify_email:
@@ -488,19 +489,18 @@ class IdentityHandler:
             validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
             )
-
-        if validation_session:
-            return validation_session
+            if validation_session:
+                return attr.asdict(validation_session)
 
         # Try to validate as msisdn
         if self.hs.config.registration.account_threepid_delegate_msisdn:
             # Ask our delegated msisdn identity server
-            validation_session = await self.threepid_from_creds(
+            return await self.threepid_from_creds(
                 self.hs.config.registration.account_threepid_delegate_msisdn,
                 threepid_creds,
             )
 
-        return validation_session
+        return None
 
     async def proxy_msisdn_submit_token(
         self, id_server: str, client_secret: str, sid: str, token: str
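Note: `get_threepid_validation_session` now returns an attrs object (the `ThreepidValidationSession` class added further down), so the handler reads `session.validated_at` instead of `session["validated_at"]`, and converts back to a plain dict with `attr.asdict` only where the external JSON shape must be preserved. A small self-contained illustration, with a hypothetical two-field class rather than the one in this commit:

```python
from typing import Optional

import attr

@attr.s(frozen=True, slots=True, auto_attribs=True)
class Session:
    session_id: str
    validated_at: Optional[int]

s = Session(session_id="abc", validated_at=None)
# Typed attribute access inside the handler...
assert s.validated_at is None
# ...and a dict again at the JSON boundary.
assert attr.asdict(s) == {"session_id": "abc", "validated_at": None}
```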
@@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
 
         if row:
             threepid = {
-                "medium": row["medium"],
-                "address": row["address"],
-                "validated_at": row["validated_at"],
+                "medium": row.medium,
+                "address": row.address,
+                "validated_at": row.validated_at,
             }
 
             # Valid threepid returned, delete from the db
@@ -949,10 +949,7 @@ class MediaRepository:
 
         deleted = 0
 
-        for media in old_media:
-            origin = media["media_origin"]
-            media_id = media["media_id"]
-            file_id = media["filesystem_id"]
+        for origin, media_id, file_id in old_media:
             key = (origin, media_id)
 
             logger.info("Deleting: %r", key)
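Note: because `get_remote_media_ids` now returns `(media_origin, media_id, filesystem_id)` tuples (see its hunk further down), the loop can unpack them directly in the `for` statement instead of performing three dict lookups per row. A sketch with made-up data:

```python
# Rows shaped like the store's SELECT: (media_origin, media_id, filesystem_id).
old_media = [("example.org", "abc123", "fs_abc"), ("example.com", "def456", "fs_def")]

for origin, media_id, file_id in old_media:
    key = (origin, media_id)
    print("Deleting: %r (file %s)" % (key, file_id))
```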
@@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
         destinations, total = await self._store.get_destinations_paginate(
             start, limit, destination, order_by, direction
         )
-        response = {"destinations": destinations, "total": total}
+        response = {
+            "destinations": [
+                {
+                    "destination": r[0],
+                    "retry_last_ts": r[1],
+                    "retry_interval": r[2],
+                    "failure_ts": r[3],
+                    "last_successful_stream_ordering": r[4],
+                }
+                for r in destinations
+            ],
+            "total": total,
+        }
         if (start + limit) < total:
             response["next_token"] = str(start + len(destinations))
 
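Note: with the store returning positional tuples, the admin REST layer is now where column positions get their JSON names, so the wire format is unchanged. The same conversion could also be written with `zip`; a purely illustrative sketch (not the code in this commit):

```python
from typing import List, Optional, Tuple

DestinationRow = Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]

def destinations_to_json(rows: List[DestinationRow]) -> List[dict]:
    keys = (
        "destination",
        "retry_last_ts",
        "retry_interval",
        "failure_ts",
        "last_successful_stream_ordering",
    )
    return [dict(zip(keys, row)) for row in rows]

rows = [("example.org", None, None, None, 7)]
assert destinations_to_json(rows)[0]["destination"] == "example.org"
```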
@@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         room_id, _ = await self.resolve_room_id(room_identifier)
 
         extremities = await self.store.get_forward_extremities_for_room(room_id)
-        return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
+        result = [
+            {
+                "event_id": ex[0],
+                "state_group": ex[1],
+                "depth": ex[2],
+                "received_ts": ex[3],
+            }
+            for ex in extremities
+        ]
+
+        return HTTPStatus.OK, {"count": len(extremities), "results": result}
 
 
 class RoomEventContextServlet(RestServlet):
@@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
         users_media, total = await self.store.get_users_media_usage_paginate(
             start, limit, from_ts, until_ts, order_by, direction, search_term
         )
-        ret = {"users": users_media, "total": total}
+        ret = {
+            "users": [
+                {
+                    "user_id": r[0],
+                    "displayname": r[1],
+                    "media_count": r[2],
+                    "media_length": r[3],
+                }
+                for r in users_media
+            ],
+            "total": total,
+        }
         if (start + limit) < total:
             ret["next_token"] = start + len(users_media)
 
@@ -35,7 +35,6 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
-    Union,
     cast,
     overload,
 )
@@ -1047,43 +1046,20 @@ class DatabasePool:
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    @overload
-    async def execute(
-        self, desc: str, decoder: Literal[None], query: str, *args: Any
-    ) -> List[Tuple[Any, ...]]:
-        ...
-
-    @overload
-    async def execute(
-        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
-    ) -> R:
-        ...
-
-    async def execute(
-        self,
-        desc: str,
-        decoder: Optional[Callable[[Cursor], R]],
-        query: str,
-        *args: Any,
-    ) -> Union[List[Tuple[Any, ...]], R]:
+    async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
         """Runs a single query for a result set.
 
         Args:
             desc: description of the transaction, for logging and metrics
-            decoder - The function which can resolve the cursor results to
-                something meaningful.
             query - The query string to execute
             *args - Query args.
         Returns:
             The result of decoder(results)
         """
 
-        def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
+        def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
             txn.execute(query, args)
-            if decoder:
-                return decoder(txn)
-            else:
-                return txn.fetchall()
+            return txn.fetchall()
 
         return await self.runInteraction(desc, interaction)
 
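Note: with the `decoder` parameter gone, `execute` has a single signature and always returns `List[Tuple[Any, ...]]`, so the `@overload` pair and the `Union` return type become unnecessary. Call sites that previously passed `cursor_to_dict` as the decoder now post-process or `cast` the tuples themselves, and call sites that passed `None` simply drop the argument. A standalone mimic of the simplified helper, using `sqlite3` as an assumed stand-in rather than Synapse's actual class:

```python
import sqlite3
from typing import Any, List, Tuple

def execute(conn: sqlite3.Connection, query: str, *args: Any) -> List[Tuple[Any, ...]]:
    # One code path: run the query and return the raw tuple rows.
    cur = conn.execute(query, args)
    return cur.fetchall()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a TEXT, b INTEGER)")
conn.execute("INSERT INTO t VALUES ('x', 1)")
assert execute(conn, "SELECT a, b FROM t WHERE b = ?", 1) == [("x", 1)]
```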
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
         """
 
         rows = await self.db_pool.execute(
-            "_censor_redactions_fetch", None, sql, before_ts, 100
+            "_censor_redactions_fetch", sql, before_ts, 100
         )
 
         updates = []
@@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         rows = await self.db_pool.execute(
             "get_all_devices_changed",
-            None,
             sql,
             from_key,
             to_key,
@@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 WHERE from_user_id = ? AND stream_id > ?
             """
             rows = await self.db_pool.execute(
-                "get_users_whose_signatures_changed", None, sql, user_id, from_key
+                "get_users_whose_signatures_changed", sql, user_id, from_key
             )
             return {user for row in rows for user in db_to_json(row[0])}
         else:
@@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         """
         rows = await self.db_pool.execute(
             "get_e2e_device_keys_for_federation_query_check",
-            None,
             sql,
             now_stream_id,
             user_id,
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
         # indexes on it.
-        # We need to pass execute a dummy function to handle the txn's result otherwise
-        # it tries to call fetchall() on it and fails because there's no result to fetch.
-        await self.db_pool.execute(
+        await self.db_pool.runInteraction(
             "background_analyze_new_stream_ordering_column",
-            lambda txn: None,
-            "ANALYZE events(stream_ordering2)",
+            lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
         )
 
         await self.db_pool.runInteraction(
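Note: `ANALYZE` produces no result set, so routing it through the query-oriented `execute` needed a dummy decoder; running it inside `runInteraction` lets the lambda call `txn.execute` directly and return nothing. A sketch of the same idea against `sqlite3` (a stand-in; SQLite also accepts a bare `ANALYZE`):

```python
import sqlite3
from typing import Callable

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (stream_ordering2 INTEGER)")

def run_interaction(desc: str, func: Callable[[sqlite3.Cursor], None]) -> None:
    # Hypothetical stand-in for db_pool.runInteraction(desc, func).
    func(conn.cursor())

run_interaction("analyze", lambda txn: txn.execute("ANALYZE"))
```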
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
 
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
 
     async def get_forward_extremities_for_room(
         self, room_id: str
-    ) -> List[Dict[str, Any]]:
-        """Get list of forward extremities for a room."""
+    ) -> List[Tuple[str, int, int, Optional[int]]]:
+        """
+        Get list of forward extremities for a room.
+
+        Returns:
+            A list of tuples of event_id, state_group, depth, and received_ts.
+        """
 
         def get_forward_extremities_for_room_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, int, int, Optional[int]]]:
             sql = """
                 SELECT event_id, state_group, depth, received_ts
                 FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
             """
 
             txn.execute(sql, (room_id,))
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_forward_extremities_for_room",
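Note: DB-API stubs type `fetchall()` loosely, so the store narrows the rows with `cast(...)` to exactly the columns of its `SELECT`; the cast is a static-only assertion with no runtime conversion. Sketch with `sqlite3`:

```python
import sqlite3
from typing import List, Optional, Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE efe (event_id TEXT, state_group INTEGER, depth INTEGER, received_ts INTEGER)"
)
conn.execute("INSERT INTO efe VALUES ('$ev1', 1, 2, NULL)")
cur = conn.execute("SELECT event_id, state_group, depth, received_ts FROM efe")

rows = cast(List[Tuple[str, int, int, Optional[int]]], cur.fetchall())
assert rows == [("$ev1", 1, 2, None)]
```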
@@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_remote_media_ids(
         self, before_ts: int, include_quarantined_media: bool
-    ) -> List[Dict[str, str]]:
+    ) -> List[Tuple[str, str, str]]:
         """
         Retrieve a list of server name, media ID tuples from the remote media cache.
 
@@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             A list of tuples containing:
                 * The server name of homeserver where the media originates from,
                 * The ID of the media.
+                * The filesystem ID.
+        """
+
+        sql = """
+            SELECT media_origin, media_id, filesystem_id
+            FROM remote_media_cache
+            WHERE last_access_ts < ?
         """
-        sql = (
-            "SELECT media_origin, media_id, filesystem_id"
-            " FROM remote_media_cache"
-            " WHERE last_access_ts < ?"
-        )
 
         if include_quarantined_media is False:
             # Only include media that has not been quarantined
@@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 AND quarantined_by IS NULL
             """
 
-        return await self.db_pool.execute(
-            "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+        return cast(
+            List[Tuple[str, str, str]],
+            await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
         )
 
     async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
@@ -151,6 +151,22 @@ class ThreepidResult:
     added_at: int
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+    address: str
+    """address of the 3pid"""
+    medium: str
+    """medium of the 3pid"""
+    client_secret: str
+    """a secret provided by the client for this validation session"""
+    session_id: str
+    """ID of the validation session"""
+    last_send_attempt: int
+    """a number serving to dedupe send attempts for this session"""
+    validated_at: Optional[int]
+    """timestamp of when this session was validated if so"""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         address: Optional[str] = None,
         sid: Optional[str] = None,
         validated: Optional[bool] = True,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThreepidValidationSession]:
         """Gets a session_id and last_send_attempt (if available) for a
         combination of validation metadata
 
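Note: `frozen=True, slots=True` makes the row type immutable and attribute-complete, so a misspelt field is a mypy error (and a runtime error) rather than a silent `KeyError`-in-waiting, as a dict lookup would be. A quick illustration with a hypothetical two-field class:

```python
from typing import Optional

import attr

@attr.s(frozen=True, slots=True, auto_attribs=True)
class Row:
    medium: str
    validated_at: Optional[int]

r = Row(medium="email", validated_at=None)
try:
    r.medium = "msisdn"  # frozen instances reject mutation
except attr.exceptions.FrozenInstanceError:
    pass
```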
@@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 perform no filtering
 
         Returns:
-            A dict containing the following:
-                * address - address of the 3pid
-                * medium - medium of the 3pid
-                * client_secret - a secret provided by the client for this validation session
-                * session_id - ID of the validation session
-                * send_attempt - a number serving to dedupe send attempts for this session
-                * validated_at - timestamp of when this session was validated if so
-
-            Otherwise None if a validation session is not found
+            A ThreepidValidationSession or None if a validation session is not found
         """
         if not client_secret:
             raise SynapseError(
@@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         def get_threepid_validation_session_txn(
             txn: LoggingTransaction,
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[ThreepidValidationSession]:
             sql = """
                 SELECT address, session_id, medium, client_secret,
                     last_send_attempt, validated_at
@@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             sql += " LIMIT 1"
 
             txn.execute(sql, list(keyvalues.values()))
-            rows = self.db_pool.cursor_to_dict(txn)
-            if not rows:
+            row = txn.fetchone()
+            if not row:
                 return None
 
-            return rows[0]
+            return ThreepidValidationSession(
+                address=row[0],
+                session_id=row[1],
+                medium=row[2],
+                client_secret=row[3],
+                last_send_attempt=row[4],
+                validated_at=row[5],
+            )
 
         return await self.db_pool.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
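Note: `txn.fetchone()` returns `Optional[Tuple]`, so the `if not row` guard both handles the no-session case and narrows the type before the positional fields are mapped to keyword arguments in the same order as the `SELECT` list. Sketch:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sessions (address TEXT, session_id TEXT)")
conn.execute("INSERT INTO sessions VALUES ('a@b.example', 'sid1')")

row = conn.execute("SELECT address, session_id FROM sessions LIMIT 1").fetchone()
if row is None:
    result = None
else:
    # Keyword mapping must follow SELECT order: address first, session_id second.
    result = {"address": row[0], "session_id": row[1]}
assert result == {"address": "a@b.example", "session_id": "sid1"}
```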
@@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             like_clause = "%:" + host
 
             rows = await self.db_pool.execute(
-                "is_host_joined", None, sql, membership, room_id, like_clause
+                "is_host_joined", sql, membership, room_id, like_clause
             )
 
             if not rows:
@@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             AND forgotten = 0;
         """
 
-        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+        rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
 
         # `count(*)` returns always an integer
         # If any rows still exist it means someone has not forgotten this room yet
@@ -26,6 +26,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 import attr
@@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
-        results = await self.db_pool.execute(
-            "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+        # List of tuples of (rank, room_id, event_id).
+        results = cast(
+            List[Tuple[Union[int, float], str, str]],
+            await self.db_pool.execute("search_msgs", sql, *args),
         )
 
-        results = list(filter(lambda row: row["room_id"] in room_ids, results))
+        results = list(filter(lambda row: row[1] in room_ids, results))
 
         # We set redact_behaviour to block here to prevent redacted events being returned in
         # search results (which is a data leak)
         events = await self.get_events_as_list(  # type: ignore[attr-defined]
-            [r["event_id"] for r in results],
+            [r[2] for r in results],
             redact_behaviour=EventRedactBehaviour.block,
         )
 
@@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = await self.db_pool.execute(
-            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+        # List of tuples of (room_id, count).
+        count_results = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
         )
 
-        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+        count = sum(row[1] for row in count_results if row[0] in room_ids)
         return {
             "results": [
-                {"event": event_map[r["event_id"]], "rank": r["rank"]}
+                {"event": event_map[r[2]], "rank": r[0]}
                 for r in results
-                if r["event_id"] in event_map
+                if r[2] in event_map
             ],
             "highlights": highlights,
             "count": count,
@@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             search_query = search_term
             sql = """
                 SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
-                origin_server_ts, stream_ordering, room_id, event_id
+                room_id, event_id, origin_server_ts, stream_ordering
                 FROM event_search
                 WHERE vector @@ websearch_to_tsquery('english', ?) AND
             """
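Note on the column reorder in the hunk above: with positional rows, the `SELECT` list is effectively the schema, so `rank, room_id, event_id, origin_server_ts, stream_ordering` fixes `r[0]` through `r[4]`, and every index used below must agree with it. An illustrative row:

```python
# (rank, room_id, event_id, origin_server_ts, stream_ordering)
row = (0.42, "!room:example.org", "$event1", 1700000000000, 5)

rank, room_id, event_id = row[0], row[1], row[2]
pagination_token = "%s,%s" % (row[3], row[4])
assert pagination_token == "1700000000000,5"
```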
@@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
         # mypy expects to append only a `str`, not an `int`
         args.append(limit)
 
-        results = await self.db_pool.execute(
-            "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+        # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+        results = cast(
+            List[Tuple[Union[int, float], str, str, int, int]],
+            await self.db_pool.execute("search_rooms", sql, *args),
         )
 
-        results = list(filter(lambda row: row["room_id"] in room_ids, results))
+        results = list(filter(lambda row: row[1] in room_ids, results))
 
         # We set redact_behaviour to block here to prevent redacted events being returned in
         # search results (which is a data leak)
         events = await self.get_events_as_list(  # type: ignore[attr-defined]
-            [r["event_id"] for r in results],
+            [r[2] for r in results],
             redact_behaviour=EventRedactBehaviour.block,
         )
 
@@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = await self.db_pool.execute(
-            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+        # List of tuples of (room_id, count).
+        count_results = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
         )
 
-        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+        count = sum(row[1] for row in count_results if row[0] in room_ids)
 
         return {
             "results": [
                 {
-                    "event": event_map[r["event_id"]],
-                    "rank": r["rank"],
-                    "pagination_token": "%s,%s"
-                    % (r["origin_server_ts"], r["stream_ordering"]),
+                    "event": event_map[r[2]],
+                    "rank": r[0],
+                    "pagination_token": "%s,%s" % (r[3], r[4]),
                 }
                 for r in results
-                if r["event_id"] in event_map
+                if r[2] in event_map
             ],
             "highlights": highlights,
             "count": count,
@@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
         order_by: Optional[str] = UserSortOrder.USER_ID.value,
         direction: Direction = Direction.FORWARDS,
         search_term: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
         """Function to retrieve a paginated list of users and their uploaded local media
         (size and number). This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
             order_by: the sort order of the returned list
             direction: sort ascending or descending
             search_term: a string to filter user names by
 
         Returns:
-            A list of user dicts and an integer representing the total number of
-            users that exist given this query
+            A tuple of:
+                A list of tuples of user information (the user ID, displayname,
+                total number of media, total length of media) and
+
+                An integer representing the total number of users that exist
+                given this query
         """
 
         def get_users_media_usage_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
             filters = []
             args: list = []
@@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
+            users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
 
             return users, count
 
@@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
 
         row = await self.db_pool.execute(
-            "get_current_topological_token", None, sql, room_id, room_id, stream_key
+            "get_current_topological_token", sql, room_id, room_id, stream_key
         )
         return row[0][0] if row else 0
 
@@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows = await self.db_pool.execute(
             "get_timeline_gaps",
-            None,
             sql,
             room_id,
             from_token.stream if from_token else 0,
@@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         destination: Optional[str] = None,
         order_by: str = DestinationSortOrder.DESTINATION.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[
+        List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+        int,
+    ]:
         """Function to retrieve a paginated list of destinations.
         This will return a json list of destinations and the
         total number of destinations matching the filter criteria.
@@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             order_by: the sort order of the returned list
             direction: sort ascending or descending
         Returns:
-            A tuple of a list of mappings from destination to information
+            A tuple of a list of tuples of destination information:
+                * destination
+                * retry_last_ts
+                * retry_interval
+                * failure_ts
+                * last_successful_stream_ordering
             and a count of total destinations.
         """
 
         def get_destinations_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[
+            List[
+                Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+            ],
+            int,
+        ]:
             order_by_column = DestinationSortOrder(order_by).value
 
             if direction == Direction.BACKWARDS:
@@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
                 LIMIT ? OFFSET ?
             """
             txn.execute(sql, args + [limit, start])
-            destinations = self.db_pool.cursor_to_dict(txn)
+            destinations = cast(
+                List[
+                    Tuple[
+                        str, Optional[int], Optional[int], Optional[int], Optional[int]
+                    ]
+                ],
+                txn.fetchall(),
+            )
             return destinations, count
 
         return await self.db_pool.runInteraction(
@@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             raise Exception("Unrecognized database engine")
 
         results = cast(
-            List[UserProfile],
-            await self.db_pool.execute(
-                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
-            ),
+            List[Tuple[str, Optional[str], Optional[str]]],
+            await self.db_pool.execute("search_user_dir", sql, *args),
         )
 
         limited = len(results) > limit
 
-        return {"limited": limited, "results": results[0:limit]}
+        return {
+            "limited": limited,
+            "results": [
+                {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+                for r in results[0:limit]
+            ],
+        }
 
 
 def _filter_text_for_index(text: str) -> str:
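Note: `search_user_dir` over-fetches one page (presumably queried with `limit + 1`) so that `limited` can report whether more results exist before slicing to `limit`; only the final slice is reshaped into JSON dicts. The idiom in miniature, with made-up `(user_id, display_name, avatar_url)` rows:

```python
results = [("@a:x", "A", None), ("@b:x", None, None), ("@c:x", None, None)]
limit = 2

limited = len(results) > limit
page = [
    {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
    for r in results[0:limit]
]
assert limited and len(page) == 2
```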
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
         if max_group is None:
             rows = await self.db_pool.execute(
                 "_background_deduplicate_state",
-                None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
             )
             max_group = rows[0][0]
@@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         event_id, stream_ordering = self.get_success(
             self.hs.get_datastores().main.db_pool.execute(
                 "test:get_destination_rooms",
-                None,
                 """
                 SELECT event_id, stream_ordering
                 FROM destination_rooms dr
@@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
         );
         """
         self.get_success(
-            self.store.db_pool.execute(
-                "test_not_null_constraint", lambda _: None, table_sql
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint", lambda txn: txn.execute(table_sql)
             )
         )
 
@@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
         # using SQLite.
         index_sql = "CREATE INDEX test_index ON test_constraint(a)"
         self.get_success(
-            self.store.db_pool.execute(
-                "test_not_null_constraint", lambda _: None, index_sql
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint", lambda txn: txn.execute(index_sql)
             )
         )
 
@@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
         );
         """
         self.get_success(
-            self.store.db_pool.execute(
-                "test_foreign_key_constraint", lambda _: None, base_sql
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
             )
         )
         self.get_success(
-            self.store.db_pool.execute(
-                "test_foreign_key_constraint", lambda _: None, table_sql
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
             )
         )
 
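Note: the test changes follow the same rule as the `ANALYZE` hunk earlier. Previously the DDL string travelled as a query argument alongside a `lambda _: None` decoder; now the lambda itself issues the statement, so nothing tries to fetch rows from a statement that has none. A miniature, with `sqlite3` standing in for the pool:

```python
import sqlite3
from typing import Callable

conn = sqlite3.connect(":memory:")

def run_interaction(desc: str, func: Callable[[sqlite3.Cursor], None]) -> None:
    # Hypothetical stand-in for store.db_pool.runInteraction(desc, func).
    func(conn.cursor())

table_sql = "CREATE TABLE test_constraint (a INTEGER NOT NULL, b INTEGER)"
# The lambda closes over the SQL and executes it; no decoder, nothing fetched.
run_interaction("test_not_null_constraint", lambda txn: txn.execute(table_sql))
run_interaction("create_index", lambda txn: txn.execute("CREATE INDEX t_idx ON test_constraint(a)"))
```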
@@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
 
         res = self.get_success(
             self.store.db_pool.execute(
-                "", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
+                "", "SELECT full_user_id from profiles ORDER BY full_user_id"
             )
         )
         self.assertEqual(len(res), len(expected_values))
@@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
 
         res = self.get_success(
             self.store.db_pool.execute(
                "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
+                "", "SELECT full_user_id from user_filters ORDER BY full_user_id"
             )
         )
         self.assertEqual(len(res), len(expected_values))