Add missing type hints to tests.replication. (#14987)
This commit is contained in:
parent
b275763c65
commit
156cd88eef
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
3
mypy.ini
3
mypy.ini
|
@ -104,6 +104,9 @@ disallow_untyped_defs = True
|
|||
[mypy-tests.push.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.replication.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.rest.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
|
@ -16,7 +16,9 @@ from collections import defaultdict
|
|||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.protocol import Protocol, connectionDone
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
|
@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
|
|||
)
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
if not hiredis:
|
||||
skip = "Requires hiredis"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
|
@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
repl_handler,
|
||||
)
|
||||
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
self._client_transport: Optional[FakeTransport] = None
|
||||
self._server_transport: Optional[FakeTransport] = None
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
|
@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
|
||||
return TestReplicationDataHandler(self.worker_hs)
|
||||
|
||||
def reconnect(self):
|
||||
def reconnect(self) -> None:
|
||||
if self._client_transport:
|
||||
self.client.close()
|
||||
|
||||
|
@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self._server_transport = FakeTransport(self.client, self.reactor)
|
||||
self.server.makeConnection(self._server_transport)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
if self._client_transport:
|
||||
self._client_transport = None
|
||||
self.client.close()
|
||||
|
@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self._server_transport = None
|
||||
self.server.close()
|
||||
|
||||
def replicate(self):
|
||||
def replicate(self) -> None:
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
|
@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
requests: List[SynapseRequest] = []
|
||||
real_request_factory = channel.requestFactory
|
||||
|
||||
def request_factory(*args, **kwargs):
|
||||
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
|
||||
request = real_request_factory(*args, **kwargs)
|
||||
requests.append(request)
|
||||
return request
|
||||
|
@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
):
|
||||
) -> None:
|
||||
"""Asserts that the given request is a HTTP replication request for
|
||||
fetching updates for given stream.
|
||||
"""
|
||||
|
@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
base["redis"] = {"enabled": True}
|
||||
return base
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
# build a replication server
|
||||
|
@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
lambda: self._handle_http_replication_attempt(self.hs, 8765),
|
||||
)
|
||||
|
||||
def create_test_resource(self):
|
||||
def create_test_resource(self) -> ReplicationRestResource:
|
||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||
# We override this so that it automatically registers all the HTTP
|
||||
# replication servlets, without having to explicitly do that in all
|
||||
|
@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
return resource
|
||||
|
||||
def make_worker_hs(
|
||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
|
||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
|
||||
) -> HomeServer:
|
||||
"""Make a new worker HS instance, correctly connecting replcation
|
||||
stream to the master HS.
|
||||
|
@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def replicate(self):
|
||||
def replicate(self) -> None:
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
self.streamer.on_notifier_poke()
|
||||
self.pump()
|
||||
|
||||
def _handle_http_replication_attempt(self, hs, repl_port):
|
||||
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
|
||||
"""Handles a connection attempt to the given HS replication HTTP
|
||||
listener on the given port.
|
||||
"""
|
||||
|
@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
# inside `connecTCP` before the connection has been passed back to the
|
||||
# code that requested the TCP connection.
|
||||
|
||||
def connect_any_redis_attempts(self):
|
||||
def connect_any_redis_attempts(self) -> None:
|
||||
"""If redis is enabled we need to deal with workers connecting to a
|
||||
redis server. We don't want to use a real Redis server so we use a
|
||||
fake one.
|
||||
|
@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(port, 6379)
|
||||
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
server_protocol = self._redis_server.buildProtocol(None)
|
||||
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
|
||||
client_protocol = client_factory.buildProtocol(client_address)
|
||||
|
||||
server_address = IPv4Address("TCP", host, port)
|
||||
server_protocol = self._redis_server.buildProtocol(server_address)
|
||||
|
||||
client_to_server_transport = FakeTransport(
|
||||
server_protocol, self.reactor, client_protocol
|
||||
|
@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
|||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
|
||||
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
) -> None:
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
|||
class FakeRedisPubSubServer:
|
||||
"""A fake Redis server for pub/sub."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._subscribers_by_channel: Dict[
|
||||
bytes, Set["FakeRedisPubSubProtocol"]
|
||||
] = defaultdict(set)
|
||||
|
||||
def add_subscriber(self, conn, channel: bytes):
|
||||
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
|
||||
"""A connection has called SUBSCRIBE"""
|
||||
self._subscribers_by_channel[channel].add(conn)
|
||||
|
||||
def remove_subscriber(self, conn):
|
||||
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
|
||||
"""A connection has lost connection"""
|
||||
for subscribers in self._subscribers_by_channel.values():
|
||||
subscribers.discard(conn)
|
||||
|
||||
def publish(self, conn, channel: bytes, msg) -> int:
|
||||
def publish(
|
||||
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
|
||||
) -> int:
|
||||
"""A connection want to publish a message to subscribers."""
|
||||
for sub in self._subscribers_by_channel[channel]:
|
||||
sub.send(["message", channel, msg])
|
||||
|
||||
return len(self._subscribers_by_channel)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
|
||||
return FakeRedisPubSubProtocol(self)
|
||||
|
||||
|
||||
|
@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
self._server = server
|
||||
self._reader = hiredis.Reader()
|
||||
|
||||
def dataReceived(self, data):
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
self._reader.feed(data)
|
||||
|
||||
# We might get multiple messages in one packet.
|
||||
|
@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
|
||||
self.handle_command(msg[0], *msg[1:])
|
||||
|
||||
def handle_command(self, command, *args):
|
||||
def handle_command(self, command: bytes, *args: bytes) -> None:
|
||||
"""Received a Redis command from the client."""
|
||||
|
||||
# We currently only support pub/sub.
|
||||
|
@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
self.send("PONG")
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown command: {command}")
|
||||
raise Exception(f"Unknown command: {command!r}")
|
||||
|
||||
def send(self, msg):
|
||||
def send(self, msg: object) -> None:
|
||||
"""Send a message back to the client."""
|
||||
assert self.transport is not None
|
||||
|
||||
|
@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
self.transport.write(raw)
|
||||
self.transport.flush()
|
||||
|
||||
def encode(self, obj):
|
||||
def encode(self, obj: object) -> str:
|
||||
"""Encode an object to its Redis format.
|
||||
|
||||
Supports: strings/bytes, integers and list/tuples.
|
||||
|
@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
|
||||
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
self._server.remove_subscriber(self)
|
||||
|
|
|
@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
|
|||
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests for `ReplicationEndpoint` cancellation."""
|
||||
|
||||
def create_test_resource(self):
|
||||
def create_test_resource(self) -> JsonResource:
|
||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||
resource = JsonResource(self.hs)
|
||||
|
||||
|
|
|
@ -13,35 +13,42 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Iterable, Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import BaseStreamTestCase
|
||||
|
||||
|
||||
class BaseSlavedStoreTestCase(BaseStreamTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
return self.setup_test_homeserver(federation_client=Mock())
|
||||
|
||||
hs = self.setup_test_homeserver(federation_client=Mock())
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
self.reconnect()
|
||||
|
||||
self.master_store = hs.get_datastores().main
|
||||
self.slaved_store = self.worker_hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
persistence = hs.get_storage_controllers().persistence
|
||||
assert persistence is not None
|
||||
self.persistance = persistence
|
||||
|
||||
def replicate(self):
|
||||
def replicate(self) -> None:
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
self.streamer.on_notifier_poke()
|
||||
self.pump(0.1)
|
||||
|
||||
def check(self, method, args, expected_result=None):
|
||||
def check(
|
||||
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
|
||||
) -> None:
|
||||
master_result = self.get_success(getattr(self.master_store, method)(*args))
|
||||
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
|
||||
if expected_result is not None:
|
||||
|
|
|
@ -12,15 +12,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Iterable, Optional
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
from parameterized import parameterized
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.handlers.room import RoomEventSource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.event_push_actions import (
|
||||
NotifCounts,
|
||||
RoomNotifCounts,
|
||||
|
@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
|
|||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
|
||||
from synapse.types import PersistedEventPosition
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.server import FakeTransport
|
||||
|
||||
|
@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dict_equals(self, other):
|
||||
def dict_equals(self: EventBase, other: EventBase) -> bool:
|
||||
me = encode_canonical_json(self.get_pdu_json())
|
||||
them = encode_canonical_json(other.get_pdu_json())
|
||||
return me == them
|
||||
|
||||
|
||||
def patch__eq__(cls):
|
||||
def patch__eq__(cls: object) -> Callable[[], None]:
|
||||
eq = getattr(cls, "__eq__", None)
|
||||
cls.__eq__ = dict_equals
|
||||
cls.__eq__ = dict_equals # type: ignore[assignment]
|
||||
|
||||
def unpatch():
|
||||
def unpatch() -> None:
|
||||
if eq is not None:
|
||||
cls.__eq__ = eq
|
||||
cls.__eq__ = eq # type: ignore[assignment]
|
||||
|
||||
return unpatch
|
||||
|
||||
|
@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
|
||||
STORE_TYPE = EventsWorkerStore
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
# Patch up the equality operator for events so that we can check
|
||||
# whether lists of events match using assertEqual
|
||||
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
|
||||
return super().setUp()
|
||||
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
|
||||
super().setUp()
|
||||
|
||||
def prepare(self, *args, **kwargs):
|
||||
super().prepare(*args, **kwargs)
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
self.get_success(
|
||||
self.master_store.store_room(
|
||||
|
@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
[unpatch() for unpatch in self.unpatches]
|
||||
|
||||
def test_get_latest_event_ids_in_room(self):
|
||||
def test_get_latest_event_ids_in_room(self) -> None:
|
||||
create = self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.replicate()
|
||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
|
||||
|
@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
self.replicate()
|
||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
|
||||
|
||||
def test_redactions(self):
|
||||
def test_redactions(self) -> None:
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
|
||||
|
@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
)
|
||||
self.check("get_event", [msg.event_id], redacted)
|
||||
|
||||
def test_backfilled_redactions(self):
|
||||
def test_backfilled_redactions(self) -> None:
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
|
||||
|
@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
)
|
||||
self.check("get_event", [msg.event_id], redacted)
|
||||
|
||||
def test_invites(self):
|
||||
def test_invites(self) -> None:
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
|
||||
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
|
||||
|
@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
)
|
||||
|
||||
@parameterized.expand([(True,), (False,)])
|
||||
def test_push_actions_for_user(self, send_receipt: bool):
|
||||
def test_push_actions_for_user(self, send_receipt: bool) -> None:
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
self.persist(
|
||||
|
@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
),
|
||||
)
|
||||
|
||||
def test_get_rooms_for_user_with_stream_ordering(self):
|
||||
def test_get_rooms_for_user_with_stream_ordering(self) -> None:
|
||||
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
|
||||
by rows in the events stream
|
||||
"""
|
||||
|
@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
|
||||
)
|
||||
|
||||
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
|
||||
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
|
||||
self,
|
||||
) -> None:
|
||||
"""Check that current_state invalidation happens correctly with multiple events
|
||||
in the persistence batch.
|
||||
|
||||
|
@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
|
||||
)
|
||||
msg, msgctx = self.build_event()
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_events(
|
||||
[(j2, j2ctx), (msg, msgctx)]
|
||||
)
|
||||
)
|
||||
self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
|
||||
self.replicate()
|
||||
assert j2.internal_metadata.stream_ordering is not None
|
||||
|
||||
|
@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
|
||||
event_id = 0
|
||||
|
||||
def persist(self, backfill=False, **kwargs) -> FrozenEvent:
|
||||
def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
|
||||
"""
|
||||
Returns:
|
||||
The event that was persisted.
|
||||
|
@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
|||
|
||||
if backfill:
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_events(
|
||||
[(event, context)], backfilled=True
|
||||
)
|
||||
self.persistance.persist_events([(event, context)], backfilled=True)
|
||||
)
|
||||
else:
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
self.get_success(self.persistance.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
def build_event(
|
||||
self,
|
||||
sender=USER_ID,
|
||||
room_id=ROOM_ID,
|
||||
type="m.room.message",
|
||||
key=None,
|
||||
sender: str = USER_ID,
|
||||
room_id: str = ROOM_ID,
|
||||
type: str = "m.room.message",
|
||||
key: Optional[str] = None,
|
||||
internal: Optional[dict] = None,
|
||||
depth=None,
|
||||
prev_events: Optional[list] = None,
|
||||
auth_events: Optional[list] = None,
|
||||
prev_state: Optional[list] = None,
|
||||
redacts=None,
|
||||
depth: Optional[int] = None,
|
||||
prev_events: Optional[List[Tuple[str, dict]]] = None,
|
||||
auth_events: Optional[List[str]] = None,
|
||||
prev_state: Optional[List[str]] = None,
|
||||
redacts: Optional[str] = None,
|
||||
push_actions: Iterable = frozenset(),
|
||||
**content,
|
||||
):
|
||||
**content: object,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
prev_events = prev_events or []
|
||||
auth_events = auth_events or []
|
||||
prev_state = prev_state or []
|
||||
|
|
|
@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
|
|||
|
||||
|
||||
class AccountDataStreamTestCase(BaseStreamTestCase):
|
||||
def test_update_function_room_account_data_limit(self):
|
||||
def test_update_function_room_account_data_limit(self) -> None:
|
||||
"""Test replication with many room account data updates"""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
|
@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
def test_update_function_global_account_data_limit(self):
|
||||
def test_update_function_global_account_data_limit(self) -> None:
|
||||
"""Test replication with many global account data updates"""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
|
@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
|
|||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import BaseStreamTestCase
|
||||
from tests.test_utils.event_injection import inject_event, inject_member_event
|
||||
|
@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
self.user_id = self.register_user("u1", "pass")
|
||||
self.user_tok = self.login("u1", "pass")
|
||||
|
@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
def test_update_function_event_row_limit(self):
|
||||
def test_update_function_event_row_limit(self) -> None:
|
||||
"""Test replication with many non-state events
|
||||
|
||||
Checks that all events are correctly replicated when there are lots of
|
||||
|
@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
def test_update_function_huge_state_change(self):
|
||||
def test_update_function_huge_state_change(self) -> None:
|
||||
"""Test replication with many state events
|
||||
|
||||
Ensures that all events are correctly replicated when there are lots of
|
||||
|
@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
# "None" indicates the state has been deleted
|
||||
self.assertIsNone(sr.event_id)
|
||||
|
||||
def test_update_function_state_row_limit(self):
|
||||
def test_update_function_state_row_limit(self) -> None:
|
||||
"""Test replication with many state events over several stream ids."""
|
||||
|
||||
# we want to generate lots of state changes, but for this test, we want to
|
||||
|
@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
def test_backwards_stream_id(self):
|
||||
def test_backwards_stream_id(self) -> None:
|
||||
"""
|
||||
Test that RDATA that comes after the current position should be discarded.
|
||||
"""
|
||||
|
@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
event_count = 0
|
||||
|
||||
def _inject_test_event(
|
||||
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
|
||||
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
|
||||
) -> EventBase:
|
||||
if sender is None:
|
||||
sender = self.user_id
|
||||
|
|
|
@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
|
|||
config["federation_sender_instances"] = ["federation_sender1"]
|
||||
return config
|
||||
|
||||
def test_catchup(self):
|
||||
def test_catchup(self) -> None:
|
||||
"""Basic test of catchup on reconnect
|
||||
|
||||
Makes sure that updates sent while we are offline are received later.
|
||||
|
|
|
@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
|
|||
hijack_auth = True
|
||||
user_id = "@bob:test"
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
|
|
|
@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
|
|||
|
||||
|
||||
class TypingStreamTestCase(BaseStreamTestCase):
|
||||
def _build_replication_data_handler(self):
|
||||
return Mock(wraps=super()._build_replication_data_handler())
|
||||
def _build_replication_data_handler(self) -> Mock:
|
||||
self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
|
||||
return self.mock_handler
|
||||
|
||||
def test_typing(self):
|
||||
def test_typing(self) -> None:
|
||||
typing = self.hs.get_typing_handler()
|
||||
|
||||
self.reconnect()
|
||||
|
@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
request = self.handle_http_replication_attempt()
|
||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.mock_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row: TypingStream.TypingStreamRow = rdata_rows[0]
|
||||
|
@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
# Now let's disconnect and insert some data.
|
||||
self.disconnect()
|
||||
|
||||
self.test_handler.on_rdata.reset_mock()
|
||||
self.mock_handler.on_rdata.reset_mock()
|
||||
|
||||
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
|
||||
|
||||
self.test_handler.on_rdata.assert_not_called()
|
||||
self.mock_handler.on_rdata.assert_not_called()
|
||||
|
||||
self.reconnect()
|
||||
self.pump(0.1)
|
||||
|
@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
assert request.args is not None
|
||||
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.mock_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0]
|
||||
self.assertEqual(ROOM_ID, row.room_id)
|
||||
self.assertEqual([], row.user_ids)
|
||||
|
||||
def test_reset(self):
|
||||
def test_reset(self) -> None:
|
||||
"""
|
||||
Test what happens when a typing stream resets.
|
||||
|
||||
|
@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
request = self.handle_http_replication_attempt()
|
||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.mock_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row: TypingStream.TypingStreamRow = rdata_rows[0]
|
||||
|
@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
# Reset the test code.
|
||||
self.test_handler.on_rdata.reset_mock()
|
||||
self.test_handler.on_rdata.assert_not_called()
|
||||
self.mock_handler.on_rdata.reset_mock()
|
||||
self.mock_handler.on_rdata.assert_not_called()
|
||||
|
||||
# Push additional data.
|
||||
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
|
||||
self.reactor.advance(0)
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.mock_handler.on_rdata.assert_called_once()
|
||||
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0]
|
||||
|
|
|
@ -21,12 +21,12 @@ from tests.unittest import TestCase
|
|||
|
||||
|
||||
class ParseCommandTestCase(TestCase):
|
||||
def test_parse_one_word_command(self):
|
||||
def test_parse_one_word_command(self) -> None:
|
||||
line = "REPLICATE"
|
||||
cmd = parse_command_from_line(line)
|
||||
self.assertIsInstance(cmd, ReplicateCommand)
|
||||
|
||||
def test_parse_rdata(self):
|
||||
def test_parse_rdata(self) -> None:
|
||||
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
|
||||
cmd = parse_command_from_line(line)
|
||||
assert isinstance(cmd, RdataCommand)
|
||||
|
@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
|
|||
self.assertEqual(cmd.instance_name, "master")
|
||||
self.assertEqual(cmd.token, 6287863)
|
||||
|
||||
def test_parse_rdata_batch(self):
|
||||
def test_parse_rdata_batch(self) -> None:
|
||||
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
|
||||
cmd = parse_command_from_line(line)
|
||||
assert isinstance(cmd, RdataCommand)
|
||||
|
|
|
@ -16,15 +16,17 @@ from typing import Tuple
|
|||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.interfaces import IProtocol
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from twisted.test.proto_helpers import MemoryReactor, StringTransport
|
||||
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class RemoteServerUpTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.factory = ReplicationStreamProtocolFactory(hs)
|
||||
|
||||
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
||||
|
@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
|
|||
|
||||
return proto, transport
|
||||
|
||||
def test_relay(self):
|
||||
def test_relay(self) -> None:
|
||||
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
|
||||
other connections, but not the one that sent it.
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,11 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest.client import register
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.server import FakeChannel, make_request
|
||||
|
@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
|||
|
||||
servlets = [register.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
# This isn't a real configuration option but is used to provide the main
|
||||
# homeserver and worker homeserver different options.
|
||||
|
@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
|||
{"auth": {"session": session, "type": "m.login.dummy"}},
|
||||
)
|
||||
|
||||
def test_no_auth(self):
|
||||
def test_no_auth(self) -> None:
|
||||
"""With no authentication the request should finish."""
|
||||
channel = self._test_register()
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
||||
|
||||
@override_config({"main_replication_secret": "my-secret"})
|
||||
def test_missing_auth(self):
|
||||
def test_missing_auth(self) -> None:
|
||||
"""If the main process expects a secret that is not provided, an error results."""
|
||||
channel = self._test_register()
|
||||
self.assertEqual(channel.code, 500)
|
||||
|
@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
|||
"worker_replication_secret": "wrong-secret",
|
||||
}
|
||||
)
|
||||
def test_unauthorized(self):
|
||||
def test_unauthorized(self) -> None:
|
||||
"""If the main process receives the wrong secret, an error results."""
|
||||
channel = self._test_register()
|
||||
self.assertEqual(channel.code, 500)
|
||||
|
||||
@override_config({"worker_replication_secret": "my-secret"})
|
||||
def test_authorized(self):
|
||||
def test_authorized(self) -> None:
|
||||
"""The request should finish when the worker provides the authentication header."""
|
||||
channel = self._test_register()
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
|
|
@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
|||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def test_register_single_worker(self):
|
||||
def test_register_single_worker(self) -> None:
|
||||
"""Test that registration works when using a single generic worker."""
|
||||
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||
site = self._hs_to_site[worker_hs]
|
||||
|
@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
|||
# We're given a registered user.
|
||||
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
|
||||
|
||||
def test_register_multi_worker(self):
|
||||
def test_register_multi_worker(self) -> None:
|
||||
"""Test that registration works when using multiple generic workers."""
|
||||
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
|
|
@ -14,10 +14,14 @@
|
|||
|
||||
from unittest import mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.replication.tcp.commands import FederationAckCommand
|
||||
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
|
|||
config["federation_sender_instances"] = ["federation_sender1"]
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
|
||||
|
||||
return hs
|
||||
|
||||
def test_federation_ack_sent(self):
|
||||
def test_federation_ack_sent(self) -> None:
|
||||
"""A FEDERATION_ACK should be sent back after each RDATA federation
|
||||
|
||||
This test checks that the federation sender is correctly sending back
|
||||
|
|
|
@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
|||
room.register_servlets,
|
||||
]
|
||||
|
||||
def test_send_event_single_sender(self):
|
||||
def test_send_event_single_sender(self) -> None:
|
||||
"""Test that using a single federation sender worker correctly sends a
|
||||
new event.
|
||||
"""
|
||||
|
@ -71,7 +71,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
|
||||
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
|
||||
|
||||
def test_send_event_sharded(self):
|
||||
def test_send_event_sharded(self) -> None:
|
||||
"""Test that using two federation sender workers correctly sends
|
||||
new events.
|
||||
"""
|
||||
|
@ -138,7 +138,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertTrue(sent_on_1)
|
||||
self.assertTrue(sent_on_2)
|
||||
|
||||
def test_send_typing_sharded(self):
|
||||
def test_send_typing_sharded(self) -> None:
|
||||
"""Test that using two federation sender workers correctly sends
|
||||
new typing EDUs.
|
||||
"""
|
||||
|
@ -215,7 +215,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertTrue(sent_on_1)
|
||||
self.assertTrue(sent_on_2)
|
||||
|
||||
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
|
||||
def create_room_with_remote_server(
|
||||
self, user: str, token: str, remote_server: str = "other_server"
|
||||
) -> str:
|
||||
room = self.helper.create_room_as(user, tok=token)
|
||||
store = self.hs.get_datastores().main
|
||||
federation = self.hs.get_federation_event_handler()
|
||||
|
|
|
@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
|
|||
synapse.rest.admin.register_servlets,
|
||||
]
|
||||
|
||||
def test_module_cache_full_invalidation(self):
|
||||
def test_module_cache_full_invalidation(self) -> None:
|
||||
main_cache = TestCache()
|
||||
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
|
||||
|
||||
|
|
|
@ -18,12 +18,14 @@ from typing import Optional, Tuple
|
|||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.http import HTTPChannel
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
|
@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("user", "pass")
|
||||
self.access_token = self.login("user", "pass")
|
||||
|
||||
self.reactor.lookups["example.com"] = "1.2.3.4"
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> dict:
|
||||
conf = super().default_config()
|
||||
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
|
||||
return conf
|
||||
|
@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
|
||||
return channel, request
|
||||
|
||||
def test_basic(self):
|
||||
def test_basic(self) -> None:
|
||||
"""Test basic fetching of remote media from a single worker."""
|
||||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
|
@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.result["body"], b"Hello!")
|
||||
|
||||
def test_download_simple_file_race(self):
|
||||
def test_download_simple_file_race(self) -> None:
|
||||
"""Test that fetching remote media from two different processes at the
|
||||
same time works.
|
||||
"""
|
||||
|
@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
# We expect only one new file to have been persisted.
|
||||
self.assertEqual(start_count + 1, self._count_remote_media())
|
||||
|
||||
def test_download_image_race(self):
|
||||
def test_download_image_race(self) -> None:
|
||||
"""Test that fetching remote *images* from two different processes at
|
||||
the same time works.
|
||||
|
||||
|
@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
return sum(len(files) for _, _, files in os.walk(path))
|
||||
|
||||
|
||||
def get_connection_factory():
|
||||
def get_connection_factory() -> TestServerTLSConnectionFactory:
|
||||
# this needs to happen once, but not until we are ready to run the first test
|
||||
global test_server_connection_factory
|
||||
if test_server_connection_factory is None:
|
||||
|
@ -263,6 +265,6 @@ def _build_test_server(
|
|||
return server_tls_factory.buildProtocol(None)
|
||||
|
||||
|
||||
def _log_request(request):
|
||||
def _log_request(request: Request) -> None:
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
||||
|
|
|
@ -15,9 +15,12 @@ import logging
|
|||
from unittest.mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
|
||||
|
@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# Register a user who sends a message that we'll get notified about
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
def _create_pusher_and_send_msg(self, localpart):
|
||||
def _create_pusher_and_send_msg(self, localpart: str) -> str:
|
||||
# Create a user that will get push notifications
|
||||
user_id = self.register_user(localpart, "pass")
|
||||
access_token = self.login(localpart, "pass")
|
||||
|
@ -79,7 +82,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
|
||||
return event_id
|
||||
|
||||
def test_send_push_single_worker(self):
|
||||
def test_send_push_single_worker(self) -> None:
|
||||
"""Test that registration works when using a pusher worker."""
|
||||
http_client_mock = Mock(spec_set=["post_json_get_json"])
|
||||
http_client_mock.post_json_get_json.side_effect = (
|
||||
|
@ -109,7 +112,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
],
|
||||
)
|
||||
|
||||
def test_send_push_multiple_workers(self):
|
||||
def test_send_push_multiple_workers(self) -> None:
|
||||
"""Test that registration works when using sharded pusher workers."""
|
||||
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
|
||||
http_client_mock1.post_json_get_json.side_effect = (
|
||||
|
|
|
@ -14,9 +14,13 @@
|
|||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room, sync
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.server import make_request
|
||||
|
@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# Register a user who sends a message that we'll get notified about
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.room_creator = self.hs.get_room_creation_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> dict:
|
||||
conf = super().default_config()
|
||||
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
|
||||
conf["instance_map"] = {
|
||||
|
@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
}
|
||||
return conf
|
||||
|
||||
def _create_room(self, room_id: str, user_id: str, tok: str):
|
||||
def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
|
||||
"""Create a room with given room_id"""
|
||||
|
||||
# We control the room ID generation by patching out the
|
||||
|
@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
mock.side_effect = lambda: room_id
|
||||
self.helper.create_room_as(user_id, tok=tok)
|
||||
|
||||
def test_basic(self):
|
||||
def test_basic(self) -> None:
|
||||
"""Simple test to ensure that multiple rooms can be created and joined,
|
||||
and that different rooms get handled by different instances.
|
||||
"""
|
||||
|
@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
self.assertTrue(persisted_on_1)
|
||||
self.assertTrue(persisted_on_2)
|
||||
|
||||
def test_vector_clock_token(self):
|
||||
def test_vector_clock_token(self) -> None:
|
||||
"""Tests that using a stream token with a vector clock component works
|
||||
correctly with basic /sync and /messages usage.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue