From 327555dddf609211def9b8a73b8d7bf2f827c89e Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Mon, 5 Jun 2023 15:32:36 +0100 Subject: [PATCH] Fix other test cases --- tests/config/test_oauth_delegation.py | 3 ++ tests/handlers/test_oauth_delegation.py | 3 ++ tests/handlers/test_password_providers.py | 44 +++++++++++++++++------ tests/rest/admin/test_jwks.py | 2 ++ tests/rest/client/test_capabilities.py | 2 +- tests/rest/client/test_login.py | 2 +- tests/rest/test_well_known.py | 1 + 7 files changed, 44 insertions(+), 13 deletions(-) diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py index f57c813a58..82cedb164c 100644 --- a/tests/config/test_oauth_delegation.py +++ b/tests/config/test_oauth_delegation.py @@ -61,6 +61,9 @@ class MSC3861OAuthDelegation(TestCase): **default_config("test"), "public_baseurl": BASE_URL, "enable_registration": False, + "login_via_existing_session": { + "enabled": False, + }, "experimental_features": { "msc3861": { "enabled": True, diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 6309d7b36e..9298cbbae8 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -115,6 +115,9 @@ class MSC3861OAuthDelegation(HomeserverTestCase): config = super().default_config() config["public_baseurl"] = BASE_URL config["disable_registration"] = True + config["login_via_existing_session"] = { + "enabled": False, + } config["experimental_features"] = { "msc3861": { "enabled": True, diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 394006f5f3..1744e1b030 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -36,8 +36,9 @@ from tests.test_utils import make_awaitable from tests.unittest import override_config # Login flows we expect to appear in the list after the normal ones. -ADDITIONAL_LOGIN_FLOWS = [ +ADDITIONAL_LOGIN_FLOWS: List[Dict] = [ {"type": "m.login.application_service"}, + {"type": "m.login.token", "get_login_token": True}, ] # a mock instance which the dummy auth providers delegate to, so we can see what's going @@ -45,6 +46,10 @@ ADDITIONAL_LOGIN_FLOWS = [ mock_password_provider = Mock() +def sort_flows(flows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return sorted(flows, key=lambda f: f["type"]) + + class LegacyPasswordOnlyAuthProvider: """A legacy password_provider which only implements `check_password`.""" @@ -184,7 +189,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def password_only_auth_provider_login_test_body(self) -> None: # login flows should only have m.login.password flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + sort_flows(flows), + sort_flows([{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS), + ) # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(True) @@ -365,7 +373,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() - self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS) + self.assertEqual(sort_flows(flows), sort_flows(ADDITIONAL_LOGIN_FLOWS)) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("u", "p") @@ -386,9 +394,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # (password must come first, because reasons) flows = self._get_login_flows() self.assertEqual( - flows, - [{"type": "m.login.password"}, {"type": "test.login_type"}] - + ADDITIONAL_LOGIN_FLOWS, + sort_flows(flows), + sort_flows( + [{"type": "m.login.password"}, {"type": "test.login_type"}] + + ADDITIONAL_LOGIN_FLOWS + ), ) # login with missing param should be rejected @@ -519,7 +529,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + sort_flows(flows), + sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS), + ) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") @@ -554,7 +567,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + sort_flows(flows), + sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS), + ) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") @@ -585,7 +601,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + sort_flows(flows), + sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS), + ) # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") @@ -690,7 +709,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") flows = self._get_login_flows() - self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + self.assertEqual( + sort_flows(flows), + sort_flows([{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS), + ) # password login shouldn't work and should be rejected with a 400 # ("unknown login type") @@ -928,7 +950,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body - def _get_login_flows(self) -> JsonDict: + def _get_login_flows(self) -> List[JsonDict]: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["flows"] diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py index a9a6191c73..30eb1feae1 100644 --- a/tests/rest/admin/test_jwks.py +++ b/tests/rest/admin/test_jwks.py @@ -45,6 +45,7 @@ class JWKSTestCase(HomeserverTestCase): @override_config( { "disable_registration": True, + "login_via_existing_session": {"enabled": False}, "experimental_features": { "msc3861": { "enabled": True, @@ -65,6 +66,7 @@ class JWKSTestCase(HomeserverTestCase): @override_config( { "disable_registration": True, + "login_via_existing_session": {"enabled": False}, "experimental_features": { "msc3861": { "enabled": True, diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index cf23430f6a..9edf7471ee 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -187,6 +187,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): for room_version in details["support"]: self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) + @override_config({"login_via_existing_session": {"enabled": False}}) def test_get_get_token_login_fields_when_disabled(self) -> None: """By default login via an existing session is disabled.""" access_token = self.get_success( @@ -201,7 +202,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK) self.assertFalse(capabilities["m.get_login_token"]["enabled"]) - @override_config({"login_via_existing_session": {"enabled": True}}) def test_get_get_token_login_fields_when_enabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f3c3bc69a9..0eacb49e46 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -446,6 +446,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] ) + @override_config({"login_via_existing_session": {"enabled": False}}) def test_get_login_flows_with_login_via_existing_disabled(self) -> None: """GET /login should return m.login.token without get_login_token""" channel = self.make_request("GET", "/_matrix/client/r0/login") @@ -454,7 +455,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): flows = {flow["type"]: flow for flow in channel.json_body["flows"]} self.assertNotIn("m.login.token", flows) - @override_config({"login_via_existing_session": {"enabled": True}}) def test_get_login_flows_with_login_via_existing_enabled(self) -> None: """GET /login should return m.login.token with get_login_token true""" channel = self.make_request("GET", "/_matrix/client/r0/login") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 377243a170..9507567218 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -119,6 +119,7 @@ class WellKnownTests(unittest.HomeserverTestCase): }, }, "disable_registration": True, + "login_via_existing_session": {"enabled": False}, } ) def test_client_well_known_msc3861_oauth_delegation(self) -> None: