Type hints for RegistrationStore (#8615)
This commit is contained in:
parent
2ac908f377
commit
a9f90fa73a
|
@ -0,0 +1 @@
|
||||||
|
Type hints for `RegistrationStore`.
|
1
mypy.ini
1
mypy.ini
|
@ -57,6 +57,7 @@ files =
|
||||||
synapse/spam_checker_api,
|
synapse/spam_checker_api,
|
||||||
synapse/state,
|
synapse/state,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
|
synapse/storage/databases/main/registration.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
synapse/storage/databases/main/ui_auth.py,
|
synapse/storage/databases/main/ui_auth.py,
|
||||||
synapse/storage/database.py,
|
synapse/storage/database.py,
|
||||||
|
|
|
@ -146,7 +146,6 @@ class DataStore(
|
||||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
|
|
|
@ -16,29 +16,33 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
|
from synapse.storage.databases.main.stats import StatsStore
|
||||||
|
from synapse.storage.types import Connection, Cursor
|
||||||
|
from synapse.storage.util.id_generators import IdGenerator
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(SQLBaseStore):
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.clock = hs.get_clock()
|
|
||||||
|
|
||||||
# Note: we don't check this sequence for consistency as we'd have to
|
# Note: we don't check this sequence for consistency as we'd have to
|
||||||
# call `find_max_generated_user_id_localpart` each time, which is
|
# call `find_max_generated_user_id_localpart` each time, which is
|
||||||
|
@ -55,7 +59,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
# Create a background job for culling expired 3PID validity tokens
|
# Create a background job for culling expired 3PID validity tokens
|
||||||
if hs.config.run_background_tasks:
|
if hs.config.run_background_tasks:
|
||||||
self.clock.looping_call(
|
self._clock.looping_call(
|
||||||
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
|
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -92,7 +96,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
if not info:
|
if not info:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
|
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
|
||||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
||||||
return is_trial
|
return is_trial
|
||||||
|
@ -257,7 +261,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_users_expiring_soon",
|
"get_users_expiring_soon",
|
||||||
select_users_txn,
|
select_users_txn,
|
||||||
self.clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
self.config.account_validity.renew_at,
|
self.config.account_validity.renew_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -328,13 +332,17 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token):
|
||||||
sql = (
|
sql = """
|
||||||
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
|
SELECT users.name,
|
||||||
" access_tokens.device_id, access_tokens.valid_until_ms"
|
users.is_guest,
|
||||||
" FROM users"
|
users.shadow_banned,
|
||||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
access_tokens.id as token_id,
|
||||||
" WHERE token = ?"
|
access_tokens.device_id,
|
||||||
)
|
access_tokens.valid_until_ms
|
||||||
|
FROM users
|
||||||
|
INNER JOIN access_tokens on users.name = access_tokens.user_id
|
||||||
|
WHERE token = ?
|
||||||
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (token,))
|
txn.execute(sql, (token,))
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
|
@ -803,7 +811,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"cull_expired_threepid_validation_tokens",
|
"cull_expired_threepid_validation_tokens",
|
||||||
cull_expired_threepid_validation_tokens_txn,
|
cull_expired_threepid_validation_tokens_txn,
|
||||||
self.clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("account_validity_set_expiration_dates")
|
@wrap_as_background_process("account_validity_set_expiration_dates")
|
||||||
|
@ -890,10 +898,10 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
|
|
||||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
self.db_pool.updates.register_background_index_update(
|
self.db_pool.updates.register_background_index_update(
|
||||||
|
@ -1016,13 +1024,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
async def set_user_deactivated_status(
|
||||||
|
self, user_id: str, deactivated: bool
|
||||||
|
) -> None:
|
||||||
|
"""Set the `deactivated` property for the provided user to the provided value.
|
||||||
|
|
||||||
class RegistrationStore(RegistrationBackgroundUpdateStore):
|
Args:
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
user_id: The ID of the user to set the status for.
|
||||||
|
deactivated: The value to set for `deactivated`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"set_user_deactivated_status",
|
||||||
|
self.set_user_deactivated_status_txn,
|
||||||
|
user_id,
|
||||||
|
deactivated,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn=txn,
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_deactivated_status, (user_id,)
|
||||||
|
)
|
||||||
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def is_guest(self, user_id: str) -> bool:
|
||||||
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
retcol="is_guest",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_guest",
|
||||||
|
)
|
||||||
|
|
||||||
|
return res if res else False
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||||
|
|
||||||
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
|
|
||||||
async def add_access_token_to_user(
|
async def add_access_token_to_user(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -1138,19 +1189,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
def _register_user(
|
def _register_user(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn,
|
||||||
user_id,
|
user_id: str,
|
||||||
password_hash,
|
password_hash: Optional[str],
|
||||||
was_guest,
|
was_guest: bool,
|
||||||
make_guest,
|
make_guest: bool,
|
||||||
appservice_id,
|
appservice_id: Optional[str],
|
||||||
create_profile_with_displayname,
|
create_profile_with_displayname: Optional[str],
|
||||||
admin,
|
admin: bool,
|
||||||
user_type,
|
user_type: Optional[str],
|
||||||
shadow_banned,
|
shadow_banned: bool,
|
||||||
):
|
):
|
||||||
user_id_obj = UserID.from_string(user_id)
|
user_id_obj = UserID.from_string(user_id)
|
||||||
|
|
||||||
now = int(self.clock.time())
|
now = int(self._clock.time())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if was_guest:
|
if was_guest:
|
||||||
|
@ -1374,18 +1425,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
|
|
||||||
await self.db_pool.runInteraction("delete_access_token", f)
|
await self.db_pool.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
@cached()
|
|
||||||
async def is_guest(self, user_id: str) -> bool:
|
|
||||||
res = await self.db_pool.simple_select_one_onecol(
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user_id},
|
|
||||||
retcol="is_guest",
|
|
||||||
allow_none=True,
|
|
||||||
desc="is_guest",
|
|
||||||
)
|
|
||||||
|
|
||||||
return res if res else False
|
|
||||||
|
|
||||||
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a user to the table of users who need to be parted from all the rooms they're
|
Adds a user to the table of users who need to be parted from all the rooms they're
|
||||||
|
@ -1479,7 +1518,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
txn,
|
txn,
|
||||||
table="threepid_validation_session",
|
table="threepid_validation_session",
|
||||||
keyvalues={"session_id": session_id},
|
keyvalues={"session_id": session_id},
|
||||||
updatevalues={"validated_at": self.clock.time_msec()},
|
updatevalues={"validated_at": self._clock.time_msec()},
|
||||||
)
|
)
|
||||||
|
|
||||||
return next_link
|
return next_link
|
||||||
|
@ -1547,35 +1586,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
start_or_continue_validation_session_txn,
|
start_or_continue_validation_session_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_user_deactivated_status(
|
|
||||||
self, user_id: str, deactivated: bool
|
|
||||||
) -> None:
|
|
||||||
"""Set the `deactivated` property for the provided user to the provided value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The ID of the user to set the status for.
|
|
||||||
deactivated: The value to set for `deactivated`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
|
||||||
"set_user_deactivated_status",
|
|
||||||
self.set_user_deactivated_status_txn,
|
|
||||||
user_id,
|
|
||||||
deactivated,
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
|
||||||
self.db_pool.simple_update_one_txn(
|
|
||||||
txn=txn,
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user_id},
|
|
||||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
|
||||||
)
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.get_user_deactivated_status, (user_id,)
|
|
||||||
)
|
|
||||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
|
||||||
|
|
||||||
|
|
||||||
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue