Improve type checking in `replication.tcp.Stream` (#7291)
The general idea here is to get rid of the type: ignore annotations on all of the current_token and update_function assignments, which would have caught #7290. After a bit of experimentation, it seems like the least-awful way to do this is to pass the offending functions in as parameters to the Stream constructor. Unfortunately that means that the concrete implementations no longer have the same constructor signature as Stream itself, which means that it gets hard to correctly annotate STREAMS_MAP. I've also introduced a couple of new types, to take out some duplication.
This commit is contained in:
parent
c07fca9e2f
commit
67ff7b8ba0
|
@ -0,0 +1 @@
|
|||
Improve typing annotations in `synapse.replication.tcp.streams.Stream`.
|
|
@ -25,8 +25,6 @@ Each stream is defined by the following information:
|
|||
update_function: The function that returns a list of updates between two tokens
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
AccountDataStream,
|
||||
BackfillStream,
|
||||
|
@ -67,8 +65,7 @@ STREAMS_MAP = {
|
|||
GroupServerStream,
|
||||
UserSignatureStream,
|
||||
)
|
||||
} # type: Dict[str, Type[Stream]]
|
||||
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"STREAMS_MAP",
|
||||
|
|
|
@ -16,12 +16,11 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
from synapse.types import JsonDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -34,8 +33,32 @@ MAX_EVENTS_BEHIND = 500000
|
|||
# A stream position token
|
||||
Token = int
|
||||
|
||||
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
|
||||
StreamRow = Tuple[Token, tuple]
|
||||
# The type of a stream update row, after JSON deserialisation, but before
|
||||
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
||||
# just a row from a database query, though this is dependent on the stream in question.
|
||||
#
|
||||
StreamRow = Tuple
|
||||
|
||||
# The type returned by the update_function of a stream, as well as get_updates(),
|
||||
# get_updates_since, etc.
|
||||
#
|
||||
# It consists of a triplet `(updates, new_last_token, limited)`, where:
|
||||
# * `updates` is a list of `(token, row)` entries.
|
||||
# * `new_last_token` is the new position in stream.
|
||||
# * `limited` is whether there are more updates to fetch.
|
||||
#
|
||||
StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
||||
|
||||
# The type of an update_function for a stream
|
||||
#
|
||||
# The arguments are:
|
||||
#
|
||||
# * from_token: the previous stream token: the starting point for fetching the
|
||||
# updates
|
||||
# * to_token: the new stream token: the point to get updates up to
|
||||
# * limit: the maximum number of rows to return
|
||||
#
|
||||
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||
|
||||
|
||||
class Stream(object):
|
||||
|
@ -50,7 +73,7 @@ class Stream(object):
|
|||
ROW_TYPE = None # type: Any
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row):
|
||||
def parse_row(cls, row: StreamRow):
|
||||
"""Parse a row received over replication
|
||||
|
||||
By default, assumes that the row data is an array object and passes its contents
|
||||
|
@ -64,7 +87,28 @@ class Stream(object):
|
|||
"""
|
||||
return cls.ROW_TYPE(*row)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(
|
||||
self,
|
||||
current_token_function: Callable[[], Token],
|
||||
update_function: UpdateFunction,
|
||||
):
|
||||
"""Instantiate a Stream
|
||||
|
||||
current_token_function and update_function are callbacks which should be
|
||||
implemented by subclasses.
|
||||
|
||||
current_token_function is called to get the current token of the underlying
|
||||
stream.
|
||||
|
||||
update_function is called to get updates for this stream between a pair of
|
||||
stream tokens. See the UpdateFunction type definition for more info.
|
||||
|
||||
Args:
|
||||
current_token_function: callback to get the current token, as above
|
||||
update_function: callback go get stream updates, as above
|
||||
"""
|
||||
self.current_token = current_token_function
|
||||
self.update_function = update_function
|
||||
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token()
|
||||
|
@ -75,7 +119,7 @@ class Stream(object):
|
|||
"""
|
||||
self.last_token = self.current_token()
|
||||
|
||||
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
async def get_updates(self) -> StreamUpdateResult:
|
||||
"""Gets all updates since the last time this function was called (or
|
||||
since the stream was constructed if it hadn't been called before).
|
||||
|
||||
|
@ -95,7 +139,7 @@ class Stream(object):
|
|||
|
||||
async def get_updates_since(
|
||||
self, from_token: Token, upto_token: Token, limit: int = 100
|
||||
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
) -> StreamUpdateResult:
|
||||
"""Like get_updates except allows specifying from when we should
|
||||
stream updates
|
||||
|
||||
|
@ -112,33 +156,14 @@ class Stream(object):
|
|||
return [], upto_token, False
|
||||
|
||||
updates, upto_token, limited = await self.update_function(
|
||||
from_token, upto_token, limit=limit,
|
||||
from_token, upto_token, limit,
|
||||
)
|
||||
return updates, upto_token, limited
|
||||
|
||||
def current_token(self):
|
||||
"""Gets the current token of the underlying streams. Should be provided
|
||||
by the sub classes
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_function(self, from_token, current_token, limit):
|
||||
"""Get updates between from_token and to_token.
|
||||
|
||||
Returns:
|
||||
Deferred(list(tuple)): the first entry in the tuple is the token for
|
||||
that update, and the rest of the tuple gets used to construct
|
||||
a ``ROW_TYPE`` instance
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
|
||||
) -> UpdateFunction:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
|
@ -157,9 +182,7 @@ def db_query_to_update_function(
|
|||
return update_function
|
||||
|
||||
|
||||
def make_http_update_function(
|
||||
hs, stream_name: str
|
||||
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
||||
"""Makes a suitable function for use as an `update_function` that queries
|
||||
the master process for updates.
|
||||
"""
|
||||
|
@ -168,7 +191,7 @@ def make_http_update_function(
|
|||
|
||||
async def update_function(
|
||||
from_token: int, upto_token: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
) -> StreamUpdateResult:
|
||||
result = await client(
|
||||
stream_name=stream_name,
|
||||
from_token=from_token,
|
||||
|
@ -202,10 +225,10 @@ class BackfillStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
self.current_token = store.get_current_backfill_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
|
||||
|
||||
super(BackfillStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_current_backfill_token,
|
||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||
)
|
||||
|
||||
|
||||
class PresenceStream(Stream):
|
||||
|
@ -227,19 +250,18 @@ class PresenceStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
presence_handler = hs.get_presence_handler()
|
||||
|
||||
self._is_worker = hs.config.worker_app is not None
|
||||
|
||||
self.current_token = store.get_current_presence_token # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
|
||||
# on the master, query the presence handler
|
||||
presence_handler = hs.get_presence_handler()
|
||||
update_function = db_query_to_update_function(
|
||||
presence_handler.get_all_presence_updates
|
||||
)
|
||||
else:
|
||||
# Query master process
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super(PresenceStream, self).__init__(hs)
|
||||
super().__init__(store.get_current_presence_token, update_function)
|
||||
|
||||
|
||||
class TypingStream(Stream):
|
||||
|
@ -253,15 +275,16 @@ class TypingStream(Stream):
|
|||
def __init__(self, hs):
|
||||
typing_handler = hs.get_typing_handler()
|
||||
|
||||
self.current_token = typing_handler.get_current_token # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
|
||||
# on the master, query the typing handler
|
||||
update_function = db_query_to_update_function(
|
||||
typing_handler.get_all_typing_updates
|
||||
)
|
||||
else:
|
||||
# Query master process
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super(TypingStream, self).__init__(hs)
|
||||
super().__init__(typing_handler.get_current_token, update_function)
|
||||
|
||||
|
||||
class ReceiptsStream(Stream):
|
||||
|
@ -281,11 +304,10 @@ class ReceiptsStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_receipt_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
|
||||
|
||||
super(ReceiptsStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_max_receipt_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_receipts),
|
||||
)
|
||||
|
||||
|
||||
class PushRulesStream(Stream):
|
||||
|
@ -299,13 +321,15 @@ class PushRulesStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(PushRulesStream, self).__init__(hs)
|
||||
super(PushRulesStream, self).__init__(
|
||||
self._current_token, self._update_function
|
||||
)
|
||||
|
||||
def current_token(self):
|
||||
def _current_token(self) -> int:
|
||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||
return push_rules_token
|
||||
|
||||
async def update_function(self, from_token, to_token, limit):
|
||||
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
|
||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||
|
||||
limited = False
|
||||
|
@ -331,10 +355,10 @@ class PushersStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_pushers_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
|
||||
|
||||
super(PushersStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_pushers_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||
)
|
||||
|
||||
|
||||
class CachesStream(Stream):
|
||||
|
@ -362,11 +386,10 @@ class CachesStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_cache_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
|
||||
|
||||
super(CachesStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_cache_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_caches),
|
||||
)
|
||||
|
||||
|
||||
class PublicRoomsStream(Stream):
|
||||
|
@ -388,11 +411,10 @@ class PublicRoomsStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_current_public_room_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
|
||||
|
||||
super(PublicRoomsStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_current_public_room_stream_id,
|
||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||
)
|
||||
|
||||
|
||||
class DeviceListsStream(Stream):
|
||||
|
@ -409,11 +431,10 @@ class DeviceListsStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
|
||||
|
||||
super(DeviceListsStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||
)
|
||||
|
||||
|
||||
class ToDeviceStream(Stream):
|
||||
|
@ -427,11 +448,10 @@ class ToDeviceStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_to_device_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
|
||||
|
||||
super(ToDeviceStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_to_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_new_device_messages),
|
||||
)
|
||||
|
||||
|
||||
class TagAccountDataStream(Stream):
|
||||
|
@ -447,11 +467,10 @@ class TagAccountDataStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
|
||||
|
||||
super(TagAccountDataStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_tags),
|
||||
)
|
||||
|
||||
|
||||
class AccountDataStream(Stream):
|
||||
|
@ -467,11 +486,10 @@ class AccountDataStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(AccountDataStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
self.store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(self._update_function),
|
||||
)
|
||||
|
||||
async def _update_function(self, from_token, to_token, limit):
|
||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||
|
@ -498,11 +516,10 @@ class GroupServerStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_group_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_group_stream_token,
|
||||
db_query_to_update_function(store.get_all_groups_changes),
|
||||
)
|
||||
|
||||
|
||||
class UserSignatureStream(Stream):
|
||||
|
@ -516,8 +533,9 @@ class UserSignatureStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
|
||||
|
||||
super(UserSignatureStream, self).__init__(hs)
|
||||
super().__init__(
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(
|
||||
store.get_all_user_signature_changes_for_remotes
|
||||
),
|
||||
)
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
from typing import Tuple, Type
|
||||
from typing import Iterable, Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, db_query_to_update_function
|
||||
from ._base import Stream, Token, db_query_to_update_function
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
@ -116,12 +116,14 @@ class EventsStream(Stream):
|
|||
|
||||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
self.current_token = self._store.get_current_events_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
super().__init__(
|
||||
self._store.get_current_events_token,
|
||||
db_query_to_update_function(self._update_function),
|
||||
)
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
async def _update_function(self, from_token, current_token, limit=None):
|
||||
async def _update_function(
|
||||
self, from_token: Token, current_token: Token, limit: int
|
||||
) -> Iterable[tuple]:
|
||||
event_rows = await self._store.get_all_new_forward_event_rows(
|
||||
from_token, current_token, limit
|
||||
)
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
|
@ -35,7 +33,6 @@ class FederationStream(Stream):
|
|||
|
||||
NAME = "federation"
|
||||
ROW_TYPE = FederationStreamRow
|
||||
_QUERY_MASTER = True
|
||||
|
||||
def __init__(self, hs):
|
||||
# Not all synapse instances will have a federation sender instance,
|
||||
|
@ -43,10 +40,16 @@ class FederationStream(Stream):
|
|||
# so we stub the stream out when that is the case.
|
||||
if hs.config.worker_app is None or hs.should_send_federation():
|
||||
federation_sender = hs.get_federation_sender()
|
||||
self.current_token = federation_sender.get_current_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
|
||||
current_token = federation_sender.get_current_token
|
||||
update_function = db_query_to_update_function(
|
||||
federation_sender.get_replication_rows
|
||||
)
|
||||
else:
|
||||
self.current_token = lambda: 0 # type: ignore
|
||||
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
|
||||
current_token = lambda: 0
|
||||
update_function = self._stub_update_function
|
||||
|
||||
super(FederationStream, self).__init__(hs)
|
||||
super().__init__(current_token, update_function)
|
||||
|
||||
@staticmethod
|
||||
async def _stub_update_function(from_token, upto_token, limit):
|
||||
return [], upto_token, False
|
||||
|
|
Loading…
Reference in New Issue