Require types in tests.storage. (#14646)

Adds missing type hints to `tests.storage` package
and does not allow untyped definitions.
This commit is contained in:
Patrick Cloke 2022-12-09 12:36:32 -05:00 committed by GitHub
parent 94bc21e69f
commit 3ac412b4e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 489 additions and 341 deletions

1
changelog.d/14646.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -88,6 +88,9 @@ disallow_untyped_defs = False
[mypy-tests.*] [mypy-tests.*]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True
[mypy-tests.handlers.test_user_directory] [mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True
@ -103,16 +106,7 @@ disallow_untyped_defs = True
[mypy-tests.state.test_profile] [mypy-tests.state.test_profile]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.storage.test_id_generators] [mypy-tests.storage.*]
disallow_untyped_defs = True
[mypy-tests.storage.test_profile]
disallow_untyped_defs = True
[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.rest.*] [mypy-tests.rest.*]

View File

@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable @cancellable
async def get_e2e_device_keys_for_cs_api( async def get_e2e_device_keys_for_cs_api(
self, self,
query_list: List[Tuple[str, Optional[str]]], query_list: Collection[Tuple[str, Optional[str]]],
include_displaynames: bool = True, include_displaynames: bool = True,
) -> Dict[str, Dict[str, JsonDict]]: ) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API. """Fetch a list of device keys, formatted suitably for the C/S API.

View File

@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import devices from synapse.rest.client import devices
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
devices.register_servlets, devices.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass") self.user_id = self.register_user("foo", "pass")
def test_background_remove_deleted_devices_from_device_inbox(self): def test_background_remove_deleted_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete old device_inboxes works properly.""" """Test that the background task to delete old device_inboxes works properly."""
# create a valid device # create a valid device
@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(1, len(res)) self.assertEqual(1, len(res))
self.assertEqual(res[0], "cur_device") self.assertEqual(res[0], "cur_device")
def test_background_remove_hidden_devices_from_device_inbox(self): def test_background_remove_hidden_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete hidden devices """Test that the background task to delete hidden devices
from device_inboxes works properly.""" from device_inboxes works properly."""

View File

@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main self.store: EventsWorkerStore = hs.get_datastores().main
@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.event_ids.append(event.event_id) self.event_ids.append(event.event_id)
def test_simple(self): def test_simple(self) -> None:
with LoggingContext(name="test") as ctx: with LoggingContext(name="test") as ctx:
res = self.get_success( res = self.get_success(
self.store.have_seen_events( self.store.have_seen_events(
@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
def test_persisting_event_invalidates_cache(self): def test_persisting_event_invalidates_cache(self) -> None:
""" """
Test to make sure that the `have_seen_event` cache Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event and returns is invalidated after we persist an event and returns
@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup # That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
def test_invalidate_cache_by_room_id(self): def test_invalidate_cache_by_room_id(self) -> None:
""" """
Test to make sure that all events associated with the given `(room_id,)` Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache. are invalidated in the `have_seen_event` cache.
@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass") self.user = self.register_user("user", "pass")
@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache so the tests start with it empty # Reset the event cache so the tests start with it empty
self.get_success(self.store._get_event_cache.clear()) self.get_success(self.store._get_event_cache.clear())
def test_simple(self): def test_simple(self) -> None:
"""Test that we cache events that we pull from the DB.""" """Test that we cache events that we pull from the DB."""
with LoggingContext("test") as ctx: with LoggingContext("test") as ctx:
@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB # We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
def test_event_ref(self): def test_event_ref(self) -> None:
"""Test that we reuse events that are still in memory but have fallen """Test that we reuse events that are still in memory but have fallen
out of the cache, rather than requesting them from the DB. out of the cache, rather than requesting them from the DB.
""" """
@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# from the DB # from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0) self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
def test_dedupe(self): def test_dedupe(self) -> None:
"""Test that if we request the same event multiple times we only pull it """Test that if we request the same event multiple times we only pull it
out once. out once.
""" """
@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
class DatabaseOutageTestCase(unittest.HomeserverTestCase): class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage.""" """Test event fetching during a database outage."""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main self.store: EventsWorkerStore = hs.get_datastores().main
self.room_id = f"!room:{hs.hostname}" self.room_id = f"!room:{hs.hostname}"
@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass") self.user = self.register_user("user", "pass")
@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
unblock: "Deferred[None]" = Deferred() unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection original_runWithConnection = self.store.db_pool.runWithConnection
async def runWithConnection(*args, **kwargs): # Don't bother with the types here, we just pass into the original function.
async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def]
await unblock await unblock
return await original_runWithConnection(*args, **kwargs) return await original_runWithConnection(*args, **kwargs)
@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1) self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0) self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
def test_first_get_event_cancelled(self): def test_first_get_event_cancelled(self) -> None:
"""Test cancellation of the first `get_event` call sharing a database fetch. """Test cancellation of the first `get_event` call sharing a database fetch.
The first `get_event` call is the one which initiates the fetch. We expect the The first `get_event` call is the one which initiates the fetch. We expect the
@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
# The second `get_event` call should complete successfully. # The second `get_event` call should complete successfully.
self.get_success(get_event2) self.get_success(get_event2)
def test_second_get_event_cancelled(self): def test_second_get_event_cancelled(self) -> None:
"""Test cancellation of the second `get_event` call sharing a database fetch.""" """Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call. # Cancel the second `get_event` call.

View File

@ -15,18 +15,20 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
from synapse.util import Clock
from tests import unittest from tests import unittest
class LockTestCase(unittest.HomeserverTestCase): class LockTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def test_acquire_contention(self): def test_acquire_contention(self) -> None:
# Track the number of tasks holding the lock. # Track the number of tasks holding the lock.
# Should be at most 1. # Should be at most 1.
in_lock = 0 in_lock = 0
@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase):
release_lock: "Deferred[None]" = Deferred() release_lock: "Deferred[None]" = Deferred()
async def task(): async def task() -> None:
nonlocal in_lock nonlocal in_lock
nonlocal max_in_lock nonlocal max_in_lock
@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# At most one task should have held the lock at a time. # At most one task should have held the lock at a time.
self.assertEqual(max_in_lock, 1) self.assertEqual(max_in_lock, 1)
def test_simple_lock(self): def test_simple_lock(self) -> None:
"""Test that we can take out a lock and that while we hold it nobody """Test that we can take out a lock and that while we hold it nobody
else can take it out. else can take it out.
""" """
@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock3.__aenter__()) self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None)) self.get_success(lock3.__aexit__(None, None, None))
def test_maintain_lock(self): def test_maintain_lock(self) -> None:
"""Test that we don't time out locks while they're still active""" """Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key")) lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aexit__(None, None, None)) self.get_success(lock.__aexit__(None, None, None))
def test_timeout_lock(self): def test_timeout_lock(self) -> None:
"""Test that we time out locks if they're not updated for ages""" """Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key")) lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.assertFalse(self.get_success(lock.is_still_valid())) self.assertFalse(self.get_success(lock.is_still_valid()))
def test_drop(self): def test_drop(self) -> None:
"""Test that dropping the context manager means we stop renewing the lock""" """Test that dropping the context manager means we stop renewing the lock"""
lock = self.get_success(self.store.try_acquire_lock("name", "key")) lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase):
lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock2) self.assertIsNotNone(lock2)
def test_shutdown(self): def test_shutdown(self) -> None:
"""Test that shutting down Synapse releases the locks""" """Test that shutting down Synapse releases the locks"""
# Acquire two locks # Acquire two locks
lock = self.get_success(self.store.try_acquire_lock("name", "key1")) lock = self.get_success(self.store.try_acquire_lock("name", "key1"))

View File

@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass") self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass") self.token = self.login("foo", "pass")
@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
table: str, table: str,
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
): ) -> None:
"""Test that the background update to uniqueify non-thread receipts in """Test that the background update to uniqueify non-thread receipts in
the given receipts table works properly. the given receipts table works properly.
@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
f"Background update did not remove all duplicate receipts from {table}", f"Background update did not remove all duplicate receipts from {table}",
) )
def test_background_receipts_linearized_unique_index(self): def test_background_receipts_linearized_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in """Test that the background update to uniqueify non-thread receipts in
`receipts_linearized` works properly. `receipts_linearized` works properly.
""" """
@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
}, },
) )
def test_background_receipts_graph_unique_index(self): def test_background_receipts_graph_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in """Test that the background update to uniqueify non-thread receipts in
`receipts_graph` works properly. `receipts_graph` works properly.
""" """

View File

@ -14,10 +14,14 @@
import json import json
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomTypes from synapse.api.constants import RoomTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.room import _BackgroundUpdates from synapse.storage.databases.main.room import _BackgroundUpdates
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass") self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass") self.token = self.login("foo", "pass")
@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
return room_id return room_id
def test_background_populate_rooms_creator_column(self): def test_background_populate_rooms_creator_column(self) -> None:
"""Test that the background update to populate the rooms creator column """Test that the background update to populate the rooms creator column
works properly. works properly.
""" """
@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
) )
self.assertEqual(room_creator_after, self.user_id) self.assertEqual(room_creator_after, self.user_id)
def test_background_add_room_type_column(self): def test_background_add_room_type_column(self) -> None:
"""Test that the background update to populate the `room_type` column in """Test that the background update to populate the `room_type` column in
`room_stats_state` works properly. `room_stats_state` works properly.
""" """

View File

@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
{(1, "user1", "hello"), (2, "user2", "bleb")}, {(1, "user1", "hello"), (2, "user2", "bleb")},
) )
def test_simple_update_many(self): def test_simple_update_many(self) -> None:
""" """
simple_update_many performs many updates at once. simple_update_many performs many updates at once.
""" """

View File

@ -14,13 +14,17 @@
from typing import Iterable, Optional, Set from typing import Iterable, Optional, Set
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
class IgnoredUsersTestCase(unittest.HomeserverTestCase): class IgnoredUsersTestCase(unittest.HomeserverTestCase):
def prepare(self, hs, reactor, clock): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.user = "@user:test" self.user = "@user:test"
@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignored_user_ids, expected_ignored_user_ids,
) )
def test_ignoring_users(self): def test_ignoring_users(self) -> None:
"""Basic adding/removing of users from the ignore list.""" """Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote") self._update_ignore_list("@other:test", "@another:remote")
self.assert_ignored(self.user, {"@other:test", "@another:remote"}) self.assert_ignored(self.user, {"@other:test", "@another:remote"})
@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Check the removed user. # Check the removed user.
self.assert_ignorers("@another:remote", {self.user}) self.assert_ignorers("@another:remote", {self.user})
def test_caching(self): def test_caching(self) -> None:
"""Ensure that caching works properly between different users.""" """Ensure that caching works properly between different users."""
# The first user ignores a user. # The first user ignores a user.
self._update_ignore_list("@other:test") self._update_ignore_list("@other:test")
@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
self.assert_ignored(self.user, set()) self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"}) self.assert_ignorers("@other:test", {"@second:test"})
def test_invalid_data(self): def test_invalid_data(self) -> None:
"""Invalid data ends up clearing out the ignored users list.""" """Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there. # Add some data and ensure it is there.
self._update_ignore_list("@other:test") self._update_ignore_list("@other:test")

View File

@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, make_conn from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn
from synapse.storage.databases.main.appservice import ( from synapse.storage.databases.main.appservice import (
ApplicationServiceStore, ApplicationServiceStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
def setUp(self): def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp() super(ApplicationServiceStoreTestCase, self).setUp()
self.as_yaml_files: List[str] = [] self.as_yaml_files: List[str] = []
@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
super(ApplicationServiceStoreTestCase, self).tearDown() super(ApplicationServiceStoreTestCase, self).tearDown()
def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: def _add_appservice(
self, as_token: str, id: str, url: str, hs_token: str, sender: str
) -> None:
as_yaml = { as_yaml = {
"url": url, "url": url,
"as_token": as_token, "as_token": as_token,
@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
database, make_conn(db_config, self.engine, "test"), self.hs database, make_conn(db_config, self.engine, "test"), self.hs
) )
def _add_service(self, url, as_token, id) -> None: def _add_service(self, url: str, as_token: str, id: str) -> None:
as_yaml = { as_yaml = {
"url": url, "url": url,
"as_token": as_token, "as_token": as_token,
@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token) self.as_yaml_files.append(as_token)
def _set_state(self, id: str, state: ApplicationServiceState): def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred:
return self.db_pool.runOperation( return self.db_pool.runOperation(
self.engine.convert_param_style( self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)" "INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(id, state.value), (id, state.value),
) )
def _insert_txn(self, as_id, txn_id, events): def _insert_txn(
self, as_id: str, txn_id: int, events: List[Mock]
) -> "defer.Deferred[None]":
return self.db_pool.runOperation( return self.db_pool.runOperation(
self.engine.convert_param_style( self.engine.convert_param_style(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
# required for ApplicationServiceTransactionStoreTestCase tests # required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs) -> None: def __init__(
self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer
) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
def _write_config(self, suffix, **kwargs) -> str: def _write_config(self, suffix: str, **kwargs: str) -> str:
vals = { vals = {
"id": "id" + suffix, "id": "id" + suffix,
"url": "url" + suffix, "url": "url" + suffix,

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from typing import Generator
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -30,7 +30,7 @@ from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase): class SQLBaseStoreTestCase(unittest.TestCase):
"""Test the "simple" SQL generating methods in SQLBaseStore.""" """Test the "simple" SQL generating methods in SQLBaseStore."""
def setUp(self): def setUp(self) -> None:
self.db_pool = Mock(spec=["runInteraction"]) self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock() self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"]) self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_conn.rollback.return_value = None self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline # Our fake runInteraction just runs synchronously inline
def runInteraction(func, *args, **kwargs): def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_txn, *args, **kwargs)) return defer.succeed(func(self.mock_txn, *args, **kwargs))
self.db_pool.runInteraction = runInteraction self.db_pool.runInteraction = runInteraction
def runWithConnection(func, *args, **kwargs): def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_conn, *args, **kwargs)) return defer.succeed(func(self.mock_conn, *args, **kwargs))
self.db_pool.runWithConnection = runWithConnection self.db_pool.runWithConnection = runWithConnection
@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks @defer.inlineCallbacks
def test_insert_1col(self): def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_insert_3cols(self): def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_one_1col(self): def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_one_3col(self): def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3) self.mock_txn.fetchone.return_value = (1, 2, 3)
@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_one_missing(self): def test_select_one_missing(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 0 self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None self.mock_txn.fetchone.return_value = None
@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertFalse(ret) self.assertFalse(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self): def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3 self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),) self.mock_txn.description = (("colA", None, None, None, None, None, None),)
@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_one_1col(self): def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_one_4cols(self): def test_update_one_4cols(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_delete_one(self): def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield defer.ensureDeferred( yield defer.ensureDeferred(

View File

@ -15,11 +15,16 @@
import os.path import os.path
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage import prepare_database from synapse.storage import prepare_database
from synapse.storage.types import Cursor
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
Test the background update to clean forward extremities table. Test the background update to clean forward extremities table.
""" """
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler() self.room_creator = homeserver.get_room_creation_handler()
@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"] self.room_id = info["room_id"]
def run_background_update(self): def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities.""" """Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates. # Make sure we don't clash with in progress updates.
self.assertTrue( self.assertTrue(
@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"delete_forward_extremities.sql", "delete_forward_extremities.sql",
) )
def run_delta_file(txn): def run_delta_file(txn: Cursor) -> None:
prepare_database.executescript(txn, schema_path) prepare_database.executescript(txn, schema_path)
self.get_success( self.get_success(
@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
(room_id,) (room_id,)
) )
def test_soft_failed_extremities_handled_correctly(self): def test_soft_failed_extremities_handled_correctly(self) -> None:
"""Test that extremities are correctly calculated in the presence of """Test that extremities are correctly calculated in the presence of
soft failed events. soft failed events.
@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(latest_event_ids, [event_id_4]) self.assertEqual(latest_event_ids, [event_id_4])
def test_basic_cleanup(self): def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of """Test that extremities are correctly calculated in the presence of
soft failed events. soft failed events.
@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
) )
self.assertEqual(latest_event_ids, [event_id_b]) self.assertEqual(latest_event_ids, [event_id_b])
def test_chain_of_fail_cleanup(self): def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of """Test that extremities are correctly calculated in the presence of
soft failed events. soft failed events.
@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
) )
self.assertEqual(latest_event_ids, [event_id_b]) self.assertEqual(latest_event_ids, [event_id_b])
def test_forked_graph_cleanup(self): def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of r"""Test that extremities are correctly calculated in the presence of
soft failed events. soft failed events.
@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True config["cleanup_extremities_with_dummy_events"] = True
return self.setup_test_homeserver(config=config) return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler() self.room_creator = homeserver.get_room_creation_handler()
self.event_creator_handler = homeserver.get_event_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler()
@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator = homeserver.get_event_creation_handler() self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
def test_send_dummy_event(self): def test_send_dummy_event(self) -> None:
self._create_extremity_rich_graph() self._create_extremity_rich_graph()
# Pump the reactor repeatedly so that the background updates have a # Pump the reactor repeatedly so that the background updates have a
@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
def test_send_dummy_events_when_insufficient_power(self): def test_send_dummy_events_when_insufficient_power(self) -> None:
self._create_extremity_rich_graph() self._create_extremity_rich_graph()
# Criple power levels # Criple power levels
self.helper.send_state( self.helper.send_state(
@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self): def test_expiry_logic(self) -> None:
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly. expires old entries correctly.
""" """
@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
0, 0,
) )
def _create_extremity_rich_graph(self): def _create_extremity_rich_graph(self) -> None:
"""Helper method to create bushy graph on demand""" """Helper method to create bushy graph on demand"""
event_id_start = self.create_and_send_event(self.room_id, self.user) event_id_start = self.create_and_send_event(self.room_id, self.user)
@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
) )
self.assertEqual(len(latest_event_ids), 50) self.assertEqual(len(latest_event_ids), 50)
def _enable_consent_checking(self): def _enable_consent_checking(self) -> None:
"""Helper method to enable consent checking""" """Helper method to enable consent checking"""
self.event_creator._block_events_without_consent_error = "No consent from user" self.event_creator._block_events_without_consent_error = "No consent from user"
consent_uri_builder = Mock() consent_uri_builder = Mock()

View File

@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock from unittest.mock import Mock
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.http.site import XForwardedForRequest from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request from tests.server import make_request
@ -30,14 +35,10 @@ from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
hs = self.setup_test_homeserver() self.store = hs.get_datastores().main
return hs
def prepare(self, hs, reactor, clock): def test_insert_new_client_ip(self) -> None:
self.store = self.hs.get_datastores().main
def test_insert_new_client_ip(self):
self.reactor.advance(12345678) self.reactor.advance(12345678)
user_id = "@user:id" user_id = "@user:id"
@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r, r,
) )
def test_insert_new_client_ip_none_device_id(self): def test_insert_new_client_ip_none_device_id(self) -> None:
""" """
An insert with a device ID of NULL will not create a new entry, but An insert with a device ID of NULL will not create a new entry, but
update an existing entry in the user_ips table. update an existing entry in the user_ips table.
@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
@parameterized.expand([(False,), (True,)]) @parameterized.expand([(False,), (True,)])
def test_get_last_client_ip_by_device(self, after_persisting: bool): def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data""" """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678) self.reactor.advance(12345678)
@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_get_last_client_ip_by_device_combined_data(self): def test_get_last_client_ip_by_device_combined_data(self) -> None:
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted """Test that `get_last_client_ip_by_device` combines persisted and unpersisted
data together correctly data together correctly
""" """
@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
@parameterized.expand([(False,), (True,)]) @parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool): def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
"""Test `get_user_ip_and_agents` for persisted and unpersisted data""" """Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678) self.reactor.advance(12345678)
@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_get_user_ip_and_agents_combined_data(self): def test_get_user_ip_and_agents_combined_data(self) -> None:
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data """Test that `get_user_ip_and_agents` combines persisted and unpersisted data
together correctly together correctly
""" """
@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self): def test_disabled_monthly_active_user(self) -> None:
user_id = "@user:server" user_id = "@user:server"
self.get_success( self.get_success(
self.store.insert_client_ip( self.store.insert_client_ip(
@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active) self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self): def test_adding_monthly_active_user_when_full(self) -> None:
lots_of_users = 100 lots_of_users = 100
user_id = "@user:server" user_id = "@user:server"
@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active) self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self): def test_adding_monthly_active_user_when_space(self) -> None:
user_id = "@user:server" user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) self.assertFalse(active)
@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertTrue(active) self.assertTrue(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self): def test_updating_monthly_active_user_when_space(self) -> None:
user_id = "@user:server" user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active) self.assertTrue(active)
def test_devices_last_seen_bg_update(self): def test_devices_last_seen_bg_update(self) -> None:
# First make sure we have completed all updates. # First make sure we have completed all updates.
self.wait_for_background_updates() self.wait_for_background_updates()
@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r, r,
) )
def test_old_user_ips_pruned(self): def test_old_user_ips_pruned(self) -> None:
# First make sure we have completed all updates. # First make sure we have completed all updates.
self.wait_for_background_updates() self.wait_for_background_updates()
@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(result, []) self.assertEqual(result, [])
# But we should still get the correct values for the device # But we should still get the correct values for the device
result = self.get_success( result2 = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id) self.store.get_last_client_ip_by_device(user_id, device_id)
) )
r = result[(user_id, device_id)] r = result2[(user_id, device_id)]
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"user_id": user_id, "user_id": user_id,
@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
hs = self.setup_test_homeserver()
return hs
def prepare(self, hs, reactor, clock):
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.user_id = self.register_user("bob", "abc123", True) self.user_id = self.register_user("bob", "abc123", True)
def test_request_with_xforwarded(self): def test_request_with_xforwarded(self) -> None:
""" """
The IP in X-Forwarded-For is entered into the client IPs table. The IP in X-Forwarded-For is entered into the client IPs table.
""" """
@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
{"request": XForwardedForRequest}, {"request": XForwardedForRequest},
) )
def test_request_from_getPeer(self): def test_request_from_getPeer(self) -> None:
""" """
The IP returned by getPeer is entered into the client IPs table, if The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header. there's no X-Forwarded-For header.
""" """
self._runtest({}, "127.0.0.1", {}) self._runtest({}, "127.0.0.1", {})
def _runtest(self, headers, expected_ip, make_request_args): def _runtest(
self,
headers: Dict[bytes, bytes],
expected_ip: str,
make_request_args: Dict[str, Any],
) -> None:
device_id = "bleb" device_id = "bleb"
access_token = self.login("bob", "abc123", device_id=device_id) access_token = self.login("bob", "abc123", device_id=device_id)

View File

@ -31,7 +31,7 @@ from tests import unittest
class TupleComparisonClauseTestCase(unittest.TestCase): class TupleComparisonClauseTestCase(unittest.TestCase):
def test_native_tuple_comparison(self): def test_native_tuple_comparison(self) -> None:
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2]) self.assertEqual(args, [1, 2])

View File

@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Collection, List, Tuple
from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors import synapse.api.errors
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class DeviceStoreTestCase(HomeserverTestCase): class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def add_device_change(self, user_id, device_ids, host): def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to """Add a device list change for the given device to
`device_lists_outbound_pokes` table. `device_lists_outbound_pokes` table.
""" """
@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
) )
) )
def test_store_new_device(self): def test_store_new_device(self) -> None:
self.get_success( self.get_success(
self.store.store_device("user_id", "device_id", "display_name") self.store.store_device("user_id", "device_id", "display_name")
) )
res = self.get_success(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"user_id": "user_id", "user_id": "user_id",
@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res, res,
) )
def test_get_devices_by_user(self): def test_get_devices_by_user(self) -> None:
self.get_success( self.get_success(
self.store.store_device("user_id", "device1", "display_name 1") self.store.store_device("user_id", "device1", "display_name 1")
) )
@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res["device2"], res["device2"],
) )
def test_count_devices_by_users(self): def test_count_devices_by_users(self) -> None:
self.get_success( self.get_success(
self.store.store_device("user_id", "device1", "display_name 1") self.store.store_device("user_id", "device1", "display_name 1")
) )
@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
) )
self.assertEqual(3, res) self.assertEqual(3, res)
def test_get_device_updates_by_remote(self): def test_get_device_updates_by_remote(self) -> None:
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s # Add two device updates with sequential `stream_id`s
@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates # Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates) self._check_devices_in_updates(device_ids, device_updates)
def test_get_device_updates_by_remote_can_limit_properly(self): def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
""" """
Tests that `get_device_updates_by_remote` returns an appropriate Tests that `get_device_updates_by_remote` returns an appropriate
stream_id to resume fetching from (without skipping any results). stream_id to resume fetching from (without skipping any results).
@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
) )
self.assertEqual(device_updates, []) self.assertEqual(device_updates, [])
def _check_devices_in_updates(self, expected_device_ids, device_updates): def _check_devices_in_updates(
self,
expected_device_ids: Collection[str],
device_updates: List[Tuple[str, JsonDict]],
) -> None:
"""Check that an specific device ids exist in a list of device update EDUs""" """Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids)) self.assertEqual(len(device_updates), len(expected_device_ids))
@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
} }
self.assertEqual(received_device_ids, set(expected_device_ids)) self.assertEqual(received_device_ids, set(expected_device_ids))
def test_update_device(self): def test_update_device(self) -> None:
self.get_success( self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1") self.store.store_device("user_id", "device_id", "display_name 1")
) )
res = self.get_success(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do a no-op first # do a no-op first
self.get_success(self.store.update_device("user_id", "device_id")) self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do the update # do the update
@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
# check it worked # check it worked
res = self.get_success(self.store.get_device("user_id", "device_id")) res = self.get_success(self.store.get_device("user_id", "device_id"))
assert res is not None
self.assertEqual("display_name 2", res["display_name"]) self.assertEqual("display_name 2", res["display_name"])
def test_update_unknown_device(self): def test_update_unknown_device(self) -> None:
exc = self.get_failure( exc = self.get_failure(
self.store.update_device( self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2" "user_id", "unknown_device_id", new_display_name="display_name 2"

View File

@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID from synapse.types import RoomAlias, RoomID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class DirectoryStoreTestCase(HomeserverTestCase): class DirectoryStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test") self.alias = RoomAlias.from_string("#my-room:test")
def test_room_to_alias(self): def test_room_to_alias(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
) )
def test_alias_to_room(self): def test_alias_to_room(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_association_from_room_alias(self.alias))), (self.get_success(self.store.get_association_from_room_alias(self.alias))),
) )
def test_delete_alias(self): def test_delete_alias(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]

View File

@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.databases.main.e2e_room_keys import RoomKey from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -26,12 +30,12 @@ room_key: RoomKey = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
def test_room_keys_version_delete(self): def test_room_keys_version_delete(self) -> None:
# test that deleting a room key backup deletes the keys # test that deleting a room key backup deletes the keys
version1 = self.get_success( version1 = self.get_success(
self.store.create_e2e_room_keys_version( self.store.create_e2e_room_keys_version(

View File

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class EndToEndKeyStoreTestCase(HomeserverTestCase): class EndToEndKeyStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def test_key_without_device_name(self): def test_key_without_device_name(self) -> None:
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev) self.assertDictContainsSubset(json, dev)
def test_reupload_key(self): def test_reupload_key(self) -> None:
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
) )
self.assertFalse(changed) self.assertFalse(changed)
def test_get_key_with_device_name(self): def test_get_key_with_device_name(self) -> None:
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
) )
def test_multiple_devices(self): def test_multiple_devices(self) -> None:
now = 1470174257070 now = 1470174257070
self.get_success(self.store.store_device("user1", "device1", None)) self.get_success(self.store.store_device("user1", "device1", None))

View File

@ -14,6 +14,7 @@
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest from twisted.trial import unittest
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -22,18 +23,22 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.events import _LinkMap from synapse.storage.databases.main.events import _LinkMap
from synapse.storage.types import Cursor
from synapse.types import create_requester from synapse.types import create_requester
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class EventChainStoreTestCase(HomeserverTestCase): class EventChainStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._next_stream_ordering = 1 self._next_stream_ordering = 1
def test_simple(self): def test_simple(self) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md` """Test that the example in `docs/auth_chain_difference_algorithm.md`
works. works.
""" """
@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
), ),
) )
def test_out_of_order_events(self): def test_out_of_order_events(self) -> None:
"""Test that we handle persisting events that we don't have the full """Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships). auth chain for yet (which should only happen for out of band memberships).
""" """
@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def persist( def persist(
self, self,
events: List[EventBase], events: List[EventBase],
): ) -> None:
"""Persist the given events and check that the links generated match """Persist the given events and check that the links generated match
those given. those given.
""" """
@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
e.internal_metadata.stream_ordering = self._next_stream_ordering e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1 self._next_stream_ordering += 1
def _persist(txn): def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events # We need to persist the events to the events and state_events
# tables. # tables.
persist_events_store._store_event_txn( persist_events_store._store_event_txn(
@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
class LinkMapTestCase(unittest.TestCase): class LinkMapTestCase(unittest.TestCase):
def test_simple(self): def test_simple(self) -> None:
"""Basic tests for the LinkMap.""" """Basic tests for the LinkMap."""
link_map = _LinkMap() link_map = _LinkMap()
@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass") self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass") self.token = self.login("foo", "pass")
@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Delete the chain cover info. # Delete the chain cover info.
def _delete_tables(txn): def _delete_tables(txn: Cursor) -> None:
txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links") txn.execute("DELETE FROM event_auth_chain_links")
@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
return room_id, [state1, state2] return room_id, [state1, state2]
def test_background_update_single_room(self): def test_background_update_single_room(self) -> None:
"""Test that the background update to calculate auth chains for historic """Test that the background update to calculate auth chains for historic
rooms works correctly. rooms works correctly.
""" """
@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
) )
) )
def test_background_update_multiple_rooms(self): def test_background_update_multiple_rooms(self) -> None:
"""Test that the background update to calculate auth chains for historic """Test that the background update to calculate auth chains for historic
rooms works correctly. rooms works correctly.
""" """
@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
) )
) )
def test_background_update_single_large_room(self): def test_background_update_single_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic """Test that the background update to calculate auth chains for historic
rooms works correctly. rooms works correctly.
""" """
@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
) )
) )
def test_background_update_multiple_large_room(self): def test_background_update_multiple_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic """Test that the background update to calculate auth chains for historic
rooms works correctly. rooms works correctly.
""" """

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union, cast
import attr import attr
from parameterized import parameterized from parameterized import parameterized
@ -26,11 +26,12 @@ from synapse.api.room_versions import (
EventFormatVersions, EventFormatVersions,
RoomVersion, RoomVersion,
) )
from synapse.events import _EventInternalMetadata from synapse.events import EventBase, _EventInternalMetadata
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
from synapse.storage.types import Cursor
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock, json_encoder from synapse.util import Clock, json_encoder
@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def test_get_prev_events_for_room(self): def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local" room_id = "@ROOM:local"
# add a bunch of events and hashes to act as forward extremities # add a bunch of events and hashes to act as forward extremities
def insert_event(txn, i): def insert_event(txn: Cursor, i: int) -> None:
event_id = "$event_%i:local" % i event_id = "$event_%i:local" % i
txn.execute( txn.execute(
@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 10): for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i]) self.assertEqual("$event_%i:local" % (19 - i), r[i])
def test_get_rooms_with_many_extremities(self): def test_get_rooms_with_many_extremities(self) -> None:
room1 = "#room1" room1 = "#room1"
room2 = "#room2" room2 = "#room2"
room3 = "#room3" room3 = "#room3"
def insert_event(txn, i, room_id): def insert_event(txn: Cursor, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i event_id = "$event_%i:local" % i
txn.execute( txn.execute(
( (
@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | | # | |
# K J # K J
auth_graph = { auth_graph: Dict[str, List[str]] = {
"a": ["e"], "a": ["e"],
"b": ["e"], "b": ["e"],
"c": ["g", "i"], "c": ["g", "i"],
@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Mark the room as maybe having a cover index. # Mark the room as maybe having a cover index.
def store_room(txn): def store_room(txn: LoggingTransaction) -> None:
self.store.db_pool.simple_insert_txn( self.store.db_pool.simple_insert_txn(
txn, txn,
"rooms", "rooms",
@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much # We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly. # easier than constructing events properly.
def insert_event(txn): def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0 stream_ordering = 0
for event_id in auth_graph: for event_id in auth_graph:
@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn( self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn, txn,
[ [
FakeEvent(event_id, room_id, auth_graph[event_id]) cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph for event_id in auth_graph
], ],
) )
@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return room_id return room_id
@parameterized.expand([(True,), (False,)]) @parameterized.expand([(True,), (False,)])
def test_auth_chain_ids(self, use_chain_cover_index: bool): def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index) room_id = self._setup_auth_chain(use_chain_cover_index)
# a and b have the same auth chain. # a and b have the same auth chain.
@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["i", "j"]) self.assertCountEqual(auth_chain_ids, ["i", "j"])
@parameterized.expand([(True,), (False,)]) @parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool): def test_auth_difference(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index) room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result: # Now actually test that various combinations give the right result:
@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
) )
self.assertSetEqual(difference, set()) self.assertSetEqual(difference, set())
def test_auth_difference_partial_cover(self): def test_auth_difference_partial_cover(self) -> None:
"""Test that we correctly handle rooms where not all events have a chain """Test that we correctly handle rooms where not all events have a chain
cover calculated. This can happen in some obscure edge cases, including cover calculated. This can happen in some obscure edge cases, including
during the background update that calculates the chain cover for old during the background update that calculates the chain cover for old
@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | | # | |
# K J # K J
auth_graph = { auth_graph: Dict[str, List[str]] = {
"a": ["e"], "a": ["e"],
"b": ["e"], "b": ["e"],
"c": ["g", "i"], "c": ["g", "i"],
@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much # We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly. # easier than constructing events properly.
def insert_event(txn): def insert_event(txn: LoggingTransaction) -> None:
# First insert the room and mark it as having a chain cover. # First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn( self.store.db_pool.simple_insert_txn(
txn, txn,
@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn( self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn, txn,
[ [
FakeEvent(event_id, room_id, auth_graph[event_id]) cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph for event_id in auth_graph
if event_id != "b" if event_id != "b"
], ],
@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn( self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn, txn,
[FakeEvent("b", room_id, auth_graph["b"])], [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
) )
self.store.db_pool.simple_update_txn( self.store.db_pool.simple_update_txn(
@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@parameterized.expand( @parameterized.expand(
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()] [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
) )
def test_prune_inbound_federation_queue(self, room_version: RoomVersion): def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
"""Test that pruning of inbound federation queues work""" """Test that pruning of inbound federation queues work"""
room_id = "some_room_id" room_id = "some_room_id"
@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
stream_ordering += 1 stream_ordering += 1
def populate_db(txn: LoggingTransaction): def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of # Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts` # `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn( self.store.db_pool.simple_insert_txn(
@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
def test_get_backfill_points_in_room(self): def test_get_backfill_points_in_room(self) -> None:
""" """
Test to make sure only backfill points that are older and come before Test to make sure only backfill points that are older and come before
the `current_depth` are returned. the `current_depth` are returned.
@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_excludes_events_we_have_attempted( def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
self, self,
): ) -> None:
""" """
Test to make sure that events we have attempted to backfill (and within Test to make sure that events we have attempted to backfill (and within
backoff timeout duration) do not show up as an event to backfill again. backoff timeout duration) do not show up as an event to backfill again.
@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
self, self,
): ) -> None:
""" """
Test to make sure after we fake attempt to backfill event "b3" many times, Test to make sure after we fake attempt to backfill event "b3" many times,
we can see retry and see the "b3" again after the backoff timeout duration we can see retry and see the "b3" again after the backoff timeout duration
@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"5": 7, "5": 7,
} }
def populate_db(txn: LoggingTransaction): def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of # Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts` # `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn( self.store.db_pool.simple_insert_txn(
@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
def test_get_insertion_event_backward_extremities_in_room(self): def test_get_insertion_event_backward_extremities_in_room(self) -> None:
""" """
Test to make sure only insertion event backward extremities that are Test to make sure only insertion event backward extremities that are
older and come before the `current_depth` are returned. older and come before the `current_depth` are returned.
@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
self, self,
): ) -> None:
""" """
Test to make sure that insertion events we have attempted to backfill Test to make sure that insertion events we have attempted to backfill
(and within backoff timeout duration) do not show up as an event to (and within backoff timeout duration) do not show up as an event to
@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
self, self,
): ) -> None:
""" """
Test to make sure after we fake attempt to backfill event Test to make sure after we fake attempt to backfill event
"insertion_eventA" many times, we can see retry and see the "insertion_eventA" many times, we can see retry and see the
@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"]) self.assertEqual(backfill_event_ids, ["insertion_eventA"])
def test_get_event_ids_to_not_pull_from_backoff( def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
self,
):
""" """
Test to make sure only event IDs we should backoff from are returned. Test to make sure only event IDs we should backoff from are returned.
""" """
@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self, self,
): ) -> None:
""" """
Test to make sure no event IDs are returned after the backoff duration has Test to make sure no event IDs are returned after the backoff duration has
elapsed. elapsed.
@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(event_ids_to_backoff, []) self.assertEqual(event_ids_to_backoff, [])
@attr.s @attr.s(auto_attribs=True)
class FakeEvent: class FakeEvent:
event_id = attr.ib() event_id: str
room_id = attr.ib() room_id: str
auth_events = attr.ib() auth_events: List[str]
type = "foo" type = "foo"
state_key = "foo" state_key = "foo"
internal_metadata = _EventInternalMetadata({}) internal_metadata = _EventInternalMetadata({})
def auth_event_ids(self): def auth_event_ids(self) -> List[str]:
return self.auth_events return self.auth_events
def is_state(self): def is_state(self) -> bool:
return True return True

View File

@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase
class ExtremStatisticsTestCase(HomeserverTestCase): class ExtremStatisticsTestCase(HomeserverTestCase):
def test_exposed_to_prometheus(self): def test_exposed_to_prometheus(self) -> None:
""" """
Forward extremity counts are exposed via Prometheus. Forward extremity counts are exposed via Prometheus.
""" """

View File

@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import StateMap
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state self._state_storage_controller = self.hs.get_storage_controllers().state
@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that the current extremities is the remote event. # Check that the current extremities is the remote event.
self.assert_extremities([self.remote_event_1.event_id]) self.assert_extremities([self.remote_event_1.event_id])
def persist_event(self, event, state=None): def persist_event(
self, event: EventBase, state: Optional[StateMap[str]] = None
) -> None:
"""Persist the event, with optional state""" """Persist the event, with optional state"""
context = self.get_success( context = self.get_success(
self.state.compute_event_context( self.state.compute_event_context(
@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
self.get_success(self._persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities): def assert_extremities(self, expected_extremities: List[str]) -> None:
"""Assert the current extremities for the room""" """Assert the current extremities for the room"""
extremities = self.get_success( extremities = self.get_success(
self.store.get_prev_events_for_room(self.room_id) self.store.get_prev_events_for_room(self.room_id)
) )
self.assertCountEqual(extremities, expected_extremities) self.assertCountEqual(extremities, expected_extremities)
def test_prune_gap(self): def test_prune_gap(self) -> None:
"""Test that we drop extremities after a gap when we see an event from """Test that we drop extremities after a gap when we see an event from
the same domain. the same domain.
""" """
@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id]) self.assert_extremities([remote_event_2.event_id])
def test_do_not_prune_gap_if_state_different(self): def test_do_not_prune_gap_if_state_different(self) -> None:
"""Test that we don't prune extremities after a gap if the resolved """Test that we don't prune extremities after a gap if the resolved
state is different. state is different.
""" """
@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that we haven't dropped the old extremity. # Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
def test_prune_gap_if_old(self): def test_prune_gap_if_old(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity """Test that we drop extremities after a gap when the previous extremity
is "old" is "old"
""" """
@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id]) self.assert_extremities([remote_event_2.event_id])
def test_do_not_prune_gap_if_other_server(self): def test_do_not_prune_gap_if_other_server(self) -> None:
"""Test that we do not drop extremities after a gap when we see an event """Test that we do not drop extremities after a gap when we see an event
from a different domain. from a different domain.
""" """
@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
def test_prune_gap_if_dummy_remote(self): def test_prune_gap_if_dummy_remote(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity """Test that we drop extremities after a gap when the previous extremity
is a local dummy event and only points to remote events. is a local dummy event and only points to remote events.
""" """
@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id]) self.assert_extremities([remote_event_2.event_id])
def test_prune_gap_if_dummy_local(self): def test_prune_gap_if_dummy_local(self) -> None:
"""Test that we don't drop extremities after a gap when the previous """Test that we don't drop extremities after a gap when the previous
extremity is a local dummy event and points to local events. extremity is a local dummy event and points to local events.
""" """
@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id]) self.assert_extremities([remote_event_2.event_id, local_message_event_id])
def test_do_not_prune_gap_if_not_dummy(self): def test_do_not_prune_gap_if_not_dummy(self) -> None:
"""Test that we do not drop extremities after a gap when the previous extremity """Test that we do not drop extremities after a gap when the previous extremity
is not a dummy event. is not a dummy event.
""" """
@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self): def test_remote_user_rooms_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_rooms_for_user` cache """Test that if the server leaves a room the `get_rooms_for_user` cache
is invalidated for remote users. is invalidated for remote users.
""" """
@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
self.assertEqual(set(rooms), set()) self.assertEqual(set(rooms), set())
def test_room_remote_user_cache_invalidated(self): def test_room_remote_user_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_users_in_room` cache """Test that if the server leaves a room the `get_users_in_room` cache
is invalidated for remote users. is invalidated for remote users.
""" """

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import signedjson.key import signedjson.key
import signedjson.types
import unpaddedbase64 import unpaddedbase64
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest import tests.unittest
def decode_verify_key_base64(key_id: str, key_base64: str): def decode_verify_key_base64(
key_id: str, key_base64: str
) -> signedjson.types.VerifyKey:
key_bytes = unpaddedbase64.decode_base64(key_base64) key_bytes = unpaddedbase64.decode_base64(key_base64)
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64(
class KeyStoreTestCase(tests.unittest.HomeserverTestCase): class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self): def test_get_server_verify_keys(self) -> None:
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
key_id_1 = "ed25519:key1" key_id_1 = "ed25519:key1"
@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
# non-existent result gives None # non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")]) self.assertIsNone(res[("server1", "ed25519:key3")])
def test_cache(self): def test_cache(self) -> None:
"""Check that updates correctly invalidate the cache.""" """Check that updates correctly invalidate the cache."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main

View File

@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.reactor.advance(FORTY_DAYS) self.reactor.advance(FORTY_DAYS)
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self): def test_initialise_reserved_users(self) -> None:
threepids = self.hs.config.server.mau_limits_reserved_threepids threepids = self.hs.config.server.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third # register three users, of which two have reserved 3pids, and a third
@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
active_count = self.get_success(self.store.get_monthly_active_count()) active_count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(active_count, 3) self.assertEqual(active_count, 3)
def test_can_insert_and_count_mau(self): def test_can_insert_and_count_mau(self) -> None:
count = self.get_success(self.store.get_monthly_active_count()) count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0) self.assertEqual(count, 0)
@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count()) count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1) self.assertEqual(count, 1)
def test_appservice_user_not_counted_in_mau(self): def test_appservice_user_not_counted_in_mau(self) -> None:
self.get_success( self.get_success(
self.store.register_user( self.store.register_user(
user_id="@appservice_user:server", appservice_id="wibble" user_id="@appservice_user:server", appservice_id="wibble"
@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count()) count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0) self.assertEqual(count, 0)
def test_user_last_seen_monthly_active(self): def test_user_last_seen_monthly_active(self) -> None:
user_id1 = "@user1:server" user_id1 = "@user1:server"
user_id2 = "@user2:server" user_id2 = "@user2:server"
user_id3 = "@user3:server" user_id3 = "@user3:server"
@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertIsNone(result) self.assertIsNone(result)
@override_config({"max_mau_value": 5}) @override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self): def test_reap_monthly_active_users(self) -> None:
initial_users = 10 initial_users = 10
for i in range(initial_users): for i in range(initial_users):
self.get_success( self.get_success(
@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Note that below says mau_limit (no s), this is the name of the config # Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits. # value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self): def test_reap_monthly_active_users_reserved_users(self) -> None:
"""Tests that reaping correctly handles reaping where reserved users are """Tests that reaping correctly handles reaping where reserved users are
present""" present"""
threepids = self.hs.config.server.mau_limits_reserved_threepids threepids = self.hs.config.server.mau_limits_reserved_threepids
@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count()) count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, self.hs.config.server.max_mau_value) self.assertEqual(count, self.hs.config.server.max_mau_value)
def test_populate_monthly_users_is_guest(self): def test_populate_monthly_users_is_guest(self) -> None:
# Test that guest users are not added to mau list # Test that guest users are not added to mau list
user_id = "@user_id:host" user_id = "@user_id:host"
@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called() self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self): def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once() self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self): def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called() self.store.upsert_monthly_active_user.assert_not_called()
def test_get_reserved_real_user_account(self): def test_get_reserved_real_user_account(self) -> None:
# Test no reserved users, or reserved threepids # Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users()) users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), 0) self.assertEqual(len(users), 0)
@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
users = self.get_success(self.store.get_registered_reserved_users()) users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids)) self.assertEqual(len(users), len(threepids))
def test_support_user_not_add_to_mau_limits(self): def test_support_user_not_add_to_mau_limits(self) -> None:
support_user_id = "@support:test" support_user_id = "@support:test"
count = self.get_success(self.store.get_monthly_active_count()) count = self.get_success(self.store.get_monthly_active_count())
@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config( @override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
) )
def test_track_monthly_users_without_cap(self): def test_track_monthly_users_without_cap(self) -> None:
count = self.get_success(self.store.get_monthly_active_count()) count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(0, count) self.assertEqual(0, count)
@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(2, count) self.assertEqual(2, count)
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self): def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called() self.store.upsert_monthly_active_user.assert_not_called()
def test_get_monthly_active_count_by_service(self): def test_get_monthly_active_count_by_service(self) -> None:
appservice1_user1 = "@appservice1_user1:example.com" appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com" appservice1_user2 = "@appservice1_user2:example.com"
@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service2], 1) self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1) self.assertEqual(result[native], 1)
def test_get_monthly_active_users_by_service(self): def test_get_monthly_active_users_by_service(self) -> None:
# (No users, no filtering) -> empty result # (No users, no filtering) -> empty result
result = self.get_success(self.store.get_monthly_active_users_by_service()) result = self.get_success(self.store.get_monthly_active_users_by_service())

View File

@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client import room from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase):
user_id = "@red:server" user_id = "@red:server"
servlets = [room.register_servlets] servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._storage_controllers = self.hs.get_storage_controllers() self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self): def test_purge_history(self) -> None:
""" """
Purging a room history will delete everything before the topological point. Purging a room history will delete everything before the topological point.
""" """
@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase):
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError) self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"])) self.get_success(self.store.get_event(last["event_id"]))
def test_purge_history_wont_delete_extrems(self): def test_purge_history_wont_delete_extrems(self) -> None:
""" """
Purging a room history will delete everything before the topological point. Purging a room history will delete everything before the topological point.
""" """
@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase):
token = self.get_success( token = self.get_success(
self.store.get_topological_token_for_event(last["event_id"]) self.store.get_topological_token_for_event(last["event_id"])
) )
assert token.topological is not None
event = f"t{token.topological + 1}-{token.stream + 1}" event = f"t{token.topological + 1}-{token.stream + 1}"
# Purge everything before this topological token # Purge everything before this topological token
@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase):
self.get_success(self.store.get_event(third["event_id"])) self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"])) self.get_success(self.store.get_event(last["event_id"]))
def test_purge_room(self): def test_purge_room(self) -> None:
""" """
Purging a room will delete everything about it. Purging a room will delete everything about it.
""" """

View File

@ -14,8 +14,12 @@
from typing import Collection, Optional from typing import Collection, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.server import HomeServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests.test_utils.event_injection import create_event from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase): class ReceiptTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver) -> None: def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase):
) )
self.assertEqual(res, {}) self.assertEqual(res, {})
res = self.get_last_unthreaded_receipt( res2 = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
) )
self.assertEqual(res, None) self.assertIsNone(res2)
def test_get_receipts_for_user(self) -> None: def test_get_receipts_for_user(self) -> None:
# Send some events into the first room # Send some events into the first room

View File

@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import List, Optional, cast
from canonicaljson import json from canonicaljson import json
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID from synapse.events import EventBase, _EventInternalMetadata
from synapse.events.builder import EventBuilder
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomID, UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.utils import create_room from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase): class RedactionTestCase(unittest.HomeserverTestCase):
def default_config(self): def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config["redaction_retention_period"] = "30d" config["redaction_retention_period"] = "30d"
return config return config
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._storage = hs.get_storage_controllers() storage = hs.get_storage_controllers()
assert storage.persistence is not None
self._persistence = storage.persistence
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1 self.depth = 1
def inject_room_member( def inject_room_member( # type: ignore[override]
self, self,
room, room: RoomID,
user, user: UserID,
membership, membership: str,
replaces_state=None, extra_content: Optional[JsonDict] = None,
extra_content: Optional[dict] = None, ) -> EventBase:
):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content or {}) content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self._storage.persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
return event return event
def inject_message(self, room, user, body): def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
self.depth += 1 self.depth += 1
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self._storage.persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
return event return event
def inject_redaction(self, room, event_id, user, reason): def inject_redaction(
self, room: RoomID, event_id: str, user: UserID, reason: str
) -> EventBase:
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
{ {
@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self._storage.persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
return event return event
def test_redact(self): def test_redact(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_message(self.room1, self.u_alice, "t") msg_event = self.inject_message(self.room1, self.u_alice, "t")
@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"], event.unsigned["redacted_because"],
) )
def test_redact_join(self): def test_redact_join(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_room_member( msg_event = self.inject_room_member(
@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"], event.unsigned["redacted_because"],
) )
def test_circular_redaction(self): def test_circular_redaction(self) -> None:
redaction_event_id1 = "$redaction1_id:test" redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test" redaction_event_id2 = "$redaction2_id:test"
class EventIdManglingBuilder: class EventIdManglingBuilder:
def __init__(self, base_builder, event_id): def __init__(self, base_builder: EventBuilder, event_id: str):
self._base_builder = base_builder self._base_builder = base_builder
self._event_id = event_id self._event_id = event_id
@ -227,31 +236,33 @@ class RedactionTestCase(unittest.HomeserverTestCase):
prev_event_ids: List[str], prev_event_ids: List[str],
auth_event_ids: Optional[List[str]], auth_event_ids: Optional[List[str]],
depth: Optional[int] = None, depth: Optional[int] = None,
): ) -> EventBase:
built_event = await self._base_builder.build( built_event = await self._base_builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
) )
built_event._event_id = self._event_id built_event._event_id = self._event_id # type: ignore[attr-defined]
built_event._dict["event_id"] = self._event_id built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id assert built_event.event_id == self._event_id
return built_event return built_event
@property @property
def room_id(self): def room_id(self) -> str:
return self._base_builder.room_id return self._base_builder.room_id
@property @property
def type(self): def type(self) -> str:
return self._base_builder.type return self._base_builder.type
@property @property
def internal_metadata(self): def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata return self._base_builder.internal_metadata
event_1, context_1 = self.get_success( event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event( self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
EventIdManglingBuilder( EventIdManglingBuilder(
self.event_builder_factory.for_room_version( self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
@ -264,14 +275,17 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
), ),
redaction_event_id1, redaction_event_id1,
),
) )
) )
) )
self.get_success(self._storage.persistence.persist_event(event_1, context_1)) self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success( event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event( self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
EventIdManglingBuilder( EventIdManglingBuilder(
self.event_builder_factory.for_room_version( self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
@ -284,10 +298,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
), ),
redaction_event_id2, redaction_event_id2,
),
) )
) )
) )
self.get_success(self._storage.persistence.persist_event(event_2, context_2)) self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions # fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1)) fetched = self.get_success(self.store.get_event(redaction_event_id1))
@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
fetched.unsigned["redacted_because"].event_id, redaction_event_id2 fetched.unsigned["redacted_because"].event_id, redaction_event_id2
) )
def test_redact_censor(self): def test_redact_censor(self) -> None:
"""Test that a redacted event gets censored in the DB after a month""" """Test that a redacted event gets censored in the DB after a month"""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json)) self.assert_dict({"content": {}}, json.loads(event_json))
def test_redact_redaction(self): def test_redact_redaction(self) -> None:
"""Tests that we can redact a redaction and can fetch it again.""" """Tests that we can redact a redaction and can fetch it again."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.store.get_event(first_redact_event.event_id, allow_none=True) self.store.get_event(first_redact_event.event_id, allow_none=True)
) )
def test_store_redacted_redaction(self): def test_store_redacted_redaction(self) -> None:
"""Tests that we can store a redacted redaction.""" """Tests that we can store a redacted redaction."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success( self.get_success(self._persistence.persist_event(redaction_event, context))
self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event # Now lets jump to the future where we have censored the redaction event
# in the DB. # in the DB.

View File

@ -14,10 +14,15 @@
from typing import List from typing import List
from unittest import mock from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
from synapse.storage.schema import SCHEMA_VERSION from synapse.storage.schema import SCHEMA_VERSION
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]:
class WorkerSchemaTests(HomeserverTestCase): class WorkerSchemaTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=GenericWorkerServer federation_http_client=None, homeserver_to_use=GenericWorkerServer
) )
return hs return hs
def default_config(self): def default_config(self) -> JsonDict:
conf = super().default_config() conf = super().default_config()
# Mark this as a worker app. # Mark this as a worker app.
@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase):
return conf return conf
def test_rolling_back(self): def test_rolling_back(self) -> None:
"""Test that workers can start if the DB is a newer schema version""" """Test that workers can start if the DB is a newer schema version"""
db_pool = self.hs.get_datastores().main.db_pool db_pool = self.hs.get_datastores().main.db_pool
@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase):
prepare_database(db_conn, db_pool.engine, self.hs.config) prepare_database(db_conn, db_pool.engine, self.hs.config)
def test_not_upgraded_old_schema_version(self): def test_not_upgraded_old_schema_version(self) -> None:
"""Test that workers don't start if the DB has an older schema version""" """Test that workers don't start if the DB has an older schema version"""
db_pool = self.hs.get_datastores().main.db_pool db_pool = self.hs.get_datastores().main.db_pool
db_conn = LoggingDatabaseConnection( db_conn = LoggingDatabaseConnection(
@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
with self.assertRaises(PrepareDatabaseException): with self.assertRaises(PrepareDatabaseException):
prepare_database(db_conn, db_pool.engine, self.hs.config) prepare_database(db_conn, db_pool.engine, self.hs.config)
def test_not_upgraded_current_schema_version_with_outstanding_deltas(self): def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None:
""" """
Test that workers don't start if the DB is on the current schema version, Test that workers don't start if the DB is on the current schema version,
but there are still outstanding delta migrations to run. but there are still outstanding delta migrations to run.

View File

@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID, UserID from synapse.types import RoomAlias, RoomID, UserID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class RoomStoreTestCase(HomeserverTestCase): class RoomStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# We can't test RoomStore on its own without the DirectoryStore, for # We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table # management of the 'room_aliases' table
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase):
) )
) )
def test_get_room(self): def test_get_room(self) -> None:
res = self.get_success(self.store.get_room(self.room.to_string()))
assert res is not None
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"room_id": self.room.to_string(), "room_id": self.room.to_string(),
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"is_public": True, "is_public": True,
}, },
(self.get_success(self.store.get_room(self.room.to_string()))), res,
) )
def test_get_room_unknown_room(self): def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
def test_get_room_with_stats(self): def test_get_room_with_stats(self) -> None:
res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
assert res is not None
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"room_id": self.room.to_string(), "room_id": self.room.to_string(),
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"public": True, "public": True,
}, },
(self.get_success(self.store.get_room_with_stats(self.room.to_string()))), res,
) )
def test_get_room_with_stats_unknown_room(self): def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone( self.assertIsNone(
(self.get_success(self.store.get_room_with_stats("!uknown:test"))), self.get_success(self.store.get_room_with_stats("!uknown:test"))
) )

View File

@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def test_null_byte(self): def test_null_byte(self) -> None:
""" """
Postgres/SQLite don't like null bytes going into the search tables. Internally Postgres/SQLite don't like null bytes going into the search tables. Internally
we replace those with a space. we replace those with a space.
@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
if isinstance(store.database_engine, PostgresEngine): if isinstance(store.database_engine, PostgresEngine):
self.assertIn("alice", result.get("highlights")) self.assertIn("alice", result.get("highlights"))
def test_non_string(self): def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`. """Test that non-string `value`s are not inserted into `event_search`.
This is particularly important when using sqlite, since a sqlite column can hold This is particularly important when using sqlite, since a sqlite column can hold
@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
self.assertEqual(f.value.code, 404) self.assertEqual(f.value.code, 404)
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
def test_sqlite_non_string_deletion_background_update(self): def test_sqlite_non_string_deletion_background_update(self) -> None:
"""Test the background update to delete bad rows from `event_search`.""" """Test the background update to delete bad rows from `event_search`."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase):
"results array length should match count", "results array length should match count",
) )
def test_postgres_web_search_for_phrase(self): def test_postgres_web_search_for_phrase(self) -> None:
""" """
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
This test is skipped unless the postgres instance supports websearch_to_tsquery. This test is skipped unless the postgres instance supports websearch_to_tsquery.
@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase):
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
def test_sqlite_search(self): def test_sqlite_search(self) -> None:
""" """
Test sqlite searching for phrases. Test sqlite searching for phrases.
""" """

View File

@ -16,10 +16,15 @@ import logging
from frozendict import frozendict from frozendict import frozendict
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.server import HomeServer
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase from tests.unittest import HomeserverTestCase, TestCase
@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase): class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers() self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state self.state_datastore = self.storage.state.stores.state
@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
) )
) )
def inject_state_event(self, room, sender, typ, state_key, content): def inject_state_event(
self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
) -> EventBase:
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,
{ {
@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(self.storage.persistence.persist_event(event, context))
return event return event
def assertStateMapEqual(self, s1, s2): def assertStateMapEqual(
self, s1: StateMap[EventBase], s2: StateMap[EventBase]
) -> None:
for t in s1: for t in s1:
# just compare event IDs for simplicity # just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2)) self.assertEqual(len(s1), len(s2))
def test_get_state_groups_ids(self): def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event( e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = self.get_success( state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) self.storage.state.get_state_groups_ids(
self.room.to_string(), [e2.event_id]
)
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0] state_map = list(state_group_map.values())[0]
@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
) )
def test_get_state_groups(self): def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event( e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = self.get_success( state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id]) self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0] state_list = list(state_group_map.values())[0]
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self): def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever # this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room. # forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
class StateFilterDifferenceTestCase(TestCase): class StateFilterDifferenceTestCase(TestCase):
def assert_difference( def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
): ) -> None:
self.assertEqual( self.assertEqual(
minuend.approx_difference(subtrahend), minuend.approx_difference(subtrahend),
expected, expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
) )
def test_state_filter_difference_no_include_other_minus_no_include_other(self): def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
""" """
Tests the StateFilter.approx_difference method Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the where, in a.approx_difference(b), both a and b do not have the
@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
), ),
) )
def test_state_filter_difference_include_other_minus_no_include_other(self): def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
""" """
Tests the StateFilter.approx_difference method Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set. where, in a.approx_difference(b), only a has the include_others flag set.
@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
), ),
) )
def test_state_filter_difference_include_other_minus_include_other(self): def test_state_filter_difference_include_other_minus_include_other(self) -> None:
""" """
Tests the StateFilter.approx_difference method Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others where, in a.approx_difference(b), both a and b have the include_others
@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
), ),
) )
def test_state_filter_difference_no_include_other_minus_include_other(self): def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
""" """
Tests the StateFilter.approx_difference method Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set. where, in a.approx_difference(b), only b has the include_others flag set.
@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
), ),
) )
def test_state_filter_difference_simple_cases(self): def test_state_filter_difference_simple_cases(self) -> None:
""" """
Tests some very simple cases of the StateFilter approx_difference, Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests. that are not explicitly tested by the more in-depth tests.
@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
class StateFilterTestCase(TestCase): class StateFilterTestCase(TestCase):
def test_return_expanded(self): def test_return_expanded(self) -> None:
""" """
Tests the behaviour of the return_expanded() function that expands Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate). StateFilters to include more state types (for the sake of cache hit rate).

View File

@ -14,11 +14,15 @@
from typing import List from typing import List
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def default_config(self): def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config["experimental_features"] = {"msc3874_enabled": True} config["experimental_features"] = {"msc3874_enabled": True}
return config return config
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.user_id = self.register_user("test", "test") self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test") self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase):
return [ev.event_id for ev in events] return [ev.event_id for ev in events]
def test_filter_relation_senders(self): def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to. # Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]} filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
def test_filter_relation_type(self): def test_filter_relation_type(self) -> None:
# Messages which have annotations. # Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
def test_filter_relation_senders_and_type(self): def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to. # Messages which second user reacted to.
filter = { filter = {
"related_by_senders": [self.second_user_id], "related_by_senders": [self.second_user_id],
@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter) chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1]) self.assertEqual(chunk, [self.event_id_1])
def test_duplicate_relation(self): def test_duplicate_relation(self) -> None:
"""An event should only be returned once if there are multiple relations to it.""" """An event should only be returned once if there are multiple relations to it."""
self.helper.send_event( self.helper.send_event(
room_id=self.room_id, room_id=self.room_id,

View File

@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.databases.main.transactions import DestinationRetryTimings from synapse.storage.databases.main.transactions import DestinationRetryTimings
from synapse.util import Clock
from synapse.util.retryutils import MAX_RETRY_INTERVAL from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
class TransactionStoreTestCase(HomeserverTestCase): class TransactionStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
def test_get_set_transactions(self): def test_get_set_transactions(self) -> None:
"""Tests that we can successfully get a non-existent entry for """Tests that we can successfully get a non-existent entry for
destination retries, as well as testing tht we can set and get destination retries, as well as testing tht we can set and get
correctly. correctly.
@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase):
r, r,
) )
def test_initial_set_transactions(self): def test_initial_set_transactions(self) -> None:
"""Tests that we can successfully set the destination retries (there """Tests that we can successfully set the destination retries (there
was a bug around invalidating the cache that broke this) was a bug around invalidating the cache that broke this)
""" """
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d) self.get_success(d)
def test_large_destination_retry(self): def test_large_destination_retry(self) -> None:
d = self.store.set_destination_retry_timings( d = self.store.set_destination_retry_timings(
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
) )
self.get_success(d) self.get_success(d)
d = self.store.get_destination_retry_timings("example.com") d2 = self.store.get_destination_retry_timings("example.com")
self.get_success(d) self.get_success(d2)

View File

@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.types import Cursor
from synapse.util import Clock
from tests import unittest from tests import unittest
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
"""Test SQL transaction limit doesn't break transactions.""" """Test SQL transaction limit doesn't break transactions."""
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(db_txn_limit=1000) return self.setup_test_homeserver(db_txn_limit=1000)
def test_config(self): def test_config(self) -> None:
db_config = self.hs.config.database.get_single_database() db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000) self.assertEqual(db_config.config["txn_limit"], 1000)
def test_select(self): def test_select(self) -> None:
def do_select(txn): def do_select(txn: Cursor) -> None:
txn.execute("SELECT 1") txn.execute("SELECT 1")
db_pool = self.hs.get_datastores().databases[0] db_pool = self.hs.get_datastores().databases[0]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Collection, Dict
from unittest import mock from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred from twisted.internet.defer import CancelledError, ensureDeferred
@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
# the results to be returned by the mocked get_partial_state_events # the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {} self._events_dict: Dict[str, bool] = {}
async def get_partial_state_events(events): async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
return {e: self._events_dict[e] for e in events} return {e: self._events_dict[e] for e in events}
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"]) self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker = PartialStateEventsTracker(self.mock_store) self.tracker = PartialStateEventsTracker(self.mock_store)
def test_does_not_block_for_full_state_events(self): def test_does_not_block_for_full_state_events(self) -> None:
self._events_dict = {"event1": False, "event2": False} self._events_dict = {"event1": False, "event2": False}
self.successResultOf( self.successResultOf(
@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
["event1", "event2"] ["event1", "event2"]
) )
def test_blocks_for_partial_state_events(self): def test_blocks_for_partial_state_events(self) -> None:
self._events_dict = {"event1": True, "event2": False} self._events_dict = {"event1": True, "event2": False}
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1") self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d) self.successResultOf(d)
def test_un_partial_state_race(self): def test_un_partial_state_race(self) -> None:
# if the event is un-partial-stated between the initial check and the # if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block. # registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False} self._events_dict = {"event1": True, "event2": False}
async def get_partial_state_events(events): async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
res = {e: self._events_dict[e] for e in events} res = {e: self._events_dict[e] for e in events}
# change the result for next time # change the result for next time
self._events_dict = {"event1": False, "event2": False} self._events_dict = {"event1": False, "event2": False}
@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
) )
def test_un_partial_state_during_get_partial_state_events(self): def test_un_partial_state_during_get_partial_state_events(self) -> None:
# we should correctly handle a call to notify_un_partial_stated during the # we should correctly handle a call to notify_un_partial_stated during the
# second call to get_partial_state_events. # second call to get_partial_state_events.
self._events_dict = {"event1": True, "event2": False} self._events_dict = {"event1": True, "event2": False}
async def get_partial_state_events1(events): async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]:
self.mock_store.get_partial_state_events.side_effect = ( self.mock_store.get_partial_state_events.side_effect = (
get_partial_state_events2 get_partial_state_events2
) )
return {e: self._events_dict[e] for e in events} return {e: self._events_dict[e] for e in events}
async def get_partial_state_events2(events): async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]:
self.tracker.notify_un_partial_stated("event1") self.tracker.notify_un_partial_stated("event1")
self._events_dict["event1"] = False self._events_dict["event1"] = False
return {e: self._events_dict[e] for e in events} return {e: self._events_dict[e] for e in events}
@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
) )
def test_cancellation(self): def test_cancellation(self) -> None:
self._events_dict = {"event1": True, "event2": False} self._events_dict = {"event1": True, "event2": False}
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"])) d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker = PartialCurrentStateTracker(self.mock_store) self.tracker = PartialCurrentStateTracker(self.mock_store)
def test_does_not_block_for_full_state_rooms(self): def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False) self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_blocks_for_partial_room_state(self): def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d = ensureDeferred(self.tracker.await_full_state("room_id")) d = ensureDeferred(self.tracker.await_full_state("room_id"))
@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("room_id") self.tracker.notify_un_partial_stated("room_id")
self.successResultOf(d) self.successResultOf(d)
def test_un_partial_state_race(self): def test_un_partial_state_race(self) -> None:
# We should correctly handle race between awaiting the state and us # We should correctly handle race between awaiting the state and us
# un-partialling the state # un-partialling the state
async def is_partial_state_room(events): async def is_partial_state_room(room_id: str) -> bool:
self.tracker.notify_un_partial_stated("room_id") self.tracker.notify_un_partial_stated("room_id")
return True return True
@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_cancellation(self): def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d1 = ensureDeferred(self.tracker.await_full_state("room_id")) d1 = ensureDeferred(self.tracker.await_full_state("room_id"))