Add type hints to `synapse/storage/databases/main` (#11984)

This commit is contained in:
Dirk Klimpel 2022-02-21 17:03:06 +01:00 committed by GitHub
parent 99f6d79fe1
commit 7c82da27aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 53 deletions

1
changelog.d/11984.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -31,14 +31,11 @@ exclude = (?x)
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/
|tests/api/test_auth.py

View File

@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
Returns:
dict: `user_id` -> `UserPresenceState`
"""
states = {
user_id: self.user_to_current_state.get(user_id, None)
for user_id in user_ids
}
states = {}
missing = []
for user_id in user_ids:
state = self.user_to_current_state.get(user_id, None)
if state:
states[user_id] = state
else:
missing.append(user_id)
missing = [user_id for user_id, state in states.items() if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = await self.store.get_presence_for_users(missing)
states.update(res)
missing = [user_id for user_id, state in states.items() if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
for user_id in missing:
# if user has no state in database, create the state
if not res.get(user_id, None):
new_state = UserPresenceState.default(user_id)
states[user_id] = new_state
self.user_to_current_state[user_id] = new_state
return states

View File

@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)
# Used by `PresenceStore._get_active_presence()`
@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._presence_id_gen: AbstractStreamIdGenerator
self._can_persist_presence = (
hs.get_instance_name() in hs.config.worker.writers.presence
self._instance_name in hs.config.worker.writers.presence
)
if isinstance(database.engine, PostgresEngine):
@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn(self, txn, stream_orderings, presence_states):
def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states
) -> None:
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
if last_id == current_id:
return [], current_id, False
def get_all_presence_updates_txn(txn):
def get_all_presence_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
status_msg, currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
updates = cast(
List[Tuple[int, list]],
[(row[0], row[1:]) for row in txn],
)
upper_bound = current_id
limited = False
@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
)
@cached()
def _get_presence_for_user(self, user_id):
def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError()
@cachedList(
@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
list_name="user_ids",
num_args=1,
)
async def get_presence_for_users(self, user_ids):
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise.
"""
def _should_user_receive_full_presence_with_token_txn(txn):
def _should_user_receive_full_presence_with_token_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
_should_user_receive_full_presence_with_token_txn,
)
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.
@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return users_to_state
def get_current_presence_token(self):
def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()
def _get_active_presence(self, db_conn: Connection):
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return [UserPresenceState(**row) for row in rows]
def take_presence_startup_info(self):
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
self._presence_on_startup = []
return active_on_startup
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows:

View File

@ -13,9 +13,10 @@
# limitations under the License.
import logging
from typing import Any, List, Set, Tuple
from typing import Any, List, Set, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
)
def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
self,
txn: LoggingTransaction,
room_id: str,
token: RoomStreamToken,
delete_local_events: bool,
) -> Set[int]:
# Tables that should be pruned:
# event_auth
@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
""",
(room_id,),
)
(min_depth,) = txn.fetchone()
(min_depth,) = cast(Tuple[int], txn.fetchone())
logger.info("[purge] updating room_depth to %d", min_depth)
@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"purge_room", self._purge_room_txn, room_id
)
def _purge_room_txn(self, txn, room_id: str) -> List[int]:
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(

View File

@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
) -> None:
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
is_in_room = await self.is_host_joined(room_id, self.server_name)
is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
if is_in_room:
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory.
users_with_profile = {
user_id: profile
@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id):
profile = await self.get_profileinfo(get_localpart_from_id(user_id))
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
if self.get_app_service_by_user_id(user) is not None:
if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False
# We're opting to exclude appservice users (anyone matching the user
@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
if self.get_if_app_services_interested_in_user(user):
if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service
return False
# Support users are for diagnostics and should not appear in the user directory.
if await self.is_support_user(user):
if await self.is_support_user(user): # type: ignore[attr-defined]
return False
# Deactivated users aren't contactable, so should not appear in the user directory.
try:
if await self.get_user_deactivated_status(user):
if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False
except StoreError:
# No such user in the users table. No need to do this when calling
@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
current_state_ids = await self.get_filtered_current_state_ids(
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev:
if (
hist_vis_ev.content.get("history_visibility")

View File

@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.main import DataStore, PurgeEventsStore
# Define a state map type from type/state_key to T (usually an event ID or
# event)
@ -485,7 +485,7 @@ class RoomStreamToken:
)
@classmethod
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
@ -502,7 +502,7 @@ class RoomStreamToken:
instance_id = int(key)
pos = int(value)
instance_name = await store.get_name_from_instance_id(instance_id)
instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
instance_map[instance_name] = pos
return cls(