Type annotations for `test_v2` (#12985)

This commit is contained in:
David Robertson 2022-06-09 09:48:04 +01:00 committed by GitHub
parent 04ca3a52f6
commit 97053c9406
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 129 additions and 58 deletions

1
changelog.d/12985.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations to `tests.state.test_v2`.

View File

@ -56,7 +56,6 @@ exclude = (?x)
|tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_media_storage.py
|tests/server.py |tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py |tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
|tests/test_metrics.py |tests/test_metrics.py
|tests/test_server.py |tests/test_server.py
|tests/test_state.py |tests/test_state.py
@ -115,6 +114,9 @@ disallow_untyped_defs = False
[mypy-tests.handlers.test_user_directory] [mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.state.test_profile]
disallow_untyped_defs = True
[mypy-tests.storage.test_profile] [mypy-tests.storage.test_profile]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -17,12 +17,14 @@ import itertools
import logging import logging
from typing import ( from typing import (
Any, Any,
Awaitable,
Callable, Callable,
Collection, Collection,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -30,33 +32,58 @@ from typing import (
overload, overload,
) )
from typing_extensions import Literal from typing_extensions import Literal, Protocol
import synapse.state
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Clock(Protocol):
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
# We only ever sleep(0) though, so that other async functions can make forward
# progress without waiting for stateres to complete.
def sleep(self, duration_ms: float) -> Awaitable[None]:
...
class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...
def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
...
# We want to await to the reactor occasionally during state res when dealing # We want to await to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by # with large data sets, so that we don't exhaust the reactor. This is done by
# awaiting to reactor during loops every N iterations. # awaiting to reactor during loops every N iterations.
_AWAIT_AFTER_ITERATIONS = 100 _AWAIT_AFTER_ITERATIONS = 100
__all__ = [
"resolve_events_with_store",
]
async def resolve_events_with_store( async def resolve_events_with_store(
clock: Clock, clock: Clock,
room_id: str, room_id: str,
room_version: RoomVersion, room_version: RoomVersion,
state_sets: Sequence[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
) -> StateMap[str]: ) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm """Resolves the state using the v2 state resolution algorithm
@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
room_id: str, room_id: str,
event_id: str, event_id: str,
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
) -> int: ) -> int:
"""Return the power level of the sender of the given event according to """Return the power level of the sender of the given event according to
their auth events. their auth events.
@ -243,9 +270,9 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference( async def _get_auth_chain_difference(
room_id: str, room_id: str,
state_sets: Sequence[StateMap[str]], state_sets: Sequence[Mapping[Any, str]],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
) -> Set[str]: ) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events """Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains. that only appear in some but not all of the auth chains.
@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph(
room_id: str, room_id: str,
event_id: str, event_id: str,
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
auth_diff: Set[str], auth_diff: Set[str],
) -> None: ) -> None:
"""Helper function for _reverse_topological_power_sort that add the event """Helper function for _reverse_topological_power_sort that add the event
@ -440,7 +467,7 @@ async def _reverse_topological_power_sort(
room_id: str, room_id: str,
event_ids: Iterable[str], event_ids: Iterable[str],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
auth_diff: Set[str], auth_diff: Set[str],
) -> List[str]: ) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering, """Returns a list of the event_ids sorted by reverse topological ordering,
@ -501,7 +528,7 @@ async def _iterative_auth_checks(
event_ids: List[str], event_ids: List[str],
base_state: StateMap[str], base_state: StateMap[str],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
) -> MutableStateMap[str]: ) -> MutableStateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the """Sequentially apply auth checks to each event in given list, updating the
state as it goes along. state as it goes along.
@ -570,7 +597,7 @@ async def _mainline_sort(
event_ids: List[str], event_ids: List[str],
resolved_power_event_id: Optional[str], resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
) -> List[str]: ) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on """Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id the given event resolved_power_event_id
@ -639,7 +666,7 @@ async def _get_mainline_depth_for_event(
event: EventBase, event: EventBase,
mainline_map: Dict[str, int], mainline_map: Dict[str, int],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
) -> int: ) -> int:
"""Get the mainline depths for the given event based on the mainline map """Get the mainline depths for the given event based on the mainline map
@ -683,7 +710,7 @@ async def _get_event(
room_id: str, room_id: str,
event_id: str, event_id: str,
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
) -> EventBase: ) -> EventBase:
... ...
@ -694,7 +721,7 @@ async def _get_event(
room_id: str, room_id: str,
event_id: str, event_id: str,
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
allow_none: Literal[True], allow_none: Literal[True],
) -> Optional[EventBase]: ) -> Optional[EventBase]:
... ...
@ -704,7 +731,7 @@ async def _get_event(
room_id: str, room_id: str,
event_id: str, event_id: str,
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: StateResolutionStore,
allow_none: bool = False, allow_none: bool = False,
) -> Optional[EventBase]: ) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking """Helper function to look up event in event_map, falling back to looking

View File

@ -13,7 +13,17 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
from typing import List from typing import (
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
)
import attr import attr
@ -22,13 +32,13 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.state.v2 import ( from synapse.state.v2 import (
_get_auth_chain_difference, _get_auth_chain_difference,
lexicographical_topological_sort, lexicographical_topological_sort,
resolve_events_with_store, resolve_events_with_store,
) )
from synapse.types import EventID from synapse.types import EventID, StateMap
from tests import unittest from tests import unittest
@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
class FakeClock: class FakeClock:
def sleep(self, msec): def sleep(self, msec: float) -> "defer.Deferred[None]":
return defer.succeed(None) return defer.succeed(None)
@ -60,7 +70,14 @@ class FakeEvent:
as domain. as domain.
""" """
def __init__(self, id, sender, type, state_key, content): def __init__(
self,
id: str,
sender: str,
type: str,
state_key: Optional[str],
content: Mapping[str, object],
):
self.node_id = id self.node_id = id
self.event_id = EventID(id, "example.com").to_string() self.event_id = EventID(id, "example.com").to_string()
self.sender = sender self.sender = sender
@ -69,12 +86,12 @@ class FakeEvent:
self.content = content self.content = content
self.room_id = ROOM_ID self.room_id = ROOM_ID
def to_event(self, auth_events, prev_events): def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
"""Given the auth_events and prev_events, convert to a Frozen Event """Given the auth_events and prev_events, convert to a Frozen Event
Args: Args:
auth_events (list[str]): list of event_ids auth_events: list of event_ids
prev_events (list[str]): list of event_ids prev_events: list of event_ids
Returns: Returns:
FrozenEvent FrozenEvent
@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def test_ban_vs_pl(self): def test_ban_vs_pl(self) -> None:
events = [ events = [
FakeEvent( FakeEvent(
id="PA", id="PA",
@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def test_join_rule_evasion(self): def test_join_rule_evasion(self) -> None:
events = [ events = [
FakeEvent( FakeEvent(
id="JR", id="JR",
@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def test_offtopic_pl(self): def test_offtopic_pl(self) -> None:
events = [ events = [
FakeEvent( FakeEvent(
id="PA", id="PA",
@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def test_topic_basic(self): def test_topic_basic(self) -> None:
events = [ events = [
FakeEvent( FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def test_topic_reset(self): def test_topic_reset(self) -> None:
events = [ events = [
FakeEvent( FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def test_topic(self): def test_topic(self) -> None:
events = [ events = [
FakeEvent( FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def test_mainline_sort(self): def test_mainline_sort(self) -> None:
"""Tests that the mainline ordering works correctly.""" """Tests that the mainline ordering works correctly."""
events = [ events = [
@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids) self.do_check(events, edges, expected_state_ids)
def do_check(self, events, edges, expected_state_ids): def do_check(
self,
events: List[FakeEvent],
edges: List[List[str]],
expected_state_ids: List[str],
) -> None:
"""Take a list of events and edges and calculate the state of the """Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids` graph at END, and asserts it matches `expected_state_ids`
Args: Args:
events (list[FakeEvent]) events
edges (list[list[str]]): A list of chains of event edges, e.g. edges: A list of chains of event edges, e.g.
`[[A, B, C]]` are edges A->B and B->C. `[[A, B, C]]` are edges A->B and B->C.
expected_state_ids (list[str]): The expected state at END, (excluding expected_state_ids: The expected state at END, (excluding
the keys that haven't changed since START). the keys that haven't changed since START).
""" """
# We want to sort the events into topological order for processing. # We want to sort the events into topological order for processing.
graph = {} graph: Dict[str, Set[str]] = {}
# node_id -> FakeEvent fake_event_map: Dict[str, FakeEvent] = {}
fake_event_map = {}
for ev in itertools.chain(INITIAL_EVENTS, events): for ev in itertools.chain(INITIAL_EVENTS, events):
graph[ev.node_id] = set() graph[ev.node_id] = set()
@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
for a, b in pairwise(edge_list): for a, b in pairwise(edge_list):
graph[a].add(b) graph[a].add(b)
# event_id -> FrozenEvent event_map: Dict[str, EventBase] = {}
event_map = {} state_at_event: Dict[str, StateMap[str]] = {}
# node_id -> state
state_at_event = {}
# We copy the map as the sort consumes the graph # We copy the map as the sort consumes the graph
graph_copy = {k: set(v) for k, v in graph.items()} graph_copy = {k: set(v) for k, v in graph.items()}
@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase):
if fake_event.state_key is not None: if fake_event.state_key is not None:
state_after[(fake_event.type, fake_event.state_key)] = event_id state_after[(fake_event.type, fake_event.state_key)] = event_id
auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # This type ignore is a bit sad. Things we have tried:
# 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
# EventBuilder. But this is Hard because the relevant attributes are
# DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
# 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
# change this function to accept Union[Event, EventBase, EventBuilder].
# This seems reasonable to me, but mypy isn't happy. I think that's
# a mypy bug, see https://github.com/python/mypy/issues/5570
# Instead, resort to a type-ignore.
auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type]
auth_events = [] auth_events = []
for key in auth_types: for key in auth_types:
@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase):
class LexicographicalTestCase(unittest.TestCase): class LexicographicalTestCase(unittest.TestCase):
def test_simple(self): def test_simple(self) -> None:
graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} graph: Dict[str, Set[str]] = {
"l": {"o"},
"m": {"n", "o"},
"n": {"o"},
"o": set(),
"p": {"o"},
}
res = list(lexicographical_topological_sort(graph, key=lambda x: x)) res = list(lexicographical_topological_sort(graph, key=lambda x: x))
@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase):
class SimpleParamStateTestCase(unittest.TestCase): class SimpleParamStateTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
# We build up a simple DAG. # We build up a simple DAG.
event_map = {} event_map = {}
@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
] ]
} }
def test_event_map_none(self): def test_event_map_none(self) -> None:
# Test that we correctly handle passing `None` as the event_map # Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
events. events.
""" """
def test_simple(self): def test_simple(self) -> None:
# Test getting the auth difference for a simple chain with a single # Test getting the auth difference for a simple chain with a single
# unpersisted event: # unpersisted event:
# #
@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {c.event_id}) self.assertEqual(difference, {c.event_id})
def test_multiple_unpersisted_chain(self): def test_multiple_unpersisted_chain(self) -> None:
# Test getting the auth difference for a simple chain with multiple # Test getting the auth difference for a simple chain with multiple
# unpersisted events: # unpersisted events:
# #
@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, c.event_id}) self.assertEqual(difference, {d.event_id, c.event_id})
def test_unpersisted_events_different_sets(self): def test_unpersisted_events_different_sets(self) -> None:
# Test getting the auth difference for with multiple unpersisted events # Test getting the auth difference for with multiple unpersisted events
# in different branches: # in different branches:
# #
@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, e.event_id}) self.assertEqual(difference, {d.event_id, e.event_id})
def pairwise(iterable): T = TypeVar("T")
def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
"s -> (s0,s1), (s1,s2), (s2, s3), ..." "s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable) a, b = itertools.tee(iterable)
next(b, None) next(b, None)
@ -829,24 +866,26 @@ def pairwise(iterable):
@attr.s @attr.s
class TestStateResolutionStore: class TestStateResolutionStore:
event_map = attr.ib() event_map: Dict[str, EventBase] = attr.ib()
def get_events(self, event_ids, allow_rejected=False): def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
) -> "defer.Deferred[Dict[str, EventBase]]":
"""Get events from the database """Get events from the database
Args: Args:
event_ids (list): The event_ids of the events to fetch event_ids: The event_ids of the events to fetch
allow_rejected (bool): If True return rejected events. allow_rejected: If True return rejected events.
Returns: Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. Dict from event_id to event.
""" """
return defer.succeed( return defer.succeed(
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
) )
def _get_auth_chain(self, event_ids: List[str]) -> List[str]: def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected """Gets the full auth chain for a set of events (including rejected
events). events).
@ -880,7 +919,9 @@ class TestStateResolutionStore:
return list(result) return list(result)
def get_auth_chain_difference(self, room_id, auth_sets): def get_auth_chain_difference(
self, room_id: str, auth_sets: List[Set[str]]
) -> "defer.Deferred[Set[str]]":
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:]) common = set(chains[0]).intersection(*chains[1:])