Type hints for RegistrationStore (#8615)
This commit is contained in:
parent
2ac908f377
commit
a9f90fa73a
|
@ -0,0 +1 @@
|
|||
Type hints for `RegistrationStore`.
|
1
mypy.ini
1
mypy.ini
|
@ -57,6 +57,7 @@ files =
|
|||
synapse/spam_checker_api,
|
||||
synapse/state,
|
||||
synapse/storage/databases/main/events.py,
|
||||
synapse/storage/databases/main/registration.py,
|
||||
synapse/storage/databases/main/stream.py,
|
||||
synapse/storage/databases/main/ui_auth.py,
|
||||
synapse/storage/database.py,
|
||||
|
|
|
@ -146,7 +146,6 @@ class DataStore(
|
|||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||
)
|
||||
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||
|
|
|
@ -16,29 +16,33 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.stats import StatsStore
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.storage.util.id_generators import IdGenerator
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegistrationWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.config = hs.config
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# Note: we don't check this sequence for consistency as we'd have to
|
||||
# call `find_max_generated_user_id_localpart` each time, which is
|
||||
|
@ -55,7 +59,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
# Create a background job for culling expired 3PID validity tokens
|
||||
if hs.config.run_background_tasks:
|
||||
self.clock.looping_call(
|
||||
self._clock.looping_call(
|
||||
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
|
||||
)
|
||||
|
||||
|
@ -92,7 +96,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
if not info:
|
||||
return False
|
||||
|
||||
now = self.clock.time_msec()
|
||||
now = self._clock.time_msec()
|
||||
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
|
||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
||||
return is_trial
|
||||
|
@ -257,7 +261,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
return await self.db_pool.runInteraction(
|
||||
"get_users_expiring_soon",
|
||||
select_users_txn,
|
||||
self.clock.time_msec(),
|
||||
self._clock.time_msec(),
|
||||
self.config.account_validity.renew_at,
|
||||
)
|
||||
|
||||
|
@ -328,13 +332,17 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
|
||||
" access_tokens.device_id, access_tokens.valid_until_ms"
|
||||
" FROM users"
|
||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||
" WHERE token = ?"
|
||||
)
|
||||
sql = """
|
||||
SELECT users.name,
|
||||
users.is_guest,
|
||||
users.shadow_banned,
|
||||
access_tokens.id as token_id,
|
||||
access_tokens.device_id,
|
||||
access_tokens.valid_until_ms
|
||||
FROM users
|
||||
INNER JOIN access_tokens on users.name = access_tokens.user_id
|
||||
WHERE token = ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
@ -803,7 +811,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
await self.db_pool.runInteraction(
|
||||
"cull_expired_threepid_validation_tokens",
|
||||
cull_expired_threepid_validation_tokens_txn,
|
||||
self.clock.time_msec(),
|
||||
self._clock.time_msec(),
|
||||
)
|
||||
|
||||
@wrap_as_background_process("account_validity_set_expiration_dates")
|
||||
|
@ -890,10 +898,10 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self._clock = hs.get_clock()
|
||||
self.config = hs.config
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
|
@ -1016,13 +1024,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
|
||||
return 1
|
||||
|
||||
async def set_user_deactivated_status(
|
||||
self, user_id: str, deactivated: bool
|
||||
) -> None:
|
||||
"""Set the `deactivated` property for the provided user to the provided value.
|
||||
|
||||
class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
Args:
|
||||
user_id: The ID of the user to set the status for.
|
||||
deactivated: The value to set for `deactivated`.
|
||||
"""
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_user_deactivated_status",
|
||||
self.set_user_deactivated_status_txn,
|
||||
user_id,
|
||||
deactivated,
|
||||
)
|
||||
|
||||
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_deactivated_status, (user_id,)
|
||||
)
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
@cached()
|
||||
async def is_guest(self, user_id: str) -> bool:
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcol="is_guest",
|
||||
allow_none=True,
|
||||
desc="is_guest",
|
||||
)
|
||||
|
||||
return res if res else False
|
||||
|
||||
|
||||
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
|
||||
async def add_access_token_to_user(
|
||||
self,
|
||||
user_id: str,
|
||||
|
@ -1138,19 +1189,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
def _register_user(
|
||||
self,
|
||||
txn,
|
||||
user_id,
|
||||
password_hash,
|
||||
was_guest,
|
||||
make_guest,
|
||||
appservice_id,
|
||||
create_profile_with_displayname,
|
||||
admin,
|
||||
user_type,
|
||||
shadow_banned,
|
||||
user_id: str,
|
||||
password_hash: Optional[str],
|
||||
was_guest: bool,
|
||||
make_guest: bool,
|
||||
appservice_id: Optional[str],
|
||||
create_profile_with_displayname: Optional[str],
|
||||
admin: bool,
|
||||
user_type: Optional[str],
|
||||
shadow_banned: bool,
|
||||
):
|
||||
user_id_obj = UserID.from_string(user_id)
|
||||
|
||||
now = int(self.clock.time())
|
||||
now = int(self._clock.time())
|
||||
|
||||
try:
|
||||
if was_guest:
|
||||
|
@ -1374,18 +1425,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
|
||||
await self.db_pool.runInteraction("delete_access_token", f)
|
||||
|
||||
@cached()
|
||||
async def is_guest(self, user_id: str) -> bool:
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcol="is_guest",
|
||||
allow_none=True,
|
||||
desc="is_guest",
|
||||
)
|
||||
|
||||
return res if res else False
|
||||
|
||||
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
||||
"""
|
||||
Adds a user to the table of users who need to be parted from all the rooms they're
|
||||
|
@ -1479,7 +1518,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
txn,
|
||||
table="threepid_validation_session",
|
||||
keyvalues={"session_id": session_id},
|
||||
updatevalues={"validated_at": self.clock.time_msec()},
|
||||
updatevalues={"validated_at": self._clock.time_msec()},
|
||||
)
|
||||
|
||||
return next_link
|
||||
|
@ -1547,35 +1586,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
start_or_continue_validation_session_txn,
|
||||
)
|
||||
|
||||
async def set_user_deactivated_status(
|
||||
self, user_id: str, deactivated: bool
|
||||
) -> None:
|
||||
"""Set the `deactivated` property for the provided user to the provided value.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to set the status for.
|
||||
deactivated: The value to set for `deactivated`.
|
||||
"""
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_user_deactivated_status",
|
||||
self.set_user_deactivated_status_txn,
|
||||
user_id,
|
||||
deactivated,
|
||||
)
|
||||
|
||||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_deactivated_status, (user_id,)
|
||||
)
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
|
||||
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue