Use `async with` for ID gens (#8383)

This will allow us to hit the DB after we've finished using the generated stream ID.
This commit is contained in:
Erik Johnston 2020-09-23 16:11:18 +01:00 committed by GitHub
parent 916bb9d0d1
commit cbabb312e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 144 additions and 105 deletions

1
changelog.d/8383.misc Normal file
View File

@ -0,0 +1 @@
Refactor ID generators to use `async with` syntax.

View File

@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
with await self._account_data_id_gen.get_next() as next_id:
async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
with await self._account_data_id_gen.get_next() as next_id:
async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.

View File

@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
with await self._device_inbox_id_gen.get_next() as stream_id:
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
with await self._device_inbox_id_gen.get_next() as stream_id:
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",

View File

@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
with await self._device_list_id_gen.get_next() as stream_id:
async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
with await self._device_list_id_gen.get_next_mult(
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
with await self._device_list_id_gen.get_next_mult(
async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(

View File

@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
with await self._cross_signing_id_gen.get_next() as stream_id:
async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,

View File

@ -156,15 +156,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream

View File

@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
with await self._group_updates_id_gen.get_next() as next_id:
async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,

View File

@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
with stream_ordering_manager as stream_orderings:
async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,

View File

@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
with await self._push_rules_stream_id_gen.get_next() as stream_id:
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
with await self._push_rules_stream_id_gen.get_next() as stream_id:
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
NotFoundError if the rule does not exist.
"""
with await self._push_rules_stream_id_gen.get_next() as stream_id:
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
with await self._push_rules_stream_id_gen.get_next() as stream_id:
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(

View File

@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
with await self._pushers_id_gen.get_next() as stream_id:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
with await self._pushers_id_gen.get_next() as stream_id:
async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)

View File

@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
with await self._receipts_id_gen.get_next() as stream_id:
async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,

View File

@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
with await self._public_room_id_gen.get_next() as next_id:
async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
with await self._public_room_id_gen.get_next() as next_id:
async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
with await self._public_room_id_gen.get_next() as next_id:
async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,

View File

@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with await self._account_data_id_gen.get_next() as next_id:
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
with await self._account_data_id_gen.get_next() as next_id:
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))

View File

@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import heapq
import logging
import threading
from collections import deque
from typing import Dict, List, Set
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Union
import attr
from typing_extensions import Deque
from synapse.storage.database import DatabasePool, LoggingTransaction
@ -86,7 +86,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
with await stream_id_gen.get_next() as stream_id:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@ -101,10 +101,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
async def get_next(self):
def get_next(self):
"""
Usage:
with await stream_id_gen.get_next() as stream_id:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@ -113,7 +113,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
@contextlib.contextmanager
@contextmanager
def manager():
try:
yield next_id
@ -121,12 +121,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
return manager()
return _AsyncCtxManagerWrapper(manager())
async def get_next_mult(self, n):
def get_next_mult(self, n):
"""
Usage:
with await stream_id_gen.get_next(n) as stream_ids:
async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@ -140,7 +140,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@contextlib.contextmanager
@contextmanager
def manager():
try:
yield next_ids
@ -149,7 +149,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
return manager()
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
async def get_next(self):
def get_next(self):
"""
Usage:
with await stream_id_gen.get_next() as stream_id:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
with self._lock:
assert self._current_positions.get(self._instance_name, 0) < next_id
return _MultiWriterCtxManager(self)
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
# Multiply by the return factor so that the ID has correct sign.
yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)
return manager()
async def get_next_mult(self, n: int):
def get_next_mult(self, n: int):
"""
Usage:
with await stream_id_gen.get_next_mult(5) as stream_ids:
async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
next_ids = await self._db.runInteraction(
"_load_next_mult_id", self._load_next_mult_id_txn, n
)
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
with self._lock:
assert max(self._current_positions.values(), default=0) < min(next_ids)
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
return manager()
return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
@attr.s(slots=True)
class _AsyncCtxManagerWrapper:
"""Helper class to convert a plain context manager to an async one.
This is mainly useful if you have a plain context manager but the interface
requires an async one.
"""
inner = attr.ib()
async def __aenter__(self):
return self.inner.__enter__()
async def __aexit__(self, exc_type, exc, tb):
return self.inner.__exit__(exc_type, exc, tb)
@attr.s(slots=True)
class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator
"""
id_gen = attr.ib(type=MultiWriterIdGenerator)
multiple_ids = attr.ib(type=Optional[int], default=None)
stream_ids = attr.ib(type=List[int], factory=list)
async def __aenter__(self) -> Union[int, List[int]]:
self.stream_ids = await self.id_gen._db.runInteraction(
"_load_next_mult_id",
self.id_gen._load_next_mult_id_txn,
self.multiple_ids or 1,
)
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
with self.id_gen._lock:
assert max(self.id_gen._current_positions.values(), default=0) < min(
self.stream_ids
)
self.id_gen._unfinished_ids.update(self.stream_ids)
if self.multiple_ids is None:
return self.stream_ids[0] * self.id_gen._return_factor
else:
return [i * self.id_gen._return_factor for i in self.stream_ids]
async def __aexit__(self, exc_type, exc, tb):
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
if exc_type is not None:
return False
return False

View File

@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
with await id_gen.get_next() as stream_id:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
ctx3 = self.get_success(id_gen.get_next())
ctx4 = self.get_success(id_gen.get_next())
s1 = ctx1.__enter__()
s2 = ctx2.__enter__()
s3 = ctx3.__enter__()
s4 = ctx4.__enter__()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
s3 = self.get_success(ctx3.__aenter__())
s4 = self.get_success(ctx4.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
ctx2.__exit__(None, None, None)
self.get_success(ctx2.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
ctx1.__exit__(None, None, None)
self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
ctx4.__exit__(None, None, None)
self.get_success(ctx4.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
ctx3.__exit__(None, None, None)
self.get_success(ctx3.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
with await first_id_gen.get_next() as stream_id:
async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# stream ID
async def _get_next_async():
with await second_id_gen.get_next() as stream_id:
async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
with self.get_success(id_gen.get_next()) as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
async def _get_next_async():
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
id_gen = self._create_id_generator()
with self.get_success(id_gen.get_next()) as stream_id:
self._insert_row("master", stream_id)
async def _get_next_async():
async with id_gen.get_next() as stream_id:
self._insert_row("master", stream_id)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": -1})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
for stream_id in stream_ids:
self._insert_row("master", stream_id)
async def _get_next_async2():
async with id_gen.get_next_mult(3) as stream_ids:
for stream_id in stream_ids:
self._insert_row("master", stream_id)
self.get_success(_get_next_async2())
self.assertEqual(id_gen.get_positions(), {"master": -4})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_1 = self._create_id_generator("first")
id_gen_2 = self._create_id_generator("second")
with self.get_success(id_gen_1.get_next()) as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
async def _get_next_async():
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
with self.get_success(id_gen_2.get_next()) as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
async def _get_next_async2():
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
self.get_success(_get_next_async2())
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})