Add type hints to some tests/handlers files. (#12224)
This commit is contained in:
parent
2fcf4b3f6c
commit
5dd949bee6
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to tests files.
|
5
mypy.ini
5
mypy.ini
|
@ -67,13 +67,8 @@ exclude = (?x)
|
||||||
|tests/federation/transport/test_knocking.py
|
|tests/federation/transport/test_knocking.py
|
||||||
|tests/federation/transport/test_server.py
|
|tests/federation/transport/test_server.py
|
||||||
|tests/handlers/test_cas.py
|
|tests/handlers/test_cas.py
|
||||||
|tests/handlers/test_directory.py
|
|
||||||
|tests/handlers/test_e2e_keys.py
|
|
||||||
|tests/handlers/test_federation.py
|
|tests/handlers/test_federation.py
|
||||||
|tests/handlers/test_oidc.py
|
|
||||||
|tests/handlers/test_presence.py
|
|tests/handlers/test_presence.py
|
||||||
|tests/handlers/test_profile.py
|
|
||||||
|tests/handlers/test_saml.py
|
|
||||||
|tests/handlers/test_typing.py
|
|tests/handlers/test_typing.py
|
||||||
|tests/http/federation/test_matrix_federation_agent.py
|
|tests/http/federation/test_matrix_federation_agent.py
|
||||||
|tests/http/federation/test_srv_resolver.py
|
|tests/http/federation/test_srv_resolver.py
|
||||||
|
|
|
@ -12,14 +12,18 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any, Awaitable, Callable, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.rest.client import directory, login, room
|
from synapse.rest.client import directory, login, room
|
||||||
from synapse.types import RoomAlias, create_requester
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict, RoomAlias, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
|
||||||
class DirectoryTestCase(unittest.HomeserverTestCase):
|
class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests the directory service."""
|
"""Tests the directory service."""
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.mock_federation = Mock()
|
self.mock_federation = Mock()
|
||||||
self.mock_registry = Mock()
|
self.mock_registry = Mock()
|
||||||
|
|
||||||
self.query_handlers = {}
|
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||||
|
|
||||||
def register_query_handler(query_type, handler):
|
def register_query_handler(
|
||||||
|
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||||
|
) -> None:
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
self.mock_registry.register_query_handler = register_query_handler
|
self.mock_registry.register_query_handler = register_query_handler
|
||||||
|
@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_get_local_association(self):
|
def test_get_local_association(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
self.my_room, "!8765qwer:test", ["test"]
|
self.my_room, "!8765qwer:test", ["test"]
|
||||||
|
@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
||||||
|
|
||||||
def test_get_remote_association(self):
|
def test_get_remote_association(self) -> None:
|
||||||
self.mock_federation.make_query.return_value = make_awaitable(
|
self.mock_federation.make_query.return_value = make_awaitable(
|
||||||
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
|
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
|
||||||
)
|
)
|
||||||
|
@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_incoming_fed_query(self):
|
def test_incoming_fed_query(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
self.your_room, "!8765asdf:test", ["test"]
|
self.your_room, "!8765asdf:test", ["test"]
|
||||||
|
@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
directory.register_servlets,
|
directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
|
@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
self.test_user_tok = self.login("user", "pass")
|
self.test_user_tok = self.login("user", "pass")
|
||||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||||
|
|
||||||
def test_create_alias_joined_room(self):
|
def test_create_alias_joined_room(self) -> None:
|
||||||
"""A user can create an alias for a room they're in."""
|
"""A user can create an alias for a room they're in."""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler.create_association(
|
self.handler.create_association(
|
||||||
|
@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_alias_other_room(self):
|
def test_create_alias_other_room(self) -> None:
|
||||||
"""A user cannot create an alias for a room they're NOT in."""
|
"""A user cannot create an alias for a room they're NOT in."""
|
||||||
other_room_id = self.helper.create_room_as(
|
other_room_id = self.helper.create_room_as(
|
||||||
self.admin_user, tok=self.admin_user_tok
|
self.admin_user, tok=self.admin_user_tok
|
||||||
|
@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.SynapseError,
|
synapse.api.errors.SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_alias_admin(self):
|
def test_create_alias_admin(self) -> None:
|
||||||
"""An admin can create an alias for a room they're NOT in."""
|
"""An admin can create an alias for a room they're NOT in."""
|
||||||
other_room_id = self.helper.create_room_as(
|
other_room_id = self.helper.create_room_as(
|
||||||
self.test_user, tok=self.test_user_tok
|
self.test_user, tok=self.test_user_tok
|
||||||
|
@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
directory.register_servlets,
|
directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
|
@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
self.test_user_tok = self.login("user", "pass")
|
self.test_user_tok = self.login("user", "pass")
|
||||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||||
|
|
||||||
def _create_alias(self, user):
|
def _create_alias(self, user) -> None:
|
||||||
# Create a new alias to this room.
|
# Create a new alias to this room.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
|
@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_not_allowed(self):
|
def test_delete_alias_not_allowed(self) -> None:
|
||||||
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
||||||
self._create_alias(self.admin_user)
|
self._create_alias(self.admin_user)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
|
@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.AuthError,
|
synapse.api.errors.AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_creator(self):
|
def test_delete_alias_creator(self) -> None:
|
||||||
"""An alias creator can delete their own alias."""
|
"""An alias creator can delete their own alias."""
|
||||||
# Create an alias from a different user.
|
# Create an alias from a different user.
|
||||||
self._create_alias(self.test_user)
|
self._create_alias(self.test_user)
|
||||||
|
@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.SynapseError,
|
synapse.api.errors.SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_admin(self):
|
def test_delete_alias_admin(self) -> None:
|
||||||
"""A server admin can delete an alias created by another user."""
|
"""A server admin can delete an alias created by another user."""
|
||||||
# Create an alias from a different user.
|
# Create an alias from a different user.
|
||||||
self._create_alias(self.test_user)
|
self._create_alias(self.test_user)
|
||||||
|
@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.SynapseError,
|
synapse.api.errors.SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_sufficient_power(self):
|
def test_delete_alias_sufficient_power(self) -> None:
|
||||||
"""A user with a sufficient power level should be able to delete an alias."""
|
"""A user with a sufficient power level should be able to delete an alias."""
|
||||||
self._create_alias(self.admin_user)
|
self._create_alias(self.admin_user)
|
||||||
|
|
||||||
|
@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
directory.register_servlets,
|
directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
|
@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return room_alias
|
return room_alias
|
||||||
|
|
||||||
def _set_canonical_alias(self, content):
|
def _set_canonical_alias(self, content) -> None:
|
||||||
"""Configure the canonical alias state on the room."""
|
"""Configure the canonical alias state on the room."""
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
self.room_id,
|
self.room_id,
|
||||||
|
@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_remove_alias(self):
|
def test_remove_alias(self) -> None:
|
||||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||||
# Set this new alias as the canonical alias for this room
|
# Set this new alias as the canonical alias for this room
|
||||||
self._set_canonical_alias(
|
self._set_canonical_alias(
|
||||||
|
@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertNotIn("alias", data["content"])
|
self.assertNotIn("alias", data["content"])
|
||||||
self.assertNotIn("alt_aliases", data["content"])
|
self.assertNotIn("alt_aliases", data["content"])
|
||||||
|
|
||||||
def test_remove_other_alias(self):
|
def test_remove_other_alias(self) -> None:
|
||||||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||||
# Create a second alias.
|
# Create a second alias.
|
||||||
other_test_alias = "#test2:test"
|
other_test_alias = "#test2:test"
|
||||||
|
@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [directory.register_servlets, room.register_servlets]
|
servlets = [directory.register_servlets, room.register_servlets]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
# Add custom alias creation rules to the config.
|
# Add custom alias creation rules to the config.
|
||||||
|
@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_denied(self):
|
def test_denied(self) -> None:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(403, channel.code, channel.result)
|
self.assertEqual(403, channel.code, channel.result)
|
||||||
|
|
||||||
def test_allowed(self):
|
def test_allowed(self) -> None:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
|
|
||||||
def test_denied_during_creation(self):
|
def test_denied_during_creation(self) -> None:
|
||||||
"""A room alias that is not allowed should be rejected during creation."""
|
"""A room alias that is not allowed should be rejected during creation."""
|
||||||
# Invalid room alias.
|
# Invalid room alias.
|
||||||
self.helper.create_room_as(
|
self.helper.create_room_as(
|
||||||
|
@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
extra_content={"room_alias_name": "foo"},
|
extra_content={"room_alias_name": "foo"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_during_creation(self):
|
def test_allowed_during_creation(self) -> None:
|
||||||
"""A valid room alias should be allowed during creation."""
|
"""A valid room alias should be allowed during creation."""
|
||||||
room_id = self.helper.create_room_as(
|
room_id = self.helper.create_room_as(
|
||||||
self.user_id,
|
self.user_id,
|
||||||
|
@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
data = {"room_alias_name": "unofficial_test"}
|
data = {"room_alias_name": "unofficial_test"}
|
||||||
allowed_localpart = "allowed"
|
allowed_localpart = "allowed"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
# Add custom room list publication rules to the config.
|
# Add custom room list publication rules to the config.
|
||||||
|
@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||||
|
) -> HomeServer:
|
||||||
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
||||||
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
||||||
|
|
||||||
|
@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_denied_without_publication_permission(self):
|
def test_denied_without_publication_permission(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
as a user without permission to publish rooms.
|
as a user without permission to publish rooms.
|
||||||
|
@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=403,
|
expect_code=403,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_when_creating_private_room(self):
|
def test_allowed_when_creating_private_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and NOT publish it,
|
Try to create a room, register an alias for it, and NOT publish it,
|
||||||
as a user without permission to publish rooms.
|
as a user without permission to publish rooms.
|
||||||
|
@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=200,
|
expect_code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_with_publication_permission(self):
|
def test_allowed_with_publication_permission(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
as a user WITH permission to publish rooms.
|
as a user WITH permission to publish rooms.
|
||||||
|
@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=200,
|
expect_code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_denied_publication_with_invalid_alias(self):
|
def test_denied_publication_with_invalid_alias(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
as a user WITH permission to publish rooms.
|
as a user WITH permission to publish rooms.
|
||||||
|
@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=403,
|
expect_code=403,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_can_create_as_private_room_after_rejection(self):
|
def test_can_create_as_private_room_after_rejection(self) -> None:
|
||||||
"""
|
"""
|
||||||
After failing to publish a room with an alias as a user without publish permission,
|
After failing to publish a room with an alias as a user without publish permission,
|
||||||
retry as the same user, but without publishing the room.
|
retry as the same user, but without publishing the room.
|
||||||
|
@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
self.test_denied_without_publication_permission()
|
self.test_denied_without_publication_permission()
|
||||||
self.test_allowed_when_creating_private_room()
|
self.test_allowed_when_creating_private_room()
|
||||||
|
|
||||||
def test_can_create_with_permission_after_rejection(self):
|
def test_can_create_with_permission_after_rejection(self) -> None:
|
||||||
"""
|
"""
|
||||||
After failing to publish a room with an alias as a user without publish permission,
|
After failing to publish a room with an alias as a user without publish permission,
|
||||||
retry as someone with permission, using the same alias.
|
retry as someone with permission, using the same alias.
|
||||||
|
@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [directory.register_servlets, room.register_servlets]
|
servlets = [directory.register_servlets, room.register_servlets]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||||
|
) -> HomeServer:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_disabling_room_list(self):
|
def test_disabling_room_list(self) -> None:
|
||||||
self.room_list_handler.enable_room_list_search = True
|
self.room_list_handler.enable_room_list_search = True
|
||||||
self.directory_handler.enable_room_list_search = True
|
self.directory_handler.enable_room_list_search = True
|
||||||
|
|
||||||
|
|
|
@ -20,33 +20,37 @@ from parameterized import parameterized
|
||||||
from signedjson import key as key, sign as sign
|
from signedjson import key as key, sign as sign
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(federation_client=mock.Mock())
|
return self.setup_test_homeserver(federation_client=mock.Mock())
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_e2e_keys_handler()
|
self.handler = hs.get_e2e_keys_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
def test_query_local_devices_no_devices(self):
|
def test_query_local_devices_no_devices(self) -> None:
|
||||||
"""If the user has no devices, we expect an empty list."""
|
"""If the user has no devices, we expect an empty list."""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||||
self.assertDictEqual(res, {local_user: {}})
|
self.assertDictEqual(res, {local_user: {}})
|
||||||
|
|
||||||
def test_reupload_one_time_keys(self):
|
def test_reupload_one_time_keys(self) -> None:
|
||||||
"""we should be able to re-upload the same keys"""
|
"""we should be able to re-upload the same keys"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {
|
keys: JsonDict = {
|
||||||
"alg1:k1": "key1",
|
"alg1:k1": "key1",
|
||||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||||
"alg2:k3": {"key": "key3"},
|
"alg2:k3": {"key": "key3"},
|
||||||
|
@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
|
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_change_one_time_keys(self):
|
def test_change_one_time_keys(self) -> None:
|
||||||
"""attempts to change one-time-keys should be rejected"""
|
"""attempts to change one-time-keys should be rejected"""
|
||||||
|
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_claim_one_time_key(self):
|
def test_claim_one_time_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {"alg1:k1": "key1"}
|
keys = {"alg1:k1": "key1"}
|
||||||
|
@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fallback_key(self):
|
def test_fallback_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||||
|
@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_replace_master_key(self):
|
def test_replace_master_key(self) -> None:
|
||||||
"""uploading a new signing key should make the old signing key unavailable"""
|
"""uploading a new signing key should make the old signing key unavailable"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
keys1 = {
|
keys1 = {
|
||||||
|
@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||||
|
|
||||||
def test_reupload_signatures(self):
|
def test_reupload_signatures(self) -> None:
|
||||||
"""re-uploading a signature should not fail"""
|
"""re-uploading a signature should not fail"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
keys1 = {
|
keys1 = {
|
||||||
|
@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
|
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
|
||||||
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
|
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
|
||||||
|
|
||||||
def test_self_signing_key_doesnt_show_up_as_device(self):
|
def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
|
||||||
"""signing keys should be hidden when fetching a user's devices"""
|
"""signing keys should be hidden when fetching a user's devices"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
keys1 = {
|
keys1 = {
|
||||||
|
@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||||
self.assertDictEqual(res, {local_user: {}})
|
self.assertDictEqual(res, {local_user: {}})
|
||||||
|
|
||||||
def test_upload_signatures(self):
|
def test_upload_signatures(self) -> None:
|
||||||
"""should check signatures that are uploaded"""
|
"""should check signatures that are uploaded"""
|
||||||
# set up a user with cross-signing keys and a device. This user will
|
# set up a user with cross-signing keys and a device. This user will
|
||||||
# try uploading signatures
|
# try uploading signatures
|
||||||
|
@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_devices_remote_no_sync(self):
|
def test_query_devices_remote_no_sync(self) -> None:
|
||||||
"""Tests that querying keys for a remote user that we don't share a room
|
"""Tests that querying keys for a remote user that we don't share a room
|
||||||
with returns the cross signing keys correctly.
|
with returns the cross signing keys correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_devices_remote_sync(self):
|
def test_query_devices_remote_sync(self) -> None:
|
||||||
"""Tests that querying keys for a remote user that we share a room with,
|
"""Tests that querying keys for a remote user that we share a room with,
|
||||||
but haven't yet fetched the keys for, returns the cross signing keys
|
but haven't yet fetched the keys for, returns the cross signing keys
|
||||||
correctly.
|
correctly.
|
||||||
|
@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
(["device_1", "device_2"],),
|
(["device_1", "device_2"],),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
|
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
|
||||||
"""Test that requests for all of a remote user's devices are cached.
|
"""Test that requests for all of a remote user's devices are cached.
|
||||||
|
|
||||||
We do this by asserting that only one call over federation was made, and that
|
We do this by asserting that only one call over federation was made, and that
|
||||||
|
@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
local_user_id = "@test:test"
|
local_user_id = "@test:test"
|
||||||
remote_user_id = "@test:other"
|
remote_user_id = "@test:other"
|
||||||
request_body = {"device_keys": {remote_user_id: []}}
|
request_body: JsonDict = {"device_keys": {remote_user_id: []}}
|
||||||
|
|
||||||
response_devices = [
|
response_devices = [
|
||||||
{
|
{
|
||||||
|
|
|
@ -13,14 +13,18 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict
|
||||||
from unittest.mock import ANY, Mock, patch
|
from unittest.mock import ANY, Mock, patch
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.macaroons import get_value_from_macaroon
|
from synapse.util.macaroons import get_value_from_macaroon
|
||||||
|
|
||||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||||
|
@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def get_json(url):
|
async def get_json(url: str) -> JsonDict:
|
||||||
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
||||||
if url == WELL_KNOWN:
|
if url == WELL_KNOWN:
|
||||||
# Minimal discovery document, as defined in OpenID.Discovery
|
# Minimal discovery document, as defined in OpenID.Discovery
|
||||||
|
@ -116,6 +120,8 @@ async def get_json(url):
|
||||||
elif url == JWKS_URI:
|
elif url == JWKS_URI:
|
||||||
return {"keys": []}
|
return {"keys": []}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _key_file_path() -> str:
|
def _key_file_path() -> str:
|
||||||
"""path to a file containing the private half of a test key"""
|
"""path to a file containing the private half of a test key"""
|
||||||
|
@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
if not HAS_OIDC:
|
if not HAS_OIDC:
|
||||||
skip = "requires OIDC"
|
skip = "requires OIDC"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.http_client = Mock(spec=["get_json"])
|
self.http_client = Mock(spec=["get_json"])
|
||||||
self.http_client.get_json.side_effect = get_json
|
self.http_client.get_json.side_effect = get_json
|
||||||
self.http_client.user_agent = b"Synapse Test"
|
self.http_client.user_agent = b"Synapse Test"
|
||||||
|
@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
sso_handler = hs.get_sso_handler()
|
sso_handler = hs.get_sso_handler()
|
||||||
# Mock the render error method.
|
# Mock the render error method.
|
||||||
self.render_error = Mock(return_value=None)
|
self.render_error = Mock(return_value=None)
|
||||||
sso_handler.render_error = self.render_error
|
sso_handler.render_error = self.render_error # type: ignore[assignment]
|
||||||
|
|
||||||
# Reduce the number of attempts when generating MXIDs.
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
sso_handler._MAP_USERNAME_RETRIES = 3
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_config(self):
|
def test_config(self) -> None:
|
||||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||||
|
|
||||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
||||||
def test_discovery(self):
|
def test_discovery(self) -> None:
|
||||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||||
# This would throw if some metadata were invalid
|
# This would throw if some metadata were invalid
|
||||||
metadata = self.get_success(self.provider.load_metadata())
|
metadata = self.get_success(self.provider.load_metadata())
|
||||||
|
@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_no_discovery(self):
|
def test_no_discovery(self) -> None:
|
||||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||||
self.get_success(self.provider.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_load_jwks(self):
|
def test_load_jwks(self) -> None:
|
||||||
"""JWKS loading is done once (then cached) if used."""
|
"""JWKS loading is done once (then cached) if used."""
|
||||||
jwks = self.get_success(self.provider.load_jwks())
|
jwks = self.get_success(self.provider.load_jwks())
|
||||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||||
|
@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_validate_config(self):
|
def test_validate_config(self) -> None:
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
h = self.provider
|
h = self.provider
|
||||||
|
|
||||||
|
@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
force_load_metadata()
|
force_load_metadata()
|
||||||
|
|
||||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
||||||
def test_skip_verification(self):
|
def test_skip_verification(self) -> None:
|
||||||
"""Provider metadata validation can be disabled by config."""
|
"""Provider metadata validation can be disabled by config."""
|
||||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||||
# This should not throw
|
# This should not throw
|
||||||
get_awaitable_result(self.provider.load_metadata())
|
get_awaitable_result(self.provider.load_metadata())
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_redirect_request(self):
|
def test_redirect_request(self) -> None:
|
||||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||||
req = Mock(spec=["cookies"])
|
req = Mock(spec=["cookies"])
|
||||||
req.cookies = []
|
req.cookies = []
|
||||||
|
@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(redirect, "http://client/redirect")
|
self.assertEqual(redirect, "http://client/redirect")
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_error(self):
|
def test_callback_error(self) -> None:
|
||||||
"""Errors from the provider returned in the callback are displayed."""
|
"""Errors from the provider returned in the callback are displayed."""
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"error"] = [b"invalid_client"]
|
request.args[b"error"] = [b"invalid_client"]
|
||||||
|
@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertRenderedError("invalid_client", "some description")
|
self.assertRenderedError("invalid_client", "some description")
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback(self):
|
def test_callback(self) -> None:
|
||||||
"""Code callback works and display errors if something went wrong.
|
"""Code callback works and display errors if something went wrong.
|
||||||
|
|
||||||
A lot of scenarios are tested here:
|
A lot of scenarios are tested here:
|
||||||
|
@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"username": username,
|
"username": username,
|
||||||
}
|
}
|
||||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
|
@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertRenderedError("mapping_error")
|
self.assertRenderedError("mapping_error")
|
||||||
|
|
||||||
# Handle ID token errors
|
# Handle ID token errors
|
||||||
self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_token")
|
self.assertRenderedError("invalid_token")
|
||||||
|
|
||||||
|
@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"type": "bearer",
|
"type": "bearer",
|
||||||
"access_token": "access_token",
|
"access_token": "access_token",
|
||||||
}
|
}
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
id_token = {
|
id_token = {
|
||||||
"sid": "abcdefgh",
|
"sid": "abcdefgh",
|
||||||
}
|
}
|
||||||
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
|
self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
self.provider._fetch_userinfo.reset_mock()
|
self.provider._fetch_userinfo.reset_mock()
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle userinfo fetching error
|
# Handle userinfo fetching error
|
||||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("fetch_error")
|
self.assertRenderedError("fetch_error")
|
||||||
|
|
||||||
# Handle code exchange failure
|
# Handle code exchange failure
|
||||||
from synapse.handlers.oidc import OidcError
|
from synapse.handlers.oidc import OidcError
|
||||||
|
|
||||||
self.provider._exchange_code = simple_async_mock(
|
self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
|
||||||
raises=OidcError("invalid_request")
|
raises=OidcError("invalid_request")
|
||||||
)
|
)
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_session(self):
|
def test_callback_session(self) -> None:
|
||||||
"""The callback verifies the session presence and validity"""
|
"""The callback verifies the session presence and validity"""
|
||||||
request = Mock(spec=["args", "getCookie", "cookies"])
|
request = Mock(spec=["args", "getCookie", "cookies"])
|
||||||
|
|
||||||
|
@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
||||||
)
|
)
|
||||||
def test_exchange_code(self):
|
def test_exchange_code(self) -> None:
|
||||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
token_json = json.dumps(token).encode("utf-8")
|
token_json = json.dumps(token).encode("utf-8")
|
||||||
|
@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_exchange_code_jwt_key(self):
|
def test_exchange_code_jwt_key(self) -> None:
|
||||||
"""Test that code exchange works with a JWK client secret."""
|
"""Test that code exchange works with a JWK client secret."""
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
|
|
||||||
|
@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_exchange_code_no_auth(self):
|
def test_exchange_code_no_auth(self) -> None:
|
||||||
"""Test that code exchange works with no client secret."""
|
"""Test that code exchange works with no client secret."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
|
@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_extra_attributes(self):
|
def test_extra_attributes(self) -> None:
|
||||||
"""
|
"""
|
||||||
Login while using a mapping provider that implements get_extra_attributes.
|
Login while using a mapping provider that implements get_extra_attributes.
|
||||||
"""
|
"""
|
||||||
|
@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"username": "foo",
|
"username": "foo",
|
||||||
"phone": "1234567",
|
"phone": "1234567",
|
||||||
}
|
}
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
|
@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_user(self):
|
def test_map_userinfo_to_user(self) -> None:
|
||||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
userinfo = {
|
userinfo: dict = {
|
||||||
"sub": "test_user",
|
"sub": "test_user",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
}
|
}
|
||||||
|
@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||||
def test_map_userinfo_to_existing_user(self):
|
def test_map_userinfo_to_existing_user(self) -> None:
|
||||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
user = UserID.from_string("@test_user:test")
|
user = UserID.from_string("@test_user:test")
|
||||||
|
@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_invalid_localpart(self):
|
def test_map_userinfo_to_invalid_localpart(self) -> None:
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
||||||
|
@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_map_userinfo_to_user_retries(self):
|
def test_map_userinfo_to_user_retries(self) -> None:
|
||||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_empty_localpart(self):
|
def test_empty_localpart(self) -> None:
|
||||||
"""Attempts to map onto an empty localpart should be rejected."""
|
"""Attempts to map onto an empty localpart should be rejected."""
|
||||||
userinfo = {
|
userinfo = {
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
|
@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_null_localpart(self):
|
def test_null_localpart(self) -> None:
|
||||||
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
||||||
userinfo = {
|
userinfo = {
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
|
@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements(self):
|
def test_attribute_requirements(self) -> None:
|
||||||
"""The required attributes must be met from the OIDC userinfo response."""
|
"""The required attributes must be met from the OIDC userinfo response."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements_contains(self):
|
def test_attribute_requirements_contains(self) -> None:
|
||||||
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements_mismatch(self):
|
def test_attribute_requirements_mismatch(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that auth fails if attributes exist but don't match,
|
Test that auth fails if attributes exist but don't match,
|
||||||
or are non-string values.
|
or are non-string values.
|
||||||
|
@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
# userinfo with "test": "not_foobar" attribute should fail
|
# userinfo with "test": "not_foobar" attribute should fail
|
||||||
userinfo = {
|
userinfo: dict = {
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
"username": "tester",
|
"username": "tester",
|
||||||
"test": "not_foobar",
|
"test": "not_foobar",
|
||||||
|
@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
|
||||||
|
|
||||||
handler = hs.get_oidc_handler()
|
handler = hs.get_oidc_handler()
|
||||||
provider = handler._providers["oidc"]
|
provider = handler._providers["oidc"]
|
||||||
provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
|
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
|
||||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
|
|
||||||
state = "state"
|
state = "state"
|
||||||
session = handler._token_generator.generate_oidc_session_token(
|
session = handler._token_generator.generate_oidc_session_token(
|
||||||
|
|
|
@ -11,14 +11,17 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict
|
from typing import Any, Awaitable, Callable, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.errors import AuthError, SynapseError
|
from synapse.api.errors import AuthError, SynapseError
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [admin.register_servlets]
|
servlets = [admin.register_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.mock_federation = Mock()
|
self.mock_federation = Mock()
|
||||||
self.mock_registry = Mock()
|
self.mock_registry = Mock()
|
||||||
|
|
||||||
self.query_handlers = {}
|
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||||
|
|
||||||
def register_query_handler(query_type, handler):
|
def register_query_handler(
|
||||||
|
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||||
|
) -> None:
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
self.mock_registry.register_query_handler = register_query_handler
|
self.mock_registry.register_query_handler = register_query_handler
|
||||||
|
@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
self.frank = UserID.from_string("@1234abcd:test")
|
self.frank = UserID.from_string("@1234abcd:test")
|
||||||
|
@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.handler = hs.get_profile_handler()
|
self.handler = hs.get_profile_handler()
|
||||||
|
|
||||||
def test_get_my_name(self):
|
def test_get_my_name(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||||
)
|
)
|
||||||
|
@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual("Frank", displayname)
|
self.assertEqual("Frank", displayname)
|
||||||
|
|
||||||
def test_set_my_name(self):
|
def test_set_my_name(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler.set_displayname(
|
self.handler.set_displayname(
|
||||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||||
|
@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_name_if_disabled(self):
|
def test_set_my_name_if_disabled(self) -> None:
|
||||||
self.hs.config.registration.enable_set_displayname = False
|
self.hs.config.registration.enable_set_displayname = False
|
||||||
|
|
||||||
# Setting displayname for the first time is allowed
|
# Setting displayname for the first time is allowed
|
||||||
|
@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_name_noauth(self):
|
def test_set_my_name_noauth(self) -> None:
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.handler.set_displayname(
|
self.handler.set_displayname(
|
||||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||||
|
@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
AuthError,
|
AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_other_name(self):
|
def test_get_other_name(self) -> None:
|
||||||
self.mock_federation.make_query.return_value = make_awaitable(
|
self.mock_federation.make_query.return_value = make_awaitable(
|
||||||
{"displayname": "Alice"}
|
{"displayname": "Alice"}
|
||||||
)
|
)
|
||||||
|
@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_incoming_fed_query(self):
|
def test_incoming_fed_query(self) -> None:
|
||||||
self.get_success(self.store.create_profile("caroline"))
|
self.get_success(self.store.create_profile("caroline"))
|
||||||
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
||||||
|
|
||||||
|
@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual({"displayname": "Caroline"}, response)
|
self.assertEqual({"displayname": "Caroline"}, response)
|
||||||
|
|
||||||
def test_get_my_avatar(self):
|
def test_get_my_avatar(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.set_profile_avatar_url(
|
self.store.set_profile_avatar_url(
|
||||||
self.frank.localpart, "http://my.server/me.png"
|
self.frank.localpart, "http://my.server/me.png"
|
||||||
|
@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual("http://my.server/me.png", avatar_url)
|
self.assertEqual("http://my.server/me.png", avatar_url)
|
||||||
|
|
||||||
def test_set_my_avatar(self):
|
def test_set_my_avatar(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler.set_avatar_url(
|
self.handler.set_avatar_url(
|
||||||
self.frank,
|
self.frank,
|
||||||
|
@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_avatar_if_disabled(self):
|
def test_set_my_avatar_if_disabled(self) -> None:
|
||||||
self.hs.config.registration.enable_set_avatar_url = False
|
self.hs.config.registration.enable_set_avatar_url = False
|
||||||
|
|
||||||
# Setting displayname for the first time is allowed
|
# Setting displayname for the first time is allowed
|
||||||
|
@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_avatar_constraints_no_config(self):
|
def test_avatar_constraints_no_config(self) -> None:
|
||||||
"""Tests that the method to check an avatar against configured constraints skips
|
"""Tests that the method to check an avatar against configured constraints skips
|
||||||
all of its check if no constraint is configured.
|
all of its check if no constraint is configured.
|
||||||
"""
|
"""
|
||||||
|
@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(res)
|
self.assertTrue(res)
|
||||||
|
|
||||||
@unittest.override_config({"max_avatar_size": 50})
|
@unittest.override_config({"max_avatar_size": 50})
|
||||||
def test_avatar_constraints_missing(self):
|
def test_avatar_constraints_missing(self) -> None:
|
||||||
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
|
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
|
||||||
be found.
|
be found.
|
||||||
"""
|
"""
|
||||||
|
@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(res)
|
self.assertFalse(res)
|
||||||
|
|
||||||
@unittest.override_config({"max_avatar_size": 50})
|
@unittest.override_config({"max_avatar_size": 50})
|
||||||
def test_avatar_constraints_file_size(self):
|
def test_avatar_constraints_file_size(self) -> None:
|
||||||
"""Tests that a file that's above the allowed file size is forbidden but one
|
"""Tests that a file that's above the allowed file size is forbidden but one
|
||||||
that's below it is allowed.
|
that's below it is allowed.
|
||||||
"""
|
"""
|
||||||
|
@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(res)
|
self.assertFalse(res)
|
||||||
|
|
||||||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
||||||
def test_avatar_constraint_mime_type(self):
|
def test_avatar_constraint_mime_type(self) -> None:
|
||||||
"""Tests that a file with an unauthorised MIME type is forbidden but one with
|
"""Tests that a file with an unauthorised MIME type is forbidden but one with
|
||||||
an authorised content type is allowed.
|
an authorised content type is allowed.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,12 +12,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
from tests.test_utils import simple_async_mock
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
|
||||||
|
|
||||||
|
|
||||||
class SamlHandlerTestCase(HomeserverTestCase):
|
class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
saml_config = {
|
saml_config: Dict[str, Any] = {
|
||||||
"sp_config": {"metadata": {}},
|
"sp_config": {"metadata": {}},
|
||||||
# Disable grandfathering.
|
# Disable grandfathering.
|
||||||
"grandfathered_mxid_source_attribute": None,
|
"grandfathered_mxid_source_attribute": None,
|
||||||
|
@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver()
|
hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
self.handler = hs.get_saml_handler()
|
self.handler = hs.get_saml_handler()
|
||||||
|
@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
elif not has_xmlsec1:
|
elif not has_xmlsec1:
|
||||||
skip = "Requires xmlsec1"
|
skip = "Requires xmlsec1"
|
||||||
|
|
||||||
def test_map_saml_response_to_user(self):
|
def test_map_saml_response_to_user(self) -> None:
|
||||||
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
def test_map_saml_response_to_existing_user(self):
|
def test_map_saml_response_to_existing_user(self) -> None:
|
||||||
"""Existing users can log in with SAML account."""
|
"""Existing users can log in with SAML account."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
auth_provider_session_id=None,
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_saml_response_to_invalid_localpart(self):
|
def test_map_saml_response_to_invalid_localpart(self) -> None:
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_not_called()
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
def test_map_saml_response_to_user_retries(self):
|
def test_map_saml_response_to_user_retries(self) -> None:
|
||||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
|
|
||||||
# stub out the auth handler and error renderer
|
# stub out the auth handler and error renderer
|
||||||
|
@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_map_saml_response_redirect(self):
|
def test_map_saml_response_redirect(self) -> None:
|
||||||
"""Test a mapping provider that raises a RedirectException"""
|
"""Test a mapping provider that raises a RedirectException"""
|
||||||
|
|
||||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||||
|
@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements(self):
|
def test_attribute_requirements(self) -> None:
|
||||||
"""The required attributes must be met from the SAML response."""
|
"""The required attributes must be met from the SAML response."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
|
Loading…
Reference in New Issue