Make `_make_callback_with_userinfo` async

... so that we can test its behaviour when it raises.

Also pull it out to the top level so that I can use it from other test classes.
This commit is contained in:
Richard van der Hoff 2020-12-15 13:03:31 +00:00
parent c1883f042d
commit 8388a7fb3a
1 changed files with 81 additions and 66 deletions

View File

@ -21,6 +21,7 @@ import pymacaroons
from synapse.handlers.oidc_handler import OidcError
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
from tests.test_utils import FakeResponse, simple_async_mock
@ -399,7 +400,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request(
request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address
)
@ -607,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request("code", state, session)
request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
@ -624,7 +625,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test_user",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, None,
)
@ -635,7 +636,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": 1234,
"username": "test_user_2",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, None,
)
@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
@ -672,14 +673,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None,
)
@ -694,7 +695,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None,
)
@ -715,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
@ -730,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, None,
)
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@ -762,7 +765,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
@ -784,32 +787,44 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
self._make_callback_with_userinfo(userinfo)
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
def _make_callback_with_userinfo(
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
async def _make_callback_with_userinfo(
hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
self.handler._exchange_code = simple_async_mock(return_value={})
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
"""Mock up an OIDC callback with the given userinfo dict
We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
Args:
hs: the HomeServer impl to send the callback to.
userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success.
"""
handler = hs.get_oidc_handler()
handler._exchange_code = simple_async_mock(return_value={})
handler._parse_id_token = simple_async_mock(return_value=userinfo)
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state"
session = self.handler._generate_oidc_session_token(
session = handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request("code", state, session)
request = _build_callback_request("code", state, session)
await handler.handle_oidc_callback(request)
self.get_success(self.handler.handle_oidc_callback(request))
def _build_callback_request(
self,
code: str,
state: str,
session: str,