Add `get_userinfo_by_id` method to `ModuleApi` (#9581)

Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions.

Signed-off-by: Jason Robinson <jasonr@matrix.org>
This commit is contained in:
Jason Robinson 2021-08-04 13:40:25 +03:00 committed by GitHub
parent 72935b7c50
commit c2000ab35b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 2 deletions

1
changelog.d/9581.feature Normal file
View File

@ -0,0 +1 @@
Add `get_userinfo_by_id` method to ModuleApi.

View File

@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -174,6 +174,16 @@ class ModuleApi:
"""The application name configured in the homeserver's configuration.""" """The application name configured in the homeserver's configuration."""
return self._hs.config.email.email_app_name return self._hs.config.email.email_app_name
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get user info by user_id
Args:
user_id: Fully qualified user id.
Returns:
UserInfo object if a user was found, otherwise None
"""
return await self._store.get_userinfo_by_id(user_id)
async def get_user_by_req( async def get_user_by_req(
self, self,
req: SynapseRequest, req: SynapseRequest,

View File

@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID from synapse.types import UserID, UserInfo
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING: if TYPE_CHECKING:
@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached() @cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
return await self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id", desc="get_user_by_id",
) )
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.
Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
return None
return UserInfo(
appservice_id=user_data["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"],
consent_version=user_data["consent_version"],
creation_ts=user_data["creation_ts"],
is_admin=bool(user_data["admin"]),
is_deactivated=bool(user_data["deactivated"]),
is_guest=bool(user_data["is_guest"]),
is_shadow_banned=bool(user_data["shadow_banned"]),
user_id=UserID.from_string(user_data["name"]),
user_type=user_data["user_type"],
)
async def is_trial_user(self, user_id: str) -> bool: async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first """Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config N days of registration defined by `mau_trial_days` config

View File

@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info):
# and return that one key # and return that one key
for key_id, key_data in keys.items(): for key_id, key_data in keys.items():
return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id.
Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options).
"""
user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool

View File

@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob")) displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino") self.assertEqual(displayname, "Bobberino")
def test_get_userinfo_by_id(self):
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False)
def test_get_userinfo_by_id__no_user_found(self):
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
def test_sending_events_into_room(self): def test_sending_events_into_room(self):
"""Tests that a module can send events into a room""" """Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent # Mock out create_and_send_nonmember_event to check whether events are being sent