Properly typecheck tests.api (#14983)

This commit is contained in:
David Robertson 2023-02-03 20:03:23 +00:00 committed by GitHub
parent b2d97bac09
commit 6e6edea6c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 111 deletions

1
changelog.d/14983.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -32,7 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/schema/
|tests/api/test_auth.py
|tests/appservice/test_scheduler.py
|tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py
@ -73,6 +72,9 @@ disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False
[mypy-tests.api.*]
disallow_untyped_defs = True
[mypy-tests.app.*]
disallow_untyped_defs = True

View File

@ -252,9 +252,9 @@ class FilterCollection:
return self._room_timeline_filter.unread_thread_notifications
async def filter_presence(
self, events: Iterable[UserPresenceState]
self, presence_states: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return await self._presence_filter.filter(events)
return await self._presence_filter.filter(presence_states)
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return await self._account_data.filter(events)

View File

@ -31,7 +31,7 @@ from synapse.api.errors import (
from synapse.appservice import ApplicationService
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester
from synapse.types import Requester, UserID
from synapse.util import Clock
from tests import unittest
@ -41,10 +41,12 @@ from tests.utils import mock_getRawHeaders
class AuthTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = Mock()
hs.datastores.main = self.store
# type-ignore: datastores is None until hs.setup() is called---but it'll
# have been called by the HomeserverTestCase machinery.
hs.datastores.main = self.store # type: ignore[union-attr]
hs.get_auth_handler().store = self.store
self.auth = Auth(hs)
@ -61,7 +63,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = simple_async_mock(False)
def test_get_user_by_req_user_valid_token(self):
def test_get_user_by_req_user_valid_token(self) -> None:
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
@ -74,7 +76,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
def test_get_user_by_req_user_bad_token(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
@ -86,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
def test_get_user_by_req_user_missing_token(self) -> None:
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = simple_async_mock(user_info)
@ -98,7 +100,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
def test_get_user_by_req_appservice_valid_token(self):
def test_get_user_by_req_appservice_valid_token(self) -> None:
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
@ -112,7 +114,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_good_ip(self):
def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
from netaddr import IPSet
app_service = Mock(
@ -131,7 +133,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
from netaddr import IPSet
app_service = Mock(
@ -153,7 +155,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
def test_get_user_by_req_appservice_bad_token(self) -> None:
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = simple_async_mock(None)
@ -166,7 +168,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self):
def test_get_user_by_req_appservice_missing_token(self) -> None:
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
@ -179,7 +181,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
@ -200,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
@ -217,7 +219,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_failure(self.auth.get_user_by_req(request), AuthError)
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None:
"""
Tests that when an application service passes the device_id URL parameter
with the ID of a valid device for the user in question,
@ -249,7 +251,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8"))
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
"""
Tests that when an application service passes the device_id URL parameter
with an ID that is not a valid device ID for the user in question,
@ -279,7 +281,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.value.code, 400)
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
user_id="@baldrick:matrix.org",
@ -298,7 +300,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth.get_user_by_req(request))
self.store.insert_client_ip.assert_called_once()
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.auth._track_puppeted_user_ips = True
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
@ -318,7 +320,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(self.store.insert_client_ip.call_count, 2)
def test_get_user_from_macaroon(self):
def test_get_user_from_macaroon(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
@ -336,7 +338,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
)
def test_get_guest_user_from_macaroon(self):
def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None)
@ -357,7 +359,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
def test_blocking_mau(self):
def test_blocking_mau(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50
lots_of_users = 100
@ -381,7 +383,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.get_success(self.auth_blocking.check_auth_blocking())
def test_blocking_mau__depending_on_user_type(self):
def test_blocking_mau__depending_on_user_type(self) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
@ -400,7 +402,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
# Real users not allowed
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(
self,
) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = False
@ -418,7 +422,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
sender="@appservice:sender",
)
requester = Requester(
user="@appservice:server",
user=UserID.from_string("@appservice:server"),
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
@ -428,7 +432,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
self,
) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = True
@ -446,7 +452,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
sender="@appservice:sender",
)
requester = Requester(
user="@appservice:server",
user=UserID.from_string("@appservice:server"),
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
@ -459,7 +465,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
def test_reserved_threepid(self):
def test_reserved_threepid(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = simple_async_mock(2)
@ -476,7 +482,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
def test_hs_disabled(self):
def test_hs_disabled(self) -> None:
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(
@ -486,7 +492,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
def test_hs_disabled_no_server_notices_user(self):
def test_hs_disabled_no_server_notices_user(self) -> None:
"""Check that 'hs_disabled_message' works correctly when there is no
server_notices user.
"""
@ -503,7 +509,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
def test_server_notices_mxid_special_cased(self):
def test_server_notices_mxid_special_cased(self) -> None:
self.auth_blocking._hs_disabled = True
user = "@user:server"
self.auth_blocking._server_notices_mxid = user

View File

@ -14,40 +14,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from unittest.mock import patch
import jsonschema
from frozendict import frozendict
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from synapse.api.presence import UserPresenceState
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.events.test_utils import MockEvent
user_localpart = "test_user"
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
if "content" not in kwargs:
kwargs["content"] = {}
return make_event_from_dict(kwargs)
class FilteringTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.filtering = hs.get_filtering()
self.datastore = hs.get_datastores().main
def test_errors_on_invalid_filters(self):
def test_errors_on_invalid_filters(self) -> None:
# See USER_FILTER_SCHEMA for the filter schema.
invalid_filters = [
invalid_filters: List[JsonDict] = [
# `account_data` must be a dictionary
{"account_data": "Hello World"},
# `event_fields` entries must not contain backslashes
@ -63,10 +59,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter)
def test_ignores_unknown_filter_fields(self):
def test_ignores_unknown_filter_fields(self) -> None:
# For forward compatibility, we must ignore unknown filter fields.
# See USER_FILTER_SCHEMA for the filter schema.
filters = [
filters: List[JsonDict] = [
{"org.matrix.msc9999.future_option": True},
{"presence": {"org.matrix.msc9999.future_option": True}},
{"room": {"org.matrix.msc9999.future_option": True}},
@ -76,8 +72,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.filtering.check_valid_filter(filter)
# Must not raise.
def test_valid_filters(self):
valid_filters = [
def test_valid_filters(self) -> None:
valid_filters: List[JsonDict] = [
{
"room": {
"timeline": {"limit": 20},
@ -132,22 +128,22 @@ class FilteringTestCase(unittest.HomeserverTestCase):
except jsonschema.ValidationError as e:
self.fail(e)
def test_limits_are_applied(self):
def test_limits_are_applied(self) -> None:
# TODO
pass
def test_definition_types_works_with_literals(self):
def test_definition_types_works_with_literals(self) -> None:
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_wildcards(self):
def test_definition_types_works_with_wildcards(self) -> None:
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_unknowns(self):
def test_definition_types_works_with_unknowns(self) -> None:
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(
sender="@foo:bar",
@ -156,24 +152,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_literals(self):
def test_definition_not_types_works_with_literals(self) -> None:
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_wildcards(self):
def test_definition_not_types_works_with_wildcards(self) -> None:
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_unknowns(self):
def test_definition_not_types_works_with_unknowns(self) -> None:
definition = {"not_types": ["m.*", "org.*"]}
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_types_takes_priority_over_types(self):
def test_definition_not_types_takes_priority_over_types(self) -> None:
definition = {
"not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"],
@ -181,35 +177,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_literals(self):
def test_definition_senders_works_with_literals(self) -> None:
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_unknowns(self):
def test_definition_senders_works_with_unknowns(self) -> None:
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_literals(self):
def test_definition_not_senders_works_with_literals(self) -> None:
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_unknowns(self):
def test_definition_not_senders_works_with_unknowns(self) -> None:
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_takes_priority_over_senders(self):
def test_definition_not_senders_takes_priority_over_senders(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"],
@ -219,14 +215,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_literals(self):
def test_definition_rooms_works_with_literals(self) -> None:
definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
)
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_unknowns(self):
def test_definition_rooms_works_with_unknowns(self) -> None:
definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar",
@ -235,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_literals(self):
def test_definition_not_rooms_works_with_literals(self) -> None:
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
event = MockEvent(
sender="@foo:bar",
@ -244,7 +240,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_unknowns(self):
def test_definition_not_rooms_works_with_unknowns(self) -> None:
definition = {"not_rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar",
@ -253,7 +249,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self):
def test_definition_not_rooms_takes_priority_over_rooms(self) -> None:
definition = {
"not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"],
@ -263,7 +259,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event(self):
def test_definition_combined_event(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@ -279,7 +275,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_sender(self):
def test_definition_combined_event_bad_sender(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@ -295,7 +291,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_room(self):
def test_definition_combined_event_bad_room(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@ -311,7 +307,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_type(self):
def test_definition_combined_event_bad_type(self) -> None:
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
@ -327,7 +323,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_labels(self):
def test_filter_labels(self) -> None:
definition = {"org.matrix.labels": ["#fun"]}
event = MockEvent(
sender="@foo:bar",
@ -356,7 +352,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_not_labels(self):
def test_filter_not_labels(self) -> None:
definition = {"org.matrix.not_labels": ["#fun"]}
event = MockEvent(
sender="@foo:bar",
@ -377,7 +373,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
def test_filter_rel_type(self):
def test_filter_rel_type(self) -> None:
definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
event = MockEvent(
sender="@foo:bar",
@ -407,7 +403,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
def test_filter_not_rel_type(self):
def test_filter_not_rel_type(self) -> None:
definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
event = MockEvent(
sender="@foo:bar",
@ -436,15 +432,25 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
def test_filter_presence_match(self) -> None:
"""Check that filter_presence return events which matches the filter."""
user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
)
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
presence_states = [
UserPresenceState(
user_id="@foo:bar",
state="unavailable",
last_active_ts=0,
last_federation_update_ts=0,
last_user_sync_ts=0,
status_msg=None,
currently_active=False,
),
]
user_filter = self.get_success(
self.filtering.get_user_filter(
@ -452,23 +458,29 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
results = self.get_success(user_filter.filter_presence(events=events))
self.assertEqual(events, results)
results = self.get_success(user_filter.filter_presence(presence_states))
self.assertEqual(presence_states, results)
def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
def test_filter_presence_no_match(self) -> None:
"""Check that filter_presence does not return events rejected by the filter."""
user_filter_json = {"presence": {"not_senders": ["@foo:bar"]}}
filter_id = self.get_success(
self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
)
)
event = MockEvent(
event_id="$asdasd:localhost",
sender="@foo:bar",
type="custom.avatar.3d.crazy",
)
events = [event]
presence_states = [
UserPresenceState(
user_id="@foo:bar",
state="unavailable",
last_active_ts=0,
last_federation_update_ts=0,
last_user_sync_ts=0,
status_msg=None,
currently_active=False,
),
]
user_filter = self.get_success(
self.filtering.get_user_filter(
@ -476,10 +488,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
results = self.get_success(user_filter.filter_presence(events=events))
results = self.get_success(user_filter.filter_presence(presence_states))
self.assertEqual([], results)
def test_filter_room_state_match(self):
def test_filter_room_state_match(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
self.datastore.add_user_filter(
@ -498,7 +510,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
results = self.get_success(user_filter.filter_room_state(events=events))
self.assertEqual(events, results)
def test_filter_room_state_no_match(self):
def test_filter_room_state_no_match(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
self.datastore.add_user_filter(
@ -519,7 +531,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
results = self.get_success(user_filter.filter_room_state(events))
self.assertEqual([], results)
def test_filter_rooms(self):
def test_filter_rooms(self) -> None:
definition = {
"rooms": ["!allowed:example.com", "!excluded:example.com"],
"not_rooms": ["!excluded:example.com"],
@ -535,7 +547,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
def test_filter_relations(self):
def test_filter_relations(self) -> None:
events = [
# An event without a relation.
MockEvent(
@ -551,9 +563,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="org.matrix.custom.event",
room_id="!foo:bar",
),
# Non-EventBase objects get passed through.
{},
]
jsondicts: List[JsonDict] = [{}]
# For the following tests we patch the datastore method (intead of injecting
# events). This is a bit cheeky, but tests the logic of _check_event_relations.
@ -561,7 +572,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
# Filter for a particular sender.
definition = {"related_by_senders": ["@foo:bar"]}
async def events_have_relations(*args, **kwargs):
async def events_have_relations(*args: object, **kwargs: object) -> List[str]:
return ["$with_relation"]
with patch.object(
@ -572,9 +583,17 @@ class FilteringTestCase(unittest.HomeserverTestCase):
Filter(self.hs, definition)._check_event_relations(events)
)
)
self.assertEqual(filtered_events, events[1:])
# Non-EventBase objects get passed through.
filtered_jsondicts = list(
self.get_success(
Filter(self.hs, definition)._check_event_relations(jsondicts)
)
)
def test_add_filter(self):
self.assertEqual(filtered_events, events[1:])
self.assertEqual(filtered_jsondicts, [{}])
def test_add_filter(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(
@ -595,7 +614,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
),
)
def test_get_filter(self):
def test_get_filter(self) -> None:
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = self.get_success(

View File

@ -6,7 +6,7 @@ from tests import unittest
class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self):
def test_allowed_via_can_do_action(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
@ -31,7 +31,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(20.0, time_allowed)
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None:
appservice = ApplicationService(
token="fake_token",
id="foo",
@ -64,7 +64,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(20.0, time_allowed)
def test_allowed_appservice_via_can_requester_do_action(self):
def test_allowed_appservice_via_can_requester_do_action(self) -> None:
appservice = ApplicationService(
token="fake_token",
id="foo",
@ -97,7 +97,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(-1, time_allowed)
def test_allowed_via_ratelimit(self):
def test_allowed_via_ratelimit(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
@ -120,7 +120,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.ratelimit(None, key="test_id", _time_now_s=10)
)
def test_allowed_via_can_do_action_and_overriding_parameters(self):
def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None:
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
@ -169,7 +169,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
def test_allowed_via_ratelimit_and_overriding_parameters(self):
def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None:
"""Test that we can override options of the ratelimit method that would otherwise
fail an action
"""
@ -204,7 +204,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
)
def test_pruning(self):
def test_pruning(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
@ -223,7 +223,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertNotIn("test_id_1", limiter.actions)
def test_db_user_override(self):
def test_db_user_override(self) -> None:
"""Test that users that have ratelimiting disabled in the DB aren't
ratelimited.
"""
@ -250,7 +250,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
for _ in range(20):
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
def test_multiple_actions(self):
def test_multiple_actions(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,

View File

@ -35,6 +35,8 @@ def MockEvent(**kwargs: Any) -> EventBase:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
if "content" not in kwargs:
kwargs["content"] = {}
return make_event_from_dict(kwargs)