Add functions to `MultiWriterIdGen` used by events stream (#8164)
This commit is contained in:
parent
5099bd68da
commit
eba98fb024
|
@ -0,0 +1 @@
|
|||
Add functions to `MultiWriterIdGen` used by events stream.
|
|
@ -14,9 +14,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, Set
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from typing_extensions import Deque
|
||||
|
||||
|
@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
|
|||
# should be less than the minimum of this set (if not empty).
|
||||
self._unfinished_ids = set() # type: Set[int]
|
||||
|
||||
# We track the max position where we know everything before has been
|
||||
# persisted. This is done by a) looking at the min across all instances
|
||||
# and b) noting that if we have seen a run of persisted positions
|
||||
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
|
||||
#
|
||||
# Note: There is no guarentee that the IDs generated by the sequence
|
||||
# will be gapless; gaps can form when e.g. a transaction was rolled
|
||||
# back. This means that sometimes we won't be able to skip forward the
|
||||
# position even though everything has been persisted. However, since
|
||||
# gaps should be relatively rare it's still worth doing the book keeping
|
||||
# that allows us to skip forwards when there are gapless runs of
|
||||
# positions.
|
||||
self._persisted_upto_position = (
|
||||
min(self._current_positions.values()) if self._current_positions else 0
|
||||
)
|
||||
self._known_persisted_positions = [] # type: List[int]
|
||||
|
||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||
|
||||
def _load_current_ids(
|
||||
|
@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
|
|||
|
||||
return current_positions
|
||||
|
||||
def _load_next_id_txn(self, txn):
|
||||
def _load_next_id_txn(self, txn) -> int:
|
||||
return self._sequence_gen.get_next_id_txn(txn)
|
||||
|
||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
|
@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
|
|||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
Usage:
|
||||
|
@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
|
|||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
)
|
||||
|
||||
self._add_persisted_position(new_id)
|
||||
|
||||
def get_persisted_upto_position(self) -> int:
|
||||
"""Get the max position where all previous positions have been
|
||||
persisted.
|
||||
|
||||
Note: In the worst case scenario this will be equal to the minimum
|
||||
position across writers. This means that the returned position here can
|
||||
lag if one writer doesn't write very often.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
||||
This is used to keep the `_current_positions` up to date.
|
||||
"""
|
||||
|
||||
# We require that the lock is locked by caller
|
||||
assert self._lock.locked()
|
||||
|
||||
heapq.heappush(self._known_persisted_positions, new_id)
|
||||
|
||||
# We move the current min position up if the minimum current positions
|
||||
# of all instances is higher (since by definition all positions less
|
||||
# that that have been persisted).
|
||||
min_curr = min(self._current_positions.values())
|
||||
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||
|
||||
# We now iterate through the seen positions, discarding those that are
|
||||
# less than the current min positions, and incrementing the min position
|
||||
# if its exactly one greater.
|
||||
#
|
||||
# This is also where we discard items from `_known_persisted_positions`
|
||||
# (to ensure the list doesn't infinitely grow).
|
||||
while self._known_persisted_positions:
|
||||
if self._known_persisted_positions[0] <= self._persisted_upto_position:
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
elif (
|
||||
self._known_persisted_positions[0] == self._persisted_upto_position + 1
|
||||
):
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
self._persisted_upto_position += 1
|
||||
else:
|
||||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import abc
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
|
@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
|||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||
txn.execute(
|
||||
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
|
||||
)
|
||||
return [i for (i,) in txn]
|
||||
|
||||
|
||||
GetFirstCallbackType = Callable[[Cursor], int]
|
||||
|
||||
|
|
|
@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_get_persisted_upto_position(self):
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions.
|
||||
"""
|
||||
|
||||
self._insert_rows("first", 3)
|
||||
self._insert_rows("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first")
|
||||
|
||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||
# we expect it to be 6
|
||||
id_gen.advance("first", 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# No gap, so we expect 7.
|
||||
id_gen.advance("second", 7)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# We haven't seen 8 yet, so we expect 7 still.
|
||||
id_gen.advance("second", 9)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||
id_gen.advance("first", 8)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||
|
||||
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||
# 10 we know that everything before 11 must be persisted.
|
||||
id_gen.advance("first", 11)
|
||||
id_gen.advance("second", 15)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||
|
|
Loading…
Reference in New Issue