Fix bug where sync could get stuck when using workers (#17438)
This is because we serialized the token wrong if the instance map contained entries from before the minimum token.
This commit is contained in:
parent
d88ba45db9
commit
df11af14db
|
@ -0,0 +1 @@
|
|||
Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.
|
|
@ -699,10 +699,17 @@ class SlidingSyncHandler:
|
|||
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
|
||||
|
||||
# Then assemble the `RoomStreamToken`
|
||||
min_stream_pos = min(instance_to_max_stream_ordering_map.values())
|
||||
membership_snapshot_token = RoomStreamToken(
|
||||
# Minimum position in the `instance_map`
|
||||
stream=min(instance_to_max_stream_ordering_map.values()),
|
||||
instance_map=immutabledict(instance_to_max_stream_ordering_map),
|
||||
stream=min_stream_pos,
|
||||
instance_map=immutabledict(
|
||||
{
|
||||
instance_name: stream_pos
|
||||
for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
|
||||
if stream_pos > min_stream_pos
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Since we fetched the users room list at some point in time after the from/to
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#
|
||||
#
|
||||
import abc
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
from enum import Enum
|
||||
|
@ -74,6 +75,9 @@ if TYPE_CHECKING:
|
|||
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define a state map type from type/state_key to T (usually an event ID or
|
||||
# event)
|
||||
T = TypeVar("T")
|
||||
|
@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||
represented by a default `stream` attribute and a map of instance name to
|
||||
stream position of any writers that are ahead of the default stream
|
||||
position.
|
||||
|
||||
The values in `instance_map` must be greater than the `stream` attribute.
|
||||
"""
|
||||
|
||||
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
|
||||
|
@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||
kw_only=True,
|
||||
)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
# Enforce that all instances have a value greater than the min stream
|
||||
# position.
|
||||
for i, v in self.instance_map.items():
|
||||
if v <= self.stream:
|
||||
raise ValueError(
|
||||
f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
async def parse(cls, store: "DataStore", string: str) -> "Self":
|
||||
|
@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||
for instance in set(self.instance_map).union(other.instance_map)
|
||||
}
|
||||
|
||||
# Filter out any redundant entries.
|
||||
instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
|
||||
|
||||
return attr.evolve(
|
||||
self, stream=max_stream, instance_map=immutabledict(instance_map)
|
||||
)
|
||||
|
@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||
def bound_stream_token(self, max_stream: int) -> "Self":
|
||||
"""Bound the stream positions to a maximum value"""
|
||||
|
||||
min_pos = min(self.stream, max_stream)
|
||||
return type(self)(
|
||||
stream=min(self.stream, max_stream),
|
||||
stream=min_pos,
|
||||
instance_map=immutabledict(
|
||||
{k: min(s, max_stream) for k, s in self.instance_map.items()}
|
||||
{
|
||||
k: min(s, max_stream)
|
||||
for k, s in self.instance_map.items()
|
||||
if min(s, max_stream) > min_pos
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
|
||||
)
|
||||
|
||||
super().__attrs_post_init__()
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
|
||||
try:
|
||||
|
@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||
|
||||
instance_map = {}
|
||||
for part in parts[1:]:
|
||||
if not part:
|
||||
# Handle tokens of the form `m5~`, which were created by
|
||||
# a bug
|
||||
continue
|
||||
|
||||
key, value = part.split(".")
|
||||
instance_id = int(key)
|
||||
pos = int(value)
|
||||
|
@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||
except CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
# We log an exception here as even though this *might* be a client
|
||||
# handing a bad token, its more likely that Synapse returned a bad
|
||||
# token (and we really want to catch those!).
|
||||
logger.exception("Failed to parse stream token: %r", string)
|
||||
raise SynapseError(400, "Invalid room stream token %r" % (string,))
|
||||
|
||||
@classmethod
|
||||
|
@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||
return self.instance_map.get(instance_name, self.stream)
|
||||
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
"""See class level docstring for information about the format."""
|
||||
|
||||
if self.topological is not None:
|
||||
return "t%d-%d" % (self.topological, self.stream)
|
||||
elif self.instance_map:
|
||||
|
@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||
instance_id = await store.get_id_for_instance(name)
|
||||
entries.append(f"{instance_id}.{pos}")
|
||||
|
||||
encoded_map = "~".join(entries)
|
||||
return f"m{self.stream}~{encoded_map}"
|
||||
if entries:
|
||||
encoded_map = "~".join(entries)
|
||||
return f"m{self.stream}~{encoded_map}"
|
||||
return f"s{self.stream}"
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
|||
|
||||
instance_map = {}
|
||||
for part in parts[1:]:
|
||||
if not part:
|
||||
# Handle tokens of the form `m5~`, which were created by
|
||||
# a bug
|
||||
continue
|
||||
|
||||
key, value = part.split(".")
|
||||
instance_id = int(key)
|
||||
pos = int(value)
|
||||
|
@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
|||
except CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
# We log an exception here as even though this *might* be a client
|
||||
# handing a bad token, its more likely that Synapse returned a bad
|
||||
# token (and we really want to catch those!).
|
||||
logger.exception("Failed to parse stream token: %r", string)
|
||||
raise SynapseError(400, "Invalid stream token %r" % (string,))
|
||||
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
"""See class level docstring for information about the format."""
|
||||
|
||||
if self.instance_map:
|
||||
entries = []
|
||||
for name, pos in self.instance_map.items():
|
||||
|
@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
|||
instance_id = await store.get_id_for_instance(name)
|
||||
entries.append(f"{instance_id}.{pos}")
|
||||
|
||||
encoded_map = "~".join(entries)
|
||||
return f"m{self.stream}~{encoded_map}"
|
||||
if entries:
|
||||
encoded_map = "~".join(entries)
|
||||
return f"m{self.stream}~{encoded_map}"
|
||||
return str(self.stream)
|
||||
else:
|
||||
return str(self.stream)
|
||||
|
||||
|
|
|
@ -19,9 +19,18 @@
|
|||
#
|
||||
#
|
||||
|
||||
from typing import Type
|
||||
from unittest import skipUnless
|
||||
|
||||
from immutabledict import immutabledict
|
||||
from parameterized import parameterized_class
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import (
|
||||
AbstractMultiWriterStreamToken,
|
||||
MultiWriterStreamToken,
|
||||
RoomAlias,
|
||||
RoomStreamToken,
|
||||
UserID,
|
||||
get_domain_from_id,
|
||||
get_localpart_from_id,
|
||||
|
@ -29,6 +38,7 @@ from synapse.types import (
|
|||
)
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class IsMineIDTests(unittest.HomeserverTestCase):
|
||||
|
@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
|
|||
# this should work with either a unicode or a bytes
|
||||
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
|
||||
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("token_type",),
|
||||
[
|
||||
(MultiWriterStreamToken,),
|
||||
(RoomStreamToken,),
|
||||
],
|
||||
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
|
||||
)
|
||||
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests for the different types of multi writer tokens."""
|
||||
|
||||
token_type: Type[AbstractMultiWriterStreamToken]
|
||||
|
||||
def test_basic_token(self) -> None:
|
||||
"""Test that a simple stream token can be serialized and unserialized"""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
token = self.token_type(stream=5)
|
||||
|
||||
string_token = self.get_success(token.to_string(store))
|
||||
|
||||
if isinstance(token, RoomStreamToken):
|
||||
self.assertEqual(string_token, "s5")
|
||||
else:
|
||||
self.assertEqual(string_token, "5")
|
||||
|
||||
parsed_token = self.get_success(self.token_type.parse(store, string_token))
|
||||
self.assertEqual(parsed_token, token)
|
||||
|
||||
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
|
||||
def test_instance_map(self) -> None:
|
||||
"""Test for stream token with instance map"""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
|
||||
|
||||
string_token = self.get_success(token.to_string(store))
|
||||
self.assertEqual(string_token, "m5~1.6")
|
||||
|
||||
parsed_token = self.get_success(self.token_type.parse(store, string_token))
|
||||
self.assertEqual(parsed_token, token)
|
||||
|
||||
def test_instance_map_assertion(self) -> None:
|
||||
"""Test that we assert values in the instance map are greater than the
|
||||
min stream position"""
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
|
||||
|
||||
def test_parse_bad_token(self) -> None:
|
||||
"""Test that we can parse tokens produced by a bug in Synapse of the
|
||||
form `m5~`"""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
|
||||
self.assertEqual(parsed_token, self.token_type(stream=5))
|
||||
|
|
Loading…
Reference in New Issue