Add type hints to `tests/rest/client` (#12094)

* Add type hints to `tests/rest/client`

* update `mypy.ini`

* newsfile

* add `test_register.py`
This commit is contained in:
Dirk Klimpel 2022-02-28 19:59:00 +01:00 committed by GitHub
parent 7754af24ab
commit 952efd0bca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 64 deletions

1
changelog.d/12094.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View File

@ -75,10 +75,7 @@ exclude = (?x)
|tests/push/test_presentable_names.py |tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_account.py |tests/rest/client/test_account.py
|tests/rest/client/test_events.py
|tests/rest/client/test_filter.py |tests/rest/client/test_filter.py
|tests/rest/client/test_groups.py
|tests/rest/client/test_register.py
|tests/rest/client/test_report_event.py |tests/rest/client/test_report_event.py
|tests/rest/client/test_rooms.py |tests/rest/client/test_rooms.py
|tests/rest/client/test_third_party_rules.py |tests/rest/client/test_third_party_rules.py

View File

@ -16,8 +16,12 @@
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client import events, login, room from synapse.rest.client import events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["enable_registration_captcha"] = False config["enable_registration_captcha"] = False
@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config) hs = self.setup_test_homeserver(config=config)
hs.get_federation_handler = Mock() hs.get_federation_handler = Mock() # type: ignore[assignment]
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register an account # register an account
self.user_id = self.register_user("sid1", "pass") self.user_id = self.register_user("sid1", "pass")
@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
self.other_user = self.register_user("other2", "pass") self.other_user = self.register_user("other2", "pass")
self.other_token = self.login(self.other_user, "pass") self.other_token = self.login(self.other_user, "pass")
def test_stream_basic_permissions(self): def test_stream_basic_permissions(self) -> None:
# invalid token, expect 401 # invalid token, expect 401
# note: this is in violation of the original v1 spec, which expected # note: this is in violation of the original v1 spec, which expected
# 403. However, since the v1 spec no longer exists and the v1 # 403. However, since the v1 spec no longer exists and the v1
@ -76,7 +80,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
self.assertTrue("start" in channel.json_body) self.assertTrue("start" in channel.json_body)
self.assertTrue("end" in channel.json_body) self.assertTrue("end" in channel.json_body)
def test_stream_room_permissions(self): def test_stream_room_permissions(self) -> None:
room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
self.helper.send(room_id, tok=self.other_token) self.helper.send(room_id, tok=self.other_token)
@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# left to room (expect no content for room) # left to room (expect no content for room)
def TODO_test_stream_items(self): def TODO_test_stream_items(self) -> None:
# new user, no content # new user, no content
# join room, expect 1 item (join) # join room, expect 1 item (join)
@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, hs, reactor, clock): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register an account # register an account
self.user_id = self.register_user("sid1", "pass") self.user_id = self.register_user("sid1", "pass")
@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
def test_get_event_via_events(self): def test_get_event_via_events(self) -> None:
resp = self.helper.send(self.room_id, tok=self.token) resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"] event_id = resp["event_id"]

View File

@ -25,7 +25,7 @@ class GroupsTestCase(unittest.HomeserverTestCase):
servlets = [room.register_servlets, groups.register_servlets] servlets = [room.register_servlets, groups.register_servlets]
@override_config({"enable_group_creation": True}) @override_config({"enable_group_creation": True})
def test_rooms_limited_by_visibility(self): def test_rooms_limited_by_visibility(self) -> None:
group_id = "+spqr:test" group_id = "+spqr:test"
# Alice creates a group # Alice creates a group

View File

@ -16,15 +16,21 @@
import datetime import datetime
import json import json
import os import os
from typing import Any, Dict, List, Tuple
import pkg_resources import pkg_resources
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync from synapse.rest.client import account, account_validity, login, logout, register, sync
from synapse.server import HomeServer
from synapse.storage._base import db_to_json from synapse.storage._base import db_to_json
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
] ]
url = b"/_matrix/client/r0/register" url = b"/_matrix/client/r0/register"
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["allow_guest_access"] = True config["allow_guest_access"] = True
return config return config
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self) -> None:
user_id = "@as_user_kermit:test" user_id = "@as_user_kermit:test"
as_token = "i_am_an_app_service" as_token = "i_am_an_app_service"
@ -69,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
det_data = {"user_id": user_id, "home_server": self.hs.hostname} det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_appservice_registration_no_type(self): def test_POST_appservice_registration_no_type(self) -> None:
as_token = "i_am_an_app_service" as_token = "i_am_an_app_service"
appservice = ApplicationService( appservice = ApplicationService(
@ -89,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.result["code"], b"400", channel.result)
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists self.appservice = None # no application service exists
request_data = json.dumps( request_data = json.dumps(
{"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
@ -100,21 +106,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self): def test_POST_bad_password(self) -> None:
request_data = json.dumps({"username": "kermit", "password": 666}) request_data = json.dumps({"username": "kermit", "password": 666})
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.result["code"], b"400", channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password") self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self): def test_POST_bad_username(self) -> None:
request_data = json.dumps({"username": 777, "password": "monkey"}) request_data = json.dumps({"username": 777, "password": "monkey"})
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.result["code"], b"400", channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username") self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self): def test_POST_user_valid(self) -> None:
user_id = "@kermit:test" user_id = "@kermit:test"
device_id = "frogfone" device_id = "frogfone"
params = { params = {
@ -135,7 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False}) @override_config({"enable_registration": False})
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "monkey"}) request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
@ -145,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self): def test_POST_guest_registration(self) -> None:
self.hs.config.key.macaroon_secret_key = "test" self.hs.config.key.macaroon_secret_key = "test"
self.hs.config.registration.allow_guest_access = True self.hs.config.registration.allow_guest_access = True
@ -155,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self): def test_POST_disabled_guest_registration(self) -> None:
self.hs.config.registration.allow_guest_access = False self.hs.config.registration.allow_guest_access = False
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@ -164,7 +170,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["error"], "Guest access is disabled") self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self): def test_POST_ratelimiting_guest(self) -> None:
for i in range(0, 6): for i in range(0, 6):
url = self.url + b"?kind=guest" url = self.url + b"?kind=guest"
channel = self.make_request(b"POST", url, b"{}") channel = self.make_request(b"POST", url, b"{}")
@ -182,7 +188,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self): def test_POST_ratelimiting(self) -> None:
for i in range(0, 6): for i in range(0, 6):
params = { params = {
"username": "kermit" + str(i), "username": "kermit" + str(i),
@ -206,7 +212,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self): def test_POST_registration_requires_token(self) -> None:
username = "kermit" username = "kermit"
device_id = "frogfone" device_id = "frogfone"
token = "abcd" token = "abcd"
@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
params = { params: JsonDict = {
"username": username, "username": username,
"password": "monkey", "password": "monkey",
"device_id": device_id, "device_id": device_id,
@ -280,8 +286,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["pending"], 0) self.assertEqual(res["pending"], 0)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self): def test_POST_registration_token_invalid(self) -> None:
params = { params: JsonDict = {
"username": "kermit", "username": "kermit",
"password": "monkey", "password": "monkey",
} }
@ -314,7 +320,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_limit_uses(self): def test_POST_registration_token_limit_uses(self) -> None:
token = "abcd" token = "abcd"
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
# Create token that can be used once # Create token that can be used once
@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
params1 = {"username": "bert", "password": "monkey"} params1: JsonDict = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs # Do 2 requests without auth to get two session IDs
channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"] session1 = channel1.json_body["session"]
@ -388,7 +394,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_expiry(self): def test_POST_registration_token_expiry(self) -> None:
token = "abcd" token = "abcd"
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
params = {"username": "kermit", "password": "monkey"} params: JsonDict = {"username": "kermit", "password": "monkey"}
# Request without auth to get session # Request without auth to get session
channel = self.make_request(b"POST", self.url, json.dumps(params)) channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"] session = channel.json_body["session"]
@ -436,7 +442,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry(self): def test_POST_registration_token_session_expiry(self) -> None:
"""Test `pending` is decremented when an uncompleted session expires.""" """Test `pending` is decremented when an uncompleted session expires."""
token = "abcd" token = "abcd"
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
) )
# Do 2 requests without auth to get two session IDs # Do 2 requests without auth to get two session IDs
params1 = {"username": "bert", "password": "monkey"} params1: JsonDict = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"}
channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"] session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
@ -522,7 +528,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(pending, 0) self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry_deleted_token(self): def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
"""Test session expiry doesn't break when the token is deleted. """Test session expiry doesn't break when the token is deleted.
1. Start but don't complete UIA with a registration token 1. Start but don't complete UIA with a registration token
@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
) )
# Do request without auth to get a session ID # Do request without auth to get a session ID
params = {"username": "kermit", "password": "monkey"} params: JsonDict = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, json.dumps(params)) channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"] session = channel.json_body["session"]
@ -570,7 +576,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
) )
def test_advertised_flows(self): def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
@ -593,7 +599,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
} }
) )
def test_advertised_flows_captcha_and_terms_and_3pids(self): def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
@ -625,7 +631,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
} }
) )
def test_advertised_flows_no_msisdn_email_required(self): def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
} }
) )
def test_request_token_existing_email_inhibit_error(self): def test_request_token_existing_email_inhibit_error(self) -> None:
"""Test that requesting a token via this endpoint doesn't leak existing """Test that requesting a token via this endpoint doesn't leak existing
associations if configured that way. associations if configured that way.
""" """
@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}, },
} }
) )
def test_reject_invalid_email(self): def test_reject_invalid_email(self) -> None:
"""Check that bad emails are rejected""" """Check that bad emails are rejected"""
# Test for email with multiple @ # Test for email with multiple @
@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"inhibit_user_in_use_error": True, "inhibit_user_in_use_error": True,
} }
) )
def test_inhibit_user_in_use_error(self): def test_inhibit_user_in_use_error(self) -> None:
"""Tests that the 'inhibit_user_in_use_error' configuration flag behaves """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
correctly. correctly.
""" """
@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
account_validity.register_servlets, account_validity.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
# Test for account expiring after a week. # Test for account expiring after a week.
config["enable_registration"] = True config["enable_registration"] = True
@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_validity_period(self): def test_validity_period(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
@ -810,7 +816,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
def test_manual_renewal(self): def test_manual_renewal(self) -> None:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
@ -833,7 +839,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
def test_manual_expire(self): def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
@ -858,7 +864,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
def test_logging_out_expired_user(self): def test_logging_out_expired_user(self) -> None:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
account.register_servlets, account.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
# Test for account expiring after a week and renewal emails being sent 2 # Test for account expiring after a week and renewal emails being sent 2
@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config) self.hs = self.setup_test_homeserver(config=config)
async def sendmail(*args, **kwargs): async def sendmail(*args: Any, **kwargs: Any) -> None:
self.email_attempts.append((args, kwargs)) self.email_attempts.append((args, kwargs))
self.email_attempts = [] self.email_attempts: List[Tuple[Any, Any]] = []
self.hs.get_send_email_handler()._sendmail = sendmail self.hs.get_send_email_handler()._sendmail = sendmail
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
return self.hs return self.hs
def test_renewal_email(self): def test_renewal_email(self) -> None:
self.email_attempts = [] self.email_attempts = []
(user_id, tok) = self.create_user() (user_id, tok) = self.create_user()
@ -999,7 +1005,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self): def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as # Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123" url = "/_matrix/client/unstable/account_validity/renew?token=123"
@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
channel.result["body"], expected_html.encode("utf8"), channel.result channel.result["body"], expected_html.encode("utf8"), channel.result
) )
def test_manual_email_send(self): def test_manual_email_send(self) -> None:
self.email_attempts = [] self.email_attempts = []
(user_id, tok) = self.create_user() (user_id, tok) = self.create_user()
@ -1032,7 +1038,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
def test_deactivated_user(self): def test_deactivated_user(self) -> None:
self.email_attempts = [] self.email_attempts = []
(user_id, tok) = self.create_user() (user_id, tok) = self.create_user()
@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 0) self.assertEqual(len(self.email_attempts), 0)
def create_user(self): def create_user(self) -> Tuple[str, str]:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do # We need to manually add an email address otherwise the handler will do
@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
) )
return user_id, tok return user_id, tok
def test_manual_email_send_expired_account(self): def test_manual_email_send_expired_account(self) -> None:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.validity_period = 10 self.validity_period = 10
self.max_delta = self.validity_period * 10.0 / 100.0 self.max_delta = self.validity_period * 10.0 / 100.0
@ -1135,7 +1141,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_background_job(self): def test_background_job(self) -> None:
""" """
Tests the same thing as test_background_job, except that it sets the Tests the same thing as test_background_job, except that it sets the
startup_job_max_delta parameter and checks that the expiration date is within the startup_job_max_delta parameter and checks that the expiration date is within the
@ -1158,12 +1164,12 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets] servlets = [register.register_servlets]
url = "/_matrix/client/v1/register/m.login.registration_token/validity" url = "/_matrix/client/v1/register/m.login.registration_token/validity"
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["registration_requires_token"] = True config["registration_requires_token"] = True
return config return config
def test_GET_token_valid(self): def test_GET_token_valid(self) -> None:
token = "abcd" token = "abcd"
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
self.get_success( self.get_success(
@ -1186,7 +1192,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["valid"], True) self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self): def test_GET_token_invalid(self) -> None:
token = "1234" token = "1234"
channel = self.make_request( channel = self.make_request(
b"GET", b"GET",
@ -1198,7 +1204,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
@override_config( @override_config(
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
) )
def test_GET_ratelimiting(self): def test_GET_ratelimiting(self) -> None:
token = "1234" token = "1234"
for i in range(0, 6): for i in range(0, 6):