Consistently check whether a password may be set for a user. (#9636)

This commit is contained in:
Dirk Klimpel 2021-03-18 17:54:08 +01:00 committed by GitHub
parent dd71eb0f8a
commit 8dd2ea65a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 133 additions and 68 deletions

1
changelog.d/9636.bugfix Normal file
View File

@ -0,0 +1 @@
Checks if passwords are allowed before setting it for the user.

View File

@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
logout_devices: bool, logout_devices: bool,
requester: Optional[Requester] = None, requester: Optional[Requester] = None,
) -> None: ) -> None:
if not self.hs.config.password_localdb_enabled: if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
try: try:

View File

@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet):
elif not deactivate and user["deactivated"]: elif not deactivate and user["deactivated"]:
if ( if (
"password" not in body "password" not in body
and self.hs.config.password_localdb_enabled and self.auth_handler.can_change_password()
): ):
raise SynapseError( raise SynapseError(
400, "Must provide a password to re-activate an account." 400, "Must provide a password to re-activate an account."

View File

@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,) txn, self.get_user_deactivated_status, (user_id,)
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
@cached() @cached()

View File

@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
# create users and get access tokens
# regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.get_success(
self.auth_handler.get_access_token_for_user_id(
self.admin_user, device_id=None, valid_until_ms=None
)
)
self.other_user = self.register_user("user", "pass", displayname="User") self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.login("user", "pass") self.other_user_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
self.other_user, device_id=None, valid_until_ms=None
)
)
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user self.other_user
) )
@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user # Get user
@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"]) self.assertFalse(channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self): def test_create_user(self):
@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user # Get user
@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"]) self.assertFalse(channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual(False, channel.json_body["shadow_banned"]) self.assertFalse(channel.json_body["shadow_banned"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config( @override_config(
@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
@override_config( @override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Admin user is not blocked by mau anymore # Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
@override_config( @override_config(
{ {
@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
# Deactivate user # Deactivate user
body = json.dumps({"deactivated": True})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"deactivated": True},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertTrue(profile["display_name"] == "User") self.assertTrue(profile["display_name"] == "User")
# Deactivate user # Deactivate user
body = json.dumps({"deactivated": True})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"deactivated": True},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
# is not in user directory # is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user)) profile = self.get_success(self.store.get_user_in_directory(self.other_user))
self.assertTrue(profile is None) self.assertIsNone(profile)
# Set new displayname user # Set new displayname user
body = json.dumps({"displayname": "Foobar"})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"displayname": "Foobar"},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"]) self.assertEqual("Foobar", channel.json_body["displayname"])
# is not in user directory # is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user)) profile = self.get_success(self.store.get_user_in_directory(self.other_user))
self.assertTrue(profile is None) self.assertIsNone(profile)
def test_reactivate_user(self): def test_reactivate_user(self):
""" """
@ -1520,24 +1527,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
""" """
# Deactivate the user. # Deactivate the user.
channel = self.make_request( self._deactivate_user("@user:test")
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._is_erased("@user:test", False)
d = self.store.mark_user_erased("@user:test")
self.assertIsNone(self.get_success(d))
self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password). # Attempt to reactivate the user (without a password).
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), content={"deactivated": False},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@ -1546,22 +1543,76 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False, "password": "foo"}).encode( content={"deactivated": False, "password": "foo"},
encoding="utf_8"
),
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
@override_config({"password_config": {"localdb_enabled": False}})
def test_reactivate_user_localdb_disabled(self):
"""
Test reactivating another user when using SSO.
"""
# Deactivate the user.
self._deactivate_user("@user:test")
# Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
@override_config({"password_config": {"enabled": False}})
def test_reactivate_user_password_disabled(self):
"""
Test reactivating another user when using SSO.
"""
# Deactivate the user.
self._deactivate_user("@user:test")
# Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False) self._is_erased("@user:test", False)
def test_set_user_as_admin(self): def test_set_user_as_admin(self):
@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
""" """
# Set a user as an admin # Set a user as an admin
body = json.dumps({"admin": True})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"admin": True},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
def test_accidental_deactivation_prevention(self): def test_accidental_deactivation_prevention(self):
""" """
@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test" url = "/_synapse/admin/v2/users/@bob:test"
# Create user # Create user
body = json.dumps({"password": "abc123"})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"password": "abc123"},
) )
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"]) self.assertEqual(0, channel.json_body["deactivated"])
# Change password (and use a str for deactivate instead of a bool) # Change password (and use a str for deactivate instead of a bool)
body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"password": "abc123", "deactivated": "false"},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive # Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"]) self.assertEqual(0, channel.json_body["deactivated"])
def _is_erased(self, user_id, expect): def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not""" """Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id) d = self.store.is_user_erased(user_id)
if expect: if expect:
@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
else: else:
self.assertFalse(self.get_success(d)) self.assertFalse(self.get_success(d))
def _deactivate_user(self, user_id: str) -> None:
"""Deactivate user and set as erased"""
# Deactivate the user.
channel = self.make_request(
"PUT",
"/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
access_token=self.admin_user_tok,
content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
self.assertIsNone(self.get_success(d))
self._is_erased(user_id, True)
class UserMembershipRestTestCase(unittest.HomeserverTestCase): class UserMembershipRestTestCase(unittest.HomeserverTestCase):