Add type hints to `tests/rest/client` (#12094)
* Add type hints to `tests/rest/client` * update `mypy.ini` * newsfile * add `test_register.py`
This commit is contained in:
parent
7754af24ab
commit
952efd0bca
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to `tests/rest/client`.
|
3
mypy.ini
3
mypy.ini
|
@ -75,10 +75,7 @@ exclude = (?x)
|
||||||
|tests/push/test_presentable_names.py
|
|tests/push/test_presentable_names.py
|
||||||
|tests/push/test_push_rule_evaluator.py
|
|tests/push/test_push_rule_evaluator.py
|
||||||
|tests/rest/client/test_account.py
|
|tests/rest/client/test_account.py
|
||||||
|tests/rest/client/test_events.py
|
|
||||||
|tests/rest/client/test_filter.py
|
|tests/rest/client/test_filter.py
|
||||||
|tests/rest/client/test_groups.py
|
|
||||||
|tests/rest/client/test_register.py
|
|
||||||
|tests/rest/client/test_report_event.py
|
|tests/rest/client/test_report_event.py
|
||||||
|tests/rest/client/test_rooms.py
|
|tests/rest/client/test_rooms.py
|
||||||
|tests/rest/client/test_third_party_rules.py
|
|tests/rest/client/test_third_party_rules.py
|
||||||
|
|
|
@ -16,8 +16,12 @@
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.rest.client import events, login, room
|
from synapse.rest.client import events, login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["enable_registration_captcha"] = False
|
config["enable_registration_captcha"] = False
|
||||||
|
@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(config=config)
|
hs = self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
hs.get_federation_handler = Mock()
|
hs.get_federation_handler = Mock() # type: ignore[assignment]
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
# register an account
|
# register an account
|
||||||
self.user_id = self.register_user("sid1", "pass")
|
self.user_id = self.register_user("sid1", "pass")
|
||||||
|
@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||||
self.other_user = self.register_user("other2", "pass")
|
self.other_user = self.register_user("other2", "pass")
|
||||||
self.other_token = self.login(self.other_user, "pass")
|
self.other_token = self.login(self.other_user, "pass")
|
||||||
|
|
||||||
def test_stream_basic_permissions(self):
|
def test_stream_basic_permissions(self) -> None:
|
||||||
# invalid token, expect 401
|
# invalid token, expect 401
|
||||||
# note: this is in violation of the original v1 spec, which expected
|
# note: this is in violation of the original v1 spec, which expected
|
||||||
# 403. However, since the v1 spec no longer exists and the v1
|
# 403. However, since the v1 spec no longer exists and the v1
|
||||||
|
@ -76,7 +80,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue("start" in channel.json_body)
|
self.assertTrue("start" in channel.json_body)
|
||||||
self.assertTrue("end" in channel.json_body)
|
self.assertTrue("end" in channel.json_body)
|
||||||
|
|
||||||
def test_stream_room_permissions(self):
|
def test_stream_room_permissions(self) -> None:
|
||||||
room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
|
room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
|
||||||
self.helper.send(room_id, tok=self.other_token)
|
self.helper.send(room_id, tok=self.other_token)
|
||||||
|
|
||||||
|
@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# left to room (expect no content for room)
|
# left to room (expect no content for room)
|
||||||
|
|
||||||
def TODO_test_stream_items(self):
|
def TODO_test_stream_items(self) -> None:
|
||||||
# new user, no content
|
# new user, no content
|
||||||
|
|
||||||
# join room, expect 1 item (join)
|
# join room, expect 1 item (join)
|
||||||
|
@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, hs, reactor, clock):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
# register an account
|
# register an account
|
||||||
self.user_id = self.register_user("sid1", "pass")
|
self.user_id = self.register_user("sid1", "pass")
|
||||||
|
@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||||
|
|
||||||
def test_get_event_via_events(self):
|
def test_get_event_via_events(self) -> None:
|
||||||
resp = self.helper.send(self.room_id, tok=self.token)
|
resp = self.helper.send(self.room_id, tok=self.token)
|
||||||
event_id = resp["event_id"]
|
event_id = resp["event_id"]
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ class GroupsTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [room.register_servlets, groups.register_servlets]
|
servlets = [room.register_servlets, groups.register_servlets]
|
||||||
|
|
||||||
@override_config({"enable_group_creation": True})
|
@override_config({"enable_group_creation": True})
|
||||||
def test_rooms_limited_by_visibility(self):
|
def test_rooms_limited_by_visibility(self) -> None:
|
||||||
group_id = "+spqr:test"
|
group_id = "+spqr:test"
|
||||||
|
|
||||||
# Alice creates a group
|
# Alice creates a group
|
||||||
|
|
|
@ -16,15 +16,21 @@
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage._base import db_to_json
|
from synapse.storage._base import db_to_json
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
url = b"/_matrix/client/r0/register"
|
url = b"/_matrix/client/r0/register"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["allow_guest_access"] = True
|
config["allow_guest_access"] = True
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_POST_appservice_registration_valid(self):
|
def test_POST_appservice_registration_valid(self) -> None:
|
||||||
user_id = "@as_user_kermit:test"
|
user_id = "@as_user_kermit:test"
|
||||||
as_token = "i_am_an_app_service"
|
as_token = "i_am_an_app_service"
|
||||||
|
|
||||||
|
@ -69,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
|
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
|
|
||||||
def test_POST_appservice_registration_no_type(self):
|
def test_POST_appservice_registration_no_type(self) -> None:
|
||||||
as_token = "i_am_an_app_service"
|
as_token = "i_am_an_app_service"
|
||||||
|
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
|
@ -89,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400", channel.result)
|
self.assertEqual(channel.result["code"], b"400", channel.result)
|
||||||
|
|
||||||
def test_POST_appservice_registration_invalid(self):
|
def test_POST_appservice_registration_invalid(self) -> None:
|
||||||
self.appservice = None # no application service exists
|
self.appservice = None # no application service exists
|
||||||
request_data = json.dumps(
|
request_data = json.dumps(
|
||||||
{"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
|
{"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
|
||||||
|
@ -100,21 +106,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
|
||||||
def test_POST_bad_password(self):
|
def test_POST_bad_password(self) -> None:
|
||||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||||
channel = self.make_request(b"POST", self.url, request_data)
|
channel = self.make_request(b"POST", self.url, request_data)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400", channel.result)
|
self.assertEqual(channel.result["code"], b"400", channel.result)
|
||||||
self.assertEqual(channel.json_body["error"], "Invalid password")
|
self.assertEqual(channel.json_body["error"], "Invalid password")
|
||||||
|
|
||||||
def test_POST_bad_username(self):
|
def test_POST_bad_username(self) -> None:
|
||||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||||
channel = self.make_request(b"POST", self.url, request_data)
|
channel = self.make_request(b"POST", self.url, request_data)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400", channel.result)
|
self.assertEqual(channel.result["code"], b"400", channel.result)
|
||||||
self.assertEqual(channel.json_body["error"], "Invalid username")
|
self.assertEqual(channel.json_body["error"], "Invalid username")
|
||||||
|
|
||||||
def test_POST_user_valid(self):
|
def test_POST_user_valid(self) -> None:
|
||||||
user_id = "@kermit:test"
|
user_id = "@kermit:test"
|
||||||
device_id = "frogfone"
|
device_id = "frogfone"
|
||||||
params = {
|
params = {
|
||||||
|
@ -135,7 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
|
|
||||||
@override_config({"enable_registration": False})
|
@override_config({"enable_registration": False})
|
||||||
def test_POST_disabled_registration(self):
|
def test_POST_disabled_registration(self) -> None:
|
||||||
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
||||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||||
|
|
||||||
|
@ -145,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
|
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
|
||||||
def test_POST_guest_registration(self):
|
def test_POST_guest_registration(self) -> None:
|
||||||
self.hs.config.key.macaroon_secret_key = "test"
|
self.hs.config.key.macaroon_secret_key = "test"
|
||||||
self.hs.config.registration.allow_guest_access = True
|
self.hs.config.registration.allow_guest_access = True
|
||||||
|
|
||||||
|
@ -155,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
|
|
||||||
def test_POST_disabled_guest_registration(self):
|
def test_POST_disabled_guest_registration(self) -> None:
|
||||||
self.hs.config.registration.allow_guest_access = False
|
self.hs.config.registration.allow_guest_access = False
|
||||||
|
|
||||||
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
|
@ -164,7 +170,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
|
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
|
||||||
|
|
||||||
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
|
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
|
||||||
def test_POST_ratelimiting_guest(self):
|
def test_POST_ratelimiting_guest(self) -> None:
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
url = self.url + b"?kind=guest"
|
url = self.url + b"?kind=guest"
|
||||||
channel = self.make_request(b"POST", url, b"{}")
|
channel = self.make_request(b"POST", url, b"{}")
|
||||||
|
@ -182,7 +188,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
|
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
|
||||||
def test_POST_ratelimiting(self):
|
def test_POST_ratelimiting(self) -> None:
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
params = {
|
params = {
|
||||||
"username": "kermit" + str(i),
|
"username": "kermit" + str(i),
|
||||||
|
@ -206,7 +212,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_requires_token(self):
|
def test_POST_registration_requires_token(self) -> None:
|
||||||
username = "kermit"
|
username = "kermit"
|
||||||
device_id = "frogfone"
|
device_id = "frogfone"
|
||||||
token = "abcd"
|
token = "abcd"
|
||||||
|
@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
params = {
|
params: JsonDict = {
|
||||||
"username": username,
|
"username": username,
|
||||||
"password": "monkey",
|
"password": "monkey",
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
@ -280,8 +286,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(res["pending"], 0)
|
self.assertEqual(res["pending"], 0)
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_token_invalid(self):
|
def test_POST_registration_token_invalid(self) -> None:
|
||||||
params = {
|
params: JsonDict = {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey",
|
"password": "monkey",
|
||||||
}
|
}
|
||||||
|
@ -314,7 +320,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["completed"], [])
|
self.assertEqual(channel.json_body["completed"], [])
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_token_limit_uses(self):
|
def test_POST_registration_token_limit_uses(self) -> None:
|
||||||
token = "abcd"
|
token = "abcd"
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
# Create token that can be used once
|
# Create token that can be used once
|
||||||
|
@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
params1 = {"username": "bert", "password": "monkey"}
|
params1: JsonDict = {"username": "bert", "password": "monkey"}
|
||||||
params2 = {"username": "ernie", "password": "monkey"}
|
params2: JsonDict = {"username": "ernie", "password": "monkey"}
|
||||||
# Do 2 requests without auth to get two session IDs
|
# Do 2 requests without auth to get two session IDs
|
||||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
session1 = channel1.json_body["session"]
|
session1 = channel1.json_body["session"]
|
||||||
|
@ -388,7 +394,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["completed"], [])
|
self.assertEqual(channel.json_body["completed"], [])
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_token_expiry(self):
|
def test_POST_registration_token_expiry(self) -> None:
|
||||||
token = "abcd"
|
token = "abcd"
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
params = {"username": "kermit", "password": "monkey"}
|
params: JsonDict = {"username": "kermit", "password": "monkey"}
|
||||||
# Request without auth to get session
|
# Request without auth to get session
|
||||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
|
@ -436,7 +442,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_token_session_expiry(self):
|
def test_POST_registration_token_session_expiry(self) -> None:
|
||||||
"""Test `pending` is decremented when an uncompleted session expires."""
|
"""Test `pending` is decremented when an uncompleted session expires."""
|
||||||
token = "abcd"
|
token = "abcd"
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Do 2 requests without auth to get two session IDs
|
# Do 2 requests without auth to get two session IDs
|
||||||
params1 = {"username": "bert", "password": "monkey"}
|
params1: JsonDict = {"username": "bert", "password": "monkey"}
|
||||||
params2 = {"username": "ernie", "password": "monkey"}
|
params2: JsonDict = {"username": "ernie", "password": "monkey"}
|
||||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||||
session1 = channel1.json_body["session"]
|
session1 = channel1.json_body["session"]
|
||||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||||
|
@ -522,7 +528,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(pending, 0)
|
self.assertEqual(pending, 0)
|
||||||
|
|
||||||
@override_config({"registration_requires_token": True})
|
@override_config({"registration_requires_token": True})
|
||||||
def test_POST_registration_token_session_expiry_deleted_token(self):
|
def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
|
||||||
"""Test session expiry doesn't break when the token is deleted.
|
"""Test session expiry doesn't break when the token is deleted.
|
||||||
|
|
||||||
1. Start but don't complete UIA with a registration token
|
1. Start but don't complete UIA with a registration token
|
||||||
|
@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Do request without auth to get a session ID
|
# Do request without auth to get a session ID
|
||||||
params = {"username": "kermit", "password": "monkey"}
|
params: JsonDict = {"username": "kermit", "password": "monkey"}
|
||||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
|
@ -570,7 +576,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_advertised_flows(self):
|
def test_advertised_flows(self) -> None:
|
||||||
channel = self.make_request(b"POST", self.url, b"{}")
|
channel = self.make_request(b"POST", self.url, b"{}")
|
||||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
flows = channel.json_body["flows"]
|
flows = channel.json_body["flows"]
|
||||||
|
@ -593,7 +599,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_advertised_flows_captcha_and_terms_and_3pids(self):
|
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
|
||||||
channel = self.make_request(b"POST", self.url, b"{}")
|
channel = self.make_request(b"POST", self.url, b"{}")
|
||||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
flows = channel.json_body["flows"]
|
flows = channel.json_body["flows"]
|
||||||
|
@ -625,7 +631,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_advertised_flows_no_msisdn_email_required(self):
|
def test_advertised_flows_no_msisdn_email_required(self) -> None:
|
||||||
channel = self.make_request(b"POST", self.url, b"{}")
|
channel = self.make_request(b"POST", self.url, b"{}")
|
||||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
flows = channel.json_body["flows"]
|
flows = channel.json_body["flows"]
|
||||||
|
@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_request_token_existing_email_inhibit_error(self):
|
def test_request_token_existing_email_inhibit_error(self) -> None:
|
||||||
"""Test that requesting a token via this endpoint doesn't leak existing
|
"""Test that requesting a token via this endpoint doesn't leak existing
|
||||||
associations if configured that way.
|
associations if configured that way.
|
||||||
"""
|
"""
|
||||||
|
@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_reject_invalid_email(self):
|
def test_reject_invalid_email(self) -> None:
|
||||||
"""Check that bad emails are rejected"""
|
"""Check that bad emails are rejected"""
|
||||||
|
|
||||||
# Test for email with multiple @
|
# Test for email with multiple @
|
||||||
|
@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
"inhibit_user_in_use_error": True,
|
"inhibit_user_in_use_error": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_inhibit_user_in_use_error(self):
|
def test_inhibit_user_in_use_error(self) -> None:
|
||||||
"""Tests that the 'inhibit_user_in_use_error' configuration flag behaves
|
"""Tests that the 'inhibit_user_in_use_error' configuration flag behaves
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
account_validity.register_servlets,
|
account_validity.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
# Test for account expiring after a week.
|
# Test for account expiring after a week.
|
||||||
config["enable_registration"] = True
|
config["enable_registration"] = True
|
||||||
|
@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_validity_period(self):
|
def test_validity_period(self) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
|
@ -810,7 +816,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
|
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_manual_renewal(self):
|
def test_manual_renewal(self) -> None:
|
||||||
user_id = self.register_user("kermit", "monkey")
|
user_id = self.register_user("kermit", "monkey")
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
|
@ -833,7 +839,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.make_request(b"GET", "/sync", access_token=tok)
|
channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
def test_manual_expire(self):
|
def test_manual_expire(self) -> None:
|
||||||
user_id = self.register_user("kermit", "monkey")
|
user_id = self.register_user("kermit", "monkey")
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
|
@ -858,7 +864,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
|
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_logging_out_expired_user(self):
|
def test_logging_out_expired_user(self) -> None:
|
||||||
user_id = self.register_user("kermit", "monkey")
|
user_id = self.register_user("kermit", "monkey")
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
|
@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
account.register_servlets,
|
account.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
|
|
||||||
# Test for account expiring after a week and renewal emails being sent 2
|
# Test for account expiring after a week and renewal emails being sent 2
|
||||||
|
@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.hs = self.setup_test_homeserver(config=config)
|
self.hs = self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
async def sendmail(*args, **kwargs):
|
async def sendmail(*args: Any, **kwargs: Any) -> None:
|
||||||
self.email_attempts.append((args, kwargs))
|
self.email_attempts.append((args, kwargs))
|
||||||
|
|
||||||
self.email_attempts = []
|
self.email_attempts: List[Tuple[Any, Any]] = []
|
||||||
self.hs.get_send_email_handler()._sendmail = sendmail
|
self.hs.get_send_email_handler()._sendmail = sendmail
|
||||||
|
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_renewal_email(self):
|
def test_renewal_email(self) -> None:
|
||||||
self.email_attempts = []
|
self.email_attempts = []
|
||||||
|
|
||||||
(user_id, tok) = self.create_user()
|
(user_id, tok) = self.create_user()
|
||||||
|
@ -999,7 +1005,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.make_request(b"GET", "/sync", access_token=tok)
|
channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
def test_renewal_invalid_token(self):
|
def test_renewal_invalid_token(self) -> None:
|
||||||
# Hit the renewal endpoint with an invalid token and check that it behaves as
|
# Hit the renewal endpoint with an invalid token and check that it behaves as
|
||||||
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
|
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
|
||||||
url = "/_matrix/client/unstable/account_validity/renew?token=123"
|
url = "/_matrix/client/unstable/account_validity/renew?token=123"
|
||||||
|
@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
channel.result["body"], expected_html.encode("utf8"), channel.result
|
channel.result["body"], expected_html.encode("utf8"), channel.result
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_manual_email_send(self):
|
def test_manual_email_send(self) -> None:
|
||||||
self.email_attempts = []
|
self.email_attempts = []
|
||||||
|
|
||||||
(user_id, tok) = self.create_user()
|
(user_id, tok) = self.create_user()
|
||||||
|
@ -1032,7 +1038,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(len(self.email_attempts), 1)
|
self.assertEqual(len(self.email_attempts), 1)
|
||||||
|
|
||||||
def test_deactivated_user(self):
|
def test_deactivated_user(self) -> None:
|
||||||
self.email_attempts = []
|
self.email_attempts = []
|
||||||
|
|
||||||
(user_id, tok) = self.create_user()
|
(user_id, tok) = self.create_user()
|
||||||
|
@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(len(self.email_attempts), 0)
|
self.assertEqual(len(self.email_attempts), 0)
|
||||||
|
|
||||||
def create_user(self):
|
def create_user(self) -> Tuple[str, str]:
|
||||||
user_id = self.register_user("kermit", "monkey")
|
user_id = self.register_user("kermit", "monkey")
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
# We need to manually add an email address otherwise the handler will do
|
# We need to manually add an email address otherwise the handler will do
|
||||||
|
@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return user_id, tok
|
return user_id, tok
|
||||||
|
|
||||||
def test_manual_email_send_expired_account(self):
|
def test_manual_email_send_expired_account(self) -> None:
|
||||||
user_id = self.register_user("kermit", "monkey")
|
user_id = self.register_user("kermit", "monkey")
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
|
@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
|
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.validity_period = 10
|
self.validity_period = 10
|
||||||
self.max_delta = self.validity_period * 10.0 / 100.0
|
self.max_delta = self.validity_period * 10.0 / 100.0
|
||||||
|
|
||||||
|
@ -1135,7 +1141,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_background_job(self):
|
def test_background_job(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the same thing as test_background_job, except that it sets the
|
Tests the same thing as test_background_job, except that it sets the
|
||||||
startup_job_max_delta parameter and checks that the expiration date is within the
|
startup_job_max_delta parameter and checks that the expiration date is within the
|
||||||
|
@ -1158,12 +1164,12 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [register.register_servlets]
|
servlets = [register.register_servlets]
|
||||||
url = "/_matrix/client/v1/register/m.login.registration_token/validity"
|
url = "/_matrix/client/v1/register/m.login.registration_token/validity"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["registration_requires_token"] = True
|
config["registration_requires_token"] = True
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_GET_token_valid(self):
|
def test_GET_token_valid(self) -> None:
|
||||||
token = "abcd"
|
token = "abcd"
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -1186,7 +1192,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["valid"], True)
|
self.assertEqual(channel.json_body["valid"], True)
|
||||||
|
|
||||||
def test_GET_token_invalid(self):
|
def test_GET_token_invalid(self) -> None:
|
||||||
token = "1234"
|
token = "1234"
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
b"GET",
|
b"GET",
|
||||||
|
@ -1198,7 +1204,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
|
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
|
||||||
)
|
)
|
||||||
def test_GET_ratelimiting(self):
|
def test_GET_ratelimiting(self) -> None:
|
||||||
token = "1234"
|
token = "1234"
|
||||||
|
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
|
|
Loading…
Reference in New Issue