Add type hints to `synapse/storage/databases/main` (#11984)
This commit is contained in:
parent
99f6d79fe1
commit
7c82da27aa
|
@ -0,0 +1 @@
|
|||
Add missing type hints to storage classes.
|
3
mypy.ini
3
mypy.ini
|
@ -31,14 +31,11 @@ exclude = (?x)
|
|||
|synapse/storage/databases/main/group_server.py
|
||||
|synapse/storage/databases/main/metrics.py
|
||||
|synapse/storage/databases/main/monthly_active_users.py
|
||||
|synapse/storage/databases/main/presence.py
|
||||
|synapse/storage/databases/main/purge_events.py
|
||||
|synapse/storage/databases/main/push_rule.py
|
||||
|synapse/storage/databases/main/receipts.py
|
||||
|synapse/storage/databases/main/roommember.py
|
||||
|synapse/storage/databases/main/search.py
|
||||
|synapse/storage/databases/main/state.py
|
||||
|synapse/storage/databases/main/user_directory.py
|
||||
|synapse/storage/schema/
|
||||
|
||||
|tests/api/test_auth.py
|
||||
|
|
|
@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
|
|||
Returns:
|
||||
dict: `user_id` -> `UserPresenceState`
|
||||
"""
|
||||
states = {
|
||||
user_id: self.user_to_current_state.get(user_id, None)
|
||||
for user_id in user_ids
|
||||
}
|
||||
states = {}
|
||||
missing = []
|
||||
for user_id in user_ids:
|
||||
state = self.user_to_current_state.get(user_id, None)
|
||||
if state:
|
||||
states[user_id] = state
|
||||
else:
|
||||
missing.append(user_id)
|
||||
|
||||
missing = [user_id for user_id, state in states.items() if not state]
|
||||
if missing:
|
||||
# There are things not in our in memory cache. Lets pull them out of
|
||||
# the database.
|
||||
res = await self.store.get_presence_for_users(missing)
|
||||
states.update(res)
|
||||
|
||||
missing = [user_id for user_id, state in states.items() if not state]
|
||||
if missing:
|
||||
new = {
|
||||
user_id: UserPresenceState.default(user_id) for user_id in missing
|
||||
}
|
||||
states.update(new)
|
||||
self.user_to_current_state.update(new)
|
||||
for user_id in missing:
|
||||
# if user has no state in database, create the state
|
||||
if not res.get(user_id, None):
|
||||
new_state = UserPresenceState.default(user_id)
|
||||
states[user_id] = new_state
|
||||
self.user_to_current_state[user_id] = new_state
|
||||
|
||||
return states
|
||||
|
||||
|
|
|
@ -12,15 +12,23 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
|
||||
|
||||
from synapse.api.presence import PresenceState, UserPresenceState
|
||||
from synapse.replication.tcp.streams import PresenceStream
|
||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
|
|||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# Used by `PresenceStore._get_active_presence()`
|
||||
|
@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._presence_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
self._can_persist_presence = (
|
||||
hs.get_instance_name() in hs.config.worker.writers.presence
|
||||
self._instance_name in hs.config.worker.writers.presence
|
||||
)
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
|
@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
|
||||
return stream_orderings[-1], self._presence_id_gen.get_current_token()
|
||||
|
||||
def _update_presence_txn(self, txn, stream_orderings, presence_states):
|
||||
def _update_presence_txn(
|
||||
self, txn: LoggingTransaction, stream_orderings, presence_states
|
||||
) -> None:
|
||||
for stream_id, state in zip(stream_orderings, presence_states):
|
||||
txn.call_after(
|
||||
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
|
||||
|
@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_presence_updates_txn(txn):
|
||||
def get_all_presence_updates_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, list]], int, bool]:
|
||||
sql = """
|
||||
SELECT stream_id, user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts,
|
||||
status_msg,
|
||||
currently_active
|
||||
status_msg, currently_active
|
||||
FROM presence_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates = [(row[0], row[1:]) for row in txn]
|
||||
updates = cast(
|
||||
List[Tuple[int, list]],
|
||||
[(row[0], row[1:]) for row in txn],
|
||||
)
|
||||
|
||||
upper_bound = current_id
|
||||
limited = False
|
||||
|
@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
@cached()
|
||||
def _get_presence_for_user(self, user_id):
|
||||
def _get_presence_for_user(self, user_id: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
|
@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def get_presence_for_users(self, user_ids):
|
||||
async def get_presence_for_users(
|
||||
self, user_ids: Iterable[str]
|
||||
) -> Dict[str, UserPresenceState]:
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="presence_stream",
|
||||
column="user_id",
|
||||
|
@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
True if the user should have full presence sent to them, False otherwise.
|
||||
"""
|
||||
|
||||
def _should_user_receive_full_presence_with_token_txn(txn):
|
||||
def _should_user_receive_full_presence_with_token_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
sql = """
|
||||
SELECT 1 FROM users_to_send_full_presence_to
|
||||
WHERE user_id = ?
|
||||
|
@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
_should_user_receive_full_presence_with_token_txn,
|
||||
)
|
||||
|
||||
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
|
||||
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
|
||||
"""Adds to the list of users who should receive a full snapshot of presence
|
||||
upon their next sync.
|
||||
|
||||
|
@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
|
||||
return users_to_state
|
||||
|
||||
def get_current_presence_token(self):
|
||||
def get_current_presence_token(self) -> int:
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def _get_active_presence(self, db_conn: Connection):
|
||||
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
|
||||
"""Fetch non-offline presence from the database so that we can register
|
||||
the appropriate time outs.
|
||||
"""
|
||||
|
@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||
|
||||
return [UserPresenceState(**row) for row in rows]
|
||||
|
||||
def take_presence_startup_info(self):
|
||||
def take_presence_startup_info(self) -> List[UserPresenceState]:
|
||||
active_on_startup = self._presence_on_startup
|
||||
self._presence_on_startup = None
|
||||
self._presence_on_startup = []
|
||||
return active_on_startup
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
|
||||
if stream_name == PresenceStream.NAME:
|
||||
self._presence_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
|
|
|
@ -13,9 +13,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Set, Tuple
|
||||
from typing import Any, List, Set, Tuple, cast
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||
from synapse.types import RoomStreamToken
|
||||
|
@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
)
|
||||
|
||||
def _purge_history_txn(
|
||||
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
token: RoomStreamToken,
|
||||
delete_local_events: bool,
|
||||
) -> Set[int]:
|
||||
# Tables that should be pruned:
|
||||
# event_auth
|
||||
|
@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
""",
|
||||
(room_id,),
|
||||
)
|
||||
(min_depth,) = txn.fetchone()
|
||||
(min_depth,) = cast(Tuple[int], txn.fetchone())
|
||||
|
||||
logger.info("[purge] updating room_depth to %d", min_depth)
|
||||
|
||||
|
@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
"purge_room", self._purge_room_txn, room_id
|
||||
)
|
||||
|
||||
def _purge_room_txn(self, txn, room_id: str) -> List[int]:
|
||||
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
|
||||
# First we fetch all the state groups that should be deleted, before
|
||||
# we delete that information.
|
||||
txn.execute(
|
||||
|
|
|
@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.server_name = hs.hostname
|
||||
|
@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
processed_event_count = 0
|
||||
|
||||
for room_id, event_count in rooms_to_work_on:
|
||||
is_in_room = await self.is_host_joined(room_id, self.server_name)
|
||||
is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
|
||||
|
||||
if is_in_room:
|
||||
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
|
||||
users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
|
||||
# Throw away users excluded from the directory.
|
||||
users_with_profile = {
|
||||
user_id: profile
|
||||
|
@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
|
||||
for user_id in users_to_work_on:
|
||||
if await self.should_include_local_user_in_dir(user_id):
|
||||
profile = await self.get_profileinfo(get_localpart_from_id(user_id))
|
||||
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
|
||||
await self.update_profile_in_user_dir(
|
||||
user_id, profile.display_name, profile.avatar_url
|
||||
)
|
||||
|
@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
# technically it could be DM-able. In the future, this could potentially
|
||||
# be configurable per-appservice whether the appservice sender can be
|
||||
# contacted.
|
||||
if self.get_app_service_by_user_id(user) is not None:
|
||||
if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
|
||||
return False
|
||||
|
||||
# We're opting to exclude appservice users (anyone matching the user
|
||||
|
@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
# they could be DM-able. In the future, this could potentially
|
||||
# be configurable per-appservice whether the appservice users can be
|
||||
# contacted.
|
||||
if self.get_if_app_services_interested_in_user(user):
|
||||
if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
|
||||
# TODO we might want to make this configurable for each app service
|
||||
return False
|
||||
|
||||
# Support users are for diagnostics and should not appear in the user directory.
|
||||
if await self.is_support_user(user):
|
||||
if await self.is_support_user(user): # type: ignore[attr-defined]
|
||||
return False
|
||||
|
||||
# Deactivated users aren't contactable, so should not appear in the user directory.
|
||||
try:
|
||||
if await self.get_user_deactivated_status(user):
|
||||
if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
|
||||
return False
|
||||
except StoreError:
|
||||
# No such user in the users table. No need to do this when calling
|
||||
|
@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
)
|
||||
|
||||
current_state_ids = await self.get_filtered_current_state_ids(
|
||||
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
|
||||
room_id, StateFilter.from_types(types_to_filter)
|
||||
)
|
||||
|
||||
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_id:
|
||||
join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
|
||||
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
|
||||
if join_rule_ev:
|
||||
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
|
||||
return True
|
||||
|
||||
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
if hist_vis_id:
|
||||
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
|
||||
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
|
||||
if hist_vis_ev:
|
||||
if (
|
||||
hist_vis_ev.content.get("history_visibility")
|
||||
|
|
|
@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.appservice.api import ApplicationService
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
||||
|
||||
# Define a state map type from type/state_key to T (usually an event ID or
|
||||
# event)
|
||||
|
@ -485,7 +485,7 @@ class RoomStreamToken:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
|
||||
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
|
||||
try:
|
||||
if string[0] == "s":
|
||||
return cls(topological=None, stream=int(string[1:]))
|
||||
|
@ -502,7 +502,7 @@ class RoomStreamToken:
|
|||
instance_id = int(key)
|
||||
pos = int(value)
|
||||
|
||||
instance_name = await store.get_name_from_instance_id(instance_id)
|
||||
instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
|
||||
instance_map[instance_name] = pos
|
||||
|
||||
return cls(
|
||||
|
|
Loading…
Reference in New Issue