Remove `ChainedIdGenerator`. (#8123)

It's just a thin wrapper around two ID gens to make `get_current_token`
and `get_next` return tuples. This can easily be replaced by calling the
appropriate methods on the underlying ID gens directly.
This commit is contained in:
Erik Johnston 2020-08-19 13:41:51 +01:00 committed by GitHub
parent f594e434c3
commit c9c544cda5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 26 additions and 95 deletions

1
changelog.d/8123.misc Normal file
View File

@ -0,0 +1 @@
Remove `ChainedIdGenerator`.

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@ -21,16 +22,13 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def get_max_push_rules_stream_id(self): def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
# We assert this for the benefit of mypy
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME: if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token) self._push_rules_stream_id_gen.advance(token)
for row in rows: for row in rows:

View File

@ -352,7 +352,7 @@ class PushRulesStream(Stream):
) )
def _current_token(self, instance_name: str) -> int: def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token() push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token return push_rules_token

View File

@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
return 200, {} return 200, {}
def notify_user(self, user_id): def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token() stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val): async def set_rule_attr(self, user_id, spec, val):

View File

@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -82,9 +82,9 @@ class PushRulesWorkerStore(
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = ChainedIdGenerator( self._push_rules_stream_id_gen = StreamIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" db_conn, "push_rules_stream", "stream_id"
) # type: Union[ChainedIdGenerator, SlavedIdTracker] ) # type: Union[StreamIdGenerator, SlavedIdTracker]
else: else:
self._push_rules_stream_id_gen = SlavedIdTracker( self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id" db_conn, "push_rules_stream", "stream_id"
@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None: ) -> None:
conditions_json = json_encoder.encode(conditions) conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions) actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as stream_id:
stream_id, event_stream_ordering = ids event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after: if before or after:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_add_push_rule_relative_txn", "_add_push_rule_relative_txn",
@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
) )
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as stream_id:
stream_id, event_stream_ordering = ids event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_push_rule", "delete_push_rule",
delete_push_rule_txn, delete_push_rule_txn,
@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
) )
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as stream_id:
stream_id, event_stream_ordering = ids event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn", "_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn, self._set_push_rule_enabled_txn,
@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json}, data={"actions": actions_json},
) )
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as stream_id:
stream_id, event_stream_ordering = ids event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_push_rule_actions", "set_push_rule_actions",
set_push_rule_actions_txn, set_push_rule_actions_txn,
@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
) )
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
def get_max_push_rules_stream_id(self): def get_max_push_rules_stream_id(self):
return self.get_push_rules_stream_token()[0] return self._push_rules_stream_id_gen.get_current_token()

View File

@ -16,7 +16,7 @@
import contextlib import contextlib
import threading import threading
from collections import deque from collections import deque
from typing import Dict, Set, Tuple from typing import Dict, Set
from typing_extensions import Deque from typing_extensions import Deque
@ -167,72 +167,6 @@ class StreamIdGenerator(object):
return self.get_current_token() return self.get_current_token()
class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync
with another stream. It generates pairs of IDs, the first element is an
integer ID for this stream, the second element is the ID for the stream
that this stream needs to be kept in sync with."""
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._table = table
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
Usage:
with stream_id_gen.get_next() as (stream_id, chained_id):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@contextlib.contextmanager
def manager():
try:
yield (next_id, chained_id)
finally:
with self._lock:
self._unfinished_ids.remove((next_id, chained_id))
return manager()
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
def advance(self, token: int):
"""Stub implementation for advancing the token when receiving updates
over replication; raises an exception as this instance should be the
only source of updates.
"""
raise Exception(
"Attempted to advance token on source for table %r", self._table
)
def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
class MultiWriterIdGenerator: class MultiWriterIdGenerator:
"""An ID generator that tracks a stream that can have multiple writers. """An ID generator that tracks a stream that can have multiple writers.

View File

@ -39,7 +39,7 @@ class EventSources(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken: def get_current_token(self) -> StreamToken:
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key = self.store.get_max_push_rules_stream_id()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token() device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token() groups_key = self.store.get_group_stream_token()