Add missing type hints to tests.replication. (#14987)
This commit is contained in:
parent
b275763c65
commit
156cd88eef
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
3
mypy.ini
3
mypy.ini
|
@ -104,6 +104,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-tests.push.*]
|
[mypy-tests.push.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-tests.replication.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.rest.*]
|
[mypy-tests.rest.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,9 @@ from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from twisted.internet.address import IPv4Address
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.protocol import Protocol
|
from twisted.internet.protocol import Protocol, connectionDone
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.app.generic_worker import GenericWorkerServer
|
from synapse.app.generic_worker import GenericWorkerServer
|
||||||
|
@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
|
@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
if not hiredis:
|
if not hiredis:
|
||||||
skip = "Requires hiredis"
|
skip = "Requires hiredis"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# build a replication server
|
# build a replication server
|
||||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||||
self.streamer = hs.get_replication_streamer()
|
self.streamer = hs.get_replication_streamer()
|
||||||
|
@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
repl_handler,
|
repl_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._client_transport = None
|
self._client_transport: Optional[FakeTransport] = None
|
||||||
self._server_transport = None
|
self._server_transport: Optional[FakeTransport] = None
|
||||||
|
|
||||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
d = super().create_resource_dict()
|
d = super().create_resource_dict()
|
||||||
|
@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
config["worker_replication_http_port"] = "8765"
|
config["worker_replication_http_port"] = "8765"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def _build_replication_data_handler(self):
|
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
|
||||||
return TestReplicationDataHandler(self.worker_hs)
|
return TestReplicationDataHandler(self.worker_hs)
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self) -> None:
|
||||||
if self._client_transport:
|
if self._client_transport:
|
||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
|
@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self._server_transport = FakeTransport(self.client, self.reactor)
|
self._server_transport = FakeTransport(self.client, self.reactor)
|
||||||
self.server.makeConnection(self._server_transport)
|
self.server.makeConnection(self._server_transport)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
if self._client_transport:
|
if self._client_transport:
|
||||||
self._client_transport = None
|
self._client_transport = None
|
||||||
self.client.close()
|
self.client.close()
|
||||||
|
@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self._server_transport = None
|
self._server_transport = None
|
||||||
self.server.close()
|
self.server.close()
|
||||||
|
|
||||||
def replicate(self):
|
def replicate(self) -> None:
|
||||||
"""Tell the master side of replication that something has happened, and then
|
"""Tell the master side of replication that something has happened, and then
|
||||||
wait for the replication to occur.
|
wait for the replication to occur.
|
||||||
"""
|
"""
|
||||||
|
@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
requests: List[SynapseRequest] = []
|
requests: List[SynapseRequest] = []
|
||||||
real_request_factory = channel.requestFactory
|
real_request_factory = channel.requestFactory
|
||||||
|
|
||||||
def request_factory(*args, **kwargs):
|
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
|
||||||
request = real_request_factory(*args, **kwargs)
|
request = real_request_factory(*args, **kwargs)
|
||||||
requests.append(request)
|
requests.append(request)
|
||||||
return request
|
return request
|
||||||
|
@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def assert_request_is_get_repl_stream_updates(
|
def assert_request_is_get_repl_stream_updates(
|
||||||
self, request: SynapseRequest, stream_name: str
|
self, request: SynapseRequest, stream_name: str
|
||||||
):
|
) -> None:
|
||||||
"""Asserts that the given request is a HTTP replication request for
|
"""Asserts that the given request is a HTTP replication request for
|
||||||
fetching updates for given stream.
|
fetching updates for given stream.
|
||||||
"""
|
"""
|
||||||
|
@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
base["redis"] = {"enabled": True}
|
base["redis"] = {"enabled": True}
|
||||||
return base
|
return base
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
# build a replication server
|
# build a replication server
|
||||||
|
@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
lambda: self._handle_http_replication_attempt(self.hs, 8765),
|
lambda: self._handle_http_replication_attempt(self.hs, 8765),
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_test_resource(self):
|
def create_test_resource(self) -> ReplicationRestResource:
|
||||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||||
# We override this so that it automatically registers all the HTTP
|
# We override this so that it automatically registers all the HTTP
|
||||||
# replication servlets, without having to explicitly do that in all
|
# replication servlets, without having to explicitly do that in all
|
||||||
|
@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
return resource
|
return resource
|
||||||
|
|
||||||
def make_worker_hs(
|
def make_worker_hs(
|
||||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
|
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
|
||||||
) -> HomeServer:
|
) -> HomeServer:
|
||||||
"""Make a new worker HS instance, correctly connecting replcation
|
"""Make a new worker HS instance, correctly connecting replcation
|
||||||
stream to the master HS.
|
stream to the master HS.
|
||||||
|
@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
config["worker_replication_http_port"] = "8765"
|
config["worker_replication_http_port"] = "8765"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def replicate(self):
|
def replicate(self) -> None:
|
||||||
"""Tell the master side of replication that something has happened, and then
|
"""Tell the master side of replication that something has happened, and then
|
||||||
wait for the replication to occur.
|
wait for the replication to occur.
|
||||||
"""
|
"""
|
||||||
self.streamer.on_notifier_poke()
|
self.streamer.on_notifier_poke()
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
def _handle_http_replication_attempt(self, hs, repl_port):
|
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
|
||||||
"""Handles a connection attempt to the given HS replication HTTP
|
"""Handles a connection attempt to the given HS replication HTTP
|
||||||
listener on the given port.
|
listener on the given port.
|
||||||
"""
|
"""
|
||||||
|
@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
# inside `connecTCP` before the connection has been passed back to the
|
# inside `connecTCP` before the connection has been passed back to the
|
||||||
# code that requested the TCP connection.
|
# code that requested the TCP connection.
|
||||||
|
|
||||||
def connect_any_redis_attempts(self):
|
def connect_any_redis_attempts(self) -> None:
|
||||||
"""If redis is enabled we need to deal with workers connecting to a
|
"""If redis is enabled we need to deal with workers connecting to a
|
||||||
redis server. We don't want to use a real Redis server so we use a
|
redis server. We don't want to use a real Redis server so we use a
|
||||||
fake one.
|
fake one.
|
||||||
|
@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(host, "localhost")
|
self.assertEqual(host, "localhost")
|
||||||
self.assertEqual(port, 6379)
|
self.assertEqual(port, 6379)
|
||||||
|
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
|
||||||
server_protocol = self._redis_server.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(client_address)
|
||||||
|
|
||||||
|
server_address = IPv4Address("TCP", host, port)
|
||||||
|
server_protocol = self._redis_server.buildProtocol(server_address)
|
||||||
|
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
server_protocol, self.reactor, client_protocol
|
server_protocol, self.reactor, client_protocol
|
||||||
|
@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
||||||
# list of received (stream_name, token, row) tuples
|
# list of received (stream_name, token, row) tuples
|
||||||
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
|
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
|
||||||
|
|
||||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
async def on_rdata(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||||
|
) -> None:
|
||||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.received_rdata_rows.append((stream_name, token, r))
|
self.received_rdata_rows.append((stream_name, token, r))
|
||||||
|
@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
||||||
class FakeRedisPubSubServer:
|
class FakeRedisPubSubServer:
|
||||||
"""A fake Redis server for pub/sub."""
|
"""A fake Redis server for pub/sub."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._subscribers_by_channel: Dict[
|
self._subscribers_by_channel: Dict[
|
||||||
bytes, Set["FakeRedisPubSubProtocol"]
|
bytes, Set["FakeRedisPubSubProtocol"]
|
||||||
] = defaultdict(set)
|
] = defaultdict(set)
|
||||||
|
|
||||||
def add_subscriber(self, conn, channel: bytes):
|
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
|
||||||
"""A connection has called SUBSCRIBE"""
|
"""A connection has called SUBSCRIBE"""
|
||||||
self._subscribers_by_channel[channel].add(conn)
|
self._subscribers_by_channel[channel].add(conn)
|
||||||
|
|
||||||
def remove_subscriber(self, conn):
|
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
|
||||||
"""A connection has lost connection"""
|
"""A connection has lost connection"""
|
||||||
for subscribers in self._subscribers_by_channel.values():
|
for subscribers in self._subscribers_by_channel.values():
|
||||||
subscribers.discard(conn)
|
subscribers.discard(conn)
|
||||||
|
|
||||||
def publish(self, conn, channel: bytes, msg) -> int:
|
def publish(
|
||||||
|
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
|
||||||
|
) -> int:
|
||||||
"""A connection want to publish a message to subscribers."""
|
"""A connection want to publish a message to subscribers."""
|
||||||
for sub in self._subscribers_by_channel[channel]:
|
for sub in self._subscribers_by_channel[channel]:
|
||||||
sub.send(["message", channel, msg])
|
sub.send(["message", channel, msg])
|
||||||
|
|
||||||
return len(self._subscribers_by_channel)
|
return len(self._subscribers_by_channel)
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
|
||||||
return FakeRedisPubSubProtocol(self)
|
return FakeRedisPubSubProtocol(self)
|
||||||
|
|
||||||
|
|
||||||
|
@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
self._server = server
|
self._server = server
|
||||||
self._reader = hiredis.Reader()
|
self._reader = hiredis.Reader()
|
||||||
|
|
||||||
def dataReceived(self, data):
|
def dataReceived(self, data: bytes) -> None:
|
||||||
self._reader.feed(data)
|
self._reader.feed(data)
|
||||||
|
|
||||||
# We might get multiple messages in one packet.
|
# We might get multiple messages in one packet.
|
||||||
|
@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
|
|
||||||
self.handle_command(msg[0], *msg[1:])
|
self.handle_command(msg[0], *msg[1:])
|
||||||
|
|
||||||
def handle_command(self, command, *args):
|
def handle_command(self, command: bytes, *args: bytes) -> None:
|
||||||
"""Received a Redis command from the client."""
|
"""Received a Redis command from the client."""
|
||||||
|
|
||||||
# We currently only support pub/sub.
|
# We currently only support pub/sub.
|
||||||
|
@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
self.send("PONG")
|
self.send("PONG")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown command: {command}")
|
raise Exception(f"Unknown command: {command!r}")
|
||||||
|
|
||||||
def send(self, msg):
|
def send(self, msg: object) -> None:
|
||||||
"""Send a message back to the client."""
|
"""Send a message back to the client."""
|
||||||
assert self.transport is not None
|
assert self.transport is not None
|
||||||
|
|
||||||
|
@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
self.transport.write(raw)
|
self.transport.write(raw)
|
||||||
self.transport.flush()
|
self.transport.flush()
|
||||||
|
|
||||||
def encode(self, obj):
|
def encode(self, obj: object) -> str:
|
||||||
"""Encode an object to its Redis format.
|
"""Encode an object to its Redis format.
|
||||||
|
|
||||||
Supports: strings/bytes, integers and list/tuples.
|
Supports: strings/bytes, integers and list/tuples.
|
||||||
|
@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
|
|
||||||
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
|
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
self._server.remove_subscriber(self)
|
self._server.remove_subscriber(self)
|
||||||
|
|
|
@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
|
||||||
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
|
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests for `ReplicationEndpoint` cancellation."""
|
"""Tests for `ReplicationEndpoint` cancellation."""
|
||||||
|
|
||||||
def create_test_resource(self):
|
def create_test_resource(self) -> JsonResource:
|
||||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||||
resource = JsonResource(self.hs)
|
resource = JsonResource(self.hs)
|
||||||
|
|
||||||
|
|
|
@ -13,35 +13,42 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Iterable, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.replication._base import BaseStreamTestCase
|
from tests.replication._base import BaseStreamTestCase
|
||||||
|
|
||||||
|
|
||||||
class BaseSlavedStoreTestCase(BaseStreamTestCase):
|
class BaseSlavedStoreTestCase(BaseStreamTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
return self.setup_test_homeserver(federation_client=Mock())
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(federation_client=Mock())
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
return hs
|
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
|
||||||
super().prepare(reactor, clock, hs)
|
super().prepare(reactor, clock, hs)
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
self.master_store = hs.get_datastores().main
|
self.master_store = hs.get_datastores().main
|
||||||
self.slaved_store = self.worker_hs.get_datastores().main
|
self.slaved_store = self.worker_hs.get_datastores().main
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
persistence = hs.get_storage_controllers().persistence
|
||||||
|
assert persistence is not None
|
||||||
|
self.persistance = persistence
|
||||||
|
|
||||||
def replicate(self):
|
def replicate(self) -> None:
|
||||||
"""Tell the master side of replication that something has happened, and then
|
"""Tell the master side of replication that something has happened, and then
|
||||||
wait for the replication to occur.
|
wait for the replication to occur.
|
||||||
"""
|
"""
|
||||||
self.streamer.on_notifier_poke()
|
self.streamer.on_notifier_poke()
|
||||||
self.pump(0.1)
|
self.pump(0.1)
|
||||||
|
|
||||||
def check(self, method, args, expected_result=None):
|
def check(
|
||||||
|
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
master_result = self.get_success(getattr(self.master_store, method)(*args))
|
master_result = self.get_success(getattr(self.master_store, method)(*args))
|
||||||
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
|
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
|
||||||
if expected_result is not None:
|
if expected_result is not None:
|
||||||
|
|
|
@ -12,15 +12,19 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, Optional
|
from typing import Any, Callable, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import ReceiptTypes
|
from synapse.api.constants import ReceiptTypes
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
|
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.handlers.room import RoomEventSource
|
from synapse.handlers.room import RoomEventSource
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.event_push_actions import (
|
from synapse.storage.databases.main.event_push_actions import (
|
||||||
NotifCounts,
|
NotifCounts,
|
||||||
RoomNotifCounts,
|
RoomNotifCounts,
|
||||||
|
@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
|
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
|
||||||
from synapse.types import PersistedEventPosition
|
from synapse.types import PersistedEventPosition
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
|
|
||||||
|
@ -41,19 +46,19 @@ ROOM_ID = "!room:test"
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def dict_equals(self, other):
|
def dict_equals(self: EventBase, other: EventBase) -> bool:
|
||||||
me = encode_canonical_json(self.get_pdu_json())
|
me = encode_canonical_json(self.get_pdu_json())
|
||||||
them = encode_canonical_json(other.get_pdu_json())
|
them = encode_canonical_json(other.get_pdu_json())
|
||||||
return me == them
|
return me == them
|
||||||
|
|
||||||
|
|
||||||
def patch__eq__(cls):
|
def patch__eq__(cls: object) -> Callable[[], None]:
|
||||||
eq = getattr(cls, "__eq__", None)
|
eq = getattr(cls, "__eq__", None)
|
||||||
cls.__eq__ = dict_equals
|
cls.__eq__ = dict_equals # type: ignore[assignment]
|
||||||
|
|
||||||
def unpatch():
|
def unpatch() -> None:
|
||||||
if eq is not None:
|
if eq is not None:
|
||||||
cls.__eq__ = eq
|
cls.__eq__ = eq # type: ignore[assignment]
|
||||||
|
|
||||||
return unpatch
|
return unpatch
|
||||||
|
|
||||||
|
@ -62,14 +67,14 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
STORE_TYPE = EventsWorkerStore
|
STORE_TYPE = EventsWorkerStore
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
# Patch up the equality operator for events so that we can check
|
# Patch up the equality operator for events so that we can check
|
||||||
# whether lists of events match using assertEqual
|
# whether lists of events match using assertEqual
|
||||||
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
|
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
|
||||||
return super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
def prepare(self, *args, **kwargs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
super().prepare(*args, **kwargs)
|
super().prepare(reactor, clock, hs)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.master_store.store_room(
|
self.master_store.store_room(
|
||||||
|
@ -80,10 +85,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
[unpatch() for unpatch in self.unpatches]
|
[unpatch() for unpatch in self.unpatches]
|
||||||
|
|
||||||
def test_get_latest_event_ids_in_room(self):
|
def test_get_latest_event_ids_in_room(self) -> None:
|
||||||
create = self.persist(type="m.room.create", key="", creator=USER_ID)
|
create = self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||||
self.replicate()
|
self.replicate()
|
||||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
|
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
|
||||||
|
@ -97,7 +102,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
self.replicate()
|
self.replicate()
|
||||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
|
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
|
||||||
|
|
||||||
def test_redactions(self):
|
def test_redactions(self) -> None:
|
||||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||||
|
|
||||||
|
@ -117,7 +122,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
self.check("get_event", [msg.event_id], redacted)
|
self.check("get_event", [msg.event_id], redacted)
|
||||||
|
|
||||||
def test_backfilled_redactions(self):
|
def test_backfilled_redactions(self) -> None:
|
||||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||||
|
|
||||||
|
@ -139,7 +144,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
self.check("get_event", [msg.event_id], redacted)
|
self.check("get_event", [msg.event_id], redacted)
|
||||||
|
|
||||||
def test_invites(self):
|
def test_invites(self) -> None:
|
||||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||||
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
|
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
|
||||||
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
|
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
|
||||||
|
@ -163,7 +168,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([(True,), (False,)])
|
@parameterized.expand([(True,), (False,)])
|
||||||
def test_push_actions_for_user(self, send_receipt: bool):
|
def test_push_actions_for_user(self, send_receipt: bool) -> None:
|
||||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||||
self.persist(
|
self.persist(
|
||||||
|
@ -219,7 +224,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_rooms_for_user_with_stream_ordering(self):
|
def test_get_rooms_for_user_with_stream_ordering(self) -> None:
|
||||||
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
|
"""Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
|
||||||
by rows in the events stream
|
by rows in the events stream
|
||||||
"""
|
"""
|
||||||
|
@ -243,7 +248,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
|
{GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
|
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""Check that current_state invalidation happens correctly with multiple events
|
"""Check that current_state invalidation happens correctly with multiple events
|
||||||
in the persistence batch.
|
in the persistence batch.
|
||||||
|
|
||||||
|
@ -283,11 +290,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
|
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
|
||||||
)
|
)
|
||||||
msg, msgctx = self.build_event()
|
msg, msgctx = self.build_event()
|
||||||
self.get_success(
|
self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
|
||||||
self._storage_controllers.persistence.persist_events(
|
|
||||||
[(j2, j2ctx), (msg, msgctx)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.replicate()
|
self.replicate()
|
||||||
assert j2.internal_metadata.stream_ordering is not None
|
assert j2.internal_metadata.stream_ordering is not None
|
||||||
|
|
||||||
|
@ -339,7 +342,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
event_id = 0
|
event_id = 0
|
||||||
|
|
||||||
def persist(self, backfill=False, **kwargs) -> FrozenEvent:
|
def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
The event that was persisted.
|
The event that was persisted.
|
||||||
|
@ -348,32 +351,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
if backfill:
|
if backfill:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self._storage_controllers.persistence.persist_events(
|
self.persistance.persist_events([(event, context)], backfilled=True)
|
||||||
[(event, context)], backfilled=True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.get_success(
|
self.get_success(self.persistance.persist_event(event, context))
|
||||||
self._storage_controllers.persistence.persist_event(event, context)
|
|
||||||
)
|
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def build_event(
|
def build_event(
|
||||||
self,
|
self,
|
||||||
sender=USER_ID,
|
sender: str = USER_ID,
|
||||||
room_id=ROOM_ID,
|
room_id: str = ROOM_ID,
|
||||||
type="m.room.message",
|
type: str = "m.room.message",
|
||||||
key=None,
|
key: Optional[str] = None,
|
||||||
internal: Optional[dict] = None,
|
internal: Optional[dict] = None,
|
||||||
depth=None,
|
depth: Optional[int] = None,
|
||||||
prev_events: Optional[list] = None,
|
prev_events: Optional[List[Tuple[str, dict]]] = None,
|
||||||
auth_events: Optional[list] = None,
|
auth_events: Optional[List[str]] = None,
|
||||||
prev_state: Optional[list] = None,
|
prev_state: Optional[List[str]] = None,
|
||||||
redacts=None,
|
redacts: Optional[str] = None,
|
||||||
push_actions: Iterable = frozenset(),
|
push_actions: Iterable = frozenset(),
|
||||||
**content,
|
**content: object,
|
||||||
):
|
) -> Tuple[EventBase, EventContext]:
|
||||||
prev_events = prev_events or []
|
prev_events = prev_events or []
|
||||||
auth_events = auth_events or []
|
auth_events = auth_events or []
|
||||||
prev_state = prev_state or []
|
prev_state = prev_state or []
|
||||||
|
|
|
@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
|
||||||
|
|
||||||
|
|
||||||
class AccountDataStreamTestCase(BaseStreamTestCase):
|
class AccountDataStreamTestCase(BaseStreamTestCase):
|
||||||
def test_update_function_room_account_data_limit(self):
|
def test_update_function_room_account_data_limit(self) -> None:
|
||||||
"""Test replication with many room account data updates"""
|
"""Test replication with many room account data updates"""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
self.assertEqual([], received_rows)
|
self.assertEqual([], received_rows)
|
||||||
|
|
||||||
def test_update_function_global_account_data_limit(self):
|
def test_update_function_global_account_data_limit(self) -> None:
|
||||||
"""Test replication with many global account data updates"""
|
"""Test replication with many global account data updates"""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
|
||||||
)
|
)
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.replication._base import BaseStreamTestCase
|
from tests.replication._base import BaseStreamTestCase
|
||||||
from tests.test_utils.event_injection import inject_event, inject_member_event
|
from tests.test_utils.event_injection import inject_event, inject_member_event
|
||||||
|
@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
super().prepare(reactor, clock, hs)
|
super().prepare(reactor, clock, hs)
|
||||||
self.user_id = self.register_user("u1", "pass")
|
self.user_id = self.register_user("u1", "pass")
|
||||||
self.user_tok = self.login("u1", "pass")
|
self.user_tok = self.login("u1", "pass")
|
||||||
|
@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||||
self.test_handler.received_rdata_rows.clear()
|
self.test_handler.received_rdata_rows.clear()
|
||||||
|
|
||||||
def test_update_function_event_row_limit(self):
|
def test_update_function_event_row_limit(self) -> None:
|
||||||
"""Test replication with many non-state events
|
"""Test replication with many non-state events
|
||||||
|
|
||||||
Checks that all events are correctly replicated when there are lots of
|
Checks that all events are correctly replicated when there are lots of
|
||||||
|
@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
self.assertEqual([], received_rows)
|
self.assertEqual([], received_rows)
|
||||||
|
|
||||||
def test_update_function_huge_state_change(self):
|
def test_update_function_huge_state_change(self) -> None:
|
||||||
"""Test replication with many state events
|
"""Test replication with many state events
|
||||||
|
|
||||||
Ensures that all events are correctly replicated when there are lots of
|
Ensures that all events are correctly replicated when there are lots of
|
||||||
|
@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
# "None" indicates the state has been deleted
|
# "None" indicates the state has been deleted
|
||||||
self.assertIsNone(sr.event_id)
|
self.assertIsNone(sr.event_id)
|
||||||
|
|
||||||
def test_update_function_state_row_limit(self):
|
def test_update_function_state_row_limit(self) -> None:
|
||||||
"""Test replication with many state events over several stream ids."""
|
"""Test replication with many state events over several stream ids."""
|
||||||
|
|
||||||
# we want to generate lots of state changes, but for this test, we want to
|
# we want to generate lots of state changes, but for this test, we want to
|
||||||
|
@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
self.assertEqual([], received_rows)
|
self.assertEqual([], received_rows)
|
||||||
|
|
||||||
def test_backwards_stream_id(self):
|
def test_backwards_stream_id(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that RDATA that comes after the current position should be discarded.
|
Test that RDATA that comes after the current position should be discarded.
|
||||||
"""
|
"""
|
||||||
|
@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
event_count = 0
|
event_count = 0
|
||||||
|
|
||||||
def _inject_test_event(
|
def _inject_test_event(
|
||||||
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
|
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
if sender is None:
|
if sender is None:
|
||||||
sender = self.user_id
|
sender = self.user_id
|
||||||
|
|
|
@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
|
||||||
config["federation_sender_instances"] = ["federation_sender1"]
|
config["federation_sender_instances"] = ["federation_sender1"]
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_catchup(self):
|
def test_catchup(self) -> None:
|
||||||
"""Basic test of catchup on reconnect
|
"""Basic test of catchup on reconnect
|
||||||
|
|
||||||
Makes sure that updates sent while we are offline are received later.
|
Makes sure that updates sent while we are offline are received later.
|
||||||
|
|
|
@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
hijack_auth = True
|
hijack_auth = True
|
||||||
user_id = "@bob:test"
|
user_id = "@bob:test"
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
|
|
@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
|
||||||
|
|
||||||
|
|
||||||
class TypingStreamTestCase(BaseStreamTestCase):
|
class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
def _build_replication_data_handler(self):
|
def _build_replication_data_handler(self) -> Mock:
|
||||||
return Mock(wraps=super()._build_replication_data_handler())
|
self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
|
||||||
|
return self.mock_handler
|
||||||
|
|
||||||
def test_typing(self):
|
def test_typing(self) -> None:
|
||||||
typing = self.hs.get_typing_handler()
|
typing = self.hs.get_typing_handler()
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
request = self.handle_http_replication_attempt()
|
request = self.handle_http_replication_attempt()
|
||||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.mock_handler.on_rdata.assert_called_once()
|
||||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "typing")
|
self.assertEqual(stream_name, "typing")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row: TypingStream.TypingStreamRow = rdata_rows[0]
|
row: TypingStream.TypingStreamRow = rdata_rows[0]
|
||||||
|
@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
# Now let's disconnect and insert some data.
|
# Now let's disconnect and insert some data.
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
|
||||||
self.test_handler.on_rdata.reset_mock()
|
self.mock_handler.on_rdata.reset_mock()
|
||||||
|
|
||||||
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
|
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_not_called()
|
self.mock_handler.on_rdata.assert_not_called()
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.pump(0.1)
|
self.pump(0.1)
|
||||||
|
@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
assert request.args is not None
|
assert request.args is not None
|
||||||
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.mock_handler.on_rdata.assert_called_once()
|
||||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "typing")
|
self.assertEqual(stream_name, "typing")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row = rdata_rows[0]
|
row = rdata_rows[0]
|
||||||
self.assertEqual(ROOM_ID, row.room_id)
|
self.assertEqual(ROOM_ID, row.room_id)
|
||||||
self.assertEqual([], row.user_ids)
|
self.assertEqual([], row.user_ids)
|
||||||
|
|
||||||
def test_reset(self):
|
def test_reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test what happens when a typing stream resets.
|
Test what happens when a typing stream resets.
|
||||||
|
|
||||||
|
@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
request = self.handle_http_replication_attempt()
|
request = self.handle_http_replication_attempt()
|
||||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.mock_handler.on_rdata.assert_called_once()
|
||||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "typing")
|
self.assertEqual(stream_name, "typing")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row: TypingStream.TypingStreamRow = rdata_rows[0]
|
row: TypingStream.TypingStreamRow = rdata_rows[0]
|
||||||
|
@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||||
|
|
||||||
# Reset the test code.
|
# Reset the test code.
|
||||||
self.test_handler.on_rdata.reset_mock()
|
self.mock_handler.on_rdata.reset_mock()
|
||||||
self.test_handler.on_rdata.assert_not_called()
|
self.mock_handler.on_rdata.assert_not_called()
|
||||||
|
|
||||||
# Push additional data.
|
# Push additional data.
|
||||||
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
|
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.mock_handler.on_rdata.assert_called_once()
|
||||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "typing")
|
self.assertEqual(stream_name, "typing")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row = rdata_rows[0]
|
row = rdata_rows[0]
|
||||||
|
|
|
@ -21,12 +21,12 @@ from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
class ParseCommandTestCase(TestCase):
|
class ParseCommandTestCase(TestCase):
|
||||||
def test_parse_one_word_command(self):
|
def test_parse_one_word_command(self) -> None:
|
||||||
line = "REPLICATE"
|
line = "REPLICATE"
|
||||||
cmd = parse_command_from_line(line)
|
cmd = parse_command_from_line(line)
|
||||||
self.assertIsInstance(cmd, ReplicateCommand)
|
self.assertIsInstance(cmd, ReplicateCommand)
|
||||||
|
|
||||||
def test_parse_rdata(self):
|
def test_parse_rdata(self) -> None:
|
||||||
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
|
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
|
||||||
cmd = parse_command_from_line(line)
|
cmd = parse_command_from_line(line)
|
||||||
assert isinstance(cmd, RdataCommand)
|
assert isinstance(cmd, RdataCommand)
|
||||||
|
@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
|
||||||
self.assertEqual(cmd.instance_name, "master")
|
self.assertEqual(cmd.instance_name, "master")
|
||||||
self.assertEqual(cmd.token, 6287863)
|
self.assertEqual(cmd.token, 6287863)
|
||||||
|
|
||||||
def test_parse_rdata_batch(self):
|
def test_parse_rdata_batch(self) -> None:
|
||||||
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
|
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
|
||||||
cmd = parse_command_from_line(line)
|
cmd = parse_command_from_line(line)
|
||||||
assert isinstance(cmd, RdataCommand)
|
assert isinstance(cmd, RdataCommand)
|
||||||
|
|
|
@ -16,15 +16,17 @@ from typing import Tuple
|
||||||
|
|
||||||
from twisted.internet.address import IPv4Address
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.interfaces import IProtocol
|
from twisted.internet.interfaces import IProtocol
|
||||||
from twisted.test.proto_helpers import StringTransport
|
from twisted.test.proto_helpers import MemoryReactor, StringTransport
|
||||||
|
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class RemoteServerUpTestCase(HomeserverTestCase):
|
class RemoteServerUpTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.factory = ReplicationStreamProtocolFactory(hs)
|
self.factory = ReplicationStreamProtocolFactory(hs)
|
||||||
|
|
||||||
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
||||||
|
@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return proto, transport
|
return proto, transport
|
||||||
|
|
||||||
def test_relay(self):
|
def test_relay(self) -> None:
|
||||||
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
|
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
|
||||||
other connections, but not the one that sent it.
|
other connections, but not the one that sent it.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -13,7 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.rest.client import register
|
from synapse.rest.client import register
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
from tests.server import FakeChannel, make_request
|
from tests.server import FakeChannel, make_request
|
||||||
|
@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
servlets = [register.register_servlets]
|
servlets = [register.register_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
# This isn't a real configuration option but is used to provide the main
|
# This isn't a real configuration option but is used to provide the main
|
||||||
# homeserver and worker homeserver different options.
|
# homeserver and worker homeserver different options.
|
||||||
|
@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
{"auth": {"session": session, "type": "m.login.dummy"}},
|
{"auth": {"session": session, "type": "m.login.dummy"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_auth(self):
|
def test_no_auth(self) -> None:
|
||||||
"""With no authentication the request should finish."""
|
"""With no authentication the request should finish."""
|
||||||
channel = self._test_register()
|
channel = self._test_register()
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
||||||
|
|
||||||
@override_config({"main_replication_secret": "my-secret"})
|
@override_config({"main_replication_secret": "my-secret"})
|
||||||
def test_missing_auth(self):
|
def test_missing_auth(self) -> None:
|
||||||
"""If the main process expects a secret that is not provided, an error results."""
|
"""If the main process expects a secret that is not provided, an error results."""
|
||||||
channel = self._test_register()
|
channel = self._test_register()
|
||||||
self.assertEqual(channel.code, 500)
|
self.assertEqual(channel.code, 500)
|
||||||
|
@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
"worker_replication_secret": "wrong-secret",
|
"worker_replication_secret": "wrong-secret",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_unauthorized(self):
|
def test_unauthorized(self) -> None:
|
||||||
"""If the main process receives the wrong secret, an error results."""
|
"""If the main process receives the wrong secret, an error results."""
|
||||||
channel = self._test_register()
|
channel = self._test_register()
|
||||||
self.assertEqual(channel.code, 500)
|
self.assertEqual(channel.code, 500)
|
||||||
|
|
||||||
@override_config({"worker_replication_secret": "my-secret"})
|
@override_config({"worker_replication_secret": "my-secret"})
|
||||||
def test_authorized(self):
|
def test_authorized(self) -> None:
|
||||||
"""The request should finish when the worker provides the authentication header."""
|
"""The request should finish when the worker provides the authentication header."""
|
||||||
channel = self._test_register()
|
channel = self._test_register()
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
|
@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
config["worker_replication_http_port"] = "8765"
|
config["worker_replication_http_port"] = "8765"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_register_single_worker(self):
|
def test_register_single_worker(self) -> None:
|
||||||
"""Test that registration works when using a single generic worker."""
|
"""Test that registration works when using a single generic worker."""
|
||||||
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
site = self._hs_to_site[worker_hs]
|
site = self._hs_to_site[worker_hs]
|
||||||
|
@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
# We're given a registered user.
|
# We're given a registered user.
|
||||||
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
|
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
|
||||||
|
|
||||||
def test_register_multi_worker(self):
|
def test_register_multi_worker(self) -> None:
|
||||||
"""Test that registration works when using multiple generic workers."""
|
"""Test that registration works when using multiple generic workers."""
|
||||||
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
|
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
|
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
|
@ -14,10 +14,14 @@
|
||||||
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.app.generic_worker import GenericWorkerServer
|
from synapse.app.generic_worker import GenericWorkerServer
|
||||||
from synapse.replication.tcp.commands import FederationAckCommand
|
from synapse.replication.tcp.commands import FederationAckCommand
|
||||||
from synapse.replication.tcp.protocol import IReplicationConnection
|
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||||
from synapse.replication.tcp.streams.federation import FederationStream
|
from synapse.replication.tcp.streams.federation import FederationStream
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
|
||||||
config["federation_sender_instances"] = ["federation_sender1"]
|
config["federation_sender_instances"] = ["federation_sender1"]
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
|
return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
|
||||||
|
|
||||||
return hs
|
def test_federation_ack_sent(self) -> None:
|
||||||
|
|
||||||
def test_federation_ack_sent(self):
|
|
||||||
"""A FEDERATION_ACK should be sent back after each RDATA federation
|
"""A FEDERATION_ACK should be sent back after each RDATA federation
|
||||||
|
|
||||||
This test checks that the federation sender is correctly sending back
|
This test checks that the federation sender is correctly sending back
|
||||||
|
|
|
@ -40,7 +40,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_send_event_single_sender(self):
|
def test_send_event_single_sender(self) -> None:
|
||||||
"""Test that using a single federation sender worker correctly sends a
|
"""Test that using a single federation sender worker correctly sends a
|
||||||
new event.
|
new event.
|
||||||
"""
|
"""
|
||||||
|
@ -71,7 +71,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
|
self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
|
||||||
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
|
self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
|
||||||
|
|
||||||
def test_send_event_sharded(self):
|
def test_send_event_sharded(self) -> None:
|
||||||
"""Test that using two federation sender workers correctly sends
|
"""Test that using two federation sender workers correctly sends
|
||||||
new events.
|
new events.
|
||||||
"""
|
"""
|
||||||
|
@ -138,7 +138,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertTrue(sent_on_1)
|
self.assertTrue(sent_on_1)
|
||||||
self.assertTrue(sent_on_2)
|
self.assertTrue(sent_on_2)
|
||||||
|
|
||||||
def test_send_typing_sharded(self):
|
def test_send_typing_sharded(self) -> None:
|
||||||
"""Test that using two federation sender workers correctly sends
|
"""Test that using two federation sender workers correctly sends
|
||||||
new typing EDUs.
|
new typing EDUs.
|
||||||
"""
|
"""
|
||||||
|
@ -215,7 +215,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertTrue(sent_on_1)
|
self.assertTrue(sent_on_1)
|
||||||
self.assertTrue(sent_on_2)
|
self.assertTrue(sent_on_2)
|
||||||
|
|
||||||
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
|
def create_room_with_remote_server(
|
||||||
|
self, user: str, token: str, remote_server: str = "other_server"
|
||||||
|
) -> str:
|
||||||
room = self.helper.create_room_as(user, tok=token)
|
room = self.helper.create_room_as(user, tok=token)
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
federation = self.hs.get_federation_event_handler()
|
federation = self.hs.get_federation_event_handler()
|
||||||
|
|
|
@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
synapse.rest.admin.register_servlets,
|
synapse.rest.admin.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_module_cache_full_invalidation(self):
|
def test_module_cache_full_invalidation(self) -> None:
|
||||||
main_cache = TestCache()
|
main_cache = TestCache()
|
||||||
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
|
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
|
||||||
|
|
||||||
|
|
|
@ -18,12 +18,14 @@ from typing import Optional, Tuple
|
||||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login
|
from synapse.rest.client import login
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
|
@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_id = self.register_user("user", "pass")
|
self.user_id = self.register_user("user", "pass")
|
||||||
self.access_token = self.login("user", "pass")
|
self.access_token = self.login("user", "pass")
|
||||||
|
|
||||||
self.reactor.lookups["example.com"] = "1.2.3.4"
|
self.reactor.lookups["example.com"] = "1.2.3.4"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> dict:
|
||||||
conf = super().default_config()
|
conf = super().default_config()
|
||||||
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
|
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
|
||||||
return conf
|
return conf
|
||||||
|
@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
return channel, request
|
return channel, request
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self) -> None:
|
||||||
"""Test basic fetching of remote media from a single worker."""
|
"""Test basic fetching of remote media from a single worker."""
|
||||||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
||||||
|
@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.result["body"], b"Hello!")
|
self.assertEqual(channel.result["body"], b"Hello!")
|
||||||
|
|
||||||
def test_download_simple_file_race(self):
|
def test_download_simple_file_race(self) -> None:
|
||||||
"""Test that fetching remote media from two different processes at the
|
"""Test that fetching remote media from two different processes at the
|
||||||
same time works.
|
same time works.
|
||||||
"""
|
"""
|
||||||
|
@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
# We expect only one new file to have been persisted.
|
# We expect only one new file to have been persisted.
|
||||||
self.assertEqual(start_count + 1, self._count_remote_media())
|
self.assertEqual(start_count + 1, self._count_remote_media())
|
||||||
|
|
||||||
def test_download_image_race(self):
|
def test_download_image_race(self) -> None:
|
||||||
"""Test that fetching remote *images* from two different processes at
|
"""Test that fetching remote *images* from two different processes at
|
||||||
the same time works.
|
the same time works.
|
||||||
|
|
||||||
|
@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
return sum(len(files) for _, _, files in os.walk(path))
|
return sum(len(files) for _, _, files in os.walk(path))
|
||||||
|
|
||||||
|
|
||||||
def get_connection_factory():
|
def get_connection_factory() -> TestServerTLSConnectionFactory:
|
||||||
# this needs to happen once, but not until we are ready to run the first test
|
# this needs to happen once, but not until we are ready to run the first test
|
||||||
global test_server_connection_factory
|
global test_server_connection_factory
|
||||||
if test_server_connection_factory is None:
|
if test_server_connection_factory is None:
|
||||||
|
@ -263,6 +265,6 @@ def _build_test_server(
|
||||||
return server_tls_factory.buildProtocol(None)
|
return server_tls_factory.buildProtocol(None)
|
||||||
|
|
||||||
|
|
||||||
def _log_request(request):
|
def _log_request(request: Request) -> None:
|
||||||
"""Implements Factory.log, which is expected by Request.finish"""
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
logger.info("Completed request %s", request)
|
logger.info("Completed request %s", request)
|
||||||
|
|
|
@ -15,9 +15,12 @@ import logging
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
|
|
||||||
|
@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# Register a user who sends a message that we'll get notified about
|
# Register a user who sends a message that we'll get notified about
|
||||||
self.other_user_id = self.register_user("otheruser", "pass")
|
self.other_user_id = self.register_user("otheruser", "pass")
|
||||||
self.other_access_token = self.login("otheruser", "pass")
|
self.other_access_token = self.login("otheruser", "pass")
|
||||||
|
|
||||||
def _create_pusher_and_send_msg(self, localpart):
|
def _create_pusher_and_send_msg(self, localpart: str) -> str:
|
||||||
# Create a user that will get push notifications
|
# Create a user that will get push notifications
|
||||||
user_id = self.register_user(localpart, "pass")
|
user_id = self.register_user(localpart, "pass")
|
||||||
access_token = self.login(localpart, "pass")
|
access_token = self.login(localpart, "pass")
|
||||||
|
@ -79,7 +82,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
def test_send_push_single_worker(self):
|
def test_send_push_single_worker(self) -> None:
|
||||||
"""Test that registration works when using a pusher worker."""
|
"""Test that registration works when using a pusher worker."""
|
||||||
http_client_mock = Mock(spec_set=["post_json_get_json"])
|
http_client_mock = Mock(spec_set=["post_json_get_json"])
|
||||||
http_client_mock.post_json_get_json.side_effect = (
|
http_client_mock.post_json_get_json.side_effect = (
|
||||||
|
@ -109,7 +112,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_send_push_multiple_workers(self):
|
def test_send_push_multiple_workers(self) -> None:
|
||||||
"""Test that registration works when using sharded pusher workers."""
|
"""Test that registration works when using sharded pusher workers."""
|
||||||
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
|
http_client_mock1 = Mock(spec_set=["post_json_get_json"])
|
||||||
http_client_mock1.post_json_get_json.side_effect = (
|
http_client_mock1.post_json_get_json.side_effect = (
|
||||||
|
|
|
@ -14,9 +14,13 @@
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room, sync
|
from synapse.rest.client import login, room, sync
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
from tests.server import make_request
|
from tests.server import make_request
|
||||||
|
@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# Register a user who sends a message that we'll get notified about
|
# Register a user who sends a message that we'll get notified about
|
||||||
self.other_user_id = self.register_user("otheruser", "pass")
|
self.other_user_id = self.register_user("otheruser", "pass")
|
||||||
self.other_access_token = self.login("otheruser", "pass")
|
self.other_access_token = self.login("otheruser", "pass")
|
||||||
|
@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.room_creator = self.hs.get_room_creation_handler()
|
self.room_creator = self.hs.get_room_creation_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> dict:
|
||||||
conf = super().default_config()
|
conf = super().default_config()
|
||||||
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
|
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
|
||||||
conf["instance_map"] = {
|
conf["instance_map"] = {
|
||||||
|
@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
}
|
}
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
def _create_room(self, room_id: str, user_id: str, tok: str):
|
def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
|
||||||
"""Create a room with given room_id"""
|
"""Create a room with given room_id"""
|
||||||
|
|
||||||
# We control the room ID generation by patching out the
|
# We control the room ID generation by patching out the
|
||||||
|
@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
mock.side_effect = lambda: room_id
|
mock.side_effect = lambda: room_id
|
||||||
self.helper.create_room_as(user_id, tok=tok)
|
self.helper.create_room_as(user_id, tok=tok)
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self) -> None:
|
||||||
"""Simple test to ensure that multiple rooms can be created and joined,
|
"""Simple test to ensure that multiple rooms can be created and joined,
|
||||||
and that different rooms get handled by different instances.
|
and that different rooms get handled by different instances.
|
||||||
"""
|
"""
|
||||||
|
@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertTrue(persisted_on_1)
|
self.assertTrue(persisted_on_1)
|
||||||
self.assertTrue(persisted_on_2)
|
self.assertTrue(persisted_on_2)
|
||||||
|
|
||||||
def test_vector_clock_token(self):
|
def test_vector_clock_token(self) -> None:
|
||||||
"""Tests that using a stream token with a vector clock component works
|
"""Tests that using a stream token with a vector clock component works
|
||||||
correctly with basic /sync and /messages usage.
|
correctly with basic /sync and /messages usage.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue