Type hints for RegistrationStore (#8615)

This commit is contained in:
Erik Johnston 2020-10-22 11:56:58 +01:00 committed by GitHub
parent 2ac908f377
commit a9f90fa73a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 85 additions and 74 deletions

1
changelog.d/8615.misc Normal file
View File

@ -0,0 +1 @@
Type hints for `RegistrationStore`.

View File

@ -57,6 +57,7 @@ files =
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,

View File

@ -146,7 +146,6 @@ class DataStore(
db_conn, "e2e_cross_signing_keys", "stream_id" db_conn, "e2e_cross_signing_keys", "stream_id"
) )
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")

View File

@ -16,29 +16,33 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock()
# Note: we don't check this sequence for consistency as we'd have to # Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is # call `find_max_generated_user_id_localpart` each time, which is
@ -55,7 +59,7 @@ class RegistrationWorkerStore(SQLBaseStore):
# Create a background job for culling expired 3PID validity tokens # Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks: if hs.config.run_background_tasks:
self.clock.looping_call( self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
) )
@ -92,7 +96,7 @@ class RegistrationWorkerStore(SQLBaseStore):
if not info: if not info:
return False return False
now = self.clock.time_msec() now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial return is_trial
@ -257,7 +261,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_expiring_soon", "get_users_expiring_soon",
select_users_txn, select_users_txn,
self.clock.time_msec(), self._clock.time_msec(),
self.config.account_validity.renew_at, self.config.account_validity.renew_at,
) )
@ -328,13 +332,17 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = """
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," SELECT users.name,
" access_tokens.device_id, access_tokens.valid_until_ms" users.is_guest,
" FROM users" users.shadow_banned,
" INNER JOIN access_tokens on users.name = access_tokens.user_id" access_tokens.id as token_id,
" WHERE token = ?" access_tokens.device_id,
) access_tokens.valid_until_ms
FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id
WHERE token = ?
"""
txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
@ -803,7 +811,7 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens", "cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn, cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(), self._clock.time_msec(),
) )
@wrap_as_background_process("account_validity_set_expiration_dates") @wrap_as_background_process("account_validity_set_expiration_dates")
@ -890,10 +898,10 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.clock = hs.get_clock() self._clock = hs.get_clock()
self.config = hs.config self.config = hs.config
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@ -1016,13 +1024,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return 1 return 1
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
class RegistrationStore(RegistrationBackgroundUpdateStore): Args:
def __init__(self, database: DatabasePool, db_conn, hs): user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
async def add_access_token_to_user( async def add_access_token_to_user(
self, self,
user_id: str, user_id: str,
@ -1138,19 +1189,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def _register_user( def _register_user(
self, self,
txn, txn,
user_id, user_id: str,
password_hash, password_hash: Optional[str],
was_guest, was_guest: bool,
make_guest, make_guest: bool,
appservice_id, appservice_id: Optional[str],
create_profile_with_displayname, create_profile_with_displayname: Optional[str],
admin, admin: bool,
user_type, user_type: Optional[str],
shadow_banned, shadow_banned: bool,
): ):
user_id_obj = UserID.from_string(user_id) user_id_obj = UserID.from_string(user_id)
now = int(self.clock.time()) now = int(self._clock.time())
try: try:
if was_guest: if was_guest:
@ -1374,18 +1425,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f) await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
async def add_user_pending_deactivation(self, user_id: str) -> None: async def add_user_pending_deactivation(self, user_id: str) -> None:
""" """
Adds a user to the table of users who need to be parted from all the rooms they're Adds a user to the table of users who need to be parted from all the rooms they're
@ -1479,7 +1518,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()}, updatevalues={"validated_at": self._clock.time_msec()},
) )
return next_link return next_link
@ -1547,35 +1586,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn, start_or_continue_validation_session_txn,
) )
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def find_max_generated_user_id_localpart(cur: Cursor) -> int: def find_max_generated_user_id_localpart(cur: Cursor) -> int:
""" """