Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator` (#17229)

Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`, which is safer.
This commit is contained in:
Erik Johnston 2024-05-30 12:07:32 +01:00 committed by GitHub
parent 225f378ffa
commit d16910ca02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 227 additions and 363 deletions

1
changelog.d/17229.misc Normal file
View File

@ -0,0 +1 @@
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.

View File

@ -777,22 +777,74 @@ class Porter:
await self._setup_events_stream_seqs() await self._setup_events_stream_seqs()
await self._setup_sequence( await self._setup_sequence(
"un_partial_stated_event_stream_sequence", "un_partial_stated_event_stream_sequence",
("un_partial_stated_event_stream",), [("un_partial_stated_event_stream", "stream_id")],
) )
await self._setup_sequence( await self._setup_sequence(
"device_inbox_sequence", ("device_inbox", "device_federation_outbox") "device_inbox_sequence",
[
("device_inbox", "stream_id"),
("device_federation_outbox", "stream_id"),
],
) )
await self._setup_sequence( await self._setup_sequence(
"account_data_sequence", "account_data_sequence",
("room_account_data", "room_tags_revisions", "account_data"), [
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
("account_data", "stream_id"),
],
)
await self._setup_sequence(
"receipts_sequence",
[
("receipts_linearized", "stream_id"),
],
)
await self._setup_sequence(
"presence_stream_sequence",
[
("presence_stream", "stream_id"),
],
) )
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
await self._setup_auth_chain_sequence() await self._setup_auth_chain_sequence()
await self._setup_sequence( await self._setup_sequence(
"application_services_txn_id_seq", "application_services_txn_id_seq",
("application_services_txns",), [
"txn_id", (
"application_services_txns",
"txn_id",
)
],
)
await self._setup_sequence(
"device_lists_sequence",
[
("device_lists_stream", "stream_id"),
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
("device_lists_changes_converted_stream_position", "stream_id"),
],
)
await self._setup_sequence(
"e2e_cross_signing_keys_sequence",
[
("e2e_cross_signing_keys", "stream_id"),
],
)
await self._setup_sequence(
"push_rules_stream_sequence",
[
("push_rules_stream", "stream_id"),
],
)
await self._setup_sequence(
"pushers_sequence",
[
("pushers", "id"),
("deleted_pushers", "stream_id"),
],
) )
# Step 3. Get tables. # Step 3. Get tables.
@ -1101,12 +1153,11 @@ class Porter:
async def _setup_sequence( async def _setup_sequence(
self, self,
sequence_name: str, sequence_name: str,
stream_id_tables: Iterable[str], stream_id_tables: Iterable[Tuple[str, str]],
column_name: str = "stream_id",
) -> None: ) -> None:
"""Set a sequence to the correct value.""" """Set a sequence to the correct value."""
current_stream_ids = [] current_stream_ids = []
for stream_id_table in stream_id_tables: for stream_id_table, column_name in stream_id_tables:
max_stream_id = cast( max_stream_id = cast(
int, int,
await self.sqlite_store.db_pool.simple_select_one_onecol( await self.sqlite_store.db_pool.simple_select_one_onecol(

View File

@ -57,10 +57,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import MultiWriterIdGenerator
AbstractStreamIdGenerator,
StreamIdGenerator,
)
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
JsonMapping, JsonMapping,
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker # In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process. # class below that is used on the main process.
self._device_list_id_gen = StreamIdGenerator( self._device_list_id_gen = MultiWriterIdGenerator(
db_conn, db_conn=db_conn,
hs.get_replication_notifier(), db=database,
"device_lists_stream", notifier=hs.get_replication_notifier(),
"stream_id", stream_name="device_lists_stream",
extra_tables=[ instance_name=self._instance_name,
("user_signature_stream", "stream_id"), tables=[
("device_lists_outbound_pokes", "stream_id"), ("device_lists_stream", "instance_name", "stream_id"),
("device_lists_changes_in_room", "stream_id"), ("user_signature_stream", "instance_name", "stream_id"),
("device_lists_remote_pending", "stream_id"), ("device_lists_outbound_pokes", "instance_name", "stream_id"),
("device_lists_changes_converted_stream_position", "stream_id"), ("device_lists_changes_in_room", "instance_name", "stream_id"),
("device_lists_remote_pending", "instance_name", "stream_id"),
], ],
is_writer=hs.config.worker.worker_app is None, sequence_name="device_lists_sequence",
writers=["master"],
) )
device_list_max = self._device_list_id_gen.get_current_token() device_list_max = self._device_list_id_gen.get_current_token()
@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"stream_id": stream_id, "stream_id": stream_id,
"from_user_id": from_user_id, "from_user_id": from_user_id,
"user_ids": json_encoder.encode(user_ids), "user_ids": json_encoder.encode(user_ids),
"instance_name": self._instance_name,
}, },
) )
@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx", "device_lists_stream_idx",
index_name="device_lists_stream_user_id", index_name="device_lists_stream_user_id",
@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"device_lists_outbound_pokes", "device_lists_outbound_pokes",
{ {
"stream_id": stream_id, "stream_id": stream_id,
"instance_name": self._instance_name,
"destination": destination, "destination": destination,
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Because we have write access, this will be a StreamIdGenerator
# (see DeviceWorkerStore.__init__)
_device_list_id_gen: AbstractStreamIdGenerator
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="device_lists_stream", table="device_lists_stream",
keys=("stream_id", "user_id", "device_id"), keys=("instance_name", "stream_id", "user_id", "device_id"),
values=[ values=[
(stream_id, user_id, device_id) (self._instance_name, stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids) for stream_id, device_id in zip(stream_ids, device_ids)
], ],
) )
@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values = [ values = [
( (
destination, destination,
self._instance_name,
next(stream_id_iterator), next(stream_id_iterator),
user_id, user_id,
device_id, device_id,
@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_outbound_pokes", table="device_lists_outbound_pokes",
keys=( keys=(
"destination", "destination",
"instance_name",
"stream_id", "stream_id",
"user_id", "user_id",
"device_id", "device_id",
@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id, device_id,
{ {
stream_id: destination stream_id: destination
for (destination, stream_id, _, _, _, _, _) in values for (destination, _, stream_id, _, _, _, _, _) in values
}, },
) )
@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id", "device_id",
"room_id", "room_id",
"stream_id", "stream_id",
"instance_name",
"converted_to_destinations", "converted_to_destinations",
"opentracing_context", "opentracing_context",
), ),
@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id, device_id,
room_id, room_id,
stream_id, stream_id,
self._instance_name,
# We only need to calculate outbound pokes for local users # We only need to calculate outbound pokes for local users
not self.hs.is_mine_id(user_id), not self.hs.is_mine_id(user_id),
encoded_context, encoded_context,
@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
}, },
values={"stream_id": stream_id}, values={
"stream_id": stream_id,
"instance_name": self._instance_name,
},
desc="add_remote_device_list_to_pending", desc="add_remote_device_list_to_pending",
) )

View File

@ -58,7 +58,7 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping from synapse.types import JsonDict, JsonMapping
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator( self._cross_signing_id_gen = MultiWriterIdGenerator(
db_conn, db_conn=db_conn,
hs.get_replication_notifier(), db=database,
"e2e_cross_signing_keys", notifier=hs.get_replication_notifier(),
"stream_id", stream_name="e2e_cross_signing_keys",
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),
],
sequence_name="e2e_cross_signing_keys_sequence",
writers=["master"],
) )
async def set_e2e_device_keys( async def set_e2e_device_keys(
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"keytype": key_type, "keytype": key_type,
"keydata": json_encoder.encode(key), "keydata": json_encoder.encode(key),
"stream_id": stream_id, "stream_id": stream_id,
"instance_name": self._instance_name,
}, },
) )

View File

@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError from synapse.util import json_encoder, unwrapFirstError
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer. `get_max_push_rules_stream_id` which can be called in the initializer.
""" """
_push_rules_stream_id_gen: StreamIdGenerator _push_rules_stream_id_gen: MultiWriterIdGenerator
def __init__( def __init__(
self, self,
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
hs.get_instance_name() in hs.config.worker.writers.push_rules hs.get_instance_name() in hs.config.worker.writers.push_rules
) )
# In the worker store this is an ID tracker which we overwrite in the non-worker self._push_rules_stream_id_gen = MultiWriterIdGenerator(
# class below that is used on the main process. db_conn=db_conn,
self._push_rules_stream_id_gen = StreamIdGenerator( db=database,
db_conn, notifier=hs.get_replication_notifier(),
hs.get_replication_notifier(), stream_name="push_rules_stream",
"push_rules_stream", instance_name=self._instance_name,
"stream_id", tables=[
is_writer=self._is_push_writer, ("push_rules_stream", "instance_name", "stream_id"),
],
sequence_name="push_rules_stream_sequence",
writers=hs.config.worker.writers.push_rules,
) )
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
raise Exception("Not a push writer") raise Exception("Not a push writer")
values = { values = {
"instance_name": self._instance_name,
"stream_id": stream_id, "stream_id": stream_id,
"event_stream_ordering": event_stream_ordering, "event_stream_ordering": event_stream_ordering,
"user_id": user_id, "user_id": user_id,

View File

@ -40,10 +40,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import MultiWriterIdGenerator
AbstractStreamIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# In the worker store this is an ID tracker which we overwrite in the non-worker self._instance_name = hs.get_instance_name()
# class below that is used on the main process.
self._pushers_id_gen = StreamIdGenerator( self._pushers_id_gen = MultiWriterIdGenerator(
db_conn, db_conn=db_conn,
hs.get_replication_notifier(), db=database,
"pushers", notifier=hs.get_replication_notifier(),
"id", stream_name="pushers",
extra_tables=[("deleted_pushers", "stream_id")], instance_name=self._instance_name,
is_writer=hs.config.worker.worker_app is None, tables=[
("pushers", "instance_name", "id"),
("deleted_pushers", "instance_name", "stream_id"),
],
sequence_name="pushers_sequence",
writers=["master"],
) )
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
# Because we have write access, this will be a StreamIdGenerator # Because we have write access, this will be a StreamIdGenerator
# (see PusherWorkerStore.__init__) # (see PusherWorkerStore.__init__)
_pushers_id_gen: AbstractStreamIdGenerator _pushers_id_gen: MultiWriterIdGenerator
async def add_pusher( async def add_pusher(
self, self,
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
"last_stream_ordering": last_stream_ordering, "last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag, "profile_tag": profile_tag,
"id": stream_id, "id": stream_id,
"instance_name": self._instance_name,
"enabled": enabled, "enabled": enabled,
"device_id": device_id, "device_id": device_id,
# XXX(quenting): We're only really persisting the access token ID # XXX(quenting): We're only really persisting the access token ID
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
table="deleted_pushers", table="deleted_pushers",
values={ values={
"stream_id": stream_id, "stream_id": stream_id,
"instance_name": self._instance_name,
"app_id": app_id, "app_id": app_id,
"pushkey": pushkey, "pushkey": pushkey,
"user_id": user_id, "user_id": user_id,
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
self.db_pool.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="deleted_pushers", table="deleted_pushers",
keys=("stream_id", "app_id", "pushkey", "user_id"), keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
values=[ values=[
(stream_id, pusher.app_id, pusher.pushkey, user_id) (
stream_id,
self._instance_name,
pusher.app_id,
pusher.pushkey,
user_id,
)
for stream_id, pusher in zip(stream_ids, pushers) for stream_id, pusher in zip(stream_ids, pushers)
], ],
) )

View File

@ -0,0 +1,27 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add `instance_name` columns to stream tables to allow them to be used with
-- `MultiWriterIdGenerator`
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;

View File

@ -0,0 +1,54 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add squences for stream tables to allow them to be used with
-- `MultiWriterIdGenerator`
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
-- We need to take the max across all the device lists tables as they share the
-- ID generator
SELECT setval('device_lists_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
)
));
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
SELECT setval('e2e_cross_signing_keys_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
));
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
SELECT setval('push_rules_stream_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
));
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
-- We need to take the max across all the pusher tables as they share the
-- ID generator
SELECT setval('pushers_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(id), 1) FROM pushers),
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
)
));

View File

@ -23,15 +23,12 @@ import abc
import heapq import heapq
import logging import logging
import threading import threading
from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AsyncContextManager, AsyncContextManager,
ContextManager, ContextManager,
Dict, Dict,
Generator,
Generic, Generic,
Iterable, Iterable,
List, List,
@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with a single writer.
This class must only be used when the current Synapse process is the sole
writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
initial value of the generator from.
table(str): A database table to read the initial value of the id
generator from.
column(str): The column of the database table to read the initial
value from the id generator from.
extra_tables(list): List of pairs of database tables and columns to
use to source the initial value of the generator from. The value
with the largest magnitude is used.
step(int): which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards.
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(
self,
db_conn: LoggingDatabaseConnection,
notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
step: int = 1,
is_writer: bool = True,
) -> None:
assert step != 0
self._lock = threading.Lock()
self._step: int = step
self._current: int = _load_current_id(db_conn, table, column, step)
self._is_writer = is_writer
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
# We use this as an ordered set, as we want to efficiently append items,
# remove items and get the first item. Since we insert IDs in order, the
# insertion ordering will ensure its in the correct ordering.
#
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
self._notifier = notifier
def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
raise Exception("Replication is not supported by writer StreamIdGenerator")
self._current = (max if self._step > 0 else min)(self._current, new_id)
def get_next(self) -> AsyncContextManager[int]:
with self._lock:
self._current += self._step
next_id = self._current
self._unfinished_ids[next_id] = next_id
@contextmanager
def manager() -> Generator[int, None, None]:
try:
yield next_id
finally:
with self._lock:
self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
with self._lock:
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
self._step,
)
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids[next_id] = next_id
@contextmanager
def manager() -> Generator[Sequence[int], None, None]:
try:
yield next_ids
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager())
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Retrieve the next stream ID from within a database transaction.
Clean-up functions will be called when the transaction finishes.
Args:
txn: The database transaction object.
Returns:
The next stream ID.
"""
if not self._is_writer:
raise Exception("Tried to allocate stream ID on non-writer")
# Get the next stream ID.
with self._lock:
self._current += self._step
next_id = self._current
self._unfinished_ids[next_id] = next_id
def clear_unfinished_id(id_to_clear: int) -> None:
"""A function to mark processing this ID as finished"""
with self._lock:
self._unfinished_ids.pop(id_to_clear)
# Mark this ID as finished once the database transaction itself finishes.
txn.call_after(clear_unfinished_id, next_id)
txn.call_on_exception(clear_unfinished_id, next_id)
# Return the new ID.
return next_id
def get_current_token(self) -> int:
if not self._is_writer:
return self._current
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token()
def get_minimal_local_current_token(self) -> int:
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator): class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers. """Generates and tracks stream IDs for a stream with multiple writers.

View File

@ -30,7 +30,7 @@ from synapse.storage.database import (
) )
from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import ( from synapse.storage.util.sequence import (
LocalSequenceGenerator, LocalSequenceGenerator,
PostgresSequenceGenerator, PostgresSequenceGenerator,
@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS from tests.utils import USE_POSTGRES_FOR_TESTS
class StreamIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
data TEXT
);
"""
)
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
def _create_id_generator(self) -> StreamIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator(
db_conn=conn,
notifier=self.hs.get_replication_notifier(),
table="foobar",
column="stream_id",
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def test_initial_value(self) -> None:
"""Check that we read the current token from the DB."""
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_current_token(), 123)
def test_single_gen_next(self) -> None:
"""Check that we correctly increment the current token from the DB."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
async with id_gen.get_next() as next_id:
# We haven't persisted `next_id` yet; current token is still 123
self.assertEqual(id_gen.get_current_token(), 123)
# But we did learn what the next value is
self.assertEqual(next_id, 124)
# Once the context manager closes we assume that the `next_id` has been
# written to the DB.
self.assertEqual(id_gen.get_current_token(), 124)
self.get_success(test_gen_next())
def test_multiple_gen_nexts(self) -> None:
"""Check that we handle overlapping calls to gen_next sensibly."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()
# Request three new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
self.assertEqual(await ctx3.__aenter__(), 126)
# None are persisted: current token unchanged.
self.assertEqual(id_gen.get_current_token(), 123)
# Persist each in turn.
await ctx1.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 124)
await ctx2.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 125)
await ctx3.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 126)
self.get_success(test_gen_next())
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
"""Check that we handle overlapping calls to gen_next, even when their IDs
created and persisted in different orders."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()
# Request three new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
self.assertEqual(await ctx3.__aenter__(), 126)
# None are persisted: current token unchanged.
self.assertEqual(id_gen.get_current_token(), 123)
# Persist them in a different order, starting with 126 from ctx3.
await ctx3.__aexit__(None, None, None)
# We haven't persisted 124 from ctx1 yet---current token is still 123.
self.assertEqual(id_gen.get_current_token(), 123)
# Now persist 124 from ctx1.
await ctx1.__aexit__(None, None, None)
# Current token is then 124, waiting for 125 to be persisted.
self.assertEqual(id_gen.get_current_token(), 124)
# Finally persist 125 from ctx2.
await ctx2.__aexit__(None, None, None)
# Current token is then 126 (skipping over 125).
self.assertEqual(id_gen.get_current_token(), 126)
self.get_success(test_gen_next())
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
"""Check that we handle overlapping calls to gen_next."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()
# Request two new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
# Persist ctx2 first.
await ctx2.__aexit__(None, None, None)
# Still waiting on ctx1's ID to be persisted.
self.assertEqual(id_gen.get_current_token(), 123)
# Now request a third stream ID. It should be 126 (the smallest ID that
# we've not yet handed out.)
self.assertEqual(await ctx3.__aenter__(), 126)
self.get_success(test_gen_next())
class MultiWriterIdGeneratorBase(HomeserverTestCase): class MultiWriterIdGeneratorBase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main