Add final type hint to tests.unittest. (#15072)

Adds a return type to HomeServerTestCase.make_homeserver and deal
with any variables which are no longer Any.
This commit is contained in:
Patrick Cloke 2023-02-14 14:03:35 -05:00 committed by GitHub
parent 119e0795a5
commit 42aea0d8af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 433 additions and 320 deletions

1
changelog.d/15072.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -56,9 +56,6 @@ disallow_untyped_defs = False
[mypy-synapse.storage.database] [mypy-synapse.storage.database]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.unittest]
disallow_untyped_defs = False
[mypy-tests.util.caches.test_descriptors] [mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False disallow_untyped_defs = False

View File

@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
} }
# Listen with the config # Listen with the config
self.hs._listen_http(parse_listener_def(0, config)) hs = self.hs
assert isinstance(hs, GenericWorkerServer)
hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
} }
# Listen with the config # Listen with the config
self.hs._listener_http(self.hs.config, parse_listener_def(0, config)) hs = self.hs
assert isinstance(hs, SynapseHomeServer)
hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]

View File

@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1") key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys( r = self.hs.get_datastores().main.store_server_verify_keys(
"server9", "server9",
time.time() * 1000, int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
) )
self.get_success(r) self.get_success(r)
@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1") key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys( r = self.hs.get_datastores().main.store_server_verify_keys(
"server9", "server9",
time.time() * 1000, int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this # None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted # API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys. # to 0 when fetched in KeyStore.get_server_verify_keys.
@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success( key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
) )
res = key_json[lookup_triplet] res_keys = key_json[lookup_triplet]
self.assertEqual(len(res), 1) self.assertEqual(len(res_keys), 1)
res = res[0] res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME) self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success( key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
) )
res = key_json[lookup_triplet] res_keys = key_json[lookup_triplet]
self.assertEqual(len(res), 1) self.assertEqual(len(res_keys), 1)
res = res[0] res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success( key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
) )
res = key_json[lookup_triplet] res_keys = key_json[lookup_triplet]
self.assertEqual(len(res), 1) self.assertEqual(len(res_keys), 1)
res = res[0] res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)

View File

@ -156,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation. # Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"]) self.fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({}) self.fed_transport_client.send_transaction = simple_async_mock({})
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client, federation_transport_client=self.fed_transport_client,
) )
load_legacy_presence_router(hs) load_legacy_presence_router(hs)
@ -422,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
# #
# Thus we reset the mock, and try sending all online local user # Thus we reset the mock, and try sending all online local user
# presence again # presence again
self.hs.get_federation_transport_client().send_transaction.reset_mock() self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence # Broadcast local user online presence
self.get_success( self.get_success(
@ -447,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
} }
found_users = set() found_users = set()
calls = ( calls = self.fed_transport_client.send_transaction.call_args_list
self.hs.get_federation_transport_client().send_transaction.call_args_list
)
for call in calls: for call in calls:
call_args = call[0] call_args = call[0]
federation_transaction: Transaction = call_args[0] federation_transaction: Transaction = call_args[0]

View File

@ -17,7 +17,7 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID, create_requester
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity # Artificially raise the complexity
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
async def get_current_state_event_counts(room_id: str) -> int:
return int(500 * 1.23)
store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
# Get the room complexity again -- make sure it's our artificial value # Get the room complexity again -- make sure it's our artificial value
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
None, create_requester(u1),
["other.example.com"], ["other.example.com"],
"roomid", "roomid",
UserID.from_string(u1), UserID.from_string(u1),
@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
None, create_requester(u1),
["other.example.com"], ["other.example.com"],
"roomid", "roomid",
UserID.from_string(u1), UserID.from_string(u1),
@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
# Artificially raise the complexity # Artificially raise the complexity
self.hs.get_datastores().main.get_current_state_event_counts = ( async def get_current_state_event_counts(room_id: str) -> int:
lambda x: make_awaitable(600) return 600
)
self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
d = handler._remote_join( d = handler._remote_join(
None, create_requester(u1),
["other.example.com"], ["other.example.com"],
room_1, room_1,
UserID.from_string(u1), UserID.from_string(u1),
@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
None, create_requester(u1),
["other.example.com"], ["other.example.com"],
"roomid", "roomid",
UserID.from_string(u1), UserID.from_string(u1),
@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
None, create_requester(u1),
["other.example.com"], ["other.example.com"],
"roomid", "roomid",
UserID.from_string(u1), UserID.from_string(u1),

View File

@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.sender import (
FederationSender,
PerDestinationQueue,
TransactionManager,
)
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=self.federation_transport_client,
) )
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.pdus: List[JsonDict] = [] self.pdus: List[JsonDict] = []
self.failed_pdus: List[JsonDict] = [] self.failed_pdus: List[JsonDict] = []
self.is_online = True self.is_online = True
self.hs.get_federation_transport_client().send_transaction.side_effect = ( self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction self.record_transaction
) )
federation_sender = hs.get_federation_sender()
assert isinstance(federation_sender, FederationSender)
self.federation_sender = federation_sender
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config["federation_sender_instances"] = None config["federation_sender_instances"] = None
@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# let's delete the federation transmission queue # let's delete the federation transmission queue
# (this pretends we are starting up fresh.) # (this pretends we are starting up fresh.)
self.assertFalse( self.assertFalse(
self.hs.get_federation_sender() self.federation_sender._per_destination_queues[
._per_destination_queues["host2"] "host2"
.transmission_loop_running ].transmission_loop_running
) )
del self.hs.get_federation_sender()._per_destination_queues["host2"] del self.federation_sender._per_destination_queues["host2"]
# let's also clear any backoffs # let's also clear any backoffs
self.get_success( self.get_success(
@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# also fetch event 5 so we know its last_successful_stream_ordering later # also fetch event 5 so we know its last_successful_stream_ordering later
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
assert event_2.internal_metadata.stream_ordering is not None
self.get_success( self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_2.internal_metadata.stream_ordering "host2", event_2.internal_metadata.stream_ordering
@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
def wake_destination_track(destination: str) -> None: def wake_destination_track(destination: str) -> None:
woken.append(destination) woken.append(destination)
self.hs.get_federation_sender().wake_destination = wake_destination_track self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
# cancel the pre-existing timer for _wake_destinations_needing_catchup # cancel the pre-existing timer for _wake_destinations_needing_catchup
# this is because we are calling it manually rather than waiting for it # this is because we are calling it manually rather than waiting for it
# to be called automatically # to be called automatically
self.hs.get_federation_sender()._catchup_after_startup_timer.cancel() assert self.federation_sender._catchup_after_startup_timer is not None
self.federation_sender._catchup_after_startup_timer.cancel()
self.get_success( self.get_success(
self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0 self.federation_sender._wake_destinations_needing_catchup(), by=5.0
) )
# ASSERT (_wake_destinations_needing_catchup): # ASSERT (_wake_destinations_needing_catchup):
@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
) )
) )
assert event_1.internal_metadata.stream_ordering is not None
self.get_success( self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering "host2", event_1.internal_metadata.stream_ordering

View File

@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9, RoomVersions.V9,
) )
) )
self.assertIsNotNone(pulled_pdu_info2) assert pulled_pdu_info2 is not None
remote_pdu2 = pulled_pdu_info2.pdu remote_pdu2 = pulled_pdu_info2.pdu
# Sanity check that we are working against the same event # Sanity check that we are working against the same event
@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9, RoomVersions.V9,
) )
) )
self.assertIsNotNone(pulled_pdu_info) assert pulled_pdu_info is not None
remote_pdu = pulled_pdu_info.pdu remote_pdu = pulled_pdu_info.pdu
# check the right call got made to the agent # check the right call got made to the agent

View File

@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer from synapse.server import HomeServer
@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
""" """
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=self.federation_transport_client,
) )
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return config return config
def test_send_receipts(self) -> None: def test_send_receipts(self) -> None:
mock_send_transaction = ( mock_send_transaction = self.federation_transport_client.send_transaction
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction.return_value = make_awaitable({}) mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
) )
def test_send_receipts_thread(self) -> None: def test_send_receipts_thread(self) -> None:
mock_send_transaction = ( mock_send_transaction = self.federation_transport_client.send_transaction
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction.return_value = make_awaitable({}) mock_send_transaction.return_value = make_awaitable({})
# Create receipts for: # Create receipts for:
@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self) -> None: def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but """Send two receipts in quick succession; the second should be flushed, but
only after 20ms""" only after 20ms"""
mock_send_transaction = ( mock_send_transaction = self.federation_transport_client.send_transaction
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction.return_value = make_awaitable({}) mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver( self.federation_transport_client = Mock(
federation_transport_client=Mock(
spec=["send_transaction", "query_user_devices"] spec=["send_transaction", "query_user_devices"]
), )
return self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
) )
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:
@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment] hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self.device_handler = device_handler
# whenever send_transaction is called, record the edu data # whenever send_transaction is called, record the edu data
self.edus: List[JsonDict] = [] self.edus: List[JsonDict] = []
self.hs.get_federation_transport_client().send_transaction.side_effect = ( self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction self.record_transaction
) )
@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause # Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists. # it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = ( self.federation_transport_client.query_user_devices.return_value = (
make_awaitable( make_awaitable(
{ {
"stream_id": "1", "stream_id": "1",
@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
) )
self.get_success( self.get_success(
self.hs.get_device_handler().device_list_updater.incoming_device_list_update( self.device_handler.device_list_updater.incoming_device_list_update(
"host2", "host2",
{ {
"user_id": "@user2:host2", "user_id": "@user2:host2",
@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again # delete them again
self.get_success( self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
# We queue up device list updates to be sent over federation, so we # We queue up device list updates to be sent over federation, so we
# advance to clear the queue. # advance to clear the queue.
@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""If the destination server is unreachable, all the updates should get sent on """If the destination server is unreachable, all the updates should get sent on
recovery recovery
""" """
mock_send_txn = self.hs.get_federation_transport_client().send_transaction mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices # create devices
@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3") self.login("user", "pass", device_id="D3")
# delete them again # delete them again
self.get_success( self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
# We queue up device list updates to be sent over federation, so we # We queue up device list updates to be sent over federation, so we
# advance to clear the queue. # advance to clear the queue.
@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable. This case tests the behaviour when the server has never been reachable.
""" """
mock_send_txn = self.hs.get_federation_transport_client().send_transaction mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices # create devices
@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3") self.login("user", "pass", device_id="D3")
# delete them again # delete them again
self.get_success( self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
# We queue up device list updates to be sent over federation, so we # We queue up device list updates to be sent over federation, so we
# advance to clear the queue. # advance to clear the queue.
@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# now the server goes offline # now the server goes offline
mock_send_txn = self.hs.get_federation_transport_client().send_transaction mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D2")
@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1) self.reactor.advance(1)
# delete them again # delete them again
self.get_success( self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
self.assertGreaterEqual(mock_send_txn.call_count, 3) self.assertGreaterEqual(mock_send_txn.call_count, 3)

View File

@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock out application services, and allow defining our own in tests # Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = [] self._services: List[ApplicationService] = []
self.hs.get_datastores().main.get_app_services = Mock( self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services return_value=self._services
) )

View File

@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {}) cas_response = CasResponse("test_user", {})
request = _mock_request() request = _mock_request()
@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO. # Map a user via SSO.
cas_response = CasResponse("test_user", {}) cas_response = CasResponse("test_user", {})
@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("föö", {}) cas_response = CasResponse("föö", {})
request = _mock_request() request = _mock_request()
@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department. # The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {}) cas_response = CasResponse("test_user", {})

View File

@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
# we should now have an unused alg1 key # we should now have an unused alg1 key
res = self.get_success( fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, ["alg1"]) self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback # claiming an OTK when no OTKs are available should return the fallback
# key # key
res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
) )
self.assertEqual( self.assertEqual(
res, claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
) )
# we shouldn't have any unused fallback keys again # we shouldn't have any unused fallback keys again
res = self.get_success( unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, []) self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key # claiming an OTK again should return the same fallback key
res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
) )
self.assertEqual( self.assertEqual(
res, claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
) )
@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
) )
res = self.get_success( unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, []) self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key # uploading a new fallback key should result in an unused fallback key
self.get_success( self.get_success(
@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
) )
res = self.get_success( unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, ["alg1"]) self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the # if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback # one-time key, and then go back to the fallback
@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
) )
res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
) )
self.assertEqual( self.assertEqual(
res, claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
) )
res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
) )
self.assertEqual( self.assertEqual(
res, claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
) )
@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
) )
res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
) )
self.assertEqual( self.assertEqual(
res, claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
) )
@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key # upload two device keys, which will be signed later by the self-signing key
device_key_1 = { device_key_1: JsonDict = {
"user_id": local_user, "user_id": local_user,
"device_id": "abc", "device_id": "abc",
"algorithms": [ "algorithms": [
@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
"signatures": {local_user: {"ed25519:abc": "base64+signature"}}, "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
} }
device_key_2 = { device_key_2: JsonDict = {
"user_id": local_user, "user_id": local_user,
"device_id": "def", "device_id": "def",
"algorithms": [ "algorithms": [
@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
} }
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
e = self.get_failure( e = self.get_failure(
self.hs.get_device_handler().check_device_registered( device_handler.check_device_registered(
user_id=local_user, user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name", initial_device_display_name="new display name",
@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_id = "xyz" device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
device_key = { device_key: JsonDict = {
"user_id": local_user, "user_id": local_user,
"device_id": device_id, "device_id": device_id,
"algorithms": [ "algorithms": [
@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
master_key = { master_key: JsonDict = {
"user_id": local_user, "user_id": local_user,
"usage": ["master"], "usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey}, "keys": {"ed25519:" + master_pubkey: master_pubkey},
@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# the first user # the first user
other_user = "@otherboris:" + self.hs.hostname other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
other_master_key = { other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user, "user_id": other_user,
"usage": ["master"], "usage": ["master"],
@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock( self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable( return_value=make_awaitable(
{ {
"device_keys": {remote_user_id: {}}, "device_keys": {remote_user_id: {}},
@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock( self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable( return_value=make_awaitable(
{ {
"user_id": remote_user_id, "user_id": remote_user_id,

View File

@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote # We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event. # server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
self.hs.get_federation_client().backfill = federation_client_backfill_mock self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us # We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function. # to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock( persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
) )
self.hs.get_federation_event_handler().persist_events_and_notify = ( self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock persist_events_and_notify_mock
) )
@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync. # Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1) self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync. # Try to start another partial state sync.
# Nothing should happen. # Nothing should happen.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1) self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync # End the partial state sync
@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work. # The next attempt to start the partial state sync should work.
is_partial_state = True is_partial_state = True
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2) self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None: def test_partial_state_room_sync_restart(self) -> None:
@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync. # Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1) self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync. # Fail the partial state sync.
@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1) self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again. # Start the partial state sync again.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2) self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync. # Deduplicate another partial state sync.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2) self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync. # Fail the partial state sync.
@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3) self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with( mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3", initial_destination="hs3",
other_destinations=["hs2"], other_destinations={"hs2"},
room_id="room_id", room_id="room_id",
) )

View File

@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier: if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self.get_success( self.get_success(
persistence.persist_event( persistence.persist_event(
prev_event, prev_event,
@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event, bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event, rejected_kick_event.event_id: rejected_kick_event,
}, },
state_res_store=main_store, state_res_store=StateResolutionStore(main_store),
) )
), ),
[bert_member_event.event_id, rejected_kick_event.event_id], [bert_member_event.event_id, rejected_kick_event.event_id],
@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id, rejected_power_levels_event.event_id,
], ],
event_map={}, event_map={},
state_res_store=main_store, state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(), full_conflicted_set=set(),
) )
), ),

View File

@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler() self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = ( persistence = self.hs.get_storage_controllers().persistence
self.hs.get_storage_controllers().persistence assert persistence is not None
) self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar") self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar") self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.info = self.get_success( info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token( self.hs.get_datastores().main.get_user_by_access_token(
self.access_token, self.access_token,
) )
) )
self.token_id = self.info.token_id assert info is not None
self.token_id = info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id) self.requester = create_requester(self.user_id, access_token_id=self.token_id)

View File

@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test. username: The username to use for the test.
registration: Whether to test with registration URLs. registration: Whether to test with registration URLs.
""" """
self.hs.get_identity_handler().send_threepid_validation = Mock( self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0), return_value=make_awaitable(0),
) )

View File

@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None: def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1)) self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True)) self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self, self,
) -> None: ) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True)) self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated. # Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertFalse(room["federatable"]) self.assertFalse(room["federatable"])
self.assertFalse(room["public"]) self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "public") self.assertEqual(room["join_rules"], "public")
@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room. # Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertEqual(room["join_rules"], "public") self.assertEqual(room["join_rules"], "public")
# Both users should be in the room. # Both users should be in the room.
@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room. # Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertFalse(room["public"]) self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite") self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join") self.assertEqual(room["guest_access"], "can_join")
@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room. # Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertFalse(room["public"]) self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite") self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join") self.assertEqual(room["guest_access"], "can_join")

View File

@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO. # Map a user via SSO.
saml_response = FakeAuthnResponse( saml_response = FakeAuthnResponse(
@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# mock out the error renderer too # mock out the error renderer too
sso_handler = self.hs.get_sso_handler() sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None) sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request() request = _mock_request()
@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer # stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler() sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None) sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
# register a user to occupy the first-choice MXID # register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department. # The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})

View File

@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_keyring.verify_json_for_server.return_value = make_awaitable(True) mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too # we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"]) self.mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000 # the tests assume that we are starting at unix time 1000
reactor.pump((1000,)) reactor.pump((1000,))
@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.mock_hs_notifier = Mock() self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
notifier=self.mock_hs_notifier, notifier=self.mock_hs_notifier,
federation_http_client=mock_federation_client, federation_http_client=self.mock_federation_client,
keyring=mock_keyring, keyring=mock_keyring,
replication_streams={}, replication_streams={},
) )
@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
) )
put_json = self.hs.get_federation_http_client().put_json self.mock_federation_client.put_json.assert_called_once_with(
put_json.assert_called_once_with(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000", path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction( data=_expect_edu_transaction(
@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
put_json = self.hs.get_federation_http_client().put_json self.mock_federation_client.put_json.assert_called_once_with(
put_json.assert_called_once_with(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000", path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction( data=_expect_edu_transaction(

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Tuple from typing import Any, Tuple
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from urllib.parse import quote from urllib.parse import quote
@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import create_requester from synapse.types import UserProfile, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config from tests.unittest import override_config
# A spam checker which doesn't implement anything, so create a bare object.
class UselessSpamChecker:
def __init__(self, config: Any):
pass
class UserDirectoryTestCase(unittest.HomeserverTestCase): class UserDirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the UserDirectoryHandler. """Tests the UserDirectoryHandler.
@ -773,7 +779,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1) self.assertEqual(len(s["results"]), 1)
async def allow_all(user_profile: ProfileInfo) -> bool: async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users. # Allow all users.
return False return False
@ -787,7 +793,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1) self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users. # Configure a spam checker that filters all users.
async def block_all(user_profile: ProfileInfo) -> bool: async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy. # All users are spammy.
return True return True
@ -797,6 +803,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0) self.assertEqual(len(s["results"]), 0)
@override_config(
{
"spam_checker": {
"module": "tests.handlers.test_user_directory.UselessSpamChecker"
}
}
)
def test_legacy_spam_checker(self) -> None: def test_legacy_spam_checker(self) -> None:
""" """
A spam checker without the expected method should be ignored. A spam checker without the expected method should be ignored.
@ -825,11 +838,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.assertEqual(public_users, set()) self.assertEqual(public_users, set())
# Configure a spam checker.
spam_checker = self.hs.get_spam_checker()
# The spam checker doesn't need any methods, so create a bare object.
spam_checker.spam_checker = object()
# We get one search result when searching for user2 by user1. # We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1) self.assertEqual(len(s["results"]), 1)
@ -954,10 +962,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
persistence = self.hs.get_storage_controllers().persistence
self.get_success( assert persistence is not None
self.hs.get_storage_controllers().persistence.persist_event(event, context) self.get_success(persistence.persist_event(event, context))
)
def test_local_user_leaving_room_remains_in_user_directory(self) -> None: def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
"""We've chosen to simplify the user directory's implementation by """We've chosen to simplify the user directory's implementation by

View File

@ -68,11 +68,11 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation. # Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"]) self.fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({}) self.fed_transport_client.send_transaction = simple_async_mock({})
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=fed_transport_client, federation_transport_client=self.fed_transport_client,
) )
def test_can_register_user(self) -> None: def test_can_register_user(self) -> None:
@ -417,7 +417,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
# #
# Thus we reset the mock, and try sending online local user # Thus we reset the mock, and try sending online local user
# presence again # presence again
self.hs.get_federation_transport_client().send_transaction.reset_mock() self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence # Broadcast local user online presence
self.get_success( self.get_success(
@ -429,9 +429,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
# Check that a presence update was sent as part of a federation transaction # Check that a presence update was sent as part of a federation transaction
found_update = False found_update = False
calls = ( calls = self.fed_transport_client.send_transaction.call_args_list
self.hs.get_federation_transport_client().send_transaction.call_args_list
)
for call in calls: for call in calls:
call_args = call[0] call_args = call[0]
federation_transaction: Transaction = call_args[0] federation_transaction: Transaction = call_args[0]
@ -581,7 +579,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
mocked_remote_join = simple_async_mock( mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id) return_value=("fake-event-id", fake_stream_id)
) )
self.hs.get_room_member_handler()._remote_join = mocked_remote_join self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote" fake_remote_host = f"{self.module_api.server_name}-remote"
# Given that the join is to be faked, we expect the relevant join event not to # Given that the join is to be faked, we expect the relevant join event not to

View File

@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(self.access_token) self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
) )
assert user_tuple is not None
self.token_id = user_tuple.token_id self.token_id = user_tuple.token_id
# We need to add email to account before we can create a pusher. # We need to add email to account before we can create a pusher.
@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase):
) )
) )
self.pusher = self.get_success( pusher = self.get_success(
self.hs.get_pusherpool().add_or_update_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id, user_id=self.user_id,
access_token=self.token_id, access_token=self.token_id,
@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase):
data={}, data={},
) )
) )
assert isinstance(pusher, EmailPusher)
self.pusher = pusher
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase):
) )
# check that the pusher for that email address has been deleted # check that the pusher for that email address has been deleted
pushers = self.get_success( pushers = list(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 0) self.assertEqual(len(pushers), 0)
def test_remove_unlinked_pushers_background_job(self) -> None: def test_remove_unlinked_pushers_background_job(self) -> None:
@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase):
self.wait_for_background_updates() self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted # Check that all pushers with unlinked addresses were deleted
pushers = self.get_success( pushers = list(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 0) self.assertEqual(len(pushers), 0)
def _check_for_mail(self) -> Tuple[Sequence, Dict]: def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase):
that notification. that notification.
""" """
# Get the stream ordering before it gets sent # Get the stream ordering before it gets sent
pushers = self.get_success( pushers = list(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering last_stream_ordering = pushers[0].last_stream_ordering
@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase):
self.pump(10) self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved # It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success( pushers = list(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
# The stream ordering has increased # The stream ordering has increased
pushers = self.get_success( pushers = list(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Tuple from typing import Any, List, Tuple
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
def test_data(data: Optional[JsonDict]) -> None: def test_data(data: Any) -> None:
self.get_failure( self.get_failure(
self.hs.get_pusherpool().add_or_update_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token) self.helper.send(room, body="There!", tok=other_access_token)
# Get the stream ordering before it gets sent # Get the stream ordering before it gets sent
pushers = self.get_success( pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
) )
pushers = list(pushers) )
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering last_stream_ordering = pushers[0].last_stream_ordering
@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump() self.pump()
# It hasn't succeeded yet, so the stream ordering shouldn't have moved # It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success( pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
) )
pushers = list(pushers) )
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump() self.pump()
# The stream ordering has increased # The stream ordering has increased
pushers = self.get_success( pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
) )
pushers = list(pushers) )
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
last_stream_ordering = pushers[0].last_stream_ordering last_stream_ordering = pushers[0].last_stream_ordering
@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump() self.pump()
# The stream ordering has increased, again # The stream ordering has increased, again
pushers = self.get_success( pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
) )
pushers = list(pushers) )
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
device_id = user_tuple.device_id device_id = user_tuple.device_id
@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
) )
# Look up the user info for the access token so we can compare the device ID. # Look up the user info for the access token so we can compare the device ID.
lookup_result: TokenLookupResult = self.get_success( lookup_result = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert lookup_result is not None
# Get the user's devices and check it has the correct device ID. # Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token) channel = self.make_request("GET", "/pushers", access_token=access_token)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List, Optional from typing import Any, List, Optional, Sequence
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
) )
# this is the point in the DAG where we make a fork # this is the point in the DAG where we make a fork
fork_point: List[str] = self.get_success( fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
) )
@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_event = self.get_success( pl_event = self.get_success(
inject_event( inject_event(
self.hs, self.hs,
prev_event_ids=prev_events, prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels, type=EventTypes.PowerLevels,
state_key="", state_key="",
sender=self.user_id, sender=self.user_id,
@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
) )
# this is the point in the DAG where we make a fork # this is the point in the DAG where we make a fork
fork_point: List[str] = self.get_success( fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
) )
@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
e = self.get_success( e = self.get_success(
inject_event( inject_event(
self.hs, self.hs,
prev_event_ids=prev_events, prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels, type=EventTypes.PowerLevels,
state_key="", state_key="",
sender=self.user_id, sender=self.user_id,

View File

@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
room_id = self.helper.create_room_as("@bob:test") room_id = self.helper.create_room_as("@bob:test")
# Mark the room as partial-stated. # Mark the room as partial-stated.
self.get_success( self.get_success(
self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1") self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
) )
worker = self.make_worker_hs("synapse.app.generic_worker") worker = self.make_worker_hs("synapse.app.generic_worker")

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import Mock
from synapse.handlers.typing import RoomMember from synapse.handlers.typing import RoomMember, TypingWriterHandler
from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams import TypingStream
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self) -> None: def test_typing(self) -> None:
typing = self.hs.get_typing_handler() typing = self.hs.get_typing_handler()
assert isinstance(typing, TypingWriterHandler)
self.reconnect() self.reconnect()
@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
sends the proper position and RDATA). sends the proper position and RDATA).
""" """
typing = self.hs.get_typing_handler() typing = self.hs.get_typing_handler()
assert isinstance(typing, TypingWriterHandler)
self.reconnect() self.reconnect()

View File

@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
# ... updating the cache ID gen on the master still shouldn't cause the # ... updating the cache ID gen on the master still shouldn't cause the
# deferred to wake up. # deferred to wake up.
assert store._cache_id_gen is not None
ctx = store._cache_id_gen.get_next() ctx = store._cache_id_gen.get_next()
self.get_success(ctx.__aenter__()) self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None)) self.get_success(ctx.__aexit__(None, None, None))

View File

@ -16,6 +16,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.handlers.typing import TypingWriterHandler
from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
token = self.login("user3", "pass") token = self.login("user3", "pass")
typing_handler = self.hs.get_typing_handler() typing_handler = self.hs.get_typing_handler()
assert isinstance(typing_handler, TypingWriterHandler)
sent_on_1 = False sent_on_1 = False
sent_on_2 = False sent_on_2 = False

View File

@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success( user_dict = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
assert user_dict is not None
token_id = user_dict.token_id token_id = user_dict.token_id
self.get_success( self.get_success(

View File

@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory() event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler() event_creation_handler = self.hs.get_event_creation_handler()
storage_controllers = self.hs.get_storage_controllers() persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
# Create two rooms, one with a local user only and one with both a local # Create two rooms, one with a local user only and one with both a local
# and remote user. # and remote user.
@ -2940,7 +2941,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
self.get_success(storage_controllers.persistence.persist_event(event, context)) self.get_success(persistence.persist_event(event, context))
# Now get rooms # Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "pass", admin=True) self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
async def check_username(username: str) -> bool: async def check_username(
if username == "allowed": localpart: str,
return True guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
inhibit_user_in_use_error: bool = False,
) -> None:
if localpart == "allowed":
return
raise SynapseError( raise SynapseError(
400, 400,
"User ID already taken.", "User ID already taken.",
@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
) )
handler = self.hs.get_registration_handler() handler = self.hs.get_registration_handler()
handler.check_username = check_username handler.check_username = check_username # type: ignore[assignment]
def test_username_available(self) -> None: def test_username_available(self) -> None:
""" """

View File

@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
return {} return {}
# Register a mock that will return the expected result depending on the remote. # Register a mock that will return the expected result depending on the remote.
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
# Check that we've got the correct response from the client-side endpoint. # Check that we've got the correct response from the client-side endpoint.
self._test_status( self._test_status(

View File

@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None: def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine _is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
channel = self.make_request( channel = self.make_request(
"POST", "POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id), "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine # type: ignore[assignment]
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)

View File

@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler) self.presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = make_awaitable(None) self.presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"red", "red",
federation_http_client=None, federation_http_client=None,
federation_client=Mock(), federation_client=Mock(),
presence_handler=presence_handler, presence_handler=self.presence_handler,
) )
return hs return hs
@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) self.assertEqual(self.presence_handler.set_state.call_count, 1)
@unittest.override_config({"use_presence": False}) @unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None: def test_put_presence_disabled(self) -> None:
@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) self.assertEqual(self.presence_handler.set_state.call_count, 0)

View File

@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self) -> None: def test_POST_guest_registration(self) -> None:
self.hs.config.key.macaroon_secret_key = "test" self.hs.config.key.macaroon_secret_key = b"test"
self.hs.config.registration.allow_guest_access = True self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
""" """
user_id = self.register_user("kermit_delta", "user") user_id = self.register_user("kermit_delta", "user")
self.hs.config.account_validity.startup_job_max_delta = self.max_delta self.hs.config.account_validity.account_validity_startup_job_max_delta = (
self.max_delta
)
now_ms = self.hs.get_clock().time_msec() now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing()) self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
assert res is not None
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period) self.assertLessEqual(res, now_ms + self.validity_period)

View File

@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send a first event, which should be filtered out at the end of the test. # Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token) resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id") first_event_id = resp.get("event_id")
assert isinstance(first_event_id, str)
# Advance the time by 2 days. We're using the default retention policy, therefore # Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid. # after this the first event will still be valid.
@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out. # Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token) resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id") valid_event_id = resp.get("event_id")
assert isinstance(valid_event_id, str)
# Advance the time by another 2 days. After this, the first event should be # Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one. # outdated but not the second one.
@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Check that we can still access state events that were sent before the event that # Check that we can still access state events that were sent before the event that
# has been purged. # has been purged.
self.get_event(room_id, create_event.event_id) self.get_event(room_id, bool(create_event))
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
event = self.get_success(self.store.get_event(event_id, allow_none=True)) event = self.get_success(self.store.get_event(event_id, allow_none=True))
@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNone(event) self.assertIsNone(event)
return {} return {}
self.assertIsNotNone(event) assert event is not None
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
serialized = self.serializer.serialize_event(event, time_now) serialized = self.serializer.serialize_event(event, time_now)

View File

@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test. # can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
self.hs.get_identity_handler().lookup_3pid = Mock( self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
) )
@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test. # can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
self.hs.get_identity_handler().lookup_3pid = Mock( self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
) )
@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
) )
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
persistence = self._storage_controllers.persistence
assert persistence is not None
self.get_success( self.get_success(
self._storage_controllers.persistence.persist_event( persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers) event, EventContext.for_outlier(self._storage_controllers)
) )
) )

View File

@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self) -> None: def test_invite_3pid(self) -> None:
"""Ensure that a 3PID invite does not attempt to contact the identity server.""" """Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_identity_handler() identity_handler = self.hs.get_identity_handler()
identity_handler.lookup_3pid = Mock( identity_handler.lookup_3pid = Mock( # type: ignore[assignment]
side_effect=AssertionError("This should not get called") side_effect=AssertionError("This should not get called")
) )
@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
event_source.get_new_events( event_source.get_new_events(
user=UserID.from_string(self.other_user_id), user=UserID.from_string(self.other_user_id),
from_key=0, from_key=0,
limit=None, limit=10,
room_ids=[room_id], room_ids=[room_id],
is_guest=False, is_guest=False,
) )
@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id, self.banned_user_id,
) )
) )
assert event is not None
self.assertEqual( self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name} event.content, {"membership": "join", "displayname": original_display_name}
) )
@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id, self.banned_user_id,
) )
) )
assert event is not None
self.assertEqual( self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name} event.content, {"membership": "join", "displayname": original_display_name}
) )

View File

@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.room_id, EventTypes.Tombstone, "" self.room_id, EventTypes.Tombstone, ""
) )
) )
self.assertIsNotNone(tombstone_event) assert tombstone_event is not None
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
# Check that the new room exists. # Check that the new room exists.

View File

@ -24,6 +24,7 @@ from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import ( from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices, ResourceLimitsServerNotices,
) )
from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -58,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.server_notices_sender = self.hs.get_server_notices_sender() server_notices_sender = self.hs.get_server_notices_sender()
assert isinstance(server_notices_sender, ServerNoticesSender)
# relying on [1] is far from ideal, but the only case where # relying on [1] is far from ideal, but the only case where
# ResourceLimitsServerNotices class needs to be isolated is this test, # ResourceLimitsServerNotices class needs to be isolated is this test,
# general code should never have a reason to do so ... # general code should never have a reason to do so ...
self._rlsn = self.server_notices_sender._server_notices[1] rlsn = list(server_notices_sender._server_notices)[1]
if not isinstance(self._rlsn, ResourceLimitsServerNotices): assert isinstance(rlsn, ResourceLimitsServerNotices)
raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn = rlsn
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000) return_value=make_awaitable(1000)
@ -101,25 +103,29 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=make_awaitable(None)
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event}) return_value=make_awaitable({"123": mock_event})
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() maybe_get_notice_room_for_user = (
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user
)
assert isinstance(maybe_get_notice_room_for_user, Mock)
maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once() self._send_notice.assert_called_once()
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
""" """
Test when user has blocked notice, but notice ought to be there (NOOP) Test when user has blocked notice, but notice ought to be there (NOOP)
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"), side_effect=ResourceLimitError(403, "foo"),
) )
@ -127,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event}) return_value=make_awaitable({"123": mock_event})
) )
@ -139,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, but should have one Test when user does not have blocked notice, but should have one
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"), side_effect=ResourceLimitError(403, "foo"),
) )
@ -152,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, nor should they (NOOP) Test when user does not have blocked notice, nor should they (NOOP)
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=make_awaitable(None)
) )
@ -165,7 +171,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=make_awaitable(None)
) )
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
@ -183,7 +189,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when server is over MAU limit and alerting is suppressed, then Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room an alert message is not sent into the room
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@ -198,7 +204,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test that when a server is disabled, that MAU limit alerting is ignored. Test that when a server is disabled, that MAU limit alerting is ignored.
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@ -217,21 +223,21 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
When the room is already in a blocked state, test that when alerting When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state. is suppressed that the room is returned to an unblocked state.
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
), ),
) )
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment]
return_value=make_awaitable((True, [])) return_value=make_awaitable((True, []))
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event}) return_value=make_awaitable({"123": mock_event})
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -262,16 +268,18 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager() self.server_notices_manager = self.hs.get_server_notices_manager()
self.event_source = self.hs.get_event_sources() self.event_source = self.hs.get_event_sources()
server_notices_sender = self.hs.get_server_notices_sender()
assert isinstance(server_notices_sender, ServerNoticesSender)
# relying on [1] is far from ideal, but the only case where # relying on [1] is far from ideal, but the only case where
# ResourceLimitsServerNotices class needs to be isolated is this test, # ResourceLimitsServerNotices class needs to be isolated is this test,
# general code should never have a reason to do so ... # general code should never have a reason to do so ...
self._rlsn = self.server_notices_sender._server_notices[1] rlsn = list(server_notices_sender._server_notices)[1]
if not isinstance(self._rlsn, ResourceLimitsServerNotices): assert isinstance(rlsn, ResourceLimitsServerNotices)
raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn = rlsn
self.user_id = "@user_id:test" self.user_id = "@user_id:test"

View File

@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# Persist the event which should invalidate or prefill the # Persist the event which should invalidate or prefill the
# `have_seen_event` cache so we don't return stale values. # `have_seen_event` cache so we don't return stale values.
persistence = self.hs.get_storage_controllers().persistence persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self.get_success( self.get_success(
persistence.persist_event( persistence.persist_event(
event, event,

View File

@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
""" """
persist_events_store = self.hs.get_datastores().persist_events persist_events_store = self.hs.get_datastores().persist_events
assert persist_events_store is not None
for e in events: for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering e.internal_metadata.stream_ordering = self._next_stream_ordering
@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def _persist(txn: LoggingTransaction) -> None: def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events # We need to persist the events to the events and state_events
# tables. # tables.
assert persist_events_store is not None
persist_events_store._store_event_txn( persist_events_store._store_event_txn(
txn, txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events], [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)] self.requester, events_and_context=[(event, context)]
) )
) )
state1 = set(self.get_success(context.get_current_state_ids()).values()) state_ids1 = self.get_success(context.get_current_state_ids())
assert state_ids1 is not None
state1 = set(state_ids1.values())
event, context = self.get_success( event, context = self.get_success(
event_handler.create_event( event_handler.create_event(
@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)] self.requester, events_and_context=[(event, context)]
) )
) )
state2 = set(self.get_success(context.get_current_state_ids()).values()) state_ids2 = self.get_success(context.get_current_state_ids())
assert state_ids2 is not None
state2 = set(state_ids2.values())
# Delete the chain cover info. # Delete the chain cover info.

View File

@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
persist_events = hs.get_datastores().persist_events
assert persist_events is not None
self.persist_events = persist_events
def test_get_prev_events_for_room(self) -> None: def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local" room_id = "@ROOM:local"
@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
}, },
) )
self.hs.datastores.persist_events._persist_event_auth_chain_txn( self.persist_events._persist_event_auth_chain_txn(
txn, txn,
[ [
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
) )
# Insert all events apart from 'B' # Insert all events apart from 'B'
self.hs.datastores.persist_events._persist_event_auth_chain_txn( self.persist_events._persist_event_auth_chain_txn(
txn, txn,
[ [
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False}, updatevalues={"has_auth_chain_index": False},
) )
self.hs.datastores.persist_events._persist_event_auth_chain_txn( self.persist_events._persist_event_auth_chain_txn(
txn, txn,
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
) )

View File

@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None: ) -> None:
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persistence = persistence
self._state_storage_controller = self.hs.get_storage_controllers().state self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None: ) -> None:
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persistence = persistence
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self) -> None: def test_remote_user_rooms_cache_invalidated(self) -> None:

View File

@ -16,8 +16,6 @@ import signedjson.key
import signedjson.types import signedjson.types
import unpaddedbase64 import unpaddedbase64
from twisted.internet.defer import Deferred
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
import tests.unittest import tests.unittest
@ -44,7 +42,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1" key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2" key_id_2 = "ed25519:KEY_ID_2"
d = store.store_server_verify_keys( self.get_success(
store.store_server_verify_keys(
"from_server", "from_server",
10, 10,
[ [
@ -52,12 +51,17 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
("server1", key_id_2, FetchKeyResult(KEY_2, 200)), ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
], ],
) )
self.get_success(d)
d = store.get_server_verify_keys(
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
) )
res = self.get_success(d)
res = self.get_success(
store.get_server_verify_keys(
[
("server1", key_id_1),
("server1", key_id_2),
("server1", "ed25519:key3"),
]
)
)
self.assertEqual(len(res.keys()), 3) self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)] res1 = res[("server1", key_id_1)]
@ -82,7 +86,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1" key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2" key_id_2 = "ed25519:key2"
d = store.store_server_verify_keys( self.get_success(
store.store_server_verify_keys(
"from_server", "from_server",
0, 0,
[ [
@ -90,10 +95,11 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
], ],
) )
self.get_success(d) )
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(
res = self.get_success(d) store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
)
self.assertEqual(len(res.keys()), 2) self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)] res1 = res[("srv1", key_id_1)]
@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(res2.valid_until_ts, 200) self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit # we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)]) res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1) self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
) )
self.get_success(d) self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(
res = self.get_success(d) store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
)
self.assertEqual(len(res.keys()), 2) self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)] res1 = res[("srv1", key_id_1)]

View File

@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id, "m.room.create", "" self.room_id, "m.room.create", ""
) )
) )
self.assertIsNotNone(create_event) assert create_event is not None
# Purge everything before this topological token # Purge everything before this topological token
self.get_success( self.get_success(

View File

@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler() self.room_creator = homeserver.get_room_creation_handler()
self.persist_event_storage_controller = ( persist_event_storage_controller = self.hs.get_storage_controllers().persistence
self.hs.get_storage_controllers().persistence assert persist_event_storage_controller is not None
) self.persist_event_storage_controller = persist_event_storage_controller
# Create a test user # Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID) self.ourUser = UserID.from_string(OUR_USER_ID)

View File

@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
"content": {"msgtype": "m.text", "body": 2}, "content": {"msgtype": "m.text", "body": 2},
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
"depth": prev_event.depth + 1,
"prev_events": prev_event_ids, "prev_events": prev_event_ids,
"origin_server_ts": self.clock.time_msec(), "origin_server_ts": self.clock.time_msec(),
} }
@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_state_map, prev_state_map,
for_verification=False, for_verification=False,
), ),
depth=event_dict["depth"], depth=prev_event.depth + 1,
) )
) )

View File

@ -16,7 +16,7 @@ from typing import List
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
room_id=self.room_id, room_id=self.room_id,
from_key=self.from_token.room_key, from_key=self.from_token.room_key,
to_key=None, to_key=None,
direction="f", direction=Direction.FORWARDS,
limit=10, limit=10,
event_filter=Filter(self.hs, filter), event_filter=Filter(self.hs, filter),
) )

View File

@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from synapse.storage.database import make_conn from synapse.storage.database import make_conn
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IncorrectDatabaseSetup from synapse.storage.engines._base import IncorrectDatabaseSetup
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
def test_safe_locale(self) -> None: def test_safe_locale(self) -> None:
database = self.hs.get_datastores().databases[0] database = self.hs.get_datastores().databases[0]
assert isinstance(database.engine, PostgresEngine)
db_conn = make_conn(database._database_config, database.engine, "test_unsafe") db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with db_conn.cursor() as txn: with db_conn.cursor() as txn:

View File

@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Union from typing import Collection, List, Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError from synapse.api.errors import FederationError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json from synapse.federation.federation_base import event_from_pdu_json
from synapse.handlers.device import DeviceListUpdater
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
from synapse.server import HomeServer from synapse.server import HomeServer
@ -81,11 +81,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) -> None: ) -> None:
pass pass
federation_event_handler._check_event_auth = _check_event_auth federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment]
self.client = self.hs.get_federation_client() self.client = self.hs.get_federation_client()
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
lambda dest, pdus, **k: succeed(pdus) async def _check_sigs_and_hash_for_pulled_events_and_fetch(
) dest: str, pdus: Collection[EventBase], room_version: RoomVersion
) -> List[EventBase]:
return list(pdus)
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
# Send the join, it should return None (which is not an error) # Send the join, it should return None (which is not an error)
self.assertEqual( self.assertEqual(
@ -187,7 +191,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register the mock on the federation client. # Register the mock on the federation client.
federation_client = self.hs.get_federation_client() federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock(side_effect=query_user_devices) federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment]
# Register a mock on the store so that the incoming update doesn't fail because # Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user. # we don't share a room with the user.
@ -197,6 +201,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Manually inject a fake device list update. We need this update to include at # Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried. # least one prev_id so that the user's device list will need to be retried.
device_list_updater = self.hs.get_device_handler().device_list_updater device_list_updater = self.hs.get_device_handler().device_list_updater
assert isinstance(device_list_updater, DeviceListUpdater)
self.get_success( self.get_success(
device_list_updater.incoming_device_list_update( device_list_updater.incoming_device_list_update(
origin=remote_origin, origin=remote_origin,
@ -236,7 +241,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client. # Register mock device list retrieval on the federation client.
federation_client = self.hs.get_federation_client() federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock( federation_client.query_user_devices = Mock( # type: ignore[assignment]
return_value=make_awaitable( return_value=make_awaitable(
{ {
"user_id": remote_user_id, "user_id": remote_user_id,
@ -269,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
keys = self.get_success( keys = self.get_success(
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
) )
self.assertTrue(remote_user_id in keys) self.assertIn(remote_user_id, keys)
key = keys[remote_user_id]
assert key is not None
# Check that the master key is the one returned by the mock. # Check that the master key is the one returned by the mock.
master_key = keys[remote_user_id]["master"] master_key = key["master"]
self.assertEqual(len(master_key["keys"]), 1) self.assertEqual(len(master_key["keys"]), 1)
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
self.assertTrue(remote_master_key in master_key["keys"].values()) self.assertTrue(remote_master_key in master_key["keys"].values())
# Check that the self-signing key is the one returned by the mock. # Check that the self-signing key is the one returned by the mock.
self_signing_key = keys[remote_user_id]["self_signing"] self_signing_key = key["self_signing"]
self.assertEqual(len(self_signing_key["keys"]), 1) self.assertEqual(len(self_signing_key["keys"]), 1)
self.assertTrue( self.assertTrue(
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),

View File

@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
If time doesn't move, don't error out. If time doesn't move, don't error out.
""" """
past_stats = [ past_stats = [
(self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) (int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
] ]
stats: JsonDict = {} stats: JsonDict = {}
self.get_success(phone_stats_home(self.hs, stats, past_stats)) self.get_success(phone_stats_home(self.hs, stats, past_stats))

View File

@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler = self.hs.get_event_creation_handler() self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory() self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers() self._storage_controllers = self.hs.get_storage_controllers()
assert self._storage_controllers.persistence is not None
self._persistence = self._storage_controllers.persistence
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@ -179,9 +181,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(self._persistence.persist_event(event, context))
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def _inject_room_member( def _inject_room_member(
@ -208,9 +208,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(self._persistence.persist_event(event, context))
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def _inject_message( def _inject_message(
@ -233,9 +231,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(self._persistence.persist_event(event, context))
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def _inject_outlier(self) -> EventBase: def _inject_outlier(self) -> EventBase:
@ -253,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
self.get_success( self.get_success(
self._storage_controllers.persistence.persist_event( self._persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers) event, EventContext.for_outlier(self._storage_controllers)
) )
) )

View File

@ -361,7 +361,9 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1 store.db_pool.updates.do_next_background_update(False), by=0.1
) )
def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock): def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
""" """
Make and return a homeserver. Make and return a homeserver.

View File

@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump() self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts) self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, failure_ts) self.assertEqual(new_timings.retry_last_ts, failure_ts)
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump() self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts) self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, retry_ts) self.assertEqual(new_timings.retry_last_ts, retry_ts)
self.assertGreaterEqual( self.assertGreaterEqual(