Add missing type hints to tests.replication. (#14987)

This commit is contained in:
Patrick Cloke 2023-02-06 09:55:00 -05:00 committed by GitHub
parent b275763c65
commit 156cd88eef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 193 additions and 149 deletions

1
changelog.d/14987.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -104,6 +104,9 @@ disallow_untyped_defs = True
[mypy-tests.push.*]
disallow_untyped_defs = True
[mypy-tests.replication.*]
disallow_untyped_defs = True
[mypy-tests.rest.*]
disallow_untyped_defs = True

View File

@ -16,7 +16,9 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler,
)
self._client_transport = None
self._server_transport = None
self._client_transport: Optional[FakeTransport] = None
self._server_transport: Optional[FakeTransport] = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
def _build_replication_data_handler(self):
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)
def reconnect(self):
def reconnect(self) -> None:
if self._client_transport:
self.client.close()
@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
def disconnect(self):
def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
self.server.close()
def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory
def request_factory(*args, **kwargs):
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
base["redis"] = {"enabled": True}
return base
def setUp(self):
def setUp(self) -> None:
super().setUp()
# build a replication server
@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
def create_test_resource(self):
def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
def _handle_http_replication_attempt(self, hs, repl_port):
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
def connect_any_redis_attempts(self):
def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
client_protocol = client_factory.buildProtocol(client_address)
server_address = IPv4Address("TCP", host, port)
server_protocol = self._redis_server.buildProtocol(server_address)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
async def on_rdata(self, stream_name, instance_name, token, rows):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
def add_subscriber(self, conn, channel: bytes):
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
def publish(self, conn, channel: bytes, msg) -> int:
def publish(
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
return len(self._subscribers_by_channel)
def buildProtocol(self, addr):
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)
@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
self._server = server
self._reader = hiredis.Reader()
def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)
# We might get multiple messages in one packet.
@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""
# We currently only support pub/sub.
@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("PONG")
else:
raise Exception(f"Unknown command: {command}")
raise Exception(f"Unknown command: {command!r}")
def send(self, msg):
def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None
@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.transport.write(raw)
self.transport.flush()
def encode(self, obj):
def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
def connectionLost(self, reason):
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)

View File

@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self):
def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)

View File

@ -13,35 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Optional
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
hs = self.setup_test_homeserver(federation_client=Mock())
return hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.reconnect()
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence
def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump(0.1)
def check(self, method, args, expected_result=None):
def check(
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:

View File

@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Optional
from typing import Any, Callable, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
NotifCounts,
RoomNotifCounts,
@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from synapse.util import Clock
from tests.server import FakeTransport
@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
def dict_equals(self, other):
def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json())
return me == them
def patch__eq__(cls):
def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None)
cls.__eq__ = dict_equals
cls.__eq__ = dict_equals # type: ignore[assignment]
def unpatch():
def unpatch() -> None:
if eq is not None:
cls.__eq__ = eq
cls.__eq__ = eq # type: ignore[assignment]
return unpatch
@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = EventsWorkerStore
def setUp(self):
def setUp(self) -> None:
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
return super().setUp()
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
super().setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.get_success(
self.master_store.store_room(
@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
)
def tearDown(self):
def tearDown(self) -> None:
[unpatch() for unpatch in self.unpatches]
def test_get_latest_event_ids_in_room(self):
def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
def test_redactions(self):
def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
def test_backfilled_redactions(self):
def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
self.check("get_event", [msg.event_id], redacted)
def test_invites(self):
def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
)
@parameterized.expand([(True,), (False,)])
def test_push_actions_for_user(self, send_receipt: bool):
def test_push_actions_for_user(self, send_receipt: bool) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
),
)
def test_get_rooms_for_user_with_stream_ordering(self):
def test_get_rooms_for_user_with_stream_ordering(self) -> None:
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
by rows in the events stream
"""
@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
self,
) -> None:
"""Check that current_state invalidation happens correctly with multiple events
in the persistence batch.
@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
self.get_success(
self._storage_controllers.persistence.persist_events(
[(j2, j2ctx), (msg, msgctx)]
)
)
self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
self.replicate()
assert j2.internal_metadata.stream_ordering is not None
@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
def persist(self, backfill=False, **kwargs) -> FrozenEvent:
def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
"""
Returns:
The event that was persisted.
@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
self.persistance.persist_events([(event, context)], backfilled=True)
)
else:
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
self.get_success(self.persistance.persist_event(event, context))
return event
def build_event(
self,
sender=USER_ID,
room_id=ROOM_ID,
type="m.room.message",
key=None,
sender: str = USER_ID,
room_id: str = ROOM_ID,
type: str = "m.room.message",
key: Optional[str] = None,
internal: Optional[dict] = None,
depth=None,
prev_events: Optional[list] = None,
auth_events: Optional[list] = None,
prev_state: Optional[list] = None,
redacts=None,
depth: Optional[int] = None,
prev_events: Optional[List[Tuple[str, dict]]] = None,
auth_events: Optional[List[str]] = None,
prev_state: Optional[List[str]] = None,
redacts: Optional[str] = None,
push_actions: Iterable = frozenset(),
**content,
):
**content: object,
) -> Tuple[EventBase, EventContext]:
prev_events = prev_events or []
auth_events = auth_events or []
prev_state = prev_state or []

View File

@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
def test_update_function_room_account_data_limit(self):
def test_update_function_room_account_data_limit(self) -> None:
"""Test replication with many room account data updates"""
store = self.hs.get_datastores().main
@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
def test_update_function_global_account_data_limit(self):
def test_update_function_global_account_data_limit(self) -> None:
"""Test replication with many global account data updates"""
store = self.hs.get_datastores().main

View File

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from typing import Any, List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
)
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
def test_update_function_event_row_limit(self):
def test_update_function_event_row_limit(self) -> None:
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self):
def test_update_function_huge_state_change(self) -> None:
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self):
def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""
# we want to generate lots of state changes, but for this test, we want to
@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
def test_backwards_stream_id(self):
def test_backwards_stream_id(self) -> None:
"""
Test that RDATA that comes after the current position should be discarded.
"""
@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
event_count = 0
def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
) -> EventBase:
if sender is None:
sender = self.user_id

View File

@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
config["federation_sender_instances"] = ["federation_sender1"]
return config
def test_catchup(self):
def test_catchup(self) -> None:
"""Basic test of catchup on reconnect
Makes sure that updates sent while we are offline are received later.

View File

@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
hijack_auth = True
user_id = "@bob:test"
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main

View File

@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def _build_replication_data_handler(self) -> Mock:
self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
return self.mock_handler
def test_typing(self):
def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
self.reconnect()
@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
self.mock_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
self.mock_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([], row.user_ids)
def test_reset(self):
def test_reset(self) -> None:
"""
Test what happens when a typing stream resets.
@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
# Reset the test code.
self.test_handler.on_rdata.reset_mock()
self.test_handler.on_rdata.assert_not_called()
self.mock_handler.on_rdata.reset_mock()
self.mock_handler.on_rdata.assert_not_called()
# Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0)
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]

View File

@ -21,12 +21,12 @@ from tests.unittest import TestCase
class ParseCommandTestCase(TestCase):
def test_parse_one_word_command(self):
def test_parse_one_word_command(self) -> None:
line = "REPLICATE"
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, ReplicateCommand)
def test_parse_rdata(self):
def test_parse_rdata(self) -> None:
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand)
@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
def test_parse_rdata_batch(self):
def test_parse_rdata_batch(self) -> None:
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand)

View File

@ -16,15 +16,17 @@ from typing import Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport
from twisted.test.proto_helpers import MemoryReactor, StringTransport
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RemoteServerUpTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
return proto, transport
def test_relay(self):
def test_relay(self) -> None:
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
other connections, but not the one that sent it.
"""

View File

@ -13,7 +13,11 @@
# limitations under the License.
import logging
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.client import register
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [register.register_servlets]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# This isn't a real configuration option but is used to provide the main
# homeserver and worker homeserver different options.
@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
{"auth": {"session": session, "type": "m.login.dummy"}},
)
def test_no_auth(self):
def test_no_auth(self) -> None:
"""With no authentication the request should finish."""
channel = self._test_register()
self.assertEqual(channel.code, 200)
@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.json_body["user_id"], "@user:test")
@override_config({"main_replication_secret": "my-secret"})
def test_missing_auth(self):
def test_missing_auth(self) -> None:
"""If the main process expects a secret that is not provided, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
"worker_replication_secret": "wrong-secret",
}
)
def test_unauthorized(self):
def test_unauthorized(self) -> None:
"""If the main process receives the wrong secret, an error results."""
channel = self._test_register()
self.assertEqual(channel.code, 500)
@override_config({"worker_replication_secret": "my-secret"})
def test_authorized(self):
def test_authorized(self) -> None:
"""The request should finish when the worker provides the authentication header."""
channel = self._test_register()
self.assertEqual(channel.code, 200)

View File

@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
config["worker_replication_http_port"] = "8765"
return config
def test_register_single_worker(self):
def test_register_single_worker(self) -> None:
"""Test that registration works when using a single generic worker."""
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self):
def test_register_multi_worker(self) -> None:
"""Test that registration works when using multiple generic workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")

View File

@ -14,10 +14,14 @@
from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
config["federation_sender_instances"] = ["federation_sender1"]
return config
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
def test_federation_ack_sent(self):
def test_federation_ack_sent(self) -> None:
"""A FEDERATION_ACK should be sent back after each RDATA federation
This test checks that the federation sender is correctly sending back

View File

@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
room.register_servlets,
]
def test_send_event_single_sender(self):
def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a
new event.
"""
@ -71,7 +71,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
def test_send_event_sharded(self):
def test_send_event_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends
new events.
"""
@ -138,7 +138,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
def test_send_typing_sharded(self):
def test_send_typing_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends
new typing EDUs.
"""
@ -215,7 +215,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
def create_room_with_remote_server(
self, user: str, token: str, remote_server: str = "other_server"
) -> str:
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastores().main
federation = self.hs.get_federation_event_handler()

View File

@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
synapse.rest.admin.register_servlets,
]
def test_module_cache_full_invalidation(self):
def test_module_cache_full_invalidation(self) -> None:
main_cache = TestCache()
self.hs.get_module_api().register_cached_function(main_cache.cached_function)

View File

@ -18,12 +18,14 @@ from typing import Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self):
def default_config(self) -> dict:
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
def test_basic(self):
def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!")
def test_download_simple_file_race(self):
def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
"""
@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self):
def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
def get_connection_factory():
def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
@ -263,6 +265,6 @@ def _build_test_server(
return server_tls_factory.buildProtocol(None)
def _log_request(request):
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)

View File

@ -15,9 +15,12 @@ import logging
from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
def _create_pusher_and_send_msg(self, localpart):
def _create_pusher_and_send_msg(self, localpart: str) -> str:
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
access_token = self.login(localpart, "pass")
@ -79,7 +82,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
return event_id
def test_send_push_single_worker(self):
def test_send_push_single_worker(self) -> None:
"""Test that registration works when using a pusher worker."""
http_client_mock = Mock(spec_set=["post_json_get_json"])
http_client_mock.post_json_get_json.side_effect = (
@ -109,7 +112,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
],
)
def test_send_push_multiple_workers(self):
def test_send_push_multiple_workers(self) -> None:
"""Test that registration works when using sharded pusher workers."""
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
http_client_mock1.post_json_get_json.side_effect = (

View File

@ -14,9 +14,13 @@
import logging
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastores().main
def default_config(self):
def default_config(self) -> dict:
conf = super().default_config()
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
def _create_room(self, room_id: str, user_id: str, tok: str):
def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
"""Create a room with given room_id"""
# We control the room ID generation by patching out the
@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
def test_basic(self):
def test_basic(self) -> None:
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
"""
@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)
def test_vector_clock_token(self):
def test_vector_clock_token(self) -> None:
"""Tests that using a stream token with a vector clock component works
correctly with basic /sync and /messages usage.
"""