diff --git a/changelog.d/11051.bugfix b/changelog.d/11051.bugfix new file mode 100644 index 0000000000..63126843d2 --- /dev/null +++ b/changelog.d/11051.bugfix @@ -0,0 +1 @@ +Fix a bug where setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f20aa65301..c0bebc3cf0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -35,6 +35,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID @@ -228,12 +229,12 @@ class UserRestServletV2(RestServlet): if not isinstance(deactivate, bool): raise SynapseError(400, "'deactivated' parameter is not of type boolean") - # convert List[Dict[str, str]] into Set[Tuple[str, str]] + # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: - new_external_ids = { + new_external_ids = [ (external_id["auth_provider"], external_id["external_id"]) for external_id in external_ids - } + ] # convert List[Dict[str, str]] into Set[Tuple[str, str]] if threepids is not None: @@ -275,28 +276,13 @@ class UserRestServletV2(RestServlet): ) if external_ids is not None: - # get changed external_ids (added and removed) - cur_external_ids = set( - await self.store.get_external_ids_by_user(user_id) - ) - add_external_ids = new_external_ids - cur_external_ids - del_external_ids = cur_external_ids - new_external_ids - - # remove old external_ids - for auth_provider, external_id in del_external_ids: - await self.store.remove_user_external_id( - auth_provider, - external_id, - user_id, - ) - - # add new external_ids - for auth_provider, external_id in add_external_ids: - await self.store.record_user_external_id( - auth_provider, - external_id, + try: + await self.store.replace_user_external_id( + new_external_ids, user_id, ) + except ExternalIDReuseException: + raise SynapseError(409, "External id is already in use.") if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -384,12 +370,15 @@ class UserRestServletV2(RestServlet): ) if external_ids is not None: - for auth_provider, external_id in new_external_ids: - await self.store.record_user_external_id( - auth_provider, - external_id, - user_id, - ) + try: + for auth_provider, external_id in new_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + except ExternalIDReuseException: + raise SynapseError(409, "External id is already in use.") if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0ab56d8a07..37d47aa823 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -23,7 +23,11 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Cursor @@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) +class ExternalIDReuseException(Exception): + """Exception if writing an external id for a user fails, + because this external id is given to an other user.""" + + pass + + @attr.s(frozen=True, slots=True) class TokenLookupResult: """Result of looking up an access token. @@ -588,24 +599,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): auth_provider: identifier for the remote auth provider external_id: id on that system user_id: complete mxid that it is mapped to + Raises: + ExternalIDReuseException if the new external_id could not be mapped. """ - await self.db_pool.simple_insert( + + try: + await self.db_pool.runInteraction( + "record_user_external_id", + self._record_user_external_id_txn, + auth_provider, + external_id, + user_id, + ) + except self.database_engine.module.IntegrityError: + raise ExternalIDReuseException() + + def _record_user_external_id_txn( + self, + txn: LoggingTransaction, + auth_provider: str, + external_id: str, + user_id: str, + ) -> None: + + self.db_pool.simple_insert_txn( + txn, table="user_external_ids", values={ "auth_provider": auth_provider, "external_id": external_id, "user_id": user_id, }, - desc="record_user_external_id", ) async def remove_user_external_id( self, auth_provider: str, external_id: str, user_id: str ) -> None: """Remove a mapping from an external user id to a mxid - If the mapping is not found, this method does nothing. - Args: auth_provider: identifier for the remote auth provider external_id: id on that system @@ -621,6 +652,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="remove_user_external_id", ) + async def replace_user_external_id( + self, + record_external_ids: List[Tuple[str, str]], + user_id: str, + ) -> None: + """Replace mappings from external user ids to a mxid in a single transaction. + All mappings are deleted and the new ones are created. + + Args: + record_external_ids: + List with tuple of auth_provider and external_id to record + user_id: complete mxid that it is mapped to + Raises: + ExternalIDReuseException if the new external_id could not be mapped. + """ + + def _remove_user_external_ids_txn( + txn: LoggingTransaction, + user_id: str, + ) -> None: + """Remove all mappings from external user ids to a mxid + If these mappings are not found, this method does nothing. + + Args: + user_id: complete mxid that it is mapped to + """ + + self.db_pool.simple_delete_txn( + txn, + table="user_external_ids", + keyvalues={"user_id": user_id}, + ) + + def _replace_user_external_id_txn( + txn: LoggingTransaction, + ): + _remove_user_external_ids_txn(txn, user_id) + + for auth_provider, external_id in record_external_ids: + self._record_user_external_id_txn( + txn, + auth_provider, + external_id, + user_id, + ) + + try: + await self.db_pool.runInteraction( + "replace_user_external_id", + _replace_user_external_id_txn, + ) + except self.database_engine.module.IntegrityError: + raise ExternalIDReuseException() + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index c9e2754b09..839442ddba 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1180,9 +1180,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.other_user, device_id=None, valid_until_ms=None ) ) - self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( - self.other_user - ) + self.url_prefix = "/_synapse/admin/v2/users/%s" + self.url_other_user = self.url_prefix % self.other_user def test_requester_is_no_admin(self): """ @@ -1738,6 +1737,93 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) + def test_set_duplicate_threepid(self): + """ + Test setting the same threepid for a second user. + First user loses and second user gets mapping of this threepid. + """ + + # create a user to set a threepid + first_user = self.register_user("first_user", "pass") + url_first_user = self.url_prefix % first_user + + # Add threepid to first user + channel = self.make_request( + "PUT", + url_first_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["threepids"])) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"]) + self._check_fields(channel.json_body) + + # Add threepids to other user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob2@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["threepids"])) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self._check_fields(channel.json_body) + + # Add two new threepids to other user + # one is used by first_user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + {"medium": "email", "address": "bob3@bob.bob"}, + ], + }, + ) + + # other user has this two threepids + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) + # result does not always have the same sort order, therefore it becomes sorted + sorted_result = sorted( + channel.json_body["threepids"], key=lambda k: k["address"] + ) + self.assertEqual("email", sorted_result[0]["medium"]) + self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) + self.assertEqual("email", sorted_result[1]["medium"]) + self.assertEqual("bob3@bob.bob", sorted_result[1]["address"]) + self._check_fields(channel.json_body) + + # first_user has no threepid anymore + channel = self.make_request( + "GET", + url_first_user, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self._check_fields(channel.json_body) + def test_set_external_id(self): """ Test setting external id for an other user. @@ -1836,6 +1922,129 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) + def test_set_duplicate_external_id(self): + """ + Test that setting the same external id for a second user fails and + external id from user must not be changed. + """ + + # create a user to use an external id + first_user = self.register_user("first_user", "pass") + url_first_user = self.url_prefix % first_user + + # Add an external id to first user + channel = self.make_request( + "PUT", + url_first_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider", + }, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + + # Add an external id to other user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider", + }, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + + # Add two new external_ids to other user + # one is used by first + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider", + }, + ], + }, + ) + + # must fail + self.assertEqual(409, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("External id is already in use.", channel.json_body["error"]) + + # other user must not changed + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + + # first user must not changed + channel = self.make_request( + "GET", + url_first_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + def test_deactivate_user(self): """ Test deactivating another user.