Use `async with` for ID gens (#8383)
This will allow us to hit the DB after we've finished using the generated stream ID.
This commit is contained in:
parent
916bb9d0d1
commit
cbabb312e0
|
@ -0,0 +1 @@
|
|||
Refactor ID generators to use `async with` syntax.
|
|
@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
|
@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
|
|
|
@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||
rows.append((destination, stream_id, now_ms, edu_json))
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
|
@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||
txn, stream_id, local_messages_by_user_then_device
|
||||
)
|
||||
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_from_remote_to_device_inbox",
|
||||
|
|
|
@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
THe new stream ID.
|
||||
"""
|
||||
|
||||
with await self._device_list_id_gen.get_next() as stream_id:
|
||||
async with self._device_list_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
|
@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
if not device_ids:
|
||||
return
|
||||
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
return stream_ids[-1]
|
||||
|
||||
context = get_active_span_text_map()
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
key (dict): the key data
|
||||
"""
|
||||
|
||||
with await self._cross_signing_id_gen.get_next() as stream_id:
|
||||
async with self._cross_signing_id_gen.get_next() as stream_id:
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
|
|
|
@ -156,15 +156,15 @@ class PersistEventsStore:
|
|||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
else:
|
||||
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
async with stream_ordering_manager as stream_orderings:
|
||||
for (event, context), stream in zip(events_and_contexts, stream_orderings):
|
||||
event.internal_metadata.stream_ordering = stream
|
||||
|
||||
|
|
|
@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
|
||||
return next_id
|
||||
|
||||
with await self._group_updates_id_gen.get_next() as next_id:
|
||||
async with self._group_updates_id_gen.get_next() as next_id:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"register_user_group_membership",
|
||||
_register_user_group_membership_txn,
|
||||
|
|
|
@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
|
|||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
async def update_presence(self, presence_states):
|
||||
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||
len(presence_states)
|
||||
)
|
||||
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
async with stream_ordering_manager as stream_orderings:
|
||||
await self.db_pool.runInteraction(
|
||||
"update_presence",
|
||||
self._update_presence_txn,
|
||||
|
|
|
@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
if before or after:
|
||||
|
@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
Raises:
|
||||
NotFoundError if the rule does not exist.
|
||||
"""
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
await self.db_pool.runInteraction(
|
||||
"_set_push_rule_enabled_txn",
|
||||
|
@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
|||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
) -> None:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
async with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||
await self.db_pool.simple_upsert(
|
||||
|
@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
async with self._pushers_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
|
|
|
@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
with await self._receipts_id_gen.get_next() as stream_id:
|
||||
async with self._receipts_id_gen.get_next() as stream_id:
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
|
|
|
@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
async with self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"store_room_txn", store_room_txn, next_id
|
||||
)
|
||||
|
@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
async with self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
|
@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
async with self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
|
|
|
@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
|||
)
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
|||
txn.execute(sql, (user_id, room_id, tag))
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, List, Set
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import Deque
|
||||
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
|
@ -86,7 +86,7 @@ class StreamIdGenerator:
|
|||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
|
@ -101,10 +101,10 @@ class StreamIdGenerator:
|
|||
)
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
async def get_next(self):
|
||||
def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -113,7 +113,7 @@ class StreamIdGenerator:
|
|||
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
|
@ -121,12 +121,12 @@ class StreamIdGenerator:
|
|||
with self._lock:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
return manager()
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
async def get_next_mult(self, n):
|
||||
def get_next_mult(self, n):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next(n) as stream_ids:
|
||||
async with stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -140,7 +140,7 @@ class StreamIdGenerator:
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
|
@ -149,7 +149,7 @@ class StreamIdGenerator:
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
return manager()
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_current_token(self):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
|
@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
|
|||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
|
||||
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
with self._lock:
|
||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
||||
return _MultiWriterCtxManager(self)
|
||||
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
# Multiply by the return factor so that the ID has correct sign.
|
||||
yield self._return_factor * next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
with self._lock:
|
||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
||||
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield [self._return_factor * i for i in next_ids]
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
return _MultiWriterCtxManager(self, n)
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
|
@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
|
|||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _AsyncCtxManagerWrapper:
|
||||
"""Helper class to convert a plain context manager to an async one.
|
||||
|
||||
This is mainly useful if you have a plain context manager but the interface
|
||||
requires an async one.
|
||||
"""
|
||||
|
||||
inner = attr.ib()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.inner.__enter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return self.inner.__exit__(exc_type, exc, tb)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _MultiWriterCtxManager:
|
||||
"""Async context manager returned by MultiWriterIdGenerator
|
||||
"""
|
||||
|
||||
id_gen = attr.ib(type=MultiWriterIdGenerator)
|
||||
multiple_ids = attr.ib(type=Optional[int], default=None)
|
||||
stream_ids = attr.ib(type=List[int], factory=list)
|
||||
|
||||
async def __aenter__(self) -> Union[int, List[int]]:
|
||||
self.stream_ids = await self.id_gen._db.runInteraction(
|
||||
"_load_next_mult_id",
|
||||
self.id_gen._load_next_mult_id_txn,
|
||||
self.multiple_ids or 1,
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
with self.id_gen._lock:
|
||||
assert max(self.id_gen._current_positions.values(), default=0) < min(
|
||||
self.stream_ids
|
||||
)
|
||||
|
||||
self.id_gen._unfinished_ids.update(self.stream_ids)
|
||||
|
||||
if self.multiple_ids is None:
|
||||
return self.stream_ids[0] * self.id_gen._return_factor
|
||||
else:
|
||||
return [i * self.id_gen._return_factor for i in self.stream_ids]
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
for i in self.stream_ids:
|
||||
self.id_gen._mark_id_as_finished(i)
|
||||
|
||||
if exc_type is not None:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
|
|
@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# advanced after we leave the context manager.
|
||||
|
||||
async def _get_next_async():
|
||||
with await id_gen.get_next() as stream_id:
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
|
@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
ctx3 = self.get_success(id_gen.get_next())
|
||||
ctx4 = self.get_success(id_gen.get_next())
|
||||
|
||||
s1 = ctx1.__enter__()
|
||||
s2 = ctx2.__enter__()
|
||||
s3 = ctx3.__enter__()
|
||||
s4 = ctx4.__enter__()
|
||||
s1 = self.get_success(ctx1.__aenter__())
|
||||
s2 = self.get_success(ctx2.__aenter__())
|
||||
s3 = self.get_success(ctx3.__aenter__())
|
||||
s4 = self.get_success(ctx4.__aenter__())
|
||||
|
||||
self.assertEqual(s1, 8)
|
||||
self.assertEqual(s2, 9)
|
||||
|
@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
ctx2.__exit__(None, None, None)
|
||||
self.get_success(ctx2.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
ctx1.__exit__(None, None, None)
|
||||
self.get_success(ctx1.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
||||
|
||||
ctx4.__exit__(None, None, None)
|
||||
self.get_success(ctx4.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
||||
|
||||
ctx3.__exit__(None, None, None)
|
||||
self.get_success(ctx3.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
||||
|
@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# advanced after we leave the context manager.
|
||||
|
||||
async def _get_next_async():
|
||||
with await first_id_gen.get_next() as stream_id:
|
||||
async with first_id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# stream ID
|
||||
|
||||
async def _get_next_async():
|
||||
with await second_id_gen.get_next() as stream_id:
|
||||
async with second_id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 9)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -305,10 +305,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
with self.get_success(id_gen.get_next()) as stream_id:
|
||||
|
||||
async def _get_next_async():
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# We assume that so long as `get_next` does correctly advance the
|
||||
|
@ -373,17 +377,23 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
with self.get_success(id_gen.get_next()) as stream_id:
|
||||
async def _get_next_async():
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -1})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
|
||||
async def _get_next_async2():
|
||||
async with id_gen.get_next_mult(3) as stream_ids:
|
||||
for stream_id in stream_ids:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.get_success(_get_next_async2())
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -4})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
|
||||
|
@ -402,19 +412,25 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
id_gen_1 = self._create_id_generator("first")
|
||||
id_gen_2 = self._create_id_generator("second")
|
||||
|
||||
with self.get_success(id_gen_1.get_next()) as stream_id:
|
||||
async def _get_next_async():
|
||||
async with id_gen_1.get_next() as stream_id:
|
||||
self._insert_row("first", stream_id)
|
||||
id_gen_2.advance("first", stream_id)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen_2.get_next()) as stream_id:
|
||||
async def _get_next_async2():
|
||||
async with id_gen_2.get_next() as stream_id:
|
||||
self._insert_row("second", stream_id)
|
||||
id_gen_1.advance("second", stream_id)
|
||||
|
||||
self.get_success(_get_next_async2())
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
||||
|
|
Loading…
Reference in New Issue