Update all stream IDs after processing replication rows (#14723)
This creates a new store method, `process_replication_position` that is called after `process_replication_rows`. By moving stream ID advances here this guarantees any relevant cache invalidations will have been applied before the stream is advanced. This avoids race conditions where Python switches between threads mid way through processing the `process_replication_rows` method where stream IDs may be advanced before caches are invalidated due to class resolution ordering. See this comment/issue for further discussion: https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703
This commit is contained in:
parent
c4456114e1
commit
db1cfe9c80
|
@ -0,0 +1 @@
|
|||
Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -152,6 +152,9 @@ class ReplicationDataHandler:
|
|||
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, instance_name, token, rows)
|
||||
# NOTE: this must be called after process_replication_rows to ensure any
|
||||
# cache invalidations are first handled before any stream ID advances.
|
||||
self.store.process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
if self.send_handler:
|
||||
await self.send_handler.process_replication_rows(stream_name, token, rows)
|
||||
|
|
|
@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta):
|
|||
token: int,
|
||||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
pass
|
||||
"""
|
||||
Used by storage classes to invalidate caches based on incoming replication data. These
|
||||
must not update any ID generators, use `process_replication_position`.
|
||||
"""
|
||||
|
||||
def process_replication_position( # noqa: B027 (no-op by design)
|
||||
self,
|
||||
stream_name: str,
|
||||
instance_name: str,
|
||||
token: int,
|
||||
) -> None:
|
||||
"""
|
||||
Used by storage classes to advance ID generators based on incoming replication data. This
|
||||
is called after process_replication_rows such that caches are invalidated before any token
|
||||
positions advance.
|
||||
"""
|
||||
|
||||
def _invalidate_state_caches(
|
||||
self, room_id: str, members_changed: Collection[str]
|
||||
|
|
|
@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
token: int,
|
||||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
if stream_name == TagAccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
elif stream_name == AccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
if stream_name == AccountDataStream.NAME:
|
||||
for row in rows:
|
||||
if not row.room_id:
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
|
@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == TagAccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
elif stream_name == AccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
async def add_account_data_to_room(
|
||||
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
|
||||
) -> int:
|
||||
|
|
|
@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
backfilled=True,
|
||||
)
|
||||
elif stream_name == CachesStream.NAME:
|
||||
if self._cache_id_gen:
|
||||
self._cache_id_gen.advance(instance_name, token)
|
||||
|
||||
for row in rows:
|
||||
if row.cache_func == CURRENT_STATE_CACHE_NAME:
|
||||
if row.keys is None:
|
||||
|
@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == CachesStream.NAME:
|
||||
if self._cache_id_gen:
|
||||
self._cache_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
|
||||
data = row.data
|
||||
|
||||
|
|
|
@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == ToDeviceStream.NAME:
|
||||
self._device_inbox_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
def get_to_device_stream_token(self) -> int:
|
||||
return self._device_inbox_id_gen.get_current_token()
|
||||
|
||||
|
|
|
@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
self._invalidate_caches_for_devices(token, rows)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
def _invalidate_caches_for_devices(
|
||||
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
|
||||
) -> None:
|
||||
|
|
|
@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
token: int,
|
||||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(instance_name, token)
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
self._backfill_id_gen.advance(instance_name, -token)
|
||||
elif stream_name == UnPartialStatedEventStream.NAME:
|
||||
if stream_name == UnPartialStatedEventStream.NAME:
|
||||
for row in rows:
|
||||
assert isinstance(row, UnPartialStatedEventStreamRow)
|
||||
|
||||
|
@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(instance_name, token)
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
self._backfill_id_gen.advance(instance_name, -token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
async def have_censored_event(self, event_id: str) -> bool:
|
||||
"""Check if an event has been censored, i.e. if the content of the event has been erased
|
||||
from the database due to a redaction.
|
||||
|
|
|
@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
if stream_name == PresenceStream.NAME:
|
||||
self._presence_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self.presence_stream_cache.entity_has_changed(row.user_id, token)
|
||||
self._get_presence_for_user.invalidate((row.user_id,))
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == PresenceStream.NAME:
|
||||
self._presence_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
|
|
@ -154,6 +154,13 @@ class PushRulesWorkerStore(
|
|||
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == PushRulesStream.NAME:
|
||||
self._push_rules_stream_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
|
|
|
@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
def get_pushers_stream_token(self) -> int:
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == PushersStream.NAME:
|
||||
self._pushers_id_gen.advance(instance_name, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
async def get_pushers_by_app_id_and_pushkey(
|
||||
self, app_id: str, pushkey: str
|
||||
|
|
|
@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == ReceiptsStream.NAME:
|
||||
self._receipts_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
def _insert_linearized_receipt_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
|
|
|
@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
if stream_name == TagAccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self.get_tags_for_user.invalidate((row.user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == TagAccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
|
||||
class TagsStore(TagsWorkerStore):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue