Merge device list replication streams (#14833)
This commit is contained in:
parent
db5145a31d
commit
2b084c5b71
|
@ -1 +1 @@
|
|||
Merge tag and normal account data replication streams.
|
||||
Merge the two account data and the two device list replication streams.
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Merge the two account data and the two device list replication streams.
|
|
@ -92,12 +92,13 @@ process, for example:
|
|||
|
||||
## Changes to the account data replication streams
|
||||
|
||||
Synapse has changed the format of the account data replication streams (between
|
||||
workers). This is a forwards- and backwards-incompatible change: v1.75 workers
|
||||
cannot process account data replicated by v1.76 workers, and vice versa.
|
||||
Synapse has changed the format of the account data and devices replication
|
||||
streams (between workers). This is a forwards- and backwards-incompatible
|
||||
change: v1.75 workers cannot process account data replicated by v1.76 workers,
|
||||
and vice versa.
|
||||
|
||||
Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data
|
||||
replication will resume as normal.
|
||||
and device replication will resume as normal.
|
||||
|
||||
|
||||
# Upgrading to v1.74.0
|
||||
|
|
|
@ -187,7 +187,7 @@ class ReplicationDataHandler:
|
|||
elif stream_name == DeviceListsStream.NAME:
|
||||
all_room_ids: Set[str] = set()
|
||||
for row in rows:
|
||||
if row.entity.startswith("@"):
|
||||
if row.entity.startswith("@") and not row.is_signature:
|
||||
room_ids = await self.store.get_rooms_for_user(row.entity)
|
||||
all_room_ids.update(room_ids)
|
||||
self.notifier.on_new_event(
|
||||
|
@ -422,7 +422,11 @@ class FederationSenderHandler:
|
|||
# The entities are either user IDs (starting with '@') whose devices
|
||||
# have changed, or remote servers that we need to tell about
|
||||
# changes.
|
||||
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
|
||||
hosts = {
|
||||
row.entity
|
||||
for row in rows
|
||||
if not row.entity.startswith("@") and not row.is_signature
|
||||
}
|
||||
for host in hosts:
|
||||
self.federation_sender.send_device_messages(host, immediate=False)
|
||||
|
||||
|
|
|
@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import (
|
|||
Stream,
|
||||
ToDeviceStream,
|
||||
TypingStream,
|
||||
UserSignatureStream,
|
||||
)
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
|
@ -62,7 +61,6 @@ STREAMS_MAP = {
|
|||
ToDeviceStream,
|
||||
FederationStream,
|
||||
AccountDataStream,
|
||||
UserSignatureStream,
|
||||
UnPartialStatedRoomStream,
|
||||
UnPartialStatedEventStream,
|
||||
)
|
||||
|
@ -82,7 +80,6 @@ __all__ = [
|
|||
"DeviceListsStream",
|
||||
"ToDeviceStream",
|
||||
"AccountDataStream",
|
||||
"UserSignatureStream",
|
||||
"UnPartialStatedRoomStream",
|
||||
"UnPartialStatedEventStream",
|
||||
]
|
||||
|
|
|
@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
|
|||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DeviceListsStreamRow:
|
||||
entity: str
|
||||
# Indicates that a user has signed their own device with their user-signing key
|
||||
is_signature: bool
|
||||
|
||||
NAME = "device_lists"
|
||||
ROW_TYPE = DeviceListsStreamRow
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
store = hs.get_datastores().main
|
||||
self.store = hs.get_datastores().main
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_device_stream_token),
|
||||
store.get_all_device_list_changes_for_remotes,
|
||||
current_token_without_instance(self.store.get_device_stream_token),
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self,
|
||||
instance_name: str,
|
||||
from_token: Token,
|
||||
current_token: Token,
|
||||
target_row_count: int,
|
||||
) -> StreamUpdateResult:
|
||||
(
|
||||
device_updates,
|
||||
devices_to_token,
|
||||
devices_limited,
|
||||
) = await self.store.get_all_device_list_changes_for_remotes(
|
||||
instance_name, from_token, current_token, target_row_count
|
||||
)
|
||||
|
||||
(
|
||||
signatures_updates,
|
||||
signatures_to_token,
|
||||
signatures_limited,
|
||||
) = await self.store.get_all_user_signature_changes_for_remotes(
|
||||
instance_name, from_token, current_token, target_row_count
|
||||
)
|
||||
|
||||
upper_limit_token = current_token
|
||||
if devices_limited:
|
||||
upper_limit_token = min(upper_limit_token, devices_to_token)
|
||||
if signatures_limited:
|
||||
upper_limit_token = min(upper_limit_token, signatures_to_token)
|
||||
|
||||
device_updates = [
|
||||
(stream_id, (entity, False))
|
||||
for stream_id, (entity,) in device_updates
|
||||
if stream_id <= upper_limit_token
|
||||
]
|
||||
|
||||
signatures_updates = [
|
||||
(stream_id, (entity, True))
|
||||
for stream_id, (entity,) in signatures_updates
|
||||
if stream_id <= upper_limit_token
|
||||
]
|
||||
|
||||
updates = list(
|
||||
heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
|
||||
)
|
||||
|
||||
return updates, upper_limit_token, devices_limited or signatures_limited
|
||||
|
||||
|
||||
class ToDeviceStream(Stream):
|
||||
"""New to_device messages for a client"""
|
||||
|
@ -583,22 +632,3 @@ class AccountDataStream(Stream):
|
|||
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
|
||||
)
|
||||
return updates, to_token, limited
|
||||
|
||||
|
||||
class UserSignatureStream(Stream):
|
||||
"""A user has signed their own device with their user-signing key"""
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class UserSignatureStreamRow:
|
||||
user_id: str
|
||||
|
||||
NAME = "user_signature"
|
||||
ROW_TYPE = UserSignatureStreamRow
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
store = hs.get_datastores().main
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_device_stream_token),
|
||||
store.get_all_user_signature_changes_for_remotes,
|
||||
)
|
||||
|
|
|
@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
|
|||
whitelisted_homeserver,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||
from synapse.replication.tcp.streams._base import DeviceListsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
|
@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
) -> None:
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._invalidate_caches_for_devices(token, rows)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
for row in rows:
|
||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
|
@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
) -> None:
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
def _invalidate_caches_for_devices(
|
||||
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
|
||||
) -> None:
|
||||
for row in rows:
|
||||
if row.is_signature:
|
||||
self._user_signature_stream_cache.entity_has_changed(row.entity, token)
|
||||
continue
|
||||
|
||||
# The entities are either user IDs (starting with '@') whose devices
|
||||
# have changed, or remote servers that we need to tell about
|
||||
# changes.
|
||||
|
|
Loading…
Reference in New Issue