Require types in tests.storage. (#14646)
Adds missing type hints to `tests.storage` package and does not allow untyped definitions.
This commit is contained in:
parent
94bc21e69f
commit
3ac412b4e2
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints.
|
14
mypy.ini
14
mypy.ini
|
@ -88,6 +88,9 @@ disallow_untyped_defs = False
|
||||||
[mypy-tests.*]
|
[mypy-tests.*]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
|
[mypy-tests.handlers.test_sso]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.handlers.test_user_directory]
|
[mypy-tests.handlers.test_user_directory]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -103,16 +106,7 @@ disallow_untyped_defs = True
|
||||||
[mypy-tests.state.test_profile]
|
[mypy-tests.state.test_profile]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.storage.test_id_generators]
|
[mypy-tests.storage.*]
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-tests.storage.test_profile]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-tests.handlers.test_sso]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-tests.storage.test_user_directory]
|
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.rest.*]
|
[mypy-tests.rest.*]
|
||||||
|
|
|
@ -140,7 +140,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
@cancellable
|
@cancellable
|
||||||
async def get_e2e_device_keys_for_cs_api(
|
async def get_e2e_device_keys_for_cs_api(
|
||||||
self,
|
self,
|
||||||
query_list: List[Tuple[str, Optional[str]]],
|
query_list: Collection[Tuple[str, Optional[str]]],
|
||||||
include_displaynames: bool = True,
|
include_displaynames: bool = True,
|
||||||
) -> Dict[str, Dict[str, JsonDict]]:
|
) -> Dict[str, Dict[str, JsonDict]]:
|
||||||
"""Fetch a list of device keys, formatted suitably for the C/S API.
|
"""Fetch a list of device keys, formatted suitably for the C/S API.
|
||||||
|
|
|
@ -12,8 +12,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import devices
|
from synapse.rest.client import devices
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
devices.register_servlets,
|
devices.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.user_id = self.register_user("foo", "pass")
|
self.user_id = self.register_user("foo", "pass")
|
||||||
|
|
||||||
def test_background_remove_deleted_devices_from_device_inbox(self):
|
def test_background_remove_deleted_devices_from_device_inbox(self) -> None:
|
||||||
"""Test that the background task to delete old device_inboxes works properly."""
|
"""Test that the background task to delete old device_inboxes works properly."""
|
||||||
|
|
||||||
# create a valid device
|
# create a valid device
|
||||||
|
@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(1, len(res))
|
self.assertEqual(1, len(res))
|
||||||
self.assertEqual(res[0], "cur_device")
|
self.assertEqual(res[0], "cur_device")
|
||||||
|
|
||||||
def test_background_remove_hidden_devices_from_device_inbox(self):
|
def test_background_remove_hidden_devices_from_device_inbox(self) -> None:
|
||||||
"""Test that the background task to delete hidden devices
|
"""Test that the background task to delete hidden devices
|
||||||
from device_inboxes works properly."""
|
from device_inboxes works properly."""
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store: EventsWorkerStore = hs.get_datastores().main
|
self.store: EventsWorkerStore = hs.get_datastores().main
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.event_ids.append(event.event_id)
|
self.event_ids.append(event.event_id)
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
with LoggingContext(name="test") as ctx:
|
with LoggingContext(name="test") as ctx:
|
||||||
res = self.get_success(
|
res = self.get_success(
|
||||||
self.store.have_seen_events(
|
self.store.have_seen_events(
|
||||||
|
@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(res, {self.event_ids[0]})
|
self.assertEqual(res, {self.event_ids[0]})
|
||||||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
|
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
|
||||||
|
|
||||||
def test_persisting_event_invalidates_cache(self):
|
def test_persisting_event_invalidates_cache(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure that the `have_seen_event` cache
|
Test to make sure that the `have_seen_event` cache
|
||||||
is invalidated after we persist an event and returns
|
is invalidated after we persist an event and returns
|
||||||
|
@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||||
# That should result in a single db query to lookup
|
# That should result in a single db query to lookup
|
||||||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
|
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
|
||||||
|
|
||||||
def test_invalidate_cache_by_room_id(self):
|
def test_invalidate_cache_by_room_id(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure that all events associated with the given `(room_id,)`
|
Test to make sure that all events associated with the given `(room_id,)`
|
||||||
are invalidated in the `have_seen_event` cache.
|
are invalidated in the `have_seen_event` cache.
|
||||||
|
@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store: EventsWorkerStore = hs.get_datastores().main
|
self.store: EventsWorkerStore = hs.get_datastores().main
|
||||||
|
|
||||||
self.user = self.register_user("user", "pass")
|
self.user = self.register_user("user", "pass")
|
||||||
|
@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
|
||||||
# Reset the event cache so the tests start with it empty
|
# Reset the event cache so the tests start with it empty
|
||||||
self.get_success(self.store._get_event_cache.clear())
|
self.get_success(self.store._get_event_cache.clear())
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
"""Test that we cache events that we pull from the DB."""
|
"""Test that we cache events that we pull from the DB."""
|
||||||
|
|
||||||
with LoggingContext("test") as ctx:
|
with LoggingContext("test") as ctx:
|
||||||
|
@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
|
||||||
# We should have fetched the event from the DB
|
# We should have fetched the event from the DB
|
||||||
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
|
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
|
||||||
|
|
||||||
def test_event_ref(self):
|
def test_event_ref(self) -> None:
|
||||||
"""Test that we reuse events that are still in memory but have fallen
|
"""Test that we reuse events that are still in memory but have fallen
|
||||||
out of the cache, rather than requesting them from the DB.
|
out of the cache, rather than requesting them from the DB.
|
||||||
"""
|
"""
|
||||||
|
@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
|
||||||
# from the DB
|
# from the DB
|
||||||
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
|
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
|
||||||
|
|
||||||
def test_dedupe(self):
|
def test_dedupe(self) -> None:
|
||||||
"""Test that if we request the same event multiple times we only pull it
|
"""Test that if we request the same event multiple times we only pull it
|
||||||
out once.
|
out once.
|
||||||
"""
|
"""
|
||||||
|
@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
|
||||||
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
|
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
|
||||||
"""Test event fetching during a database outage."""
|
"""Test event fetching during a database outage."""
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store: EventsWorkerStore = hs.get_datastores().main
|
self.store: EventsWorkerStore = hs.get_datastores().main
|
||||||
|
|
||||||
self.room_id = f"!room:{hs.hostname}"
|
self.room_id = f"!room:{hs.hostname}"
|
||||||
|
@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store: EventsWorkerStore = hs.get_datastores().main
|
self.store: EventsWorkerStore = hs.get_datastores().main
|
||||||
|
|
||||||
self.user = self.register_user("user", "pass")
|
self.user = self.register_user("user", "pass")
|
||||||
|
@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
unblock: "Deferred[None]" = Deferred()
|
unblock: "Deferred[None]" = Deferred()
|
||||||
original_runWithConnection = self.store.db_pool.runWithConnection
|
original_runWithConnection = self.store.db_pool.runWithConnection
|
||||||
|
|
||||||
async def runWithConnection(*args, **kwargs):
|
# Don't bother with the types here, we just pass into the original function.
|
||||||
|
async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
await unblock
|
await unblock
|
||||||
return await original_runWithConnection(*args, **kwargs)
|
return await original_runWithConnection(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
|
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
|
||||||
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
|
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
|
||||||
|
|
||||||
def test_first_get_event_cancelled(self):
|
def test_first_get_event_cancelled(self) -> None:
|
||||||
"""Test cancellation of the first `get_event` call sharing a database fetch.
|
"""Test cancellation of the first `get_event` call sharing a database fetch.
|
||||||
|
|
||||||
The first `get_event` call is the one which initiates the fetch. We expect the
|
The first `get_event` call is the one which initiates the fetch. We expect the
|
||||||
|
@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
# The second `get_event` call should complete successfully.
|
# The second `get_event` call should complete successfully.
|
||||||
self.get_success(get_event2)
|
self.get_success(get_event2)
|
||||||
|
|
||||||
def test_second_get_event_cancelled(self):
|
def test_second_get_event_cancelled(self) -> None:
|
||||||
"""Test cancellation of the second `get_event` call sharing a database fetch."""
|
"""Test cancellation of the second `get_event` call sharing a database fetch."""
|
||||||
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
|
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
|
||||||
# Cancel the second `get_event` call.
|
# Cancel the second `get_event` call.
|
||||||
|
|
|
@ -15,18 +15,20 @@
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.internet.base import ReactorBase
|
from twisted.internet.base import ReactorBase
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
|
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class LockTestCase(unittest.HomeserverTestCase):
|
class LockTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
def test_acquire_contention(self):
|
def test_acquire_contention(self) -> None:
|
||||||
# Track the number of tasks holding the lock.
|
# Track the number of tasks holding the lock.
|
||||||
# Should be at most 1.
|
# Should be at most 1.
|
||||||
in_lock = 0
|
in_lock = 0
|
||||||
|
@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
release_lock: "Deferred[None]" = Deferred()
|
release_lock: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
async def task():
|
async def task() -> None:
|
||||||
nonlocal in_lock
|
nonlocal in_lock
|
||||||
nonlocal max_in_lock
|
nonlocal max_in_lock
|
||||||
|
|
||||||
|
@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
# At most one task should have held the lock at a time.
|
# At most one task should have held the lock at a time.
|
||||||
self.assertEqual(max_in_lock, 1)
|
self.assertEqual(max_in_lock, 1)
|
||||||
|
|
||||||
def test_simple_lock(self):
|
def test_simple_lock(self) -> None:
|
||||||
"""Test that we can take out a lock and that while we hold it nobody
|
"""Test that we can take out a lock and that while we hold it nobody
|
||||||
else can take it out.
|
else can take it out.
|
||||||
"""
|
"""
|
||||||
|
@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(lock3.__aenter__())
|
self.get_success(lock3.__aenter__())
|
||||||
self.get_success(lock3.__aexit__(None, None, None))
|
self.get_success(lock3.__aexit__(None, None, None))
|
||||||
|
|
||||||
def test_maintain_lock(self):
|
def test_maintain_lock(self) -> None:
|
||||||
"""Test that we don't time out locks while they're still active"""
|
"""Test that we don't time out locks while they're still active"""
|
||||||
|
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
|
@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(lock.__aexit__(None, None, None))
|
self.get_success(lock.__aexit__(None, None, None))
|
||||||
|
|
||||||
def test_timeout_lock(self):
|
def test_timeout_lock(self) -> None:
|
||||||
"""Test that we time out locks if they're not updated for ages"""
|
"""Test that we time out locks if they're not updated for ages"""
|
||||||
|
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
|
@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertFalse(self.get_success(lock.is_still_valid()))
|
self.assertFalse(self.get_success(lock.is_still_valid()))
|
||||||
|
|
||||||
def test_drop(self):
|
def test_drop(self) -> None:
|
||||||
"""Test that dropping the context manager means we stop renewing the lock"""
|
"""Test that dropping the context manager means we stop renewing the lock"""
|
||||||
|
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
|
@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
self.assertIsNotNone(lock2)
|
self.assertIsNotNone(lock2)
|
||||||
|
|
||||||
def test_shutdown(self):
|
def test_shutdown(self) -> None:
|
||||||
"""Test that shutting down Synapse releases the locks"""
|
"""Test that shutting down Synapse releases the locks"""
|
||||||
# Acquire two locks
|
# Acquire two locks
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key1"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key1"))
|
||||||
|
|
|
@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.user_id = self.register_user("foo", "pass")
|
self.user_id = self.register_user("foo", "pass")
|
||||||
self.token = self.login("foo", "pass")
|
self.token = self.login("foo", "pass")
|
||||||
|
@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
table: str,
|
table: str,
|
||||||
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
|
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
|
||||||
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
|
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
|
||||||
):
|
) -> None:
|
||||||
"""Test that the background update to uniqueify non-thread receipts in
|
"""Test that the background update to uniqueify non-thread receipts in
|
||||||
the given receipts table works properly.
|
the given receipts table works properly.
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
f"Background update did not remove all duplicate receipts from {table}",
|
f"Background update did not remove all duplicate receipts from {table}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_background_receipts_linearized_unique_index(self):
|
def test_background_receipts_linearized_unique_index(self) -> None:
|
||||||
"""Test that the background update to uniqueify non-thread receipts in
|
"""Test that the background update to uniqueify non-thread receipts in
|
||||||
`receipts_linearized` works properly.
|
`receipts_linearized` works properly.
|
||||||
"""
|
"""
|
||||||
|
@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_background_receipts_graph_unique_index(self):
|
def test_background_receipts_graph_unique_index(self) -> None:
|
||||||
"""Test that the background update to uniqueify non-thread receipts in
|
"""Test that the background update to uniqueify non-thread receipts in
|
||||||
`receipts_graph` works properly.
|
`receipts_graph` works properly.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -14,10 +14,14 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import RoomTypes
|
from synapse.api.constants import RoomTypes
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.room import _BackgroundUpdates
|
from synapse.storage.databases.main.room import _BackgroundUpdates
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.user_id = self.register_user("foo", "pass")
|
self.user_id = self.register_user("foo", "pass")
|
||||||
self.token = self.login("foo", "pass")
|
self.token = self.login("foo", "pass")
|
||||||
|
@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
def test_background_populate_rooms_creator_column(self):
|
def test_background_populate_rooms_creator_column(self) -> None:
|
||||||
"""Test that the background update to populate the rooms creator column
|
"""Test that the background update to populate the rooms creator column
|
||||||
works properly.
|
works properly.
|
||||||
"""
|
"""
|
||||||
|
@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(room_creator_after, self.user_id)
|
self.assertEqual(room_creator_after, self.user_id)
|
||||||
|
|
||||||
def test_background_add_room_type_column(self):
|
def test_background_add_room_type_column(self) -> None:
|
||||||
"""Test that the background update to populate the `room_type` column in
|
"""Test that the background update to populate the `room_type` column in
|
||||||
`room_stats_state` works properly.
|
`room_stats_state` works properly.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
|
||||||
{(1, "user1", "hello"), (2, "user2", "bleb")},
|
{(1, "user1", "hello"), (2, "user2", "bleb")},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_simple_update_many(self):
|
def test_simple_update_many(self) -> None:
|
||||||
"""
|
"""
|
||||||
simple_update_many performs many updates at once.
|
simple_update_many performs many updates at once.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -14,13 +14,17 @@
|
||||||
|
|
||||||
from typing import Iterable, Optional, Set
|
from typing import Iterable, Optional, Set
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import AccountDataTypes
|
from synapse.api.constants import AccountDataTypes
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, hs, reactor, clock):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
self.user = "@user:test"
|
self.user = "@user:test"
|
||||||
|
|
||||||
|
@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
expected_ignored_user_ids,
|
expected_ignored_user_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_ignoring_users(self):
|
def test_ignoring_users(self) -> None:
|
||||||
"""Basic adding/removing of users from the ignore list."""
|
"""Basic adding/removing of users from the ignore list."""
|
||||||
self._update_ignore_list("@other:test", "@another:remote")
|
self._update_ignore_list("@other:test", "@another:remote")
|
||||||
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
|
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
|
||||||
|
@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
# Check the removed user.
|
# Check the removed user.
|
||||||
self.assert_ignorers("@another:remote", {self.user})
|
self.assert_ignorers("@another:remote", {self.user})
|
||||||
|
|
||||||
def test_caching(self):
|
def test_caching(self) -> None:
|
||||||
"""Ensure that caching works properly between different users."""
|
"""Ensure that caching works properly between different users."""
|
||||||
# The first user ignores a user.
|
# The first user ignores a user.
|
||||||
self._update_ignore_list("@other:test")
|
self._update_ignore_list("@other:test")
|
||||||
|
@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
self.assert_ignored(self.user, set())
|
self.assert_ignored(self.user, set())
|
||||||
self.assert_ignorers("@other:test", {"@second:test"})
|
self.assert_ignorers("@other:test", {"@second:test"})
|
||||||
|
|
||||||
def test_invalid_data(self):
|
def test_invalid_data(self) -> None:
|
||||||
"""Invalid data ends up clearing out the ignored users list."""
|
"""Invalid data ends up clearing out the ignored users list."""
|
||||||
# Add some data and ensure it is there.
|
# Add some data and ensure it is there.
|
||||||
self._update_ignore_list("@other:test")
|
self._update_ignore_list("@other:test")
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState
|
||||||
from synapse.config._base import ConfigError
|
from synapse.config._base import ConfigError
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.database import DatabasePool, make_conn
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn
|
||||||
from synapse.storage.databases.main.appservice import (
|
from synapse.storage.databases.main.appservice import (
|
||||||
ApplicationServiceStore,
|
ApplicationServiceStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
|
@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
|
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
super(ApplicationServiceStoreTestCase, self).setUp()
|
super(ApplicationServiceStoreTestCase, self).setUp()
|
||||||
|
|
||||||
self.as_yaml_files: List[str] = []
|
self.as_yaml_files: List[str] = []
|
||||||
|
@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
super(ApplicationServiceStoreTestCase, self).tearDown()
|
super(ApplicationServiceStoreTestCase, self).tearDown()
|
||||||
|
|
||||||
def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
|
def _add_appservice(
|
||||||
|
self, as_token: str, id: str, url: str, hs_token: str, sender: str
|
||||||
|
) -> None:
|
||||||
as_yaml = {
|
as_yaml = {
|
||||||
"url": url,
|
"url": url,
|
||||||
"as_token": as_token,
|
"as_token": as_token,
|
||||||
|
@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||||
database, make_conn(db_config, self.engine, "test"), self.hs
|
database, make_conn(db_config, self.engine, "test"), self.hs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_service(self, url, as_token, id) -> None:
|
def _add_service(self, url: str, as_token: str, id: str) -> None:
|
||||||
as_yaml = {
|
as_yaml = {
|
||||||
"url": url,
|
"url": url,
|
||||||
"as_token": as_token,
|
"as_token": as_token,
|
||||||
|
@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||||
outfile.write(yaml.dump(as_yaml))
|
outfile.write(yaml.dump(as_yaml))
|
||||||
self.as_yaml_files.append(as_token)
|
self.as_yaml_files.append(as_token)
|
||||||
|
|
||||||
def _set_state(self, id: str, state: ApplicationServiceState):
|
def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred:
|
||||||
return self.db_pool.runOperation(
|
return self.db_pool.runOperation(
|
||||||
self.engine.convert_param_style(
|
self.engine.convert_param_style(
|
||||||
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
|
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
|
||||||
|
@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||||
(id, state.value),
|
(id, state.value),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _insert_txn(self, as_id, txn_id, events):
|
def _insert_txn(
|
||||||
|
self, as_id: str, txn_id: int, events: List[Mock]
|
||||||
|
) -> "defer.Deferred[None]":
|
||||||
return self.db_pool.runOperation(
|
return self.db_pool.runOperation(
|
||||||
self.engine.convert_param_style(
|
self.engine.convert_param_style(
|
||||||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
||||||
|
@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# required for ApplicationServiceTransactionStoreTestCase tests
|
# required for ApplicationServiceTransactionStoreTestCase tests
|
||||||
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
|
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs) -> None:
|
def __init__(
|
||||||
|
self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer
|
||||||
|
) -> None:
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
|
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
|
||||||
def _write_config(self, suffix, **kwargs) -> str:
|
def _write_config(self, suffix: str, **kwargs: str) -> str:
|
||||||
vals = {
|
vals = {
|
||||||
"id": "id" + suffix,
|
"id": "id" + suffix,
|
||||||
"url": "url" + suffix,
|
"url": "url" + suffix,
|
||||||
|
|
|
@ -12,8 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Generator
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -30,7 +30,7 @@ from tests.utils import default_config
|
||||||
class SQLBaseStoreTestCase(unittest.TestCase):
|
class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
"""Test the "simple" SQL generating methods in SQLBaseStore."""
|
"""Test the "simple" SQL generating methods in SQLBaseStore."""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.db_pool = Mock(spec=["runInteraction"])
|
self.db_pool = Mock(spec=["runInteraction"])
|
||||||
self.mock_txn = Mock()
|
self.mock_txn = Mock()
|
||||||
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
|
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
|
||||||
|
@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.mock_conn.rollback.return_value = None
|
self.mock_conn.rollback.return_value = None
|
||||||
# Our fake runInteraction just runs synchronously inline
|
# Our fake runInteraction just runs synchronously inline
|
||||||
|
|
||||||
def runInteraction(func, *args, **kwargs):
|
def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def]
|
||||||
return defer.succeed(func(self.mock_txn, *args, **kwargs))
|
return defer.succeed(func(self.mock_txn, *args, **kwargs))
|
||||||
|
|
||||||
self.db_pool.runInteraction = runInteraction
|
self.db_pool.runInteraction = runInteraction
|
||||||
|
|
||||||
def runWithConnection(func, *args, **kwargs):
|
def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
return defer.succeed(func(self.mock_conn, *args, **kwargs))
|
return defer.succeed(func(self.mock_conn, *args, **kwargs))
|
||||||
|
|
||||||
self.db_pool.runWithConnection = runWithConnection
|
self.db_pool.runWithConnection = runWithConnection
|
||||||
|
@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
|
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_insert_1col(self):
|
def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_insert_3cols(self):
|
def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_one_1col(self):
|
def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
|
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_one_3col(self):
|
def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
self.mock_txn.fetchone.return_value = (1, 2, 3)
|
self.mock_txn.fetchone.return_value = (1, 2, 3)
|
||||||
|
|
||||||
|
@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_one_missing(self):
|
def test_select_one_missing(
|
||||||
|
self,
|
||||||
|
) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 0
|
self.mock_txn.rowcount = 0
|
||||||
self.mock_txn.fetchone.return_value = None
|
self.mock_txn.fetchone.return_value = None
|
||||||
|
|
||||||
|
@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
self.assertFalse(ret)
|
self.assertFalse(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_select_list(self):
|
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 3
|
self.mock_txn.rowcount = 3
|
||||||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
||||||
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
|
||||||
|
@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_one_1col(self):
|
def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_one_4cols(self):
|
def test_update_one_4cols(
|
||||||
|
self,
|
||||||
|
) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_delete_one(self):
|
def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
|
|
@ -15,11 +15,16 @@
|
||||||
import os.path
|
import os.path
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import prepare_database
|
from synapse.storage import prepare_database
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
Test the background update to clean forward extremities table.
|
Test the background update to clean forward extremities table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.store = homeserver.get_datastores().main
|
self.store = homeserver.get_datastores().main
|
||||||
self.room_creator = homeserver.get_room_creation_handler()
|
self.room_creator = homeserver.get_room_creation_handler()
|
||||||
|
|
||||||
|
@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
||||||
self.room_id = info["room_id"]
|
self.room_id = info["room_id"]
|
||||||
|
|
||||||
def run_background_update(self):
|
def run_background_update(self) -> None:
|
||||||
"""Re run the background update to clean up the extremities."""
|
"""Re run the background update to clean up the extremities."""
|
||||||
# Make sure we don't clash with in progress updates.
|
# Make sure we don't clash with in progress updates.
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
"delete_forward_extremities.sql",
|
"delete_forward_extremities.sql",
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_delta_file(txn):
|
def run_delta_file(txn: Cursor) -> None:
|
||||||
prepare_database.executescript(txn, schema_path)
|
prepare_database.executescript(txn, schema_path)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
(room_id,)
|
(room_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_soft_failed_extremities_handled_correctly(self):
|
def test_soft_failed_extremities_handled_correctly(self) -> None:
|
||||||
"""Test that extremities are correctly calculated in the presence of
|
"""Test that extremities are correctly calculated in the presence of
|
||||||
soft failed events.
|
soft failed events.
|
||||||
|
|
||||||
|
@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(latest_event_ids, [event_id_4])
|
self.assertEqual(latest_event_ids, [event_id_4])
|
||||||
|
|
||||||
def test_basic_cleanup(self):
|
def test_basic_cleanup(self) -> None:
|
||||||
"""Test that extremities are correctly calculated in the presence of
|
"""Test that extremities are correctly calculated in the presence of
|
||||||
soft failed events.
|
soft failed events.
|
||||||
|
|
||||||
|
@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(latest_event_ids, [event_id_b])
|
self.assertEqual(latest_event_ids, [event_id_b])
|
||||||
|
|
||||||
def test_chain_of_fail_cleanup(self):
|
def test_chain_of_fail_cleanup(self) -> None:
|
||||||
"""Test that extremities are correctly calculated in the presence of
|
"""Test that extremities are correctly calculated in the presence of
|
||||||
soft failed events.
|
soft failed events.
|
||||||
|
|
||||||
|
@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(latest_event_ids, [event_id_b])
|
self.assertEqual(latest_event_ids, [event_id_b])
|
||||||
|
|
||||||
def test_forked_graph_cleanup(self):
|
def test_forked_graph_cleanup(self) -> None:
|
||||||
r"""Test that extremities are correctly calculated in the presence of
|
r"""Test that extremities are correctly calculated in the presence of
|
||||||
soft failed events.
|
soft failed events.
|
||||||
|
|
||||||
|
@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["cleanup_extremities_with_dummy_events"] = True
|
config["cleanup_extremities_with_dummy_events"] = True
|
||||||
return self.setup_test_homeserver(config=config)
|
return self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.store = homeserver.get_datastores().main
|
self.store = homeserver.get_datastores().main
|
||||||
self.room_creator = homeserver.get_room_creation_handler()
|
self.room_creator = homeserver.get_room_creation_handler()
|
||||||
self.event_creator_handler = homeserver.get_event_creation_handler()
|
self.event_creator_handler = homeserver.get_event_creation_handler()
|
||||||
|
@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
self.event_creator = homeserver.get_event_creation_handler()
|
self.event_creator = homeserver.get_event_creation_handler()
|
||||||
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
|
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
|
||||||
|
|
||||||
def test_send_dummy_event(self):
|
def test_send_dummy_event(self) -> None:
|
||||||
self._create_extremity_rich_graph()
|
self._create_extremity_rich_graph()
|
||||||
|
|
||||||
# Pump the reactor repeatedly so that the background updates have a
|
# Pump the reactor repeatedly so that the background updates have a
|
||||||
|
@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
|
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
|
||||||
|
|
||||||
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
|
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
|
||||||
def test_send_dummy_events_when_insufficient_power(self):
|
def test_send_dummy_events_when_insufficient_power(self) -> None:
|
||||||
self._create_extremity_rich_graph()
|
self._create_extremity_rich_graph()
|
||||||
# Criple power levels
|
# Criple power levels
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
|
@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
|
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
|
||||||
|
|
||||||
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
|
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
|
||||||
def test_expiry_logic(self):
|
def test_expiry_logic(self) -> None:
|
||||||
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
|
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
|
||||||
expires old entries correctly.
|
expires old entries correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_extremity_rich_graph(self):
|
def _create_extremity_rich_graph(self) -> None:
|
||||||
"""Helper method to create bushy graph on demand"""
|
"""Helper method to create bushy graph on demand"""
|
||||||
|
|
||||||
event_id_start = self.create_and_send_event(self.room_id, self.user)
|
event_id_start = self.create_and_send_event(self.room_id, self.user)
|
||||||
|
@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(len(latest_event_ids), 50)
|
self.assertEqual(len(latest_event_ids), 50)
|
||||||
|
|
||||||
def _enable_consent_checking(self):
|
def _enable_consent_checking(self) -> None:
|
||||||
"""Helper method to enable consent checking"""
|
"""Helper method to enable consent checking"""
|
||||||
self.event_creator._block_events_without_consent_error = "No consent from user"
|
self.event_creator._block_events_without_consent_error = "No consent from user"
|
||||||
consent_uri_builder = Mock()
|
consent_uri_builder = Mock()
|
||||||
|
|
|
@ -13,15 +13,20 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.http.site import XForwardedForRequest
|
from synapse.http.site import XForwardedForRequest
|
||||||
from synapse.rest.client import login
|
from synapse.rest.client import login
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
|
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request
|
from tests.server import make_request
|
||||||
|
@ -30,14 +35,10 @@ from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
hs = self.setup_test_homeserver()
|
self.store = hs.get_datastores().main
|
||||||
return hs
|
|
||||||
|
|
||||||
def prepare(self, hs, reactor, clock):
|
def test_insert_new_client_ip(self) -> None:
|
||||||
self.store = self.hs.get_datastores().main
|
|
||||||
|
|
||||||
def test_insert_new_client_ip(self):
|
|
||||||
self.reactor.advance(12345678)
|
self.reactor.advance(12345678)
|
||||||
|
|
||||||
user_id = "@user:id"
|
user_id = "@user:id"
|
||||||
|
@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_insert_new_client_ip_none_device_id(self):
|
def test_insert_new_client_ip_none_device_id(self) -> None:
|
||||||
"""
|
"""
|
||||||
An insert with a device ID of NULL will not create a new entry, but
|
An insert with a device ID of NULL will not create a new entry, but
|
||||||
update an existing entry in the user_ips table.
|
update an existing entry in the user_ips table.
|
||||||
|
@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([(False,), (True,)])
|
@parameterized.expand([(False,), (True,)])
|
||||||
def test_get_last_client_ip_by_device(self, after_persisting: bool):
|
def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
|
||||||
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
|
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
|
||||||
self.reactor.advance(12345678)
|
self.reactor.advance(12345678)
|
||||||
|
|
||||||
|
@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_last_client_ip_by_device_combined_data(self):
|
def test_get_last_client_ip_by_device_combined_data(self) -> None:
|
||||||
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
|
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
|
||||||
data together correctly
|
data together correctly
|
||||||
"""
|
"""
|
||||||
|
@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([(False,), (True,)])
|
@parameterized.expand([(False,), (True,)])
|
||||||
def test_get_user_ip_and_agents(self, after_persisting: bool):
|
def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
|
||||||
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
|
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
|
||||||
self.reactor.advance(12345678)
|
self.reactor.advance(12345678)
|
||||||
|
|
||||||
|
@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_user_ip_and_agents_combined_data(self):
|
def test_get_user_ip_and_agents_combined_data(self) -> None:
|
||||||
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
|
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
|
||||||
together correctly
|
together correctly
|
||||||
"""
|
"""
|
||||||
|
@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
|
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
|
||||||
def test_disabled_monthly_active_user(self):
|
def test_disabled_monthly_active_user(self) -> None:
|
||||||
user_id = "@user:server"
|
user_id = "@user:server"
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.insert_client_ip(
|
self.store.insert_client_ip(
|
||||||
|
@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(active)
|
self.assertFalse(active)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
|
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
|
||||||
def test_adding_monthly_active_user_when_full(self):
|
def test_adding_monthly_active_user_when_full(self) -> None:
|
||||||
lots_of_users = 100
|
lots_of_users = 100
|
||||||
user_id = "@user:server"
|
user_id = "@user:server"
|
||||||
|
|
||||||
|
@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(active)
|
self.assertFalse(active)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
|
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
|
||||||
def test_adding_monthly_active_user_when_space(self):
|
def test_adding_monthly_active_user_when_space(self) -> None:
|
||||||
user_id = "@user:server"
|
user_id = "@user:server"
|
||||||
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
|
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
|
||||||
self.assertFalse(active)
|
self.assertFalse(active)
|
||||||
|
@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(active)
|
self.assertTrue(active)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
|
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
|
||||||
def test_updating_monthly_active_user_when_space(self):
|
def test_updating_monthly_active_user_when_space(self) -> None:
|
||||||
user_id = "@user:server"
|
user_id = "@user:server"
|
||||||
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
|
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
|
||||||
|
|
||||||
|
@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
|
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
|
||||||
self.assertTrue(active)
|
self.assertTrue(active)
|
||||||
|
|
||||||
def test_devices_last_seen_bg_update(self):
|
def test_devices_last_seen_bg_update(self) -> None:
|
||||||
# First make sure we have completed all updates.
|
# First make sure we have completed all updates.
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
|
@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_old_user_ips_pruned(self):
|
def test_old_user_ips_pruned(self) -> None:
|
||||||
# First make sure we have completed all updates.
|
# First make sure we have completed all updates.
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
|
@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(result, [])
|
self.assertEqual(result, [])
|
||||||
|
|
||||||
# But we should still get the correct values for the device
|
# But we should still get the correct values for the device
|
||||||
result = self.get_success(
|
result2 = self.get_success(
|
||||||
self.store.get_last_client_ip_by_device(user_id, device_id)
|
self.store.get_last_client_ip_by_device(user_id, device_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
r = result[(user_id, device_id)]
|
r = result2[(user_id, device_id)]
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
hs = self.setup_test_homeserver()
|
|
||||||
return hs
|
|
||||||
|
|
||||||
def prepare(self, hs, reactor, clock):
|
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
self.user_id = self.register_user("bob", "abc123", True)
|
self.user_id = self.register_user("bob", "abc123", True)
|
||||||
|
|
||||||
def test_request_with_xforwarded(self):
|
def test_request_with_xforwarded(self) -> None:
|
||||||
"""
|
"""
|
||||||
The IP in X-Forwarded-For is entered into the client IPs table.
|
The IP in X-Forwarded-For is entered into the client IPs table.
|
||||||
"""
|
"""
|
||||||
|
@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
|
||||||
{"request": XForwardedForRequest},
|
{"request": XForwardedForRequest},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_request_from_getPeer(self):
|
def test_request_from_getPeer(self) -> None:
|
||||||
"""
|
"""
|
||||||
The IP returned by getPeer is entered into the client IPs table, if
|
The IP returned by getPeer is entered into the client IPs table, if
|
||||||
there's no X-Forwarded-For header.
|
there's no X-Forwarded-For header.
|
||||||
"""
|
"""
|
||||||
self._runtest({}, "127.0.0.1", {})
|
self._runtest({}, "127.0.0.1", {})
|
||||||
|
|
||||||
def _runtest(self, headers, expected_ip, make_request_args):
|
def _runtest(
|
||||||
|
self,
|
||||||
|
headers: Dict[bytes, bytes],
|
||||||
|
expected_ip: str,
|
||||||
|
make_request_args: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
device_id = "bleb"
|
device_id = "bleb"
|
||||||
|
|
||||||
access_token = self.login("bob", "abc123", device_id=device_id)
|
access_token = self.login("bob", "abc123", device_id=device_id)
|
||||||
|
|
|
@ -31,7 +31,7 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class TupleComparisonClauseTestCase(unittest.TestCase):
|
class TupleComparisonClauseTestCase(unittest.TestCase):
|
||||||
def test_native_tuple_comparison(self):
|
def test_native_tuple_comparison(self) -> None:
|
||||||
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
|
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
|
||||||
self.assertEqual(clause, "(a,b) > (?,?)")
|
self.assertEqual(clause, "(a,b) > (?,?)")
|
||||||
self.assertEqual(args, [1, 2])
|
self.assertEqual(args, [1, 2])
|
||||||
|
|
|
@ -12,17 +12,24 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Collection, List, Tuple
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
from synapse.api.constants import EduTypes
|
from synapse.api.constants import EduTypes
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class DeviceStoreTestCase(HomeserverTestCase):
|
class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
def add_device_change(self, user_id, device_ids, host):
|
def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
|
||||||
"""Add a device list change for the given device to
|
"""Add a device list change for the given device to
|
||||||
`device_lists_outbound_pokes` table.
|
`device_lists_outbound_pokes` table.
|
||||||
"""
|
"""
|
||||||
|
@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_store_new_device(self):
|
def test_store_new_device(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_device("user_id", "device_id", "display_name")
|
self.store.store_device("user_id", "device_id", "display_name")
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
||||||
|
assert res is not None
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
"user_id": "user_id",
|
"user_id": "user_id",
|
||||||
|
@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
res,
|
res,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_devices_by_user(self):
|
def test_get_devices_by_user(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_device("user_id", "device1", "display_name 1")
|
self.store.store_device("user_id", "device1", "display_name 1")
|
||||||
)
|
)
|
||||||
|
@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
res["device2"],
|
res["device2"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_count_devices_by_users(self):
|
def test_count_devices_by_users(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_device("user_id", "device1", "display_name 1")
|
self.store.store_device("user_id", "device1", "display_name 1")
|
||||||
)
|
)
|
||||||
|
@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(3, res)
|
self.assertEqual(3, res)
|
||||||
|
|
||||||
def test_get_device_updates_by_remote(self):
|
def test_get_device_updates_by_remote(self) -> None:
|
||||||
device_ids = ["device_id1", "device_id2"]
|
device_ids = ["device_id1", "device_id2"]
|
||||||
|
|
||||||
# Add two device updates with sequential `stream_id`s
|
# Add two device updates with sequential `stream_id`s
|
||||||
|
@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
# Check original device_ids are contained within these updates
|
# Check original device_ids are contained within these updates
|
||||||
self._check_devices_in_updates(device_ids, device_updates)
|
self._check_devices_in_updates(device_ids, device_updates)
|
||||||
|
|
||||||
def test_get_device_updates_by_remote_can_limit_properly(self):
|
def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that `get_device_updates_by_remote` returns an appropriate
|
Tests that `get_device_updates_by_remote` returns an appropriate
|
||||||
stream_id to resume fetching from (without skipping any results).
|
stream_id to resume fetching from (without skipping any results).
|
||||||
|
@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(device_updates, [])
|
self.assertEqual(device_updates, [])
|
||||||
|
|
||||||
def _check_devices_in_updates(self, expected_device_ids, device_updates):
|
def _check_devices_in_updates(
|
||||||
|
self,
|
||||||
|
expected_device_ids: Collection[str],
|
||||||
|
device_updates: List[Tuple[str, JsonDict]],
|
||||||
|
) -> None:
|
||||||
"""Check that an specific device ids exist in a list of device update EDUs"""
|
"""Check that an specific device ids exist in a list of device update EDUs"""
|
||||||
self.assertEqual(len(device_updates), len(expected_device_ids))
|
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||||
|
|
||||||
|
@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.assertEqual(received_device_ids, set(expected_device_ids))
|
self.assertEqual(received_device_ids, set(expected_device_ids))
|
||||||
|
|
||||||
def test_update_device(self):
|
def test_update_device(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_device("user_id", "device_id", "display_name 1")
|
self.store.store_device("user_id", "device_id", "display_name 1")
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
||||||
|
assert res is not None
|
||||||
self.assertEqual("display_name 1", res["display_name"])
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
# do a no-op first
|
# do a no-op first
|
||||||
self.get_success(self.store.update_device("user_id", "device_id"))
|
self.get_success(self.store.update_device("user_id", "device_id"))
|
||||||
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
||||||
|
assert res is not None
|
||||||
self.assertEqual("display_name 1", res["display_name"])
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
# do the update
|
# do the update
|
||||||
|
@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check it worked
|
# check it worked
|
||||||
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
res = self.get_success(self.store.get_device("user_id", "device_id"))
|
||||||
|
assert res is not None
|
||||||
self.assertEqual("display_name 2", res["display_name"])
|
self.assertEqual("display_name 2", res["display_name"])
|
||||||
|
|
||||||
def test_update_unknown_device(self):
|
def test_update_unknown_device(self) -> None:
|
||||||
exc = self.get_failure(
|
exc = self.get_failure(
|
||||||
self.store.update_device(
|
self.store.update_device(
|
||||||
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
||||||
|
|
|
@ -12,19 +12,23 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import RoomAlias, RoomID
|
from synapse.types import RoomAlias, RoomID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class DirectoryStoreTestCase(HomeserverTestCase):
|
class DirectoryStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
self.room = RoomID.from_string("!abcde:test")
|
self.room = RoomID.from_string("!abcde:test")
|
||||||
self.alias = RoomAlias.from_string("#my-room:test")
|
self.alias = RoomAlias.from_string("#my-room:test")
|
||||||
|
|
||||||
def test_room_to_alias(self):
|
def test_room_to_alias(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||||
|
@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
|
||||||
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
|
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_alias_to_room(self):
|
def test_alias_to_room(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||||
|
@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
|
||||||
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
|
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias(self):
|
def test_delete_alias(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||||
|
|
|
@ -12,7 +12,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.e2e_room_keys import RoomKey
|
from synapse.storage.databases.main.e2e_room_keys import RoomKey
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
@ -26,12 +30,12 @@ room_key: RoomKey = {
|
||||||
|
|
||||||
|
|
||||||
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_room_keys_version_delete(self):
|
def test_room_keys_version_delete(self) -> None:
|
||||||
# test that deleting a room key backup deletes the keys
|
# test that deleting a room key backup deletes the keys
|
||||||
version1 = self.get_success(
|
version1 = self.get_success(
|
||||||
self.store.create_e2e_room_keys_version(
|
self.store.create_e2e_room_keys_version(
|
||||||
|
|
|
@ -12,14 +12,19 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
def test_key_without_device_name(self):
|
def test_key_without_device_name(self) -> None:
|
||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
json = {"key": "value"}
|
json = {"key": "value"}
|
||||||
|
|
||||||
|
@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
||||||
dev = res["user"]["device"]
|
dev = res["user"]["device"]
|
||||||
self.assertDictContainsSubset(json, dev)
|
self.assertDictContainsSubset(json, dev)
|
||||||
|
|
||||||
def test_reupload_key(self):
|
def test_reupload_key(self) -> None:
|
||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
json = {"key": "value"}
|
json = {"key": "value"}
|
||||||
|
|
||||||
|
@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(changed)
|
self.assertFalse(changed)
|
||||||
|
|
||||||
def test_get_key_with_device_name(self):
|
def test_get_key_with_device_name(self) -> None:
|
||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
json = {"key": "value"}
|
json = {"key": "value"}
|
||||||
|
|
||||||
|
@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
|
||||||
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
|
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multiple_devices(self):
|
def test_multiple_devices(self) -> None:
|
||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
|
|
||||||
self.get_success(self.store.store_device("user1", "device1", None))
|
self.get_success(self.store.store_device("user1", "device1", None))
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
@ -22,18 +23,22 @@ from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.databases.main.events import _LinkMap
|
from synapse.storage.databases.main.events import _LinkMap
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.types import create_requester
|
from synapse.types import create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class EventChainStoreTestCase(HomeserverTestCase):
|
class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._next_stream_ordering = 1
|
self._next_stream_ordering = 1
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
|
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
|
||||||
works.
|
works.
|
||||||
"""
|
"""
|
||||||
|
@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_out_of_order_events(self):
|
def test_out_of_order_events(self) -> None:
|
||||||
"""Test that we handle persisting events that we don't have the full
|
"""Test that we handle persisting events that we don't have the full
|
||||||
auth chain for yet (which should only happen for out of band memberships).
|
auth chain for yet (which should only happen for out of band memberships).
|
||||||
"""
|
"""
|
||||||
|
@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
def persist(
|
def persist(
|
||||||
self,
|
self,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
):
|
) -> None:
|
||||||
"""Persist the given events and check that the links generated match
|
"""Persist the given events and check that the links generated match
|
||||||
those given.
|
those given.
|
||||||
"""
|
"""
|
||||||
|
@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
e.internal_metadata.stream_ordering = self._next_stream_ordering
|
e.internal_metadata.stream_ordering = self._next_stream_ordering
|
||||||
self._next_stream_ordering += 1
|
self._next_stream_ordering += 1
|
||||||
|
|
||||||
def _persist(txn):
|
def _persist(txn: LoggingTransaction) -> None:
|
||||||
# We need to persist the events to the events and state_events
|
# We need to persist the events to the events and state_events
|
||||||
# tables.
|
# tables.
|
||||||
persist_events_store._store_event_txn(
|
persist_events_store._store_event_txn(
|
||||||
|
@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
|
|
||||||
class LinkMapTestCase(unittest.TestCase):
|
class LinkMapTestCase(unittest.TestCase):
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
"""Basic tests for the LinkMap."""
|
"""Basic tests for the LinkMap."""
|
||||||
link_map = _LinkMap()
|
link_map = _LinkMap()
|
||||||
|
|
||||||
|
@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.user_id = self.register_user("foo", "pass")
|
self.user_id = self.register_user("foo", "pass")
|
||||||
self.token = self.login("foo", "pass")
|
self.token = self.login("foo", "pass")
|
||||||
|
@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# Delete the chain cover info.
|
# Delete the chain cover info.
|
||||||
|
|
||||||
def _delete_tables(txn):
|
def _delete_tables(txn: Cursor) -> None:
|
||||||
txn.execute("DELETE FROM event_auth_chains")
|
txn.execute("DELETE FROM event_auth_chains")
|
||||||
txn.execute("DELETE FROM event_auth_chain_links")
|
txn.execute("DELETE FROM event_auth_chain_links")
|
||||||
|
|
||||||
|
@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return room_id, [state1, state2]
|
return room_id, [state1, state2]
|
||||||
|
|
||||||
def test_background_update_single_room(self):
|
def test_background_update_single_room(self) -> None:
|
||||||
"""Test that the background update to calculate auth chains for historic
|
"""Test that the background update to calculate auth chains for historic
|
||||||
rooms works correctly.
|
rooms works correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_background_update_multiple_rooms(self):
|
def test_background_update_multiple_rooms(self) -> None:
|
||||||
"""Test that the background update to calculate auth chains for historic
|
"""Test that the background update to calculate auth chains for historic
|
||||||
rooms works correctly.
|
rooms works correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_background_update_single_large_room(self):
|
def test_background_update_single_large_room(self) -> None:
|
||||||
"""Test that the background update to calculate auth chains for historic
|
"""Test that the background update to calculate auth chains for historic
|
||||||
rooms works correctly.
|
rooms works correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_background_update_multiple_large_room(self):
|
def test_background_update_multiple_large_room(self) -> None:
|
||||||
"""Test that the background update to calculate auth chains for historic
|
"""Test that the background update to calculate auth chains for historic
|
||||||
rooms works correctly.
|
rooms works correctly.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
@ -26,11 +26,12 @@ from synapse.api.room_versions import (
|
||||||
EventFormatVersions,
|
EventFormatVersions,
|
||||||
RoomVersion,
|
RoomVersion,
|
||||||
)
|
)
|
||||||
from synapse.events import _EventInternalMetadata
|
from synapse.events import EventBase, _EventInternalMetadata
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock, json_encoder
|
from synapse.util import Clock, json_encoder
|
||||||
|
|
||||||
|
@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
def test_get_prev_events_for_room(self):
|
def test_get_prev_events_for_room(self) -> None:
|
||||||
room_id = "@ROOM:local"
|
room_id = "@ROOM:local"
|
||||||
|
|
||||||
# add a bunch of events and hashes to act as forward extremities
|
# add a bunch of events and hashes to act as forward extremities
|
||||||
def insert_event(txn, i):
|
def insert_event(txn: Cursor, i: int) -> None:
|
||||||
event_id = "$event_%i:local" % i
|
event_id = "$event_%i:local" % i
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
for i in range(0, 10):
|
for i in range(0, 10):
|
||||||
self.assertEqual("$event_%i:local" % (19 - i), r[i])
|
self.assertEqual("$event_%i:local" % (19 - i), r[i])
|
||||||
|
|
||||||
def test_get_rooms_with_many_extremities(self):
|
def test_get_rooms_with_many_extremities(self) -> None:
|
||||||
room1 = "#room1"
|
room1 = "#room1"
|
||||||
room2 = "#room2"
|
room2 = "#room2"
|
||||||
room3 = "#room3"
|
room3 = "#room3"
|
||||||
|
|
||||||
def insert_event(txn, i, room_id):
|
def insert_event(txn: Cursor, i: int, room_id: str) -> None:
|
||||||
event_id = "$event_%i:local" % i
|
event_id = "$event_%i:local" % i
|
||||||
txn.execute(
|
txn.execute(
|
||||||
(
|
(
|
||||||
|
@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
# | |
|
# | |
|
||||||
# K J
|
# K J
|
||||||
|
|
||||||
auth_graph = {
|
auth_graph: Dict[str, List[str]] = {
|
||||||
"a": ["e"],
|
"a": ["e"],
|
||||||
"b": ["e"],
|
"b": ["e"],
|
||||||
"c": ["g", "i"],
|
"c": ["g", "i"],
|
||||||
|
@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Mark the room as maybe having a cover index.
|
# Mark the room as maybe having a cover index.
|
||||||
|
|
||||||
def store_room(txn):
|
def store_room(txn: LoggingTransaction) -> None:
|
||||||
self.store.db_pool.simple_insert_txn(
|
self.store.db_pool.simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
"rooms",
|
"rooms",
|
||||||
|
@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
# We rudely fiddle with the appropriate tables directly, as that's much
|
# We rudely fiddle with the appropriate tables directly, as that's much
|
||||||
# easier than constructing events properly.
|
# easier than constructing events properly.
|
||||||
|
|
||||||
def insert_event(txn):
|
def insert_event(txn: LoggingTransaction) -> None:
|
||||||
stream_ordering = 0
|
stream_ordering = 0
|
||||||
|
|
||||||
for event_id in auth_graph:
|
for event_id in auth_graph:
|
||||||
|
@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn,
|
||||||
[
|
[
|
||||||
FakeEvent(event_id, room_id, auth_graph[event_id])
|
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||||
for event_id in auth_graph
|
for event_id in auth_graph
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
@parameterized.expand([(True,), (False,)])
|
@parameterized.expand([(True,), (False,)])
|
||||||
def test_auth_chain_ids(self, use_chain_cover_index: bool):
|
def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
|
||||||
room_id = self._setup_auth_chain(use_chain_cover_index)
|
room_id = self._setup_auth_chain(use_chain_cover_index)
|
||||||
|
|
||||||
# a and b have the same auth chain.
|
# a and b have the same auth chain.
|
||||||
|
@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
self.assertCountEqual(auth_chain_ids, ["i", "j"])
|
self.assertCountEqual(auth_chain_ids, ["i", "j"])
|
||||||
|
|
||||||
@parameterized.expand([(True,), (False,)])
|
@parameterized.expand([(True,), (False,)])
|
||||||
def test_auth_difference(self, use_chain_cover_index: bool):
|
def test_auth_difference(self, use_chain_cover_index: bool) -> None:
|
||||||
room_id = self._setup_auth_chain(use_chain_cover_index)
|
room_id = self._setup_auth_chain(use_chain_cover_index)
|
||||||
|
|
||||||
# Now actually test that various combinations give the right result:
|
# Now actually test that various combinations give the right result:
|
||||||
|
@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, set())
|
self.assertSetEqual(difference, set())
|
||||||
|
|
||||||
def test_auth_difference_partial_cover(self):
|
def test_auth_difference_partial_cover(self) -> None:
|
||||||
"""Test that we correctly handle rooms where not all events have a chain
|
"""Test that we correctly handle rooms where not all events have a chain
|
||||||
cover calculated. This can happen in some obscure edge cases, including
|
cover calculated. This can happen in some obscure edge cases, including
|
||||||
during the background update that calculates the chain cover for old
|
during the background update that calculates the chain cover for old
|
||||||
|
@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
# | |
|
# | |
|
||||||
# K J
|
# K J
|
||||||
|
|
||||||
auth_graph = {
|
auth_graph: Dict[str, List[str]] = {
|
||||||
"a": ["e"],
|
"a": ["e"],
|
||||||
"b": ["e"],
|
"b": ["e"],
|
||||||
"c": ["g", "i"],
|
"c": ["g", "i"],
|
||||||
|
@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
# We rudely fiddle with the appropriate tables directly, as that's much
|
# We rudely fiddle with the appropriate tables directly, as that's much
|
||||||
# easier than constructing events properly.
|
# easier than constructing events properly.
|
||||||
|
|
||||||
def insert_event(txn):
|
def insert_event(txn: LoggingTransaction) -> None:
|
||||||
# First insert the room and mark it as having a chain cover.
|
# First insert the room and mark it as having a chain cover.
|
||||||
self.store.db_pool.simple_insert_txn(
|
self.store.db_pool.simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn,
|
||||||
[
|
[
|
||||||
FakeEvent(event_id, room_id, auth_graph[event_id])
|
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||||
for event_id in auth_graph
|
for event_id in auth_graph
|
||||||
if event_id != "b"
|
if event_id != "b"
|
||||||
],
|
],
|
||||||
|
@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn,
|
||||||
[FakeEvent("b", room_id, auth_graph["b"])],
|
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.db_pool.simple_update_txn(
|
self.store.db_pool.simple_update_txn(
|
||||||
|
@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
@parameterized.expand(
|
@parameterized.expand(
|
||||||
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
|
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
|
||||||
)
|
)
|
||||||
def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
|
def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
|
||||||
"""Test that pruning of inbound federation queues work"""
|
"""Test that pruning of inbound federation queues work"""
|
||||||
|
|
||||||
room_id = "some_room_id"
|
room_id = "some_room_id"
|
||||||
|
@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
stream_ordering += 1
|
stream_ordering += 1
|
||||||
|
|
||||||
def populate_db(txn: LoggingTransaction):
|
def populate_db(txn: LoggingTransaction) -> None:
|
||||||
# Insert the room to satisfy the foreign key constraint of
|
# Insert the room to satisfy the foreign key constraint of
|
||||||
# `event_failed_pull_attempts`
|
# `event_failed_pull_attempts`
|
||||||
self.store.db_pool.simple_insert_txn(
|
self.store.db_pool.simple_insert_txn(
|
||||||
|
@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
|
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
|
||||||
|
|
||||||
def test_get_backfill_points_in_room(self):
|
def test_get_backfill_points_in_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure only backfill points that are older and come before
|
Test to make sure only backfill points that are older and come before
|
||||||
the `current_depth` are returned.
|
the `current_depth` are returned.
|
||||||
|
@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
|
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure that events we have attempted to backfill (and within
|
Test to make sure that events we have attempted to backfill (and within
|
||||||
backoff timeout duration) do not show up as an event to backfill again.
|
backoff timeout duration) do not show up as an event to backfill again.
|
||||||
|
@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
|
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure after we fake attempt to backfill event "b3" many times,
|
Test to make sure after we fake attempt to backfill event "b3" many times,
|
||||||
we can see retry and see the "b3" again after the backoff timeout duration
|
we can see retry and see the "b3" again after the backoff timeout duration
|
||||||
|
@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
"5": 7,
|
"5": 7,
|
||||||
}
|
}
|
||||||
|
|
||||||
def populate_db(txn: LoggingTransaction):
|
def populate_db(txn: LoggingTransaction) -> None:
|
||||||
# Insert the room to satisfy the foreign key constraint of
|
# Insert the room to satisfy the foreign key constraint of
|
||||||
# `event_failed_pull_attempts`
|
# `event_failed_pull_attempts`
|
||||||
self.store.db_pool.simple_insert_txn(
|
self.store.db_pool.simple_insert_txn(
|
||||||
|
@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
|
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
|
||||||
|
|
||||||
def test_get_insertion_event_backward_extremities_in_room(self):
|
def test_get_insertion_event_backward_extremities_in_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure only insertion event backward extremities that are
|
Test to make sure only insertion event backward extremities that are
|
||||||
older and come before the `current_depth` are returned.
|
older and come before the `current_depth` are returned.
|
||||||
|
@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
|
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure that insertion events we have attempted to backfill
|
Test to make sure that insertion events we have attempted to backfill
|
||||||
(and within backoff timeout duration) do not show up as an event to
|
(and within backoff timeout duration) do not show up as an event to
|
||||||
|
@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
|
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure after we fake attempt to backfill event
|
Test to make sure after we fake attempt to backfill event
|
||||||
"insertion_eventA" many times, we can see retry and see the
|
"insertion_eventA" many times, we can see retry and see the
|
||||||
|
@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||||
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
|
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
|
||||||
|
|
||||||
def test_get_event_ids_to_not_pull_from_backoff(
|
def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
|
||||||
self,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Test to make sure only event IDs we should backoff from are returned.
|
Test to make sure only event IDs we should backoff from are returned.
|
||||||
"""
|
"""
|
||||||
|
@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
|
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
|
||||||
self,
|
self,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Test to make sure no event IDs are returned after the backoff duration has
|
Test to make sure no event IDs are returned after the backoff duration has
|
||||||
elapsed.
|
elapsed.
|
||||||
|
@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
self.assertEqual(event_ids_to_backoff, [])
|
self.assertEqual(event_ids_to_backoff, [])
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(auto_attribs=True)
|
||||||
class FakeEvent:
|
class FakeEvent:
|
||||||
event_id = attr.ib()
|
event_id: str
|
||||||
room_id = attr.ib()
|
room_id: str
|
||||||
auth_events = attr.ib()
|
auth_events: List[str]
|
||||||
|
|
||||||
type = "foo"
|
type = "foo"
|
||||||
state_key = "foo"
|
state_key = "foo"
|
||||||
|
|
||||||
internal_metadata = _EventInternalMetadata({})
|
internal_metadata = _EventInternalMetadata({})
|
||||||
|
|
||||||
def auth_event_ids(self):
|
def auth_event_ids(self) -> List[str]:
|
||||||
return self.auth_events
|
return self.auth_events
|
||||||
|
|
||||||
def is_state(self):
|
def is_state(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class ExtremStatisticsTestCase(HomeserverTestCase):
|
class ExtremStatisticsTestCase(HomeserverTestCase):
|
||||||
def test_exposed_to_prometheus(self):
|
def test_exposed_to_prometheus(self) -> None:
|
||||||
"""
|
"""
|
||||||
Forward extremity counts are exposed via Prometheus.
|
Forward extremity counts are exposed via Prometheus.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,12 +12,19 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.federation.federation_base import event_from_pdu_json
|
from synapse.federation.federation_base import event_from_pdu_json
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import StateMap
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self._persistence = self.hs.get_storage_controllers().persistence
|
self._persistence = self.hs.get_storage_controllers().persistence
|
||||||
self._state_storage_controller = self.hs.get_storage_controllers().state
|
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||||
|
@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check that the current extremities is the remote event.
|
# Check that the current extremities is the remote event.
|
||||||
self.assert_extremities([self.remote_event_1.event_id])
|
self.assert_extremities([self.remote_event_1.event_id])
|
||||||
|
|
||||||
def persist_event(self, event, state=None):
|
def persist_event(
|
||||||
|
self, event: EventBase, state: Optional[StateMap[str]] = None
|
||||||
|
) -> None:
|
||||||
"""Persist the event, with optional state"""
|
"""Persist the event, with optional state"""
|
||||||
context = self.get_success(
|
context = self.get_success(
|
||||||
self.state.compute_event_context(
|
self.state.compute_event_context(
|
||||||
|
@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.get_success(self._persistence.persist_event(event, context))
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
|
|
||||||
def assert_extremities(self, expected_extremities):
|
def assert_extremities(self, expected_extremities: List[str]) -> None:
|
||||||
"""Assert the current extremities for the room"""
|
"""Assert the current extremities for the room"""
|
||||||
extremities = self.get_success(
|
extremities = self.get_success(
|
||||||
self.store.get_prev_events_for_room(self.room_id)
|
self.store.get_prev_events_for_room(self.room_id)
|
||||||
)
|
)
|
||||||
self.assertCountEqual(extremities, expected_extremities)
|
self.assertCountEqual(extremities, expected_extremities)
|
||||||
|
|
||||||
def test_prune_gap(self):
|
def test_prune_gap(self) -> None:
|
||||||
"""Test that we drop extremities after a gap when we see an event from
|
"""Test that we drop extremities after a gap when we see an event from
|
||||||
the same domain.
|
the same domain.
|
||||||
"""
|
"""
|
||||||
|
@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id])
|
self.assert_extremities([remote_event_2.event_id])
|
||||||
|
|
||||||
def test_do_not_prune_gap_if_state_different(self):
|
def test_do_not_prune_gap_if_state_different(self) -> None:
|
||||||
"""Test that we don't prune extremities after a gap if the resolved
|
"""Test that we don't prune extremities after a gap if the resolved
|
||||||
state is different.
|
state is different.
|
||||||
"""
|
"""
|
||||||
|
@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check that we haven't dropped the old extremity.
|
# Check that we haven't dropped the old extremity.
|
||||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||||
|
|
||||||
def test_prune_gap_if_old(self):
|
def test_prune_gap_if_old(self) -> None:
|
||||||
"""Test that we drop extremities after a gap when the previous extremity
|
"""Test that we drop extremities after a gap when the previous extremity
|
||||||
is "old"
|
is "old"
|
||||||
"""
|
"""
|
||||||
|
@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id])
|
self.assert_extremities([remote_event_2.event_id])
|
||||||
|
|
||||||
def test_do_not_prune_gap_if_other_server(self):
|
def test_do_not_prune_gap_if_other_server(self) -> None:
|
||||||
"""Test that we do not drop extremities after a gap when we see an event
|
"""Test that we do not drop extremities after a gap when we see an event
|
||||||
from a different domain.
|
from a different domain.
|
||||||
"""
|
"""
|
||||||
|
@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||||
|
|
||||||
def test_prune_gap_if_dummy_remote(self):
|
def test_prune_gap_if_dummy_remote(self) -> None:
|
||||||
"""Test that we drop extremities after a gap when the previous extremity
|
"""Test that we drop extremities after a gap when the previous extremity
|
||||||
is a local dummy event and only points to remote events.
|
is a local dummy event and only points to remote events.
|
||||||
"""
|
"""
|
||||||
|
@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id])
|
self.assert_extremities([remote_event_2.event_id])
|
||||||
|
|
||||||
def test_prune_gap_if_dummy_local(self):
|
def test_prune_gap_if_dummy_local(self) -> None:
|
||||||
"""Test that we don't drop extremities after a gap when the previous
|
"""Test that we don't drop extremities after a gap when the previous
|
||||||
extremity is a local dummy event and points to local events.
|
extremity is a local dummy event and points to local events.
|
||||||
"""
|
"""
|
||||||
|
@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
# Check the new extremity is just the new remote event.
|
# Check the new extremity is just the new remote event.
|
||||||
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
||||||
|
|
||||||
def test_do_not_prune_gap_if_not_dummy(self):
|
def test_do_not_prune_gap_if_not_dummy(self) -> None:
|
||||||
"""Test that we do not drop extremities after a gap when the previous extremity
|
"""Test that we do not drop extremities after a gap when the previous extremity
|
||||||
is not a dummy event.
|
is not a dummy event.
|
||||||
"""
|
"""
|
||||||
|
@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self._persistence = self.hs.get_storage_controllers().persistence
|
self._persistence = self.hs.get_storage_controllers().persistence
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
def test_remote_user_rooms_cache_invalidated(self):
|
def test_remote_user_rooms_cache_invalidated(self) -> None:
|
||||||
"""Test that if the server leaves a room the `get_rooms_for_user` cache
|
"""Test that if the server leaves a room the `get_rooms_for_user` cache
|
||||||
is invalidated for remote users.
|
is invalidated for remote users.
|
||||||
"""
|
"""
|
||||||
|
@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
|
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
|
||||||
self.assertEqual(set(rooms), set())
|
self.assertEqual(set(rooms), set())
|
||||||
|
|
||||||
def test_room_remote_user_cache_invalidated(self):
|
def test_room_remote_user_cache_invalidated(self) -> None:
|
||||||
"""Test that if the server leaves a room the `get_users_in_room` cache
|
"""Test that if the server leaves a room the `get_users_in_room` cache
|
||||||
is invalidated for remote users.
|
is invalidated for remote users.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import signedjson.key
|
import signedjson.key
|
||||||
|
import signedjson.types
|
||||||
import unpaddedbase64
|
import unpaddedbase64
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
|
|
||||||
|
|
||||||
def decode_verify_key_base64(key_id: str, key_base64: str):
|
def decode_verify_key_base64(
|
||||||
|
key_id: str, key_base64: str
|
||||||
|
) -> signedjson.types.VerifyKey:
|
||||||
key_bytes = unpaddedbase64.decode_base64(key_base64)
|
key_bytes = unpaddedbase64.decode_base64(key_base64)
|
||||||
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
|
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
|
||||||
|
|
||||||
|
@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64(
|
||||||
|
|
||||||
|
|
||||||
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
def test_get_server_verify_keys(self):
|
def test_get_server_verify_keys(self) -> None:
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
key_id_1 = "ed25519:key1"
|
key_id_1 = "ed25519:key1"
|
||||||
|
@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
# non-existent result gives None
|
# non-existent result gives None
|
||||||
self.assertIsNone(res[("server1", "ed25519:key3")])
|
self.assertIsNone(res[("server1", "ed25519:key3")])
|
||||||
|
|
||||||
def test_cache(self):
|
def test_cache(self) -> None:
|
||||||
"""Check that updates correctly invalidate the cache."""
|
"""Check that updates correctly invalidate the cache."""
|
||||||
|
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
|
|
@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(FORTY_DAYS)
|
self.reactor.advance(FORTY_DAYS)
|
||||||
|
|
||||||
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
|
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
|
||||||
def test_initialise_reserved_users(self):
|
def test_initialise_reserved_users(self) -> None:
|
||||||
threepids = self.hs.config.server.mau_limits_reserved_threepids
|
threepids = self.hs.config.server.mau_limits_reserved_threepids
|
||||||
|
|
||||||
# register three users, of which two have reserved 3pids, and a third
|
# register three users, of which two have reserved 3pids, and a third
|
||||||
|
@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
active_count = self.get_success(self.store.get_monthly_active_count())
|
active_count = self.get_success(self.store.get_monthly_active_count())
|
||||||
self.assertEqual(active_count, 3)
|
self.assertEqual(active_count, 3)
|
||||||
|
|
||||||
def test_can_insert_and_count_mau(self):
|
def test_can_insert_and_count_mau(self) -> None:
|
||||||
count = self.get_success(self.store.get_monthly_active_count())
|
count = self.get_success(self.store.get_monthly_active_count())
|
||||||
self.assertEqual(count, 0)
|
self.assertEqual(count, 0)
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
count = self.get_success(self.store.get_monthly_active_count())
|
count = self.get_success(self.store.get_monthly_active_count())
|
||||||
self.assertEqual(count, 1)
|
self.assertEqual(count, 1)
|
||||||
|
|
||||||
def test_appservice_user_not_counted_in_mau(self):
|
def test_appservice_user_not_counted_in_mau(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.register_user(
|
self.store.register_user(
|
||||||
user_id="@appservice_user:server", appservice_id="wibble"
|
user_id="@appservice_user:server", appservice_id="wibble"
|
||||||
|
@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
count = self.get_success(self.store.get_monthly_active_count())
|
count = self.get_success(self.store.get_monthly_active_count())
|
||||||
self.assertEqual(count, 0)
|
self.assertEqual(count, 0)
|
||||||
|
|
||||||
def test_user_last_seen_monthly_active(self):
|
def test_user_last_seen_monthly_active(self) -> None:
|
||||||
user_id1 = "@user1:server"
|
user_id1 = "@user1:server"
|
||||||
user_id2 = "@user2:server"
|
user_id2 = "@user2:server"
|
||||||
user_id3 = "@user3:server"
|
user_id3 = "@user3:server"
|
||||||
|
@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
|
|
||||||
@override_config({"max_mau_value": 5})
|
@override_config({"max_mau_value": 5})
|
||||||
def test_reap_monthly_active_users(self):
|
def test_reap_monthly_active_users(self) -> None:
|
||||||
initial_users = 10
|
initial_users = 10
|
||||||
for i in range(initial_users):
|
for i in range(initial_users):
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
# Note that below says mau_limit (no s), this is the name of the config
|
# Note that below says mau_limit (no s), this is the name of the config
|
||||||
# value, although it gets stored on the config object as mau_limits.
|
# value, although it gets stored on the config object as mau_limits.
|
||||||
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
|
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
|
||||||
def test_reap_monthly_active_users_reserved_users(self):
|
def test_reap_monthly_active_users_reserved_users(self) -> None:
|
||||||
"""Tests that reaping correctly handles reaping where reserved users are
|
"""Tests that reaping correctly handles reaping where reserved users are
|
||||||
present"""
|
present"""
|
||||||
threepids = self.hs.config.server.mau_limits_reserved_threepids
|
threepids = self.hs.config.server.mau_limits_reserved_threepids
|
||||||
|
@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
count = self.get_success(self.store.get_monthly_active_count())
|
count = self.get_success(self.store.get_monthly_active_count())
|
||||||
self.assertEqual(count, self.hs.config.server.max_mau_value)
|
self.assertEqual(count, self.hs.config.server.max_mau_value)
|
||||||
|
|
||||||
def test_populate_monthly_users_is_guest(self):
|
def test_populate_monthly_users_is_guest(self) -> None:
|
||||||
# Test that guest users are not added to mau list
|
# Test that guest users are not added to mau list
|
||||||
user_id = "@user_id:host"
|
user_id = "@user_id:host"
|
||||||
|
|
||||||
|
@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.store.upsert_monthly_active_user.assert_not_called()
|
self.store.upsert_monthly_active_user.assert_not_called()
|
||||||
|
|
||||||
def test_populate_monthly_users_should_update(self):
|
def test_populate_monthly_users_should_update(self) -> None:
|
||||||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
|
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
|
||||||
|
@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.store.upsert_monthly_active_user.assert_called_once()
|
self.store.upsert_monthly_active_user.assert_called_once()
|
||||||
|
|
||||||
def test_populate_monthly_users_should_not_update(self):
|
def test_populate_monthly_users_should_not_update(self) -> None:
|
||||||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
|
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
|
||||||
|
@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.store.upsert_monthly_active_user.assert_not_called()
|
self.store.upsert_monthly_active_user.assert_not_called()
|
||||||
|
|
||||||
def test_get_reserved_real_user_account(self):
|
def test_get_reserved_real_user_account(self) -> None:
|
||||||
# Test no reserved users, or reserved threepids
|
# Test no reserved users, or reserved threepids
|
||||||
users = self.get_success(self.store.get_registered_reserved_users())
|
users = self.get_success(self.store.get_registered_reserved_users())
|
||||||
self.assertEqual(len(users), 0)
|
self.assertEqual(len(users), 0)
|
||||||
|
@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
users = self.get_success(self.store.get_registered_reserved_users())
|
users = self.get_success(self.store.get_registered_reserved_users())
|
||||||
self.assertEqual(len(users), len(threepids))
|
self.assertEqual(len(users), len(threepids))
|
||||||
|
|
||||||
def test_support_user_not_add_to_mau_limits(self):
|
def test_support_user_not_add_to_mau_limits(self) -> None:
|
||||||
support_user_id = "@support:test"
|
support_user_id = "@support:test"
|
||||||
|
|
||||||
count = self.get_success(self.store.get_monthly_active_count())
|
count = self.get_success(self.store.get_monthly_active_count())
|
||||||
|
@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
|
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
|
||||||
)
|
)
|
||||||
def test_track_monthly_users_without_cap(self):
|
def test_track_monthly_users_without_cap(self) -> None:
|
||||||
count = self.get_success(self.store.get_monthly_active_count())
|
count = self.get_success(self.store.get_monthly_active_count())
|
||||||
self.assertEqual(0, count)
|
self.assertEqual(0, count)
|
||||||
|
|
||||||
|
@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(2, count)
|
self.assertEqual(2, count)
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
|
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
|
||||||
def test_no_users_when_not_tracking(self):
|
def test_no_users_when_not_tracking(self) -> None:
|
||||||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||||
|
|
||||||
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
|
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
|
||||||
|
|
||||||
self.store.upsert_monthly_active_user.assert_not_called()
|
self.store.upsert_monthly_active_user.assert_not_called()
|
||||||
|
|
||||||
def test_get_monthly_active_count_by_service(self):
|
def test_get_monthly_active_count_by_service(self) -> None:
|
||||||
appservice1_user1 = "@appservice1_user1:example.com"
|
appservice1_user1 = "@appservice1_user1:example.com"
|
||||||
appservice1_user2 = "@appservice1_user2:example.com"
|
appservice1_user2 = "@appservice1_user2:example.com"
|
||||||
|
|
||||||
|
@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(result[service2], 1)
|
self.assertEqual(result[service2], 1)
|
||||||
self.assertEqual(result[native], 1)
|
self.assertEqual(result[native], 1)
|
||||||
|
|
||||||
def test_get_monthly_active_users_by_service(self):
|
def test_get_monthly_active_users_by_service(self) -> None:
|
||||||
# (No users, no filtering) -> empty result
|
# (No users, no filtering) -> empty result
|
||||||
result = self.get_success(self.store.get_monthly_active_users_by_service())
|
result = self.get_success(self.store.get_monthly_active_users_by_service())
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import NotFoundError, SynapseError
|
from synapse.api.errors import NotFoundError, SynapseError
|
||||||
from synapse.rest.client import room
|
from synapse.rest.client import room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase):
|
||||||
user_id = "@red:server"
|
user_id = "@red:server"
|
||||||
servlets = [room.register_servlets]
|
servlets = [room.register_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._storage_controllers = self.hs.get_storage_controllers()
|
self._storage_controllers = self.hs.get_storage_controllers()
|
||||||
|
|
||||||
def test_purge_history(self):
|
def test_purge_history(self) -> None:
|
||||||
"""
|
"""
|
||||||
Purging a room history will delete everything before the topological point.
|
Purging a room history will delete everything before the topological point.
|
||||||
"""
|
"""
|
||||||
|
@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase):
|
||||||
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
|
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
|
||||||
self.get_success(self.store.get_event(last["event_id"]))
|
self.get_success(self.store.get_event(last["event_id"]))
|
||||||
|
|
||||||
def test_purge_history_wont_delete_extrems(self):
|
def test_purge_history_wont_delete_extrems(self) -> None:
|
||||||
"""
|
"""
|
||||||
Purging a room history will delete everything before the topological point.
|
Purging a room history will delete everything before the topological point.
|
||||||
"""
|
"""
|
||||||
|
@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase):
|
||||||
token = self.get_success(
|
token = self.get_success(
|
||||||
self.store.get_topological_token_for_event(last["event_id"])
|
self.store.get_topological_token_for_event(last["event_id"])
|
||||||
)
|
)
|
||||||
|
assert token.topological is not None
|
||||||
event = f"t{token.topological + 1}-{token.stream + 1}"
|
event = f"t{token.topological + 1}-{token.stream + 1}"
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
|
@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase):
|
||||||
self.get_success(self.store.get_event(third["event_id"]))
|
self.get_success(self.store.get_event(third["event_id"]))
|
||||||
self.get_success(self.store.get_event(last["event_id"]))
|
self.get_success(self.store.get_event(last["event_id"]))
|
||||||
|
|
||||||
def test_purge_room(self):
|
def test_purge_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
Purging a room will delete everything about it.
|
Purging a room will delete everything about it.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -14,8 +14,12 @@
|
||||||
|
|
||||||
from typing import Collection, Optional
|
from typing import Collection, Optional
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import ReceiptTypes
|
from synapse.api.constants import ReceiptTypes
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils.event_injection import create_event
|
from tests.test_utils.event_injection import create_event
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test"
|
||||||
|
|
||||||
|
|
||||||
class ReceiptTestCase(HomeserverTestCase):
|
class ReceiptTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, homeserver) -> None:
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
super().prepare(reactor, clock, homeserver)
|
super().prepare(reactor, clock, homeserver)
|
||||||
|
|
||||||
self.store = homeserver.get_datastores().main
|
self.store = homeserver.get_datastores().main
|
||||||
|
@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(res, {})
|
self.assertEqual(res, {})
|
||||||
|
|
||||||
res = self.get_last_unthreaded_receipt(
|
res2 = self.get_last_unthreaded_receipt(
|
||||||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
|
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(res, None)
|
self.assertIsNone(res2)
|
||||||
|
|
||||||
def test_get_receipts_for_user(self) -> None:
|
def test_get_receipts_for_user(self) -> None:
|
||||||
# Send some events into the first room
|
# Send some events into the first room
|
||||||
|
|
|
@ -11,27 +11,35 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional
|
from typing import List, Optional, cast
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.types import RoomID, UserID
|
from synapse.events import EventBase, _EventInternalMetadata
|
||||||
|
from synapse.events.builder import EventBuilder
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict, RoomID, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import create_room
|
from tests.utils import create_room
|
||||||
|
|
||||||
|
|
||||||
class RedactionTestCase(unittest.HomeserverTestCase):
|
class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
def default_config(self):
|
def default_config(self) -> JsonDict:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["redaction_retention_period"] = "30d"
|
config["redaction_retention_period"] = "30d"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._storage = hs.get_storage_controllers()
|
storage = hs.get_storage_controllers()
|
||||||
|
assert storage.persistence is not None
|
||||||
|
self._persistence = storage.persistence
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
||||||
|
@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.depth = 1
|
self.depth = 1
|
||||||
|
|
||||||
def inject_room_member(
|
def inject_room_member( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
room,
|
room: RoomID,
|
||||||
user,
|
user: UserID,
|
||||||
membership,
|
membership: str,
|
||||||
replaces_state=None,
|
extra_content: Optional[JsonDict] = None,
|
||||||
extra_content: Optional[dict] = None,
|
) -> EventBase:
|
||||||
):
|
|
||||||
content = {"membership": membership}
|
content = {"membership": membership}
|
||||||
content.update(extra_content or {})
|
content.update(extra_content or {})
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
|
@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def inject_message(self, room, user, body):
|
def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
|
||||||
self.depth += 1
|
self.depth += 1
|
||||||
|
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
|
@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def inject_redaction(self, room, event_id, user, reason):
|
def inject_redaction(
|
||||||
|
self, room: RoomID, event_id: str, user: UserID, reason: str
|
||||||
|
) -> EventBase:
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
RoomVersions.V1,
|
RoomVersions.V1,
|
||||||
{
|
{
|
||||||
|
@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def test_redact(self):
|
def test_redact(self) -> None:
|
||||||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
||||||
|
|
||||||
msg_event = self.inject_message(self.room1, self.u_alice, "t")
|
msg_event = self.inject_message(self.room1, self.u_alice, "t")
|
||||||
|
@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
event.unsigned["redacted_because"],
|
event.unsigned["redacted_because"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_redact_join(self):
|
def test_redact_join(self) -> None:
|
||||||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
||||||
|
|
||||||
msg_event = self.inject_room_member(
|
msg_event = self.inject_room_member(
|
||||||
|
@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
event.unsigned["redacted_because"],
|
event.unsigned["redacted_because"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_circular_redaction(self):
|
def test_circular_redaction(self) -> None:
|
||||||
redaction_event_id1 = "$redaction1_id:test"
|
redaction_event_id1 = "$redaction1_id:test"
|
||||||
redaction_event_id2 = "$redaction2_id:test"
|
redaction_event_id2 = "$redaction2_id:test"
|
||||||
|
|
||||||
class EventIdManglingBuilder:
|
class EventIdManglingBuilder:
|
||||||
def __init__(self, base_builder, event_id):
|
def __init__(self, base_builder: EventBuilder, event_id: str):
|
||||||
self._base_builder = base_builder
|
self._base_builder = base_builder
|
||||||
self._event_id = event_id
|
self._event_id = event_id
|
||||||
|
|
||||||
|
@ -227,31 +236,33 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
prev_event_ids: List[str],
|
prev_event_ids: List[str],
|
||||||
auth_event_ids: Optional[List[str]],
|
auth_event_ids: Optional[List[str]],
|
||||||
depth: Optional[int] = None,
|
depth: Optional[int] = None,
|
||||||
):
|
) -> EventBase:
|
||||||
built_event = await self._base_builder.build(
|
built_event = await self._base_builder.build(
|
||||||
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
|
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
built_event._event_id = self._event_id
|
built_event._event_id = self._event_id # type: ignore[attr-defined]
|
||||||
built_event._dict["event_id"] = self._event_id
|
built_event._dict["event_id"] = self._event_id
|
||||||
assert built_event.event_id == self._event_id
|
assert built_event.event_id == self._event_id
|
||||||
|
|
||||||
return built_event
|
return built_event
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def room_id(self):
|
def room_id(self) -> str:
|
||||||
return self._base_builder.room_id
|
return self._base_builder.room_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self):
|
def type(self) -> str:
|
||||||
return self._base_builder.type
|
return self._base_builder.type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def internal_metadata(self):
|
def internal_metadata(self) -> _EventInternalMetadata:
|
||||||
return self._base_builder.internal_metadata
|
return self._base_builder.internal_metadata
|
||||||
|
|
||||||
event_1, context_1 = self.get_success(
|
event_1, context_1 = self.get_success(
|
||||||
self.event_creation_handler.create_new_client_event(
|
self.event_creation_handler.create_new_client_event(
|
||||||
|
cast(
|
||||||
|
EventBuilder,
|
||||||
EventIdManglingBuilder(
|
EventIdManglingBuilder(
|
||||||
self.event_builder_factory.for_room_version(
|
self.event_builder_factory.for_room_version(
|
||||||
RoomVersions.V1,
|
RoomVersions.V1,
|
||||||
|
@ -264,14 +275,17 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
redaction_event_id1,
|
redaction_event_id1,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._storage.persistence.persist_event(event_1, context_1))
|
self.get_success(self._persistence.persist_event(event_1, context_1))
|
||||||
|
|
||||||
event_2, context_2 = self.get_success(
|
event_2, context_2 = self.get_success(
|
||||||
self.event_creation_handler.create_new_client_event(
|
self.event_creation_handler.create_new_client_event(
|
||||||
|
cast(
|
||||||
|
EventBuilder,
|
||||||
EventIdManglingBuilder(
|
EventIdManglingBuilder(
|
||||||
self.event_builder_factory.for_room_version(
|
self.event_builder_factory.for_room_version(
|
||||||
RoomVersions.V1,
|
RoomVersions.V1,
|
||||||
|
@ -284,10 +298,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
redaction_event_id2,
|
redaction_event_id2,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(self._storage.persistence.persist_event(event_2, context_2))
|
self.get_success(self._persistence.persist_event(event_2, context_2))
|
||||||
|
|
||||||
# fetch one of the redactions
|
# fetch one of the redactions
|
||||||
fetched = self.get_success(self.store.get_event(redaction_event_id1))
|
fetched = self.get_success(self.store.get_event(redaction_event_id1))
|
||||||
|
@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
|
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_redact_censor(self):
|
def test_redact_censor(self) -> None:
|
||||||
"""Test that a redacted event gets censored in the DB after a month"""
|
"""Test that a redacted event gets censored in the DB after a month"""
|
||||||
|
|
||||||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
||||||
|
@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assert_dict({"content": {}}, json.loads(event_json))
|
self.assert_dict({"content": {}}, json.loads(event_json))
|
||||||
|
|
||||||
def test_redact_redaction(self):
|
def test_redact_redaction(self) -> None:
|
||||||
"""Tests that we can redact a redaction and can fetch it again."""
|
"""Tests that we can redact a redaction and can fetch it again."""
|
||||||
|
|
||||||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
||||||
|
@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_event(first_redact_event.event_id, allow_none=True)
|
self.store.get_event(first_redact_event.event_id, allow_none=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_store_redacted_redaction(self):
|
def test_store_redacted_redaction(self) -> None:
|
||||||
"""Tests that we can store a redacted redaction."""
|
"""Tests that we can store a redacted redaction."""
|
||||||
|
|
||||||
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
|
||||||
|
@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(self._persistence.persist_event(redaction_event, context))
|
||||||
self._storage.persistence.persist_event(redaction_event, context)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now lets jump to the future where we have censored the redaction event
|
# Now lets jump to the future where we have censored the redaction event
|
||||||
# in the DB.
|
# in the DB.
|
||||||
|
|
|
@ -14,10 +14,15 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.app.generic_worker import GenericWorkerServer
|
from synapse.app.generic_worker import GenericWorkerServer
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.database import LoggingDatabaseConnection
|
from synapse.storage.database import LoggingDatabaseConnection
|
||||||
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
|
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
|
||||||
from synapse.storage.schema import SCHEMA_VERSION
|
from synapse.storage.schema import SCHEMA_VERSION
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]:
|
||||||
|
|
||||||
|
|
||||||
class WorkerSchemaTests(HomeserverTestCase):
|
class WorkerSchemaTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
federation_http_client=None, homeserver_to_use=GenericWorkerServer
|
federation_http_client=None, homeserver_to_use=GenericWorkerServer
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> JsonDict:
|
||||||
conf = super().default_config()
|
conf = super().default_config()
|
||||||
|
|
||||||
# Mark this as a worker app.
|
# Mark this as a worker app.
|
||||||
|
@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase):
|
||||||
|
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
def test_rolling_back(self):
|
def test_rolling_back(self) -> None:
|
||||||
"""Test that workers can start if the DB is a newer schema version"""
|
"""Test that workers can start if the DB is a newer schema version"""
|
||||||
|
|
||||||
db_pool = self.hs.get_datastores().main.db_pool
|
db_pool = self.hs.get_datastores().main.db_pool
|
||||||
|
@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase):
|
||||||
|
|
||||||
prepare_database(db_conn, db_pool.engine, self.hs.config)
|
prepare_database(db_conn, db_pool.engine, self.hs.config)
|
||||||
|
|
||||||
def test_not_upgraded_old_schema_version(self):
|
def test_not_upgraded_old_schema_version(self) -> None:
|
||||||
"""Test that workers don't start if the DB has an older schema version"""
|
"""Test that workers don't start if the DB has an older schema version"""
|
||||||
db_pool = self.hs.get_datastores().main.db_pool
|
db_pool = self.hs.get_datastores().main.db_pool
|
||||||
db_conn = LoggingDatabaseConnection(
|
db_conn = LoggingDatabaseConnection(
|
||||||
|
@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
|
||||||
with self.assertRaises(PrepareDatabaseException):
|
with self.assertRaises(PrepareDatabaseException):
|
||||||
prepare_database(db_conn, db_pool.engine, self.hs.config)
|
prepare_database(db_conn, db_pool.engine, self.hs.config)
|
||||||
|
|
||||||
def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
|
def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that workers don't start if the DB is on the current schema version,
|
Test that workers don't start if the DB is on the current schema version,
|
||||||
but there are still outstanding delta migrations to run.
|
but there are still outstanding delta migrations to run.
|
||||||
|
|
|
@ -12,14 +12,18 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import RoomAlias, RoomID, UserID
|
from synapse.types import RoomAlias, RoomID, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class RoomStoreTestCase(HomeserverTestCase):
|
class RoomStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# We can't test RoomStore on its own without the DirectoryStore, for
|
# We can't test RoomStore on its own without the DirectoryStore, for
|
||||||
# management of the 'room_aliases' table
|
# management of the 'room_aliases' table
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_room(self):
|
def test_get_room(self) -> None:
|
||||||
|
res = self.get_success(self.store.get_room(self.room.to_string()))
|
||||||
|
assert res is not None
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
"room_id": self.room.to_string(),
|
"room_id": self.room.to_string(),
|
||||||
"creator": self.u_creator.to_string(),
|
"creator": self.u_creator.to_string(),
|
||||||
"is_public": True,
|
"is_public": True,
|
||||||
},
|
},
|
||||||
(self.get_success(self.store.get_room(self.room.to_string()))),
|
res,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_room_unknown_room(self):
|
def test_get_room_unknown_room(self) -> None:
|
||||||
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
|
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
|
||||||
|
|
||||||
def test_get_room_with_stats(self):
|
def test_get_room_with_stats(self) -> None:
|
||||||
|
res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
|
||||||
|
assert res is not None
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
"room_id": self.room.to_string(),
|
"room_id": self.room.to_string(),
|
||||||
"creator": self.u_creator.to_string(),
|
"creator": self.u_creator.to_string(),
|
||||||
"public": True,
|
"public": True,
|
||||||
},
|
},
|
||||||
(self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
|
res,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_room_with_stats_unknown_room(self):
|
def test_get_room_with_stats_unknown_room(self) -> None:
|
||||||
self.assertIsNone(
|
self.assertIsNone(
|
||||||
(self.get_success(self.store.get_room_with_stats("!uknown:test"))),
|
self.get_success(self.store.get_room_with_stats("!uknown:test"))
|
||||||
)
|
)
|
||||||
|
|
|
@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_null_byte(self):
|
def test_null_byte(self) -> None:
|
||||||
"""
|
"""
|
||||||
Postgres/SQLite don't like null bytes going into the search tables. Internally
|
Postgres/SQLite don't like null bytes going into the search tables. Internally
|
||||||
we replace those with a space.
|
we replace those with a space.
|
||||||
|
@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||||
if isinstance(store.database_engine, PostgresEngine):
|
if isinstance(store.database_engine, PostgresEngine):
|
||||||
self.assertIn("alice", result.get("highlights"))
|
self.assertIn("alice", result.get("highlights"))
|
||||||
|
|
||||||
def test_non_string(self):
|
def test_non_string(self) -> None:
|
||||||
"""Test that non-string `value`s are not inserted into `event_search`.
|
"""Test that non-string `value`s are not inserted into `event_search`.
|
||||||
|
|
||||||
This is particularly important when using sqlite, since a sqlite column can hold
|
This is particularly important when using sqlite, since a sqlite column can hold
|
||||||
|
@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||||
self.assertEqual(f.value.code, 404)
|
self.assertEqual(f.value.code, 404)
|
||||||
|
|
||||||
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
|
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
|
||||||
def test_sqlite_non_string_deletion_background_update(self):
|
def test_sqlite_non_string_deletion_background_update(self) -> None:
|
||||||
"""Test the background update to delete bad rows from `event_search`."""
|
"""Test the background update to delete bad rows from `event_search`."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase):
|
||||||
"results array length should match count",
|
"results array length should match count",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_postgres_web_search_for_phrase(self):
|
def test_postgres_web_search_for_phrase(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
|
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
|
||||||
This test is skipped unless the postgres instance supports websearch_to_tsquery.
|
This test is skipped unless the postgres instance supports websearch_to_tsquery.
|
||||||
|
@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase):
|
||||||
|
|
||||||
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
|
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
|
||||||
|
|
||||||
def test_sqlite_search(self):
|
def test_sqlite_search(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test sqlite searching for phrases.
|
Test sqlite searching for phrases.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,10 +16,15 @@ import logging
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import RoomID, UserID
|
from synapse.types import JsonDict, RoomID, StateMap, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase, TestCase
|
from tests.unittest import HomeserverTestCase, TestCase
|
||||||
|
|
||||||
|
@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StateStoreTestCase(HomeserverTestCase):
|
class StateStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage_controllers()
|
self.storage = hs.get_storage_controllers()
|
||||||
self.state_datastore = self.storage.state.stores.state
|
self.state_datastore = self.storage.state.stores.state
|
||||||
|
@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def inject_state_event(self, room, sender, typ, state_key, content):
|
def inject_state_event(
|
||||||
|
self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
|
||||||
|
) -> EventBase:
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
RoomVersions.V1,
|
RoomVersions.V1,
|
||||||
{
|
{
|
||||||
|
@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert self.storage.persistence is not None
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def assertStateMapEqual(self, s1, s2):
|
def assertStateMapEqual(
|
||||||
|
self, s1: StateMap[EventBase], s2: StateMap[EventBase]
|
||||||
|
) -> None:
|
||||||
for t in s1:
|
for t in s1:
|
||||||
# just compare event IDs for simplicity
|
# just compare event IDs for simplicity
|
||||||
self.assertEqual(s1[t].event_id, s2[t].event_id)
|
self.assertEqual(s1[t].event_id, s2[t].event_id)
|
||||||
self.assertEqual(len(s1), len(s2))
|
self.assertEqual(len(s1), len(s2))
|
||||||
|
|
||||||
def test_get_state_groups_ids(self):
|
def test_get_state_groups_ids(self) -> None:
|
||||||
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
|
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
|
||||||
e2 = self.inject_state_event(
|
e2 = self.inject_state_event(
|
||||||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_map = self.get_success(
|
state_group_map = self.get_success(
|
||||||
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
|
self.storage.state.get_state_groups_ids(
|
||||||
|
self.room.to_string(), [e2.event_id]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(state_group_map), 1)
|
self.assertEqual(len(state_group_map), 1)
|
||||||
state_map = list(state_group_map.values())[0]
|
state_map = list(state_group_map.values())[0]
|
||||||
|
@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
|
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_state_groups(self):
|
def test_get_state_groups(self) -> None:
|
||||||
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
|
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
|
||||||
e2 = self.inject_state_event(
|
e2 = self.inject_state_event(
|
||||||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_map = self.get_success(
|
state_group_map = self.get_success(
|
||||||
self.storage.state.get_state_groups(self.room, [e2.event_id])
|
self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
|
||||||
)
|
)
|
||||||
self.assertEqual(len(state_group_map), 1)
|
self.assertEqual(len(state_group_map), 1)
|
||||||
state_list = list(state_group_map.values())[0]
|
state_list = list(state_group_map.values())[0]
|
||||||
|
|
||||||
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
|
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
|
||||||
|
|
||||||
def test_get_state_for_event(self):
|
def test_get_state_for_event(self) -> None:
|
||||||
# this defaults to a linear DAG as each new injection defaults to whatever
|
# this defaults to a linear DAG as each new injection defaults to whatever
|
||||||
# forward extremities are currently in the DB for this room.
|
# forward extremities are currently in the DB for this room.
|
||||||
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
|
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
|
||||||
|
@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
class StateFilterDifferenceTestCase(TestCase):
|
class StateFilterDifferenceTestCase(TestCase):
|
||||||
def assert_difference(
|
def assert_difference(
|
||||||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
|
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
|
||||||
):
|
) -> None:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
minuend.approx_difference(subtrahend),
|
minuend.approx_difference(subtrahend),
|
||||||
expected,
|
expected,
|
||||||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
|
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_state_filter_difference_no_include_other_minus_no_include_other(self):
|
def test_state_filter_difference_no_include_other_minus_no_include_other(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the StateFilter.approx_difference method
|
Tests the StateFilter.approx_difference method
|
||||||
where, in a.approx_difference(b), both a and b do not have the
|
where, in a.approx_difference(b), both a and b do not have the
|
||||||
|
@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_state_filter_difference_include_other_minus_no_include_other(self):
|
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the StateFilter.approx_difference method
|
Tests the StateFilter.approx_difference method
|
||||||
where, in a.approx_difference(b), only a has the include_others flag set.
|
where, in a.approx_difference(b), only a has the include_others flag set.
|
||||||
|
@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_state_filter_difference_include_other_minus_include_other(self):
|
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the StateFilter.approx_difference method
|
Tests the StateFilter.approx_difference method
|
||||||
where, in a.approx_difference(b), both a and b have the include_others
|
where, in a.approx_difference(b), both a and b have the include_others
|
||||||
|
@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_state_filter_difference_no_include_other_minus_include_other(self):
|
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the StateFilter.approx_difference method
|
Tests the StateFilter.approx_difference method
|
||||||
where, in a.approx_difference(b), only b has the include_others flag set.
|
where, in a.approx_difference(b), only b has the include_others flag set.
|
||||||
|
@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_state_filter_difference_simple_cases(self):
|
def test_state_filter_difference_simple_cases(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests some very simple cases of the StateFilter approx_difference,
|
Tests some very simple cases of the StateFilter approx_difference,
|
||||||
that are not explicitly tested by the more in-depth tests.
|
that are not explicitly tested by the more in-depth tests.
|
||||||
|
@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
|
||||||
|
|
||||||
|
|
||||||
class StateFilterTestCase(TestCase):
|
class StateFilterTestCase(TestCase):
|
||||||
def test_return_expanded(self):
|
def test_return_expanded(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the behaviour of the return_expanded() function that expands
|
Tests the behaviour of the return_expanded() function that expands
|
||||||
StateFilters to include more state types (for the sake of cache hit rate).
|
StateFilters to include more state types (for the sake of cache hit rate).
|
||||||
|
|
|
@ -14,11 +14,15 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, RelationTypes
|
from synapse.api.constants import EventTypes, RelationTypes
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> JsonDict:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["experimental_features"] = {"msc3874_enabled": True}
|
config["experimental_features"] = {"msc3874_enabled": True}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.user_id = self.register_user("test", "test")
|
self.user_id = self.register_user("test", "test")
|
||||||
self.tok = self.login("test", "test")
|
self.tok = self.login("test", "test")
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
|
@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return [ev.event_id for ev in events]
|
return [ev.event_id for ev in events]
|
||||||
|
|
||||||
def test_filter_relation_senders(self):
|
def test_filter_relation_senders(self) -> None:
|
||||||
# Messages which second user reacted to.
|
# Messages which second user reacted to.
|
||||||
filter = {"related_by_senders": [self.second_user_id]}
|
filter = {"related_by_senders": [self.second_user_id]}
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
|
@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase):
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
|
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
|
||||||
|
|
||||||
def test_filter_relation_type(self):
|
def test_filter_relation_type(self) -> None:
|
||||||
# Messages which have annotations.
|
# Messages which have annotations.
|
||||||
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
|
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
|
@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase):
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
|
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
|
||||||
|
|
||||||
def test_filter_relation_senders_and_type(self):
|
def test_filter_relation_senders_and_type(self) -> None:
|
||||||
# Messages which second user reacted to.
|
# Messages which second user reacted to.
|
||||||
filter = {
|
filter = {
|
||||||
"related_by_senders": [self.second_user_id],
|
"related_by_senders": [self.second_user_id],
|
||||||
|
@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase):
|
||||||
chunk = self._filter_messages(filter)
|
chunk = self._filter_messages(filter)
|
||||||
self.assertEqual(chunk, [self.event_id_1])
|
self.assertEqual(chunk, [self.event_id_1])
|
||||||
|
|
||||||
def test_duplicate_relation(self):
|
def test_duplicate_relation(self) -> None:
|
||||||
"""An event should only be returned once if there are multiple relations to it."""
|
"""An event should only be returned once if there are multiple relations to it."""
|
||||||
self.helper.send_event(
|
self.helper.send_event(
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
|
|
|
@ -12,17 +12,23 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.transactions import DestinationRetryTimings
|
from synapse.storage.databases.main.transactions import DestinationRetryTimings
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.retryutils import MAX_RETRY_INTERVAL
|
from synapse.util.retryutils import MAX_RETRY_INTERVAL
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class TransactionStoreTestCase(HomeserverTestCase):
|
class TransactionStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.store = homeserver.get_datastores().main
|
self.store = homeserver.get_datastores().main
|
||||||
|
|
||||||
def test_get_set_transactions(self):
|
def test_get_set_transactions(self) -> None:
|
||||||
"""Tests that we can successfully get a non-existent entry for
|
"""Tests that we can successfully get a non-existent entry for
|
||||||
destination retries, as well as testing tht we can set and get
|
destination retries, as well as testing tht we can set and get
|
||||||
correctly.
|
correctly.
|
||||||
|
@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase):
|
||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initial_set_transactions(self):
|
def test_initial_set_transactions(self) -> None:
|
||||||
"""Tests that we can successfully set the destination retries (there
|
"""Tests that we can successfully set the destination retries (there
|
||||||
was a bug around invalidating the cache that broke this)
|
was a bug around invalidating the cache that broke this)
|
||||||
"""
|
"""
|
||||||
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
|
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
def test_large_destination_retry(self):
|
def test_large_destination_retry(self) -> None:
|
||||||
d = self.store.set_destination_retry_timings(
|
d = self.store.set_destination_retry_timings(
|
||||||
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
|
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
d = self.store.get_destination_retry_timings("example.com")
|
d2 = self.store.get_destination_retry_timings("example.com")
|
||||||
self.get_success(d)
|
self.get_success(d2)
|
||||||
|
|
|
@ -12,21 +12,27 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
|
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
|
||||||
"""Test SQL transaction limit doesn't break transactions."""
|
"""Test SQL transaction limit doesn't break transactions."""
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(db_txn_limit=1000)
|
return self.setup_test_homeserver(db_txn_limit=1000)
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self) -> None:
|
||||||
db_config = self.hs.config.database.get_single_database()
|
db_config = self.hs.config.database.get_single_database()
|
||||||
self.assertEqual(db_config.config["txn_limit"], 1000)
|
self.assertEqual(db_config.config["txn_limit"], 1000)
|
||||||
|
|
||||||
def test_select(self):
|
def test_select(self) -> None:
|
||||||
def do_select(txn):
|
def do_select(txn: Cursor) -> None:
|
||||||
txn.execute("SELECT 1")
|
txn.execute("SELECT 1")
|
||||||
|
|
||||||
db_pool = self.hs.get_datastores().databases[0]
|
db_pool = self.hs.get_datastores().databases[0]
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Collection, Dict
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from twisted.internet.defer import CancelledError, ensureDeferred
|
from twisted.internet.defer import CancelledError, ensureDeferred
|
||||||
|
@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
# the results to be returned by the mocked get_partial_state_events
|
# the results to be returned by the mocked get_partial_state_events
|
||||||
self._events_dict: Dict[str, bool] = {}
|
self._events_dict: Dict[str, bool] = {}
|
||||||
|
|
||||||
async def get_partial_state_events(events):
|
async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
|
||||||
return {e: self._events_dict[e] for e in events}
|
return {e: self._events_dict[e] for e in events}
|
||||||
|
|
||||||
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
|
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
|
||||||
|
@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
|
|
||||||
self.tracker = PartialStateEventsTracker(self.mock_store)
|
self.tracker = PartialStateEventsTracker(self.mock_store)
|
||||||
|
|
||||||
def test_does_not_block_for_full_state_events(self):
|
def test_does_not_block_for_full_state_events(self) -> None:
|
||||||
self._events_dict = {"event1": False, "event2": False}
|
self._events_dict = {"event1": False, "event2": False}
|
||||||
|
|
||||||
self.successResultOf(
|
self.successResultOf(
|
||||||
|
@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
["event1", "event2"]
|
["event1", "event2"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blocks_for_partial_state_events(self):
|
def test_blocks_for_partial_state_events(self) -> None:
|
||||||
self._events_dict = {"event1": True, "event2": False}
|
self._events_dict = {"event1": True, "event2": False}
|
||||||
|
|
||||||
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
||||||
|
@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
self.tracker.notify_un_partial_stated("event1")
|
self.tracker.notify_un_partial_stated("event1")
|
||||||
self.successResultOf(d)
|
self.successResultOf(d)
|
||||||
|
|
||||||
def test_un_partial_state_race(self):
|
def test_un_partial_state_race(self) -> None:
|
||||||
# if the event is un-partial-stated between the initial check and the
|
# if the event is un-partial-stated between the initial check and the
|
||||||
# registration of the listener, it should not block.
|
# registration of the listener, it should not block.
|
||||||
self._events_dict = {"event1": True, "event2": False}
|
self._events_dict = {"event1": True, "event2": False}
|
||||||
|
|
||||||
async def get_partial_state_events(events):
|
async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
|
||||||
res = {e: self._events_dict[e] for e in events}
|
res = {e: self._events_dict[e] for e in events}
|
||||||
# change the result for next time
|
# change the result for next time
|
||||||
self._events_dict = {"event1": False, "event2": False}
|
self._events_dict = {"event1": False, "event2": False}
|
||||||
|
@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_un_partial_state_during_get_partial_state_events(self):
|
def test_un_partial_state_during_get_partial_state_events(self) -> None:
|
||||||
# we should correctly handle a call to notify_un_partial_stated during the
|
# we should correctly handle a call to notify_un_partial_stated during the
|
||||||
# second call to get_partial_state_events.
|
# second call to get_partial_state_events.
|
||||||
|
|
||||||
self._events_dict = {"event1": True, "event2": False}
|
self._events_dict = {"event1": True, "event2": False}
|
||||||
|
|
||||||
async def get_partial_state_events1(events):
|
async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]:
|
||||||
self.mock_store.get_partial_state_events.side_effect = (
|
self.mock_store.get_partial_state_events.side_effect = (
|
||||||
get_partial_state_events2
|
get_partial_state_events2
|
||||||
)
|
)
|
||||||
return {e: self._events_dict[e] for e in events}
|
return {e: self._events_dict[e] for e in events}
|
||||||
|
|
||||||
async def get_partial_state_events2(events):
|
async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]:
|
||||||
self.tracker.notify_un_partial_stated("event1")
|
self.tracker.notify_un_partial_stated("event1")
|
||||||
self._events_dict["event1"] = False
|
self._events_dict["event1"] = False
|
||||||
return {e: self._events_dict[e] for e in events}
|
return {e: self._events_dict[e] for e in events}
|
||||||
|
@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_cancellation(self):
|
def test_cancellation(self) -> None:
|
||||||
self._events_dict = {"event1": True, "event2": False}
|
self._events_dict = {"event1": True, "event2": False}
|
||||||
|
|
||||||
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
|
||||||
|
@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase):
|
||||||
|
|
||||||
self.tracker = PartialCurrentStateTracker(self.mock_store)
|
self.tracker = PartialCurrentStateTracker(self.mock_store)
|
||||||
|
|
||||||
def test_does_not_block_for_full_state_rooms(self):
|
def test_does_not_block_for_full_state_rooms(self) -> None:
|
||||||
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
|
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
|
||||||
|
|
||||||
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
|
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
|
||||||
|
|
||||||
def test_blocks_for_partial_room_state(self):
|
def test_blocks_for_partial_room_state(self) -> None:
|
||||||
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
|
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
|
||||||
|
|
||||||
d = ensureDeferred(self.tracker.await_full_state("room_id"))
|
d = ensureDeferred(self.tracker.await_full_state("room_id"))
|
||||||
|
@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase):
|
||||||
self.tracker.notify_un_partial_stated("room_id")
|
self.tracker.notify_un_partial_stated("room_id")
|
||||||
self.successResultOf(d)
|
self.successResultOf(d)
|
||||||
|
|
||||||
def test_un_partial_state_race(self):
|
def test_un_partial_state_race(self) -> None:
|
||||||
# We should correctly handle race between awaiting the state and us
|
# We should correctly handle race between awaiting the state and us
|
||||||
# un-partialling the state
|
# un-partialling the state
|
||||||
async def is_partial_state_room(events):
|
async def is_partial_state_room(room_id: str) -> bool:
|
||||||
self.tracker.notify_un_partial_stated("room_id")
|
self.tracker.notify_un_partial_stated("room_id")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
|
||||||
|
|
||||||
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
|
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
|
||||||
|
|
||||||
def test_cancellation(self):
|
def test_cancellation(self) -> None:
|
||||||
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
|
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
|
||||||
|
|
||||||
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
|
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
|
||||||
|
|
Loading…
Reference in New Issue