In sync wait for worker to catch up since token (#17215)
Otherwise things will get confused. An alternative would be to make sure that for lagging stream we don't return anything (and make sure the returned next_batch token doesn't go backwards). But that is a faff.
This commit is contained in:
parent
4e3868dc46
commit
5624c8b961
|
@ -0,0 +1 @@
|
||||||
|
Fix bug where duplicate events could be sent down sync when using workers that are overloaded.
|
|
@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
|
||||||
# add a lower bound to the Jinja2 dependency.
|
# add a lower bound to the Jinja2 dependency.
|
||||||
Jinja2 = ">=3.0"
|
Jinja2 = ">=3.0"
|
||||||
bleach = ">=1.4.3"
|
bleach = ">=1.4.3"
|
||||||
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
|
# We use `Self`, which were added in `typing-extensions` 4.0.
|
||||||
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
|
typing-extensions = ">=4.0"
|
||||||
# generic over ParamSpecs.
|
|
||||||
typing-extensions = ">=3.10.0.1"
|
|
||||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||||
# with the latest security patches.
|
# with the latest security patches.
|
||||||
cryptography = ">=3.4.7"
|
cryptography = ">=3.4.7"
|
||||||
|
|
|
@ -284,6 +284,23 @@ class SyncResult:
|
||||||
or self.device_lists
|
or self.device_lists
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty(next_batch: StreamToken) -> "SyncResult":
|
||||||
|
"Return a new empty result"
|
||||||
|
return SyncResult(
|
||||||
|
next_batch=next_batch,
|
||||||
|
presence=[],
|
||||||
|
account_data=[],
|
||||||
|
joined=[],
|
||||||
|
invited=[],
|
||||||
|
knocked=[],
|
||||||
|
archived=[],
|
||||||
|
to_device=[],
|
||||||
|
device_lists=DeviceListUpdates(),
|
||||||
|
device_one_time_keys_count={},
|
||||||
|
device_unused_fallback_key_types=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
class E2eeSyncResult:
|
class E2eeSyncResult:
|
||||||
|
@ -497,6 +514,24 @@ class SyncHandler:
|
||||||
if context:
|
if context:
|
||||||
context.tag = sync_label
|
context.tag = sync_label
|
||||||
|
|
||||||
|
if since_token is not None:
|
||||||
|
# We need to make sure this worker has caught up with the token. If
|
||||||
|
# this returns false it means we timed out waiting, and we should
|
||||||
|
# just return an empty response.
|
||||||
|
start = self.clock.time_msec()
|
||||||
|
if not await self.notifier.wait_for_stream_token(since_token):
|
||||||
|
logger.warning(
|
||||||
|
"Timed out waiting for worker to catch up. Returning empty response"
|
||||||
|
)
|
||||||
|
return SyncResult.empty(since_token)
|
||||||
|
|
||||||
|
# If we've spent significant time waiting to catch up, take it off
|
||||||
|
# the timeout.
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
if now - start > 1_000:
|
||||||
|
timeout -= now - start
|
||||||
|
timeout = max(timeout, 0)
|
||||||
|
|
||||||
# if we have a since token, delete any to-device messages before that token
|
# if we have a since token, delete any to-device messages before that token
|
||||||
# (since we now know that the device has received them)
|
# (since we now know that the device has received them)
|
||||||
if since_token is not None:
|
if since_token is not None:
|
||||||
|
|
|
@ -763,6 +763,29 @@ class Notifier:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
|
||||||
|
"""Wait for this worker to catch up with the given stream token."""
|
||||||
|
|
||||||
|
start = self.clock.time_msec()
|
||||||
|
while True:
|
||||||
|
current_token = self.event_sources.get_current_token()
|
||||||
|
if stream_token.is_before_or_eq(current_token):
|
||||||
|
return True
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
if now - start > 10_000:
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Waiting for current token to reach %s; currently at %s",
|
||||||
|
stream_token,
|
||||||
|
current_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: be better
|
||||||
|
await self.clock.sleep(0.5)
|
||||||
|
|
||||||
async def _get_room_ids(
|
async def _get_room_ids(
|
||||||
self, user: UserID, explicit_room_id: Optional[str]
|
self, user: UserID, explicit_room_id: Optional[str]
|
||||||
) -> Tuple[StrCollection, bool]:
|
) -> Tuple[StrCollection, bool]:
|
||||||
|
|
|
@ -95,6 +95,10 @@ class DeltaState:
|
||||||
to_insert: StateMap[str]
|
to_insert: StateMap[str]
|
||||||
no_longer_in_room: bool = False
|
no_longer_in_room: bool = False
|
||||||
|
|
||||||
|
def is_noop(self) -> bool:
|
||||||
|
"""Whether this state delta is actually empty"""
|
||||||
|
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||||
|
|
||||||
|
|
||||||
class PersistEventsStore:
|
class PersistEventsStore:
|
||||||
"""Contains all the functions for writing events to the database.
|
"""Contains all the functions for writing events to the database.
|
||||||
|
@ -1017,6 +1021,9 @@ class PersistEventsStore:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the current state stored in the datatabase for the given room"""
|
"""Update the current state stored in the datatabase for the given room"""
|
||||||
|
|
||||||
|
if state_delta.is_noop():
|
||||||
|
return
|
||||||
|
|
||||||
async with self._stream_id_gen.get_next() as stream_ordering:
|
async with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"update_current_state",
|
"update_current_state",
|
||||||
|
|
|
@ -200,7 +200,11 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
notifier=hs.get_replication_notifier(),
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="events",
|
stream_name="events",
|
||||||
instance_name=hs.get_instance_name(),
|
instance_name=hs.get_instance_name(),
|
||||||
tables=[("events", "instance_name", "stream_ordering")],
|
tables=[
|
||||||
|
("events", "instance_name", "stream_ordering"),
|
||||||
|
("current_state_delta_stream", "instance_name", "stream_id"),
|
||||||
|
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||||
|
],
|
||||||
sequence_name="events_stream_seq",
|
sequence_name="events_stream_seq",
|
||||||
writers=hs.config.worker.writers.events,
|
writers=hs.config.worker.writers.events,
|
||||||
)
|
)
|
||||||
|
@ -210,7 +214,10 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
notifier=hs.get_replication_notifier(),
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="backfill",
|
stream_name="backfill",
|
||||||
instance_name=hs.get_instance_name(),
|
instance_name=hs.get_instance_name(),
|
||||||
tables=[("events", "instance_name", "stream_ordering")],
|
tables=[
|
||||||
|
("events", "instance_name", "stream_ordering"),
|
||||||
|
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||||
|
],
|
||||||
sequence_name="events_backfill_stream_seq",
|
sequence_name="events_backfill_stream_seq",
|
||||||
positive=False,
|
positive=False,
|
||||||
writers=hs.config.worker.writers.events,
|
writers=hs.config.worker.writers.events,
|
||||||
|
|
|
@ -48,7 +48,7 @@ import attr
|
||||||
from immutabledict import immutabledict
|
from immutabledict import immutabledict
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.types import VerifyKey
|
from signedjson.types import VerifyKey
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import Self, TypedDict
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
from zope.interface import Interface
|
from zope.interface import Interface
|
||||||
|
|
||||||
|
@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
||||||
# at `self.stream`.
|
# at `self.stream`.
|
||||||
return self.instance_map.get(instance_name, self.stream)
|
return self.instance_map.get(instance_name, self.stream)
|
||||||
|
|
||||||
|
def is_before_or_eq(self, other_token: Self) -> bool:
|
||||||
|
"""Wether this token is before the other token, i.e. every constituent
|
||||||
|
part is before the other.
|
||||||
|
|
||||||
|
Essentially it is `self <= other`.
|
||||||
|
|
||||||
|
Note: if `self.is_before_or_eq(other_token) is False` then that does not
|
||||||
|
imply that the reverse is True.
|
||||||
|
"""
|
||||||
|
if self.stream > other_token.stream:
|
||||||
|
return False
|
||||||
|
|
||||||
|
instances = self.instance_map.keys() | other_token.instance_map.keys()
|
||||||
|
for instance in instances:
|
||||||
|
if self.instance_map.get(
|
||||||
|
instance, self.stream
|
||||||
|
) > other_token.instance_map.get(instance, other_token.stream):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True, order=False)
|
@attr.s(frozen=True, slots=True, order=False)
|
||||||
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
||||||
|
@ -1008,6 +1029,41 @@ class StreamToken:
|
||||||
"""Returns the stream ID for the given key."""
|
"""Returns the stream ID for the given key."""
|
||||||
return getattr(self, key.value)
|
return getattr(self, key.value)
|
||||||
|
|
||||||
|
def is_before_or_eq(self, other_token: "StreamToken") -> bool:
|
||||||
|
"""Wether this token is before the other token, i.e. every constituent
|
||||||
|
part is before the other.
|
||||||
|
|
||||||
|
Essentially it is `self <= other`.
|
||||||
|
|
||||||
|
Note: if `self.is_before_or_eq(other_token) is False` then that does not
|
||||||
|
imply that the reverse is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for _, key in StreamKeyType.__members__.items():
|
||||||
|
if key == StreamKeyType.TYPING:
|
||||||
|
# Typing stream is allowed to "reset", and so comparisons don't
|
||||||
|
# really make sense as is.
|
||||||
|
# TODO: Figure out a better way of tracking resets.
|
||||||
|
continue
|
||||||
|
|
||||||
|
self_value = self.get_field(key)
|
||||||
|
other_value = other_token.get_field(key)
|
||||||
|
|
||||||
|
if isinstance(self_value, RoomStreamToken):
|
||||||
|
assert isinstance(other_value, RoomStreamToken)
|
||||||
|
if not self_value.is_before_or_eq(other_value):
|
||||||
|
return False
|
||||||
|
elif isinstance(self_value, MultiWriterStreamToken):
|
||||||
|
assert isinstance(other_value, MultiWriterStreamToken)
|
||||||
|
if not self_value.is_before_or_eq(other_value):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
assert isinstance(other_value, int)
|
||||||
|
if self_value > other_value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
StreamToken.START = StreamToken(
|
StreamToken.START = StreamToken(
|
||||||
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
|
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
|
||||||
|
|
Loading…
Reference in New Issue