Add missing type hints to tests.replication. (#14987)

This commit is contained in:
Patrick Cloke 2023-02-06 09:55:00 -05:00 committed by GitHub
parent b275763c65
commit 156cd88eef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 193 additions and 149 deletions

1
changelog.d/14987.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -104,6 +104,9 @@ disallow_untyped_defs = True
[mypy-tests.push.*] [mypy-tests.push.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.replication.*]
disallow_untyped_defs = True
[mypy-tests.rest.*] [mypy-tests.rest.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -16,7 +16,9 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
) )
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeTransport from tests.server import FakeTransport
@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis: if not hiredis:
skip = "Requires hiredis" skip = "Requires hiredis"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server # build a replication server
server_factory = ReplicationStreamProtocolFactory(hs) server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer() self.streamer = hs.get_replication_streamer()
@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler, repl_handler,
) )
self._client_transport = None self._client_transport: Optional[FakeTransport] = None
self._server_transport = None self._server_transport: Optional[FakeTransport] = None
def create_resource_dict(self) -> Dict[str, Resource]: def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict() d = super().create_resource_dict()
@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765" config["worker_replication_http_port"] = "8765"
return config return config
def _build_replication_data_handler(self): def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs) return TestReplicationDataHandler(self.worker_hs)
def reconnect(self): def reconnect(self) -> None:
if self._client_transport: if self._client_transport:
self.client.close() self.client.close()
@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = FakeTransport(self.client, self.reactor) self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport) self.server.makeConnection(self._server_transport)
def disconnect(self): def disconnect(self) -> None:
if self._client_transport: if self._client_transport:
self._client_transport = None self._client_transport = None
self.client.close() self.client.close()
@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None self._server_transport = None
self.server.close() self.server.close()
def replicate(self): def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
wait for the replication to occur. wait for the replication to occur.
""" """
@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests: List[SynapseRequest] = [] requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory real_request_factory = channel.requestFactory
def request_factory(*args, **kwargs): def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs) request = real_request_factory(*args, **kwargs)
requests.append(request) requests.append(request)
return request return request
@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def assert_request_is_get_repl_stream_updates( def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str self, request: SynapseRequest, stream_name: str
): ) -> None:
"""Asserts that the given request is a HTTP replication request for """Asserts that the given request is a HTTP replication request for
fetching updates for given stream. fetching updates for given stream.
""" """
@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
base["redis"] = {"enabled": True} base["redis"] = {"enabled": True}
return base return base
def setUp(self): def setUp(self) -> None:
super().setUp() super().setUp()
# build a replication server # build a replication server
@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765), lambda: self._handle_http_replication_attempt(self.hs, 8765),
) )
def create_test_resource(self): def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`.""" """Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP # We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all # replication servlets, without having to explicitly do that in all
@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource return resource
def make_worker_hs( def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer: ) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation """Make a new worker HS instance, correctly connecting replcation
stream to the master HS. stream to the master HS.
@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765" config["worker_replication_http_port"] = "8765"
return config return config
def replicate(self): def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
wait for the replication to occur. wait for the replication to occur.
""" """
self.streamer.on_notifier_poke() self.streamer.on_notifier_poke()
self.pump() self.pump()
def _handle_http_replication_attempt(self, hs, repl_port): def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP """Handles a connection attempt to the given HS replication HTTP
listener on the given port. listener on the given port.
""" """
@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the # inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection. # code that requested the TCP connection.
def connect_any_redis_attempts(self): def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a """If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a redis server. We don't want to use a real Redis server so we use a
fake one. fake one.
@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(host, "localhost") self.assertEqual(host, "localhost")
self.assertEqual(port, 6379) self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None) client_address = IPv4Address("TCP", "127.0.0.1", 6379)
server_protocol = self._redis_server.buildProtocol(None) client_protocol = client_factory.buildProtocol(client_address)
server_address = IPv4Address("TCP", host, port)
server_protocol = self._redis_server.buildProtocol(server_address)
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol server_protocol, self.reactor, client_protocol
@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
# list of received (stream_name, token, row) tuples # list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = [] self.received_rdata_rows: List[Tuple[str, int, Any]] = []
async def on_rdata(self, stream_name, instance_name, token, rows): async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
) -> None:
await super().on_rdata(stream_name, instance_name, token, rows) await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows: for r in rows:
self.received_rdata_rows.append((stream_name, token, r)) self.received_rdata_rows.append((stream_name, token, r))
@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
class FakeRedisPubSubServer: class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub.""" """A fake Redis server for pub/sub."""
def __init__(self): def __init__(self) -> None:
self._subscribers_by_channel: Dict[ self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"] bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set) ] = defaultdict(set)
def add_subscriber(self, conn, channel: bytes): def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE""" """A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn) self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn): def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection""" """A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values(): for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn) subscribers.discard(conn)
def publish(self, conn, channel: bytes, msg) -> int: def publish(
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
) -> int:
"""A connection want to publish a message to subscribers.""" """A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]: for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg]) sub.send(["message", channel, msg])
return len(self._subscribers_by_channel) return len(self._subscribers_by_channel)
def buildProtocol(self, addr): def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self) return FakeRedisPubSubProtocol(self)
@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
self._server = server self._server = server
self._reader = hiredis.Reader() self._reader = hiredis.Reader()
def dataReceived(self, data): def dataReceived(self, data: bytes) -> None:
self._reader.feed(data) self._reader.feed(data)
# We might get multiple messages in one packet. # We might get multiple messages in one packet.
@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:]) self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args): def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client.""" """Received a Redis command from the client."""
# We currently only support pub/sub. # We currently only support pub/sub.
@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("PONG") self.send("PONG")
else: else:
raise Exception(f"Unknown command: {command}") raise Exception(f"Unknown command: {command!r}")
def send(self, msg): def send(self, msg: object) -> None:
"""Send a message back to the client.""" """Send a message back to the client."""
assert self.transport is not None assert self.transport is not None
@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.transport.write(raw) self.transport.write(raw)
self.transport.flush() self.transport.flush()
def encode(self, obj): def encode(self, obj: object) -> str:
"""Encode an object to its Redis format. """Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples. Supports: strings/bytes, integers and list/tuples.
@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
def connectionLost(self, reason): def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self) self._server.remove_subscriber(self)

View File

@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase): class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation.""" """Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self): def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`.""" """Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs) resource = JsonResource(self.hs)

View File

@ -13,35 +13,42 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Iterable, Optional
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase from tests.replication._base import BaseStreamTestCase
class BaseSlavedStoreTestCase(BaseStreamTestCase): class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
hs = self.setup_test_homeserver(federation_client=Mock()) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
return hs
def prepare(self, reactor, clock, hs):
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
self.reconnect() self.reconnect()
self.master_store = hs.get_datastores().main self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers() persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence
def replicate(self): def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
wait for the replication to occur. wait for the replication to occur.
""" """
self.streamer.on_notifier_poke() self.streamer.on_notifier_poke()
self.pump(0.1) self.pump(0.1)
def check(self, method, args, expected_result=None): def check(
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args)) master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None: if expected_result is not None:

View File

@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, Optional from typing import Any, Callable, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import ( from synapse.storage.databases.main.event_push_actions import (
NotifCounts, NotifCounts,
RoomNotifCounts, RoomNotifCounts,
@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition from synapse.types import PersistedEventPosition
from synapse.util import Clock
from tests.server import FakeTransport from tests.server import FakeTransport
@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def dict_equals(self, other): def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json()) me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json()) them = encode_canonical_json(other.get_pdu_json())
return me == them return me == them
def patch__eq__(cls): def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None) eq = getattr(cls, "__eq__", None)
cls.__eq__ = dict_equals cls.__eq__ = dict_equals # type: ignore[assignment]
def unpatch(): def unpatch() -> None:
if eq is not None: if eq is not None:
cls.__eq__ = eq cls.__eq__ = eq # type: ignore[assignment]
return unpatch return unpatch
@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = EventsWorkerStore STORE_TYPE = EventsWorkerStore
def setUp(self): def setUp(self) -> None:
# Patch up the equality operator for events so that we can check # Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual # whether lists of events match using assertEqual
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
return super().setUp() super().setUp()
def prepare(self, *args, **kwargs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(*args, **kwargs) super().prepare(reactor, clock, hs)
self.get_success( self.get_success(
self.master_store.store_room( self.master_store.store_room(
@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
) )
) )
def tearDown(self): def tearDown(self) -> None:
[unpatch() for unpatch in self.unpatches] [unpatch() for unpatch in self.unpatches]
def test_get_latest_event_ids_in_room(self): def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID) create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate() self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
self.replicate() self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
def test_redactions(self): def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist(type="m.room.member", key=USER_ID, membership="join")
@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
) )
self.check("get_event", [msg.event_id], redacted) self.check("get_event", [msg.event_id], redacted)
def test_backfilled_redactions(self): def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist(type="m.room.member", key=USER_ID, membership="join")
@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
) )
self.check("get_event", [msg.event_id], redacted) self.check("get_event", [msg.event_id], redacted)
def test_invites(self): def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
) )
@parameterized.expand([(True,), (False,)]) @parameterized.expand([(True,), (False,)])
def test_push_actions_for_user(self, send_receipt: bool): def test_push_actions_for_user(self, send_receipt: bool) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist( self.persist(
@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
), ),
) )
def test_get_rooms_for_user_with_stream_ordering(self): def test_get_rooms_for_user_with_stream_ordering(self) -> None:
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
by rows in the events stream by rows in the events stream
""" """
@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
) )
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
self,
) -> None:
"""Check that current_state invalidation happens correctly with multiple events """Check that current_state invalidation happens correctly with multiple events
in the persistence batch. in the persistence batch.
@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
) )
msg, msgctx = self.build_event() msg, msgctx = self.build_event()
self.get_success( self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
self._storage_controllers.persistence.persist_events(
[(j2, j2ctx), (msg, msgctx)]
)
)
self.replicate() self.replicate()
assert j2.internal_metadata.stream_ordering is not None assert j2.internal_metadata.stream_ordering is not None
@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0 event_id = 0
def persist(self, backfill=False, **kwargs) -> FrozenEvent: def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
""" """
Returns: Returns:
The event that was persisted. The event that was persisted.
@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
if backfill: if backfill:
self.get_success( self.get_success(
self._storage_controllers.persistence.persist_events( self.persistance.persist_events([(event, context)], backfilled=True)
[(event, context)], backfilled=True
)
) )
else: else:
self.get_success( self.get_success(self.persistance.persist_event(event, context))
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def build_event( def build_event(
self, self,
sender=USER_ID, sender: str = USER_ID,
room_id=ROOM_ID, room_id: str = ROOM_ID,
type="m.room.message", type: str = "m.room.message",
key=None, key: Optional[str] = None,
internal: Optional[dict] = None, internal: Optional[dict] = None,
depth=None, depth: Optional[int] = None,
prev_events: Optional[list] = None, prev_events: Optional[List[Tuple[str, dict]]] = None,
auth_events: Optional[list] = None, auth_events: Optional[List[str]] = None,
prev_state: Optional[list] = None, prev_state: Optional[List[str]] = None,
redacts=None, redacts: Optional[str] = None,
push_actions: Iterable = frozenset(), push_actions: Iterable = frozenset(),
**content, **content: object,
): ) -> Tuple[EventBase, EventContext]:
prev_events = prev_events or [] prev_events = prev_events or []
auth_events = auth_events or [] auth_events = auth_events or []
prev_state = prev_state or [] prev_state = prev_state or []

View File

@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase): class AccountDataStreamTestCase(BaseStreamTestCase):
def test_update_function_room_account_data_limit(self): def test_update_function_room_account_data_limit(self) -> None:
"""Test replication with many room account data updates""" """Test replication with many room account data updates"""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows) self.assertEqual([], received_rows)
def test_update_function_global_account_data_limit(self): def test_update_function_global_account_data_limit(self) -> None:
"""Test replication with many global account data updates""" """Test replication with many global account data updates"""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main

View File

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import Any, List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
) )
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event from tests.test_utils.event_injection import inject_event, inject_member_event
@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass") self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass") self.user_tok = self.login("u1", "pass")
@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.room_id = self.helper.create_room_as(tok=self.user_tok) self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear() self.test_handler.received_rdata_rows.clear()
def test_update_function_event_row_limit(self): def test_update_function_event_row_limit(self) -> None:
"""Test replication with many non-state events """Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of Checks that all events are correctly replicated when there are lots of
@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows) self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self): def test_update_function_huge_state_change(self) -> None:
"""Test replication with many state events """Test replication with many state events
Ensures that all events are correctly replicated when there are lots of Ensures that all events are correctly replicated when there are lots of
@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# "None" indicates the state has been deleted # "None" indicates the state has been deleted
self.assertIsNone(sr.event_id) self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self): def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids.""" """Test replication with many state events over several stream ids."""
# we want to generate lots of state changes, but for this test, we want to # we want to generate lots of state changes, but for this test, we want to
@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows) self.assertEqual([], received_rows)
def test_backwards_stream_id(self): def test_backwards_stream_id(self) -> None:
""" """
Test that RDATA that comes after the current position should be discarded. Test that RDATA that comes after the current position should be discarded.
""" """
@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
event_count = 0 event_count = 0
def _inject_test_event( def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
) -> EventBase: ) -> EventBase:
if sender is None: if sender is None:
sender = self.user_id sender = self.user_id

View File

@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
config["federation_sender_instances"] = ["federation_sender1"] config["federation_sender_instances"] = ["federation_sender1"]
return config return config
def test_catchup(self): def test_catchup(self) -> None:
"""Basic test of catchup on reconnect """Basic test of catchup on reconnect
Makes sure that updates sent while we are offline are received later. Makes sure that updates sent while we are offline are received later.

View File

@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
hijack_auth = True hijack_auth = True
user_id = "@bob:test" user_id = "@bob:test"
def setUp(self): def setUp(self) -> None:
super().setUp() super().setUp()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main

View File

@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase): class TypingStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self): def _build_replication_data_handler(self) -> Mock:
return Mock(wraps=super()._build_replication_data_handler()) self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
return self.mock_handler
def test_typing(self): def test_typing(self) -> None:
typing = self.hs.get_typing_handler() typing = self.hs.get_typing_handler()
self.reconnect() self.reconnect()
@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt() request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing") self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once() self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0] row: TypingStream.TypingStreamRow = rdata_rows[0]
@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
# Now let's disconnect and insert some data. # Now let's disconnect and insert some data.
self.disconnect() self.disconnect()
self.test_handler.on_rdata.reset_mock() self.mock_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False) typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called() self.mock_handler.on_rdata.assert_not_called()
self.reconnect() self.reconnect()
self.pump(0.1) self.pump(0.1)
@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
assert request.args is not None assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token) self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once() self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] row = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id) self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([], row.user_ids) self.assertEqual([], row.user_ids)
def test_reset(self): def test_reset(self) -> None:
""" """
Test what happens when a typing stream resets. Test what happens when a typing stream resets.
@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
request = self.handle_http_replication_attempt() request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing") self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once() self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0] row: TypingStream.TypingStreamRow = rdata_rows[0]
@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing") self.assert_request_is_get_repl_stream_updates(request, "typing")
# Reset the test code. # Reset the test code.
self.test_handler.on_rdata.reset_mock() self.mock_handler.on_rdata.reset_mock()
self.test_handler.on_rdata.assert_not_called() self.mock_handler.on_rdata.assert_not_called()
# Push additional data. # Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0) self.reactor.advance(0)
self.test_handler.on_rdata.assert_called_once() self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] row = rdata_rows[0]

View File

@ -21,12 +21,12 @@ from tests.unittest import TestCase
class ParseCommandTestCase(TestCase): class ParseCommandTestCase(TestCase):
def test_parse_one_word_command(self): def test_parse_one_word_command(self) -> None:
line = "REPLICATE" line = "REPLICATE"
cmd = parse_command_from_line(line) cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, ReplicateCommand) self.assertIsInstance(cmd, ReplicateCommand)
def test_parse_rdata(self): def test_parse_rdata(self) -> None:
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line) cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand) assert isinstance(cmd, RdataCommand)
@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863) self.assertEqual(cmd.token, 6287863)
def test_parse_rdata_batch(self): def test_parse_rdata_batch(self) -> None:
line = 'RDATA presence master batch ["@foo:example.com", "online"]' line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line) cmd = parse_command_from_line(line)
assert isinstance(cmd, RdataCommand) assert isinstance(cmd, RdataCommand)

View File

@ -16,15 +16,17 @@ from typing import Tuple
from twisted.internet.address import IPv4Address from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport from twisted.test.proto_helpers import MemoryReactor, StringTransport
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class RemoteServerUpTestCase(HomeserverTestCase): class RemoteServerUpTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.factory = ReplicationStreamProtocolFactory(hs) self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]: def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
return proto, transport return proto, transport
def test_relay(self): def test_relay(self) -> None:
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all """Test that Synapse will relay REMOTE_SERVER_UP commands to all
other connections, but not the one that sent it. other connections, but not the one that sent it.
""" """

View File

@ -13,7 +13,11 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.client import register from synapse.rest.client import register
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request from tests.server import FakeChannel, make_request
@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [register.register_servlets] servlets = [register.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
# This isn't a real configuration option but is used to provide the main # This isn't a real configuration option but is used to provide the main
# homeserver and worker homeserver different options. # homeserver and worker homeserver different options.
@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
{"auth": {"session": session, "type": "m.login.dummy"}}, {"auth": {"session": session, "type": "m.login.dummy"}},
) )
def test_no_auth(self): def test_no_auth(self) -> None:
"""With no authentication the request should finish.""" """With no authentication the request should finish."""
channel = self._test_register() channel = self._test_register()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.json_body["user_id"], "@user:test") self.assertEqual(channel.json_body["user_id"], "@user:test")
@override_config({"main_replication_secret": "my-secret"}) @override_config({"main_replication_secret": "my-secret"})
def test_missing_auth(self): def test_missing_auth(self) -> None:
"""If the main process expects a secret that is not provided, an error results.""" """If the main process expects a secret that is not provided, an error results."""
channel = self._test_register() channel = self._test_register()
self.assertEqual(channel.code, 500) self.assertEqual(channel.code, 500)
@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
"worker_replication_secret": "wrong-secret", "worker_replication_secret": "wrong-secret",
} }
) )
def test_unauthorized(self): def test_unauthorized(self) -> None:
"""If the main process receives the wrong secret, an error results.""" """If the main process receives the wrong secret, an error results."""
channel = self._test_register() channel = self._test_register()
self.assertEqual(channel.code, 500) self.assertEqual(channel.code, 500)
@override_config({"worker_replication_secret": "my-secret"}) @override_config({"worker_replication_secret": "my-secret"})
def test_authorized(self): def test_authorized(self) -> None:
"""The request should finish when the worker provides the authentication header.""" """The request should finish when the worker provides the authentication header."""
channel = self._test_register() channel = self._test_register()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)

View File

@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
config["worker_replication_http_port"] = "8765" config["worker_replication_http_port"] = "8765"
return config return config
def test_register_single_worker(self): def test_register_single_worker(self) -> None:
"""Test that registration works when using a single generic worker.""" """Test that registration works when using a single generic worker."""
worker_hs = self.make_worker_hs("synapse.app.generic_worker") worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs] site = self._hs_to_site[worker_hs]
@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
# We're given a registered user. # We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test") self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self): def test_register_multi_worker(self) -> None:
"""Test that registration works when using multiple generic workers.""" """Test that registration works when using multiple generic workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker") worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker") worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")

View File

@ -14,10 +14,14 @@
from unittest import mock from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream from synapse.replication.tcp.streams.federation import FederationStream
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
config["federation_sender_instances"] = ["federation_sender1"] config["federation_sender_instances"] = ["federation_sender1"]
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs def test_federation_ack_sent(self) -> None:
def test_federation_ack_sent(self):
"""A FEDERATION_ACK should be sent back after each RDATA federation """A FEDERATION_ACK should be sent back after each RDATA federation
This test checks that the federation sender is correctly sending back This test checks that the federation sender is correctly sending back

View File

@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
room.register_servlets, room.register_servlets,
] ]
def test_send_event_single_sender(self): def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a """Test that using a single federation sender worker correctly sends a
new event. new event.
""" """
@ -71,7 +71,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server") self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus")) self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
def test_send_event_sharded(self): def test_send_event_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends """Test that using two federation sender workers correctly sends
new events. new events.
""" """
@ -138,7 +138,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1) self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2) self.assertTrue(sent_on_2)
def test_send_typing_sharded(self): def test_send_typing_sharded(self) -> None:
"""Test that using two federation sender workers correctly sends """Test that using two federation sender workers correctly sends
new typing EDUs. new typing EDUs.
""" """
@ -215,7 +215,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1) self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2) self.assertTrue(sent_on_2)
def create_room_with_remote_server(self, user, token, remote_server="other_server"): def create_room_with_remote_server(
self, user: str, token: str, remote_server: str = "other_server"
) -> str:
room = self.helper.create_room_as(user, tok=token) room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
federation = self.hs.get_federation_event_handler() federation = self.hs.get_federation_event_handler()

View File

@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
] ]
def test_module_cache_full_invalidation(self): def test_module_cache_full_invalidation(self) -> None:
main_cache = TestCache() main_cache = TestCache()
self.hs.get_module_api().register_cached_function(main_cache.cached_function) self.hs.get_module_api().register_cached_function(main_cache.cached_function)

View File

@ -18,12 +18,14 @@ from typing import Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.server import Request from twisted.web.server import Request
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass") self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass") self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4" self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self): def default_config(self) -> dict:
conf = super().default_config() conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf return conf
@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request return channel, request
def test_basic(self): def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker.""" """Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker") hs1 = self.make_worker_hs("synapse.app.generic_worker")
@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!") self.assertEqual(channel.result["body"], b"Hello!")
def test_download_simple_file_race(self): def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the """Test that fetching remote media from two different processes at the
same time works. same time works.
""" """
@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
# We expect only one new file to have been persisted. # We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media()) self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self): def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at """Test that fetching remote *images* from two different processes at
the same time works. the same time works.
@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path)) return sum(len(files) for _, _, files in os.walk(path))
def get_connection_factory(): def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test # this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory global test_server_connection_factory
if test_server_connection_factory is None: if test_server_connection_factory is None:
@ -263,6 +265,6 @@ def _build_test_server(
return server_tls_factory.buildProtocol(None) return server_tls_factory.buildProtocol(None)
def _log_request(request): def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish""" """Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request) logger.info("Completed request %s", request)

View File

@ -15,9 +15,12 @@ import logging
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about # Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass") self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass")
def _create_pusher_and_send_msg(self, localpart): def _create_pusher_and_send_msg(self, localpart: str) -> str:
# Create a user that will get push notifications # Create a user that will get push notifications
user_id = self.register_user(localpart, "pass") user_id = self.register_user(localpart, "pass")
access_token = self.login(localpart, "pass") access_token = self.login(localpart, "pass")
@ -79,7 +82,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
return event_id return event_id
def test_send_push_single_worker(self): def test_send_push_single_worker(self) -> None:
"""Test that registration works when using a pusher worker.""" """Test that registration works when using a pusher worker."""
http_client_mock = Mock(spec_set=["post_json_get_json"]) http_client_mock = Mock(spec_set=["post_json_get_json"])
http_client_mock.post_json_get_json.side_effect = ( http_client_mock.post_json_get_json.side_effect = (
@ -109,7 +112,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
], ],
) )
def test_send_push_multiple_workers(self): def test_send_push_multiple_workers(self) -> None:
"""Test that registration works when using sharded pusher workers.""" """Test that registration works when using sharded pusher workers."""
http_client_mock1 = Mock(spec_set=["post_json_get_json"]) http_client_mock1 = Mock(spec_set=["post_json_get_json"])
http_client_mock1.post_json_get_json.side_effect = ( http_client_mock1.post_json_get_json.side_effect = (

View File

@ -14,9 +14,13 @@
import logging import logging
from unittest.mock import patch from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room, sync from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request from tests.server import make_request
@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about # Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass") self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass")
@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.room_creator = self.hs.get_room_creation_handler() self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def default_config(self): def default_config(self) -> dict:
conf = super().default_config() conf = super().default_config()
conf["stream_writers"] = {"events": ["worker1", "worker2"]} conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = { conf["instance_map"] = {
@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
} }
return conf return conf
def _create_room(self, room_id: str, user_id: str, tok: str): def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
"""Create a room with given room_id""" """Create a room with given room_id"""
# We control the room ID generation by patching out the # We control the room ID generation by patching out the
@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
mock.side_effect = lambda: room_id mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok) self.helper.create_room_as(user_id, tok=tok)
def test_basic(self): def test_basic(self) -> None:
"""Simple test to ensure that multiple rooms can be created and joined, """Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances. and that different rooms get handled by different instances.
""" """
@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(persisted_on_1) self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2) self.assertTrue(persisted_on_2)
def test_vector_clock_token(self): def test_vector_clock_token(self) -> None:
"""Tests that using a stream token with a vector clock component works """Tests that using a stream token with a vector clock component works
correctly with basic /sync and /messages usage. correctly with basic /sync and /messages usage.
""" """