Add type hints to `tests/rest/client` (#12072)

This commit is contained in:
Dirk Klimpel 2022-02-24 19:56:38 +01:00 committed by GitHub
parent 2cc5ea933d
commit 54e74cc15f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 160 additions and 102 deletions

1
changelog.d/12072.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View File

@ -13,11 +13,16 @@
# limitations under the License.
import os
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.urls import ConsentURIBuilder
from synapse.rest.client import login, room
from synapse.rest.consent import consent_resource
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
user_id = True
hijack_auth = False
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["form_secret"] = "123abc"
@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
return hs
def test_render_public_consent(self):
def test_render_public_consent(self) -> None:
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
channel = make_request(
@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"/consent?v=1",
shorthand=False,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
def test_accept_consent(self):
def test_accept_consent(self) -> None:
"""
A user can use the consent form to accept the terms.
"""
@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
access_token=access_token,
shorthand=False,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
# Get the version from the body, and whether we've consented
version, consented = channel.result["body"].decode("ascii").split(",")
@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
access_token=access_token,
shorthand=False,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
# Fetch the consent page, to get the consent version -- it should have
# changed
@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
access_token=access_token,
shorthand=False,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
# Get the version from the body, and check that it's the version we
# agreed to, and that we've consented to it.

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, login, register
@ -30,7 +32,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
devices.register_servlets,
]
def test_receiving_local_device_list_changes(self):
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
@ -84,7 +86,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
@ -97,7 +99,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
)
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self):
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
@ -119,7 +121,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
@ -141,11 +143,13 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
self.assertEqual(
bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []

View File

@ -11,9 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest import admin
from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["enable_ephemeral_messages"] = True
@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
def test_message_expiry_no_delay(self):
def test_message_expiry_no_delay(self) -> None:
"""Tests that sending a message sent with a m.self_destruct_after field set to the
past results in that event being deleted right away.
"""
@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
event_content = self.get_event(self.room_id, event_id)["content"]
self.assertFalse(bool(event_content), event_content)
def test_message_expiry_delay(self):
def test_message_expiry_delay(self) -> None:
"""Tests that sending a message with a m.self_destruct_after field set to the
future results in that event not being deleted right away, but advancing the
clock to after that expiry timestamp causes the event to be deleted.
@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
event_content = self.get_event(self.room_id, event_id)["content"]
self.assertFalse(bool(event_content), event_content)
def get_event(self, room_id, event_id, expected_code=200):
def get_event(
self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK
) -> JsonDict:
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
channel = self.make_request("GET", url)

View File

@ -13,9 +13,14 @@
# limitations under the License.
import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["enable_3pid_lookup"] = False
@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase):
return self.hs
def test_3pid_lookup_disabled(self):
def test_3pid_lookup_disabled(self) -> None:
self.hs.config.registration.enable_3pid_lookup = False
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
room_id = channel.json_body["room_id"]
params = {
@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)

View File

@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def test_rejects_device_id_ice_key_outside_of_list(self):
def test_rejects_device_id_ice_key_outside_of_list(self) -> None:
self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland")
bob = self.register_user("bob", "uncle")
@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
channel.result,
)
def test_rejects_device_key_given_as_map_to_bool(self):
def test_rejects_device_key_given_as_map_to_bool(self) -> None:
self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland")
bob = self.register_user("bob", "uncle")
@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
channel.result,
)
def test_requires_device_key(self):
def test_requires_device_key(self) -> None:
"""`device_keys` is required. We should complain if it's missing."""
self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland")

View File

@ -13,11 +13,16 @@
# limitations under the License.
import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import account, login, password_policy, register
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
account.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.register_url = "/_matrix/client/r0/register"
self.policy = {
"enabled": True,
@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
return hs
def test_get_policy(self):
def test_get_policy(self) -> None:
"""Tests if the /password_policy endpoint returns the configured policy."""
channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(
channel.json_body,
{
@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
channel.result,
)
def test_password_too_short(self):
def test_password_too_short(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "shorty"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.PASSWORD_TOO_SHORT,
channel.result,
)
def test_password_no_digit(self):
def test_password_no_digit(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.PASSWORD_NO_DIGIT,
channel.result,
)
def test_password_no_symbol(self):
def test_password_no_symbol(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.PASSWORD_NO_SYMBOL,
channel.result,
)
def test_password_no_uppercase(self):
def test_password_no_uppercase(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.PASSWORD_NO_UPPERCASE,
channel.result,
)
def test_password_no_lowercase(self):
def test_password_no_lowercase(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.PASSWORD_NO_LOWERCASE,
channel.result,
)
def test_password_compliant(self):
def test_password_compliant(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
# responded with a list of registration flows.
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
def test_password_change(self):
def test_password_change(self) -> None:
"""This doesn't test every possible use case, only that hitting /account/password
triggers the password validation code.
"""
@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
access_token=tok,
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)

View File

@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase):
sync.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register a room admin, moderator and regular user
self.admin_user_id = self.register_user("admin", "pass")
self.admin_access_token = self.login("admin", "pass")
@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
tok=self.admin_access_token,
)
def test_non_admins_cannot_enable_room_encryption(self):
def test_non_admins_cannot_enable_room_encryption(self) -> None:
# have the mod try to enable room encryption
self.helper.send_state(
self.room_id,
@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"},
tok=self.user_access_token,
expect_code=403, # expect failure
expect_code=HTTPStatus.FORBIDDEN, # expect failure
)
def test_non_admins_cannot_send_server_acl(self):
def test_non_admins_cannot_send_server_acl(self) -> None:
# have the mod try to send a server ACL
self.helper.send_state(
self.room_id,
@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"deny": ["*.evil.com", "evil.com"],
},
tok=self.mod_access_token,
expect_code=403, # expect failure
expect_code=HTTPStatus.FORBIDDEN, # expect failure
)
# have the user try to send a server ACL
@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"deny": ["*.evil.com", "evil.com"],
},
tok=self.user_access_token,
expect_code=403, # expect failure
expect_code=HTTPStatus.FORBIDDEN, # expect failure
)
def test_non_admins_cannot_tombstone_room(self):
def test_non_admins_cannot_tombstone_room(self) -> None:
# Create another room that will serve as our "upgraded room"
self.upgraded_room_id = self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token
@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"replacement_room": self.upgraded_room_id,
},
tok=self.mod_access_token,
expect_code=403, # expect failure
expect_code=HTTPStatus.FORBIDDEN, # expect failure
)
# have the user try to send a tombstone event
@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=403, # expect failure
)
def test_admins_can_enable_room_encryption(self):
def test_admins_can_enable_room_encryption(self) -> None:
# have the admin try to enable room encryption
self.helper.send_state(
self.room_id,
"m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"},
tok=self.admin_access_token,
expect_code=200, # expect success
expect_code=HTTPStatus.OK, # expect success
)
def test_admins_can_send_server_acl(self):
def test_admins_can_send_server_acl(self) -> None:
# have the admin try to send a server ACL
self.helper.send_state(
self.room_id,
@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"deny": ["*.evil.com", "evil.com"],
},
tok=self.admin_access_token,
expect_code=200, # expect success
expect_code=HTTPStatus.OK, # expect success
)
def test_admins_can_tombstone_room(self):
def test_admins_can_tombstone_room(self) -> None:
# Create another room that will serve as our "upgraded room"
self.upgraded_room_id = self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token
@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"replacement_room": self.upgraded_room_id,
},
tok=self.admin_access_token,
expect_code=200, # expect success
expect_code=HTTPStatus.OK, # expect success
)
def test_cannot_set_string_power_levels(self):
def test_cannot_set_string_power_levels(self) -> None:
room_power_levels = self.helper.get_state(
self.room_id,
"m.room.power_levels",
@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.power_levels",
room_power_levels,
tok=self.admin_access_token,
expect_code=400, # expect failure
expect_code=HTTPStatus.BAD_REQUEST, # expect failure
)
self.assertEqual(
@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
body,
)
def test_cannot_set_unsafe_large_power_levels(self):
def test_cannot_set_unsafe_large_power_levels(self) -> None:
room_power_levels = self.helper.get_state(
self.room_id,
"m.room.power_levels",
@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.power_levels",
room_power_levels,
tok=self.admin_access_token,
expect_code=400, # expect failure
expect_code=HTTPStatus.BAD_REQUEST, # expect failure
)
self.assertEqual(
@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
body,
)
def test_cannot_set_unsafe_small_power_levels(self):
def test_cannot_set_unsafe_small_power_levels(self) -> None:
room_power_levels = self.helper.get_state(
self.room_id,
"m.room.power_levels",
@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.power_levels",
room_power_levels,
tok=self.admin_access_token,
expect_code=400, # expect failure
expect_code=HTTPStatus.BAD_REQUEST, # expect failure
)
self.assertEqual(

View File

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.presence import PresenceHandler
from synapse.rest.client import presence
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id)
servlets = [presence.register_servlets]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None)
@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
return hs
def test_put_presence(self):
def test_put_presence(self) -> None:
"""
PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler.
@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase):
"PUT", "/presence/%s/status" % (self.user_id,), body
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self):
def test_put_presence_disabled(self) -> None:
"""
PUT to the status endpoint with use_presence disabled will NOT call
set_state on the presence handler.
@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase):
"PUT", "/presence/%s/status" % (self.user_id,), body
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)

View File

@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
return room_id, event_id_a, event_id_b, event_id_c
@unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
def test_same_state_groups_for_whole_historical_batch(self):
def test_same_state_groups_for_whole_historical_batch(self) -> None:
"""Make sure that when using the `/batch_send` endpoint to import a
bunch of historical messages, it re-uses the same `state_group` across
the whole batch. This is an easy optimization to make sure we're getting

View File

@ -19,6 +19,7 @@ import json
import re
import time
import urllib.parse
from http import HTTPStatus
from typing import (
Any,
AnyStr,
@ -89,7 +90,7 @@ class RestHelper:
is_public: Optional[bool] = None,
room_version: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
extra_content: Optional[Dict] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> Optional[str]:
@ -137,12 +138,19 @@ class RestHelper:
assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
if expect_code == 200:
if expect_code == HTTPStatus.OK:
return channel.json_body["room_id"]
else:
return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
def invite(
self,
room: Optional[str] = None,
src: Optional[str] = None,
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
self.change_membership(
room=room,
src=src,
@ -156,7 +164,7 @@ class RestHelper:
self,
room: str,
user: Optional[str] = None,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
appservice_user_id: Optional[str] = None,
) -> None:
@ -170,7 +178,14 @@ class RestHelper:
expect_code=expect_code,
)
def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None):
def knock(
self,
room: Optional[str] = None,
user: Optional[str] = None,
reason: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
temp_id = self.auth_user_id
self.auth_user_id = user
path = "/knock/%s" % room
@ -199,7 +214,13 @@ class RestHelper:
self.auth_user_id = temp_id
def leave(self, room=None, user=None, expect_code=200, tok=None):
def leave(
self,
room: Optional[str] = None,
user: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
self.change_membership(
room=room,
src=user,
@ -209,7 +230,7 @@ class RestHelper:
expect_code=expect_code,
)
def ban(self, room: str, src: str, targ: str, **kwargs: object):
def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None:
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
self.change_membership(
room=room,
@ -228,7 +249,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
appservice_user_id: Optional[str] = None,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None,
) -> None:
"""
@ -294,13 +315,13 @@ class RestHelper:
def send(
self,
room_id,
body=None,
txn_id=None,
tok=None,
expect_code=200,
room_id: str,
body: Optional[str] = None,
txn_id: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
) -> JsonDict:
if body is None:
body = "body_text_here"
@ -318,14 +339,14 @@ class RestHelper:
def send_event(
self,
room_id,
type,
room_id: str,
type: str,
content: Optional[dict] = None,
txn_id=None,
tok=None,
expect_code=200,
txn_id: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
) -> JsonDict:
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@ -358,10 +379,10 @@ class RestHelper:
event_type: str,
body: Optional[Dict[str, Any]],
tok: str,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
state_key: str = "",
method: str = "GET",
) -> Dict:
) -> JsonDict:
"""Read or write some state from a given room
Args:
@ -410,9 +431,9 @@ class RestHelper:
room_id: str,
event_type: str,
tok: str,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
state_key: str = "",
):
) -> JsonDict:
"""Gets some state from a room
Args:
@ -438,9 +459,9 @@ class RestHelper:
event_type: str,
body: Dict[str, Any],
tok: str,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
state_key: str = "",
):
) -> JsonDict:
"""Set some state in a room
Args:
@ -467,8 +488,8 @@ class RestHelper:
image_data: bytes,
tok: str,
filename: str = "test.png",
expect_code: int = 200,
) -> dict:
expect_code: int = HTTPStatus.OK,
) -> JsonDict:
"""Upload a piece of test media to the media repo
Args:
resource: The resource that will handle the upload request
@ -513,7 +534,7 @@ class RestHelper:
channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
# expect a confirmation page
assert channel.code == 200, channel.result
assert channel.code == HTTPStatus.OK, channel.result
# fish the matrix login token out of the body of the confirmation page
m = re.search(
@ -532,7 +553,7 @@ class RestHelper:
"/login",
content={"type": "m.login.token", "token": login_token},
)
assert channel.code == 200
assert channel.code == HTTPStatus.OK
return channel.json_body
def auth_via_oidc(
@ -641,7 +662,7 @@ class RestHelper:
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
code=200,
code=HTTPStatus.OK,
phrase=b"OK",
body=json.dumps(resp_obj).encode("utf-8"),
)
@ -739,7 +760,7 @@ class RestHelper:
self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
)
# that should serve a confirmation page
assert channel.code == 200, channel.text_body
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
# parse the confirmation page to fish out the link.