Add `get_userinfo_by_id` method to `ModuleApi` (#9581)
Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions. Signed-off-by: Jason Robinson <jasonr@matrix.org>
This commit is contained in:
parent
72935b7c50
commit
c2000ab35b
|
@ -0,0 +1 @@
|
||||||
|
Add `get_userinfo_by_id` method to ModuleApi.
|
|
@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
@ -174,6 +174,16 @@ class ModuleApi:
|
||||||
"""The application name configured in the homeserver's configuration."""
|
"""The application name configured in the homeserver's configuration."""
|
||||||
return self._hs.config.email.email_app_name
|
return self._hs.config.email.email_app_name
|
||||||
|
|
||||||
|
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||||
|
"""Get user info by user_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Fully qualified user id.
|
||||||
|
Returns:
|
||||||
|
UserInfo object if a user was found, otherwise None
|
||||||
|
"""
|
||||||
|
return await self._store.get_userinfo_by_id(user_id)
|
||||||
|
|
||||||
async def get_user_by_req(
|
async def get_user_by_req(
|
||||||
self,
|
self,
|
||||||
req: SynapseRequest,
|
req: SynapseRequest,
|
||||||
|
|
|
@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Connection, Cursor
|
||||||
from synapse.storage.util.id_generators import IdGenerator
|
from synapse.storage.util.id_generators import IdGenerator
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, UserInfo
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Deprecated: use get_userinfo_by_id instead"""
|
||||||
return await self.db_pool.simple_select_one(
|
return await self.db_pool.simple_select_one(
|
||||||
table="users",
|
table="users",
|
||||||
keyvalues={"name": user_id},
|
keyvalues={"name": user_id},
|
||||||
|
@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
desc="get_user_by_id",
|
desc="get_user_by_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||||
|
"""Get a UserInfo object for a user by user ID.
|
||||||
|
|
||||||
|
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
|
||||||
|
this method should be cached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user to fetch user info for.
|
||||||
|
Returns:
|
||||||
|
`UserInfo` object if user found, otherwise `None`.
|
||||||
|
"""
|
||||||
|
user_data = await self.get_user_by_id(user_id)
|
||||||
|
if not user_data:
|
||||||
|
return None
|
||||||
|
return UserInfo(
|
||||||
|
appservice_id=user_data["appservice_id"],
|
||||||
|
consent_server_notice_sent=user_data["consent_server_notice_sent"],
|
||||||
|
consent_version=user_data["consent_version"],
|
||||||
|
creation_ts=user_data["creation_ts"],
|
||||||
|
is_admin=bool(user_data["admin"]),
|
||||||
|
is_deactivated=bool(user_data["deactivated"]),
|
||||||
|
is_guest=bool(user_data["is_guest"]),
|
||||||
|
is_shadow_banned=bool(user_data["shadow_banned"]),
|
||||||
|
user_id=UserID.from_string(user_data["name"]),
|
||||||
|
user_type=user_data["user_type"],
|
||||||
|
)
|
||||||
|
|
||||||
async def is_trial_user(self, user_id: str) -> bool:
|
async def is_trial_user(self, user_id: str) -> bool:
|
||||||
"""Checks if user is in the "trial" period, i.e. within the first
|
"""Checks if user is in the "trial" period, i.e. within the first
|
||||||
N days of registration defined by `mau_trial_days` config
|
N days of registration defined by `mau_trial_days` config
|
||||||
|
|
|
@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info):
|
||||||
# and return that one key
|
# and return that one key
|
||||||
for key_id, key_data in keys.items():
|
for key_id, key_data in keys.items():
|
||||||
return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
|
return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
|
class UserInfo:
|
||||||
|
"""Holds information about a user. Result of get_userinfo_by_id.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
user_id: ID of the user.
|
||||||
|
appservice_id: Application service ID that created this user.
|
||||||
|
consent_server_notice_sent: Version of policy documents the user has been sent.
|
||||||
|
consent_version: Version of policy documents the user has consented to.
|
||||||
|
creation_ts: Creation timestamp of the user.
|
||||||
|
is_admin: True if the user is an admin.
|
||||||
|
is_deactivated: True if the user has been deactivated.
|
||||||
|
is_guest: True if the user is a guest user.
|
||||||
|
is_shadow_banned: True if the user has been shadow-banned.
|
||||||
|
user_type: User type (None for normal user, 'support' and 'bot' other options).
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id: UserID
|
||||||
|
appservice_id: Optional[int]
|
||||||
|
consent_server_notice_sent: Optional[str]
|
||||||
|
consent_version: Optional[str]
|
||||||
|
user_type: Optional[str]
|
||||||
|
creation_ts: int
|
||||||
|
is_admin: bool
|
||||||
|
is_deactivated: bool
|
||||||
|
is_guest: bool
|
||||||
|
is_shadow_banned: bool
|
||||||
|
|
|
@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase):
|
||||||
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
||||||
self.assertEqual(displayname, "Bobberino")
|
self.assertEqual(displayname, "Bobberino")
|
||||||
|
|
||||||
|
def test_get_userinfo_by_id(self):
|
||||||
|
user_id = self.register_user("alice", "1234")
|
||||||
|
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
|
||||||
|
self.assertEqual(found_user.user_id.to_string(), user_id)
|
||||||
|
self.assertIdentical(found_user.is_admin, False)
|
||||||
|
|
||||||
|
def test_get_userinfo_by_id__no_user_found(self):
|
||||||
|
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
|
||||||
|
self.assertIsNone(found_user)
|
||||||
|
|
||||||
def test_sending_events_into_room(self):
|
def test_sending_events_into_room(self):
|
||||||
"""Tests that a module can send events into a room"""
|
"""Tests that a module can send events into a room"""
|
||||||
# Mock out create_and_send_nonmember_event to check whether events are being sent
|
# Mock out create_and_send_nonmember_event to check whether events are being sent
|
||||||
|
|
Loading…
Reference in New Issue