Remove not needed database updates in modify user admin API (#10627)

This commit is contained in:
Dirk Klimpel 2021-08-19 11:25:05 +02:00 committed by GitHub
parent 0c3565da4c
commit 220f901229
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 118 additions and 33 deletions

1
changelog.d/10627.misc Normal file
View File

@ -0,0 +1 @@
Remove not needed database updates in modify user admin API.

View File

@ -21,11 +21,15 @@ It returns a JSON body like the following:
"threepids": [ "threepids": [
{ {
"medium": "email", "medium": "email",
"address": "<user_mail_1>" "address": "<user_mail_1>",
"added_at": 1586458409743,
"validated_at": 1586458409743
}, },
{ {
"medium": "email", "medium": "email",
"address": "<user_mail_2>" "address": "<user_mail_2>",
"added_at": 1586458409743,
"validated_at": 1586458409743
} }
], ],
"avatar_url": "<avatar_url>", "avatar_url": "<avatar_url>",

View File

@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
if not isinstance(deactivate, bool): if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean") raise SynapseError(400, "'deactivated' parameter is not of type boolean")
# convert into List[Tuple[str, str]] # convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None: if external_ids is not None:
new_external_ids = [] new_external_ids = {
for external_id in external_ids:
new_external_ids.append(
(external_id["auth_provider"], external_id["external_id"]) (external_id["auth_provider"], external_id["external_id"])
) for external_id in external_ids
}
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
if threepids is not None:
new_threepids = {
(threepid["medium"], threepid["address"]) for threepid in threepids
}
if user: # modify user if user: # modify user
if "displayname" in body: if "displayname" in body:
@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
) )
if threepids is not None: if threepids is not None:
# remove old threepids from user # get changed threepids (added and removed)
old_threepids = await self.store.user_get_threepids(user_id) # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
for threepid in old_threepids: cur_threepids = {
(threepid["medium"], threepid["address"])
for threepid in await self.store.user_get_threepids(user_id)
}
add_threepids = new_threepids - cur_threepids
del_threepids = cur_threepids - new_threepids
# remove old threepids
for medium, address in del_threepids:
try: try:
await self.auth_handler.delete_threepid( await self.auth_handler.delete_threepid(
user_id, threepid["medium"], threepid["address"], None user_id, medium, address, None
) )
except Exception: except Exception:
logger.exception("Failed to remove threepids") logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids")
# add new threepids to user # add new threepids
current_time = self.hs.get_clock().time_msec() current_time = self.hs.get_clock().time_msec()
for threepid in threepids: for medium, address in add_threepids:
await self.auth_handler.add_threepid( await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time user_id, medium, address, current_time
) )
if external_ids is not None: if external_ids is not None:
# get changed external_ids (added and removed) # get changed external_ids (added and removed)
cur_external_ids = await self.store.get_external_ids_by_user(user_id) cur_external_ids = set(
add_external_ids = set(new_external_ids) - set(cur_external_ids) await self.store.get_external_ids_by_user(user_id)
del_external_ids = set(cur_external_ids) - set(new_external_ids) )
add_external_ids = new_external_ids - cur_external_ids
del_external_ids = cur_external_ids - new_external_ids
# remove old external_ids # remove old external_ids
for auth_provider, external_id in del_external_ids: for auth_provider, external_id in del_external_ids:
@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
if threepids is not None: if threepids is not None:
current_time = self.hs.get_clock().time_msec() current_time = self.hs.get_clock().time_msec()
for threepid in threepids: for medium, address in new_threepids:
await self.auth_handler.add_threepid( await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time user_id, medium, address, current_time
) )
if ( if (
self.hs.config.email_enable_notifs self.hs.config.email_enable_notifs
@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
kind="email", kind="email",
app_id="m.email", app_id="m.email",
app_display_name="Email Notifications", app_display_name="Email Notifications",
device_display_name=threepid["address"], device_display_name=address,
pushkey=threepid["address"], pushkey=address,
lang=None, # We don't know a user's language here lang=None, # We don't know a user's language here
data={}, data={},
) )

View File

@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
return user_id return user_id
def get_user_id_by_threepid_txn(self, txn, medium, address): def get_user_id_by_threepid_txn(
self, txn, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid """Returns user id from threepid
Args: Args:
txn (cursor): txn (cursor):
medium (str): threepid medium e.g. email medium: threepid medium e.g. email
address (str): threepid address e.g. me@example.com address: threepid address e.g. me@example.com
Returns: Returns:
str|None: user id or None if no user id/threepid mapping exists user id, or None if no user id/threepid mapping exists
""" """
ret = self.db_pool.simple_select_one_txn( ret = self.db_pool.simple_select_one_txn(
txn, txn,
@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return ret["user_id"] return ret["user_id"]
return None return None
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): async def user_add_threepid(
self,
user_id: str,
medium: str,
address: str,
validated_at: int,
added_at: int,
) -> None:
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
"user_threepids", "user_threepids",
{"medium": medium, "address": address}, {"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
) )
async def user_get_threepids(self, user_id): async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list( return await self.db_pool.simple_select_list(
"user_threepids", "user_threepids",
{"user_id": user_id}, {"user_id": user_id},
@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"user_get_threepids", "user_get_threepids",
) )
async def user_delete_threepid(self, user_id, medium, address) -> None: async def user_delete_threepid(
self, user_id: str, medium: str, address: str
) -> None:
await self.db_pool.simple_delete( await self.db_pool.simple_delete(
"user_threepids", "user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address}, keyvalues={"user_id": user_id, "medium": medium, "address": address},

View File

@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual( self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"] "external_id1", channel.json_body["external_ids"][0]["external_id"]
) )
self.assertEqual( self.assertEqual(
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
) )
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test setting threepid for an other user. Test setting threepid for an other user.
""" """
# Delete old and add new threepid to user # Add two threepids to user
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, content={
"threepids": [
{"medium": "email", "address": "bob1@bob.bob"},
{"medium": "email", "address": "bob2@bob.bob"},
],
},
) )
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
sorted_result = sorted(
channel.json_body["threepids"], key=lambda k: k["address"]
)
self.assertEqual("email", sorted_result[0]["medium"])
self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
self.assertEqual("email", sorted_result[1]["medium"])
self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
self._check_fields(channel.json_body)
# Set a new and remove a threepid
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={
"threepids": [
{"medium": "email", "address": "bob2@bob.bob"},
{"medium": "email", "address": "bob3@bob.bob"},
],
},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
self._check_fields(channel.json_body)
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
self._check_fields(channel.json_body)
# Remove threepids
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"threepids": []},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
def test_set_external_id(self): def test_set_external_id(self):
""" """
@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
channel.json_body["external_ids"], channel.json_body["external_ids"],
[ [