Make background updates controllable via a plugin (#11306)

Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
Erik Johnston 2021-11-29 16:57:06 +00:00 committed by GitHub
parent 9d1971a5c4
commit d08ef6f155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 407 additions and 61 deletions

View File

@ -0,0 +1 @@
Add plugin support for controlling database background updates.

View File

@ -0,0 +1,71 @@
# Background update controller callbacks
Background update controller callbacks allow module developers to control (e.g. rate-limit)
how database background updates are run. A database background update is an operation
Synapse runs on its database in the background after it starts. It's usually used to run
database operations that would take too long if they were run at the same time as schema
updates (which are run on startup) and delay Synapse's startup too much: populating a
table with a big amount of data, adding an index on a big table, deleting superfluous data,
etc.
Background update controller callbacks can be registered using the module API's
`register_background_update_controller_callbacks` method. Only the first module (in order
of appearance in Synapse's configuration file) calling this method can register background
update controller callbacks, subsequent calls are ignored.
The available background update controller callbacks are:
### `on_update`
_First introduced in Synapse v1.49.0_
```python
def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int]
```
Called when about to do an iteration of a background update. The module is given the name
of the update, the name of the database, and a flag to indicate whether the background
update will happen in one go and may take a long time (e.g. creating indices). If this last
argument is set to `False`, the update will be run in batches.
The module must return an async context manager. It will be entered before Synapse runs a
background update; this should return the desired duration of the iteration, in
milliseconds.
The context manager will be exited when the iteration completes. Note that the duration
returned by the context manager is a target, and an iteration may take substantially longer
or shorter. If the `one_shot` flag is set to `True`, the duration returned is ignored.
__Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is
because asynchronous operations are expected to be run by the async context manager.
This callback is required when registering any other background update controller callback.
### `default_batch_size`
_First introduced in Synapse v1.49.0_
```python
async def default_batch_size(update_name: str, database_name: str) -> int
```
Called before the first iteration of a background update, with the name of the update and
of the database. The module must return the number of elements to process in this first
iteration.
If this callback is not defined, Synapse will use a default value of 100.
### `min_batch_size`
_First introduced in Synapse v1.49.0_
```python
async def min_batch_size(update_name: str, database_name: str) -> int
```
Called before running a new batch for a background update, with the name of the update and
of the database. The module must return an integer representing the minimum number of
elements to process in this iteration. This number must be at least 1, and is used to
ensure that progress is always made.
If this callback is not defined, Synapse will use a default value of 100.

View File

@ -71,15 +71,15 @@ Modules **must** register their web resources in their `__init__` method.
## Registering a callback ## Registering a callback
Modules can use Synapse's module API to register callbacks. Callbacks are functions that Modules can use Synapse's module API to register callbacks. Callbacks are functions that
Synapse will call when performing specific actions. Callbacks must be asynchronous, and Synapse will call when performing specific actions. Callbacks must be asynchronous (unless
are split in categories. A single module may implement callbacks from multiple categories, specified otherwise), and are split in categories. A single module may implement callbacks
and is under no obligation to implement all callbacks from the categories it registers from multiple categories, and is under no obligation to implement all callbacks from the
callbacks for. categories it registers callbacks for.
Modules can register callbacks using one of the module API's `register_[...]_callbacks` Modules can register callbacks using one of the module API's `register_[...]_callbacks`
methods. The callback functions are passed to these methods as keyword arguments, with methods. The callback functions are passed to these methods as keyword arguments, with
the callback name as the argument name and the function as its value. This is demonstrated the callback name as the argument name and the function as its value. A
in the example below. A `register_[...]_callbacks` method exists for each category. `register_[...]_callbacks` method exists for each category.
Callbacks for each category can be found on their respective page of the Callbacks for each category can be found on their respective page of the
[Synapse documentation website](https://matrix-org.github.io/synapse). [Synapse documentation website](https://matrix-org.github.io/synapse).

View File

@ -119,7 +119,9 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
# Tests assume that all optional dependencies are installed. # Tests assume that all optional dependencies are installed.
# #
# parameterized_class decorator was introduced in parameterized 0.7.0 # parameterized_class decorator was introduced in parameterized 0.7.0
CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"] #
# We use `mock` library as that backports `AsyncMock` to Python 3.6
CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
CONDITIONAL_REQUIREMENTS["dev"] = ( CONDITIONAL_REQUIREMENTS["dev"] = (
CONDITIONAL_REQUIREMENTS["lint"] CONDITIONAL_REQUIREMENTS["lint"]

View File

@ -82,10 +82,19 @@ from synapse.http.server import (
) )
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import (
defer_to_thread,
make_deferred_yieldable,
run_in_background,
)
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.background_updates import (
DEFAULT_BATCH_SIZE_CALLBACK,
MIN_BATCH_SIZE_CALLBACK,
ON_UPDATE_CALLBACK,
)
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -311,6 +320,24 @@ class ModuleApi:
auth_checkers=auth_checkers, auth_checkers=auth_checkers,
) )
def register_background_update_controller_callbacks(
self,
on_update: ON_UPDATE_CALLBACK,
default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
) -> None:
"""Registers background update controller callbacks.
Added in Synapse v1.49.0.
"""
for db in self._hs.get_datastores().databases:
db.updates.register_update_controller_callbacks(
on_update=on_update,
default_batch_size=default_batch_size,
min_batch_size=min_batch_size,
)
def register_web_resource(self, path: str, resource: Resource) -> None: def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path. """Registers a web resource to be served at the given path.
@ -995,6 +1022,11 @@ class ModuleApi:
f, f,
) )
async def sleep(self, seconds: float) -> None:
"""Sleeps for the given number of seconds."""
await self._clock.sleep(seconds)
async def send_mail( async def send_mail(
self, self,
recipient: str, recipient: str,
@ -1149,6 +1181,26 @@ class ModuleApi:
return {key: state_events[event_id] for key, event_id in state_ids.items()} return {key: state_events[event_id] for key, event_id in state_ids.items()}
async def defer_to_thread(
self,
f: Callable[..., T],
*args: Any,
**kwargs: Any,
) -> T:
"""Runs the given function in a separate thread from Synapse's thread pool.
Added in Synapse v1.49.0.
Args:
f: The function to run.
args: The function's arguments.
kwargs: The function's keyword arguments.
Returns:
The return value of the function once ran in a thread.
"""
return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
class PublicRoomListManager: class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room """Contains methods for adding to, removing from and querying whether a room

View File

@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional from typing import (
TYPE_CHECKING,
AsyncContextManager,
Awaitable,
Callable,
Dict,
Iterable,
Optional,
)
import attr
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection from synapse.storage.types import Connection
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import Clock, json_encoder
from . import engines from . import engines
@ -28,6 +38,45 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler:
"""A handler for a given background update.
Attributes:
callback: The function to call to make progress on the background
update.
oneshot: Wether the update is likely to happen all in one go, ignoring
the supplied target duration, e.g. index creation. This is used by
the update controller to help correctly schedule the update.
"""
callback: Callable[[JsonDict, int], Awaitable[int]]
oneshot: bool = False
class _BackgroundUpdateContextManager:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, sleep: bool, clock: Clock):
self._sleep = sleep
self._clock = clock
async def __aenter__(self) -> int:
if self._sleep:
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
return self.BACKGROUND_UPDATE_DURATION_MS
async def __aexit__(self, *exc) -> None:
pass
class BackgroundUpdatePerformance: class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items""" """Tracks the how long a background update is taking to update its items"""
@ -84,20 +133,22 @@ class BackgroundUpdater:
MINIMUM_BACKGROUND_BATCH_SIZE = 1 MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100 DEFAULT_BACKGROUND_BATCH_SIZE = 100
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"): def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.db_pool = database self.db_pool = database
self._database_name = database.name()
# if a background update is currently running, its name. # if a background update is currently running, its name.
self._current_background_update: Optional[str] = None self._current_background_update: Optional[str] = None
self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
self._background_update_handlers: Dict[ self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
str, Callable[[JsonDict, int], Awaitable[int]]
] = {}
self._all_done = False self._all_done = False
# Whether we're currently running updates # Whether we're currently running updates
@ -107,6 +158,83 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API. # enable/disable background updates via the admin API.
self.enabled = True self.enabled = True
def register_update_controller_callbacks(
self,
on_update: ON_UPDATE_CALLBACK,
default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
) -> None:
"""Register callbacks from a module for each hook."""
if self._on_update_callback is not None:
logger.warning(
"More than one module tried to register callbacks for controlling"
" background updates. Only the callbacks registered by the first module"
" (in order of appearance in Synapse's configuration file) that tried to"
" do so will be called."
)
return
self._on_update_callback = on_update
if default_batch_size is not None:
self._default_batch_size_callback = default_batch_size
if min_batch_size is not None:
self._min_batch_size_callback = min_batch_size
def _get_context_manager_for_update(
self,
sleep: bool,
update_name: str,
database_name: str,
oneshot: bool,
) -> AsyncContextManager[int]:
"""Get a context manager to run a background update with.
If a module has registered a `update_handler` callback, use the context manager
it returns.
Otherwise, returns a context manager that will return a default value, optionally
sleeping if needed.
Args:
sleep: Whether we can sleep between updates.
update_name: The name of the update.
database_name: The name of the database the update is being run on.
oneshot: Whether the update will complete all in one go, e.g. index creation.
In such cases the returned target duration is ignored.
Returns:
The target duration in milliseconds that the background update should run for.
Note: this is a *target*, and an iteration may take substantially longer or
shorter.
"""
if self._on_update_callback is not None:
return self._on_update_callback(update_name, database_name, oneshot)
return _BackgroundUpdateContextManager(sleep, self._clock)
async def _default_batch_size(self, update_name: str, database_name: str) -> int:
"""The batch size to use for the first iteration of a new background
update.
"""
if self._default_batch_size_callback is not None:
return await self._default_batch_size_callback(update_name, database_name)
return self.DEFAULT_BACKGROUND_BATCH_SIZE
async def _min_batch_size(self, update_name: str, database_name: str) -> int:
"""A lower bound on the batch size of a new background update.
Used to ensure that progress is always made. Must be greater than 0.
"""
if self._min_batch_size_callback is not None:
return await self._min_batch_size_callback(update_name, database_name)
return self.MINIMUM_BACKGROUND_BATCH_SIZE
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any.""" """Returns the current background update, if any."""
@ -135,13 +263,8 @@ class BackgroundUpdater:
try: try:
logger.info("Starting background schema updates") logger.info("Starting background schema updates")
while self.enabled: while self.enabled:
if sleep:
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try: try:
result = await self.do_next_background_update( result = await self.do_next_background_update(sleep)
self.BACKGROUND_UPDATE_DURATION_MS
)
except Exception: except Exception:
logger.exception("Error doing update") logger.exception("Error doing update")
else: else:
@ -203,13 +326,15 @@ class BackgroundUpdater:
return not update_exists return not update_exists
async def do_next_background_update(self, desired_duration_ms: float) -> bool: async def do_next_background_update(self, sleep: bool = True) -> bool:
"""Does some amount of work on the next queued background update """Does some amount of work on the next queued background update
Returns once some amount of work is done. Returns once some amount of work is done.
Args: Args:
desired_duration_ms: How long we want to spend updating. sleep: Whether to limit how quickly we run background updates or
not.
Returns: Returns:
True if we have finished running all the background updates, otherwise False True if we have finished running all the background updates, otherwise False
""" """
@ -252,7 +377,19 @@ class BackgroundUpdater:
self._current_background_update = upd["update_name"] self._current_background_update = upd["update_name"]
# We have a background update to run, otherwise we would have returned
# early.
assert self._current_background_update is not None
update_info = self._background_update_handlers[self._current_background_update]
async with self._get_context_manager_for_update(
sleep=sleep,
update_name=self._current_background_update,
database_name=self._database_name,
oneshot=update_info.oneshot,
) as desired_duration_ms:
await self._do_background_update(desired_duration_ms) await self._do_background_update(desired_duration_ms)
return False return False
async def _do_background_update(self, desired_duration_ms: float) -> int: async def _do_background_update(self, desired_duration_ms: float) -> int:
@ -260,7 +397,7 @@ class BackgroundUpdater:
update_name = self._current_background_update update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name) logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name] update_handler = self._background_update_handlers[update_name].callback
performance = self._background_update_performance.get(update_name) performance = self._background_update_performance.get(update_name)
@ -273,9 +410,14 @@ class BackgroundUpdater:
if items_per_ms is not None: if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms) batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress # Clamp the batch size so that we always make progress
batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE) batch_size = max(
batch_size,
await self._min_batch_size(update_name, self._database_name),
)
else: else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE batch_size = await self._default_batch_size(
update_name, self._database_name
)
progress_json = await self.db_pool.simple_select_one_onecol( progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates", "background_updates",
@ -294,6 +436,8 @@ class BackgroundUpdater:
duration_ms = time_stop - time_start duration_ms = time_stop - time_start
performance.update(items_updated, duration_ms)
logger.info( logger.info(
"Running background update %r. Processed %r items in %rms." "Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@ -306,8 +450,6 @@ class BackgroundUpdater:
batch_size, batch_size,
) )
performance.update(items_updated, duration_ms)
return len(self._background_update_performance) return len(self._background_update_performance)
def register_background_update_handler( def register_background_update_handler(
@ -331,7 +473,9 @@ class BackgroundUpdater:
update_name: The name of the update that this code handles. update_name: The name of the update that this code handles.
update_handler: The function that does the update. update_handler: The function that does the update.
""" """
self._background_update_handlers[update_name] = update_handler self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
update_handler
)
def register_noop_background_update(self, update_name: str) -> None: def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update. """Register a noop handler for a background update.
@ -453,7 +597,9 @@ class BackgroundUpdater:
await self._end_background_update(update_name) await self._end_background_update(update_name)
return 1 return 1
self.register_background_update_handler(update_name, updater) self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
updater, oneshot=True
)
async def _end_background_update(self, update_name: str) -> None: async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue. """Removes a completed background update task from the queue.

View File

@ -128,6 +128,7 @@ class EmailPusherTests(HomeserverTestCase):
) )
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastore()
def test_need_validated_email(self): def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated """Test that we can only add an email pusher if the user has validated
@ -408,13 +409,7 @@ class EmailPusherTests(HomeserverTestCase):
self.hs.get_datastore().db_pool.updates._all_done = False self.hs.get_datastore().db_pool.updates._all_done = False
# Now let's actually drive the updates to completion # Now let's actually drive the updates to completion
while not self.get_success( self.wait_for_background_updates()
self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
):
self.get_success(
self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
by=0.1,
)
# Check that all pushers with unlinked addresses were deleted # Check that all pushers with unlinked addresses were deleted
pushers = self.get_success( pushers = self.get_success(

View File

@ -135,7 +135,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self._register_bg_update() self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates() self.store.db_pool.updates.start_doing_background_updates()
self.reactor.pump([1.0, 1.0]) self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request( channel = self.make_request(
"GET", "GET",

View File

@ -1,8 +1,11 @@
from unittest.mock import Mock from mock import Mock
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@ -20,10 +23,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def test_do_background_update(self): def test_do_background_update(self):
# the time we claim it takes to update one item when running the update # the time we claim it takes to update one item when running the update
duration_ms = 4200 duration_ms = 10
# the target runtime for each bg update # the target runtime for each bg update
target_background_update_duration_ms = 5000000 target_background_update_duration_ms = 100
store = self.hs.get_datastore() store = self.hs.get_datastore()
self.get_success( self.get_success(
@ -48,10 +51,8 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update self.update_handler.side_effect = update
self.update_handler.reset_mock() self.update_handler.reset_mock()
res = self.get_success( res = self.get_success(
self.updates.do_next_background_update( self.updates.do_next_background_update(False),
target_background_update_duration_ms by=0.01,
),
by=0.1,
) )
self.assertFalse(res) self.assertFalse(res)
@ -74,16 +75,93 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update self.update_handler.side_effect = update
self.update_handler.reset_mock() self.update_handler.reset_mock()
result = self.get_success( result = self.get_success(self.updates.do_next_background_update(False))
self.updates.do_next_background_update(target_background_update_duration_ms)
)
self.assertFalse(result) self.assertFalse(result)
self.update_handler.assert_called_once() self.update_handler.assert_called_once()
# third step: we don't expect to be called any more # third step: we don't expect to be called any more
self.update_handler.reset_mock() self.update_handler.reset_mock()
result = self.get_success( result = self.get_success(self.updates.do_next_background_update(False))
self.updates.do_next_background_update(target_background_update_duration_ms)
)
self.assertTrue(result) self.assertTrue(result)
self.assertFalse(self.update_handler.called) self.assertFalse(self.update_handler.called)
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
)
self.update_deferred = Deferred()
self.update_handler = Mock(return_value=self.update_deferred)
self.updates.register_background_update_handler(
"test_update", self.update_handler
)
# Mock out the AsyncContextManager
self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
self._update_ctx_manager.__aenter__ = Mock(
return_value=make_awaitable(None),
)
self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
# Mock out the `update_handler` callback
self._on_update = Mock(return_value=self._update_ctx_manager)
# Define a default batch size value that's not the same as the internal default
# value (100).
self._default_batch_size = 500
# Register the callbacks with more mocks
self.hs.get_module_api().register_background_update_controller_callbacks(
on_update=self._on_update,
min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
default_batch_size=Mock(
return_value=make_awaitable(self._default_batch_size),
),
)
def test_controller(self):
store = self.hs.get_datastore()
self.get_success(
store.db_pool.simple_insert(
"background_updates",
values={"update_name": "test_update", "progress_json": "{}"},
)
)
# Set the return value for the context manager.
enter_defer = Deferred()
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
# Start the background update.
do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
self.pump()
# `run_update` should have been called, but the update handler won't be
# called until the `enter_defer` (returned by `__aenter__`) is resolved.
self._on_update.assert_called_once_with(
"test_update",
"master",
False,
)
self.assertFalse(do_update_d.called)
self.assertFalse(self.update_deferred.called)
# Resolving the `enter_defer` should call the update handler, which then
# blocks.
enter_defer.callback(100)
self.pump()
self.update_handler.assert_called_once_with({}, self._default_batch_size)
self.assertFalse(self.update_deferred.called)
self._update_ctx_manager.__aexit__.assert_not_called()
# Resolving the update handler deferred should cause the
# `do_next_background_update` to finish and return
self.update_deferred.callback(100)
self.pump()
self._update_ctx_manager.__aexit__.assert_called()
self.get_success(do_update_d)

View File

@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
): ):
iterations += 1 iterations += 1
self.get_success( self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1 self.store.db_pool.updates.do_next_background_update(False), by=0.1
) )
# Ensure that we did actually take multiple iterations to process the # Ensure that we did actually take multiple iterations to process the
@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
): ):
iterations += 1 iterations += 1
self.get_success( self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1 self.store.db_pool.updates.do_next_background_update(False), by=0.1
) )
# Ensure that we did actually take multiple iterations to process the # Ensure that we did actually take multiple iterations to process the

View File

@ -23,6 +23,7 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room from synapse.rest.client import login, register, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.background_updates import _BackgroundUpdateHandler
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock from synapse.util import Clock
@ -391,7 +392,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
with mock.patch.dict( with mock.patch.dict(
self.store.db_pool.updates._background_update_handlers, self.store.db_pool.updates._background_update_handlers,
populate_user_directory_process_users=mocked_process_users, populate_user_directory_process_users=_BackgroundUpdateHandler(
mocked_process_users,
),
): ):
self._purge_and_rebuild_user_dir() self._purge_and_rebuild_user_dir()

View File

@ -331,17 +331,16 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01) time.sleep(0.01)
def wait_for_background_updates(self) -> None: def wait_for_background_updates(self) -> None:
""" """Block until all background database updates have completed.
Block until all background database updates have completed.
Note that callers must ensure that's a store property created on the Note that callers must ensure there's a store property created on the
testcase. testcase.
""" """
while not self.get_success( while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates() self.store.db_pool.updates.has_completed_background_updates()
): ):
self.get_success( self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1 self.store.db_pool.updates.do_next_background_update(False), by=0.1
) )
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
@ -500,8 +499,7 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates(): async def run_bg_updates():
with LoggingContext("run_bg_updates"): with LoggingContext("run_bg_updates"):
while not await stor.db_pool.updates.has_completed_background_updates(): self.get_success(stor.db_pool.updates.run_background_updates(False))
await stor.db_pool.updates.do_next_background_update(1)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore() stor = hs.get_datastore()