Task scheduler: add replication notify for new task to launch ASAP (#16184)

Mathieu Velten 2023-08-28 16:03:51 +02:00 committed by GitHub
parent 224c2bbcfa
commit 501da8ecd8
5 changed files with 114 additions and 67 deletions
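
In short: an ACTIVE task scheduled from a worker that does not run background tasks used to wait for the next run of the reconciliation loop (up to SCHEDULE_INTERVAL_MS, about a minute); with this change the scheduling worker sends a NEW_ACTIVE_TASK replication command so the background-tasks worker launches the task right away. A simplified sketch of that decision, using a hypothetical helper and arguments (only `_launch_task`, `send_new_active_task` and `TaskStatus` come from the diff below):

    # Simplified sketch; run_background_tasks mirrors hs.config.worker.run_background_tasks.
    from synapse.types import TaskStatus

    async def on_task_scheduled(scheduler, replication_handler, task, run_background_tasks: bool) -> None:
        if task.status != TaskStatus.ACTIVE:
            return  # SCHEDULED tasks still wait for the reconciliation loop
        if run_background_tasks:
            await scheduler._launch_task(task)  # this worker runs background tasks itself
        else:
            replication_handler.send_new_active_task(task.id)  # notify the background-tasks worker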

changelog.d/16184.misc (new file)

@@ -0,0 +1 @@
+Task scheduler: add replication notify for new task to launch ASAP.

synapse/replication/tcp/commands.py

@@ -452,6 +452,17 @@ class LockReleasedCommand(Command):
         return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])


+class NewActiveTaskCommand(_SimpleCommand):
+    """Sent to inform instance handling background tasks that a new active task is available to run.
+
+    Format::
+
+        NEW_ACTIVE_TASK "<task_id>"
+    """
+
+    NAME = "NEW_ACTIVE_TASK"
+
+
 _COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
@@ -466,6 +477,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
     LockReleasedCommand,
+    NewActiveTaskCommand,
 )

 # Map of command name to command type.
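
Since `NewActiveTaskCommand` is a `_SimpleCommand`, the whole remainder of the replication line becomes its `data` field. A minimal sketch of the round trip under that assumption (the task id string is made up):

    from synapse.replication.tcp.commands import NewActiveTaskCommand

    cmd = NewActiveTaskCommand("AbCdEf123456")
    assert cmd.NAME == "NEW_ACTIVE_TASK"
    assert cmd.to_line() == "AbCdEf123456"  # wire line: NEW_ACTIVE_TASK AbCdEf123456

    # The receiving side rebuilds the command from the payload alone:
    parsed = NewActiveTaskCommand.from_line("AbCdEf123456")
    assert parsed.data == "AbCdEf123456"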

synapse/replication/tcp/handler.py

@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import (
     Command,
     FederationAckCommand,
     LockReleasedCommand,
+    NewActiveTaskCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -238,6 +239,10 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()

+        self._task_scheduler = None
+        if hs.config.worker.run_background_tasks:
+            self._task_scheduler = hs.get_task_scheduler()
+
         if hs.config.redis.redis_enabled:
             # If we're using Redis, it's the background worker that should
             # receive USER_IP commands and store the relevant client IPs.
@@ -663,6 +668,15 @@ class ReplicationCommandHandler:
             cmd.instance_name, cmd.lock_name, cmd.lock_key
         )

+    async def on_NEW_ACTIVE_TASK(
+        self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
+    ) -> None:
+        """Called when we get a new NEW_ACTIVE_TASK command."""
+        if self._task_scheduler:
+            task = await self._task_scheduler.get_task(cmd.data)
+            if task:
+                await self._task_scheduler._launch_task(task)
+
     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -776,6 +790,10 @@ class ReplicationCommandHandler:
         if instance_name == self._instance_name:
             self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))

+    def send_new_active_task(self, task_id: str) -> None:
+        """Called when a new task has been scheduled for immediate launch and is ACTIVE."""
+        self.send_command(NewActiveTaskCommand(task_id))
+

 UpdateToken = TypeVar("UpdateToken")
 UpdateRow = TypeVar("UpdateRow")
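
On the background-tasks worker, the handler resolves replication commands to `on_<NAME>` methods, so `on_NEW_ACTIVE_TASK` is picked up without extra routing: it looks the task up by `cmd.data` and launches it. A hedged sketch of that receive path (the driver function and its arguments are illustrative; in practice the replication protocol makes this call when a `NEW_ACTIVE_TASK <task_id>` line arrives):

    from synapse.replication.tcp.commands import NewActiveTaskCommand

    async def deliver_new_active_task(handler, conn, task_id: str) -> None:
        # handler is a ReplicationCommandHandler on the worker configured with
        # `run_background_tasks_on`; conn is whatever IReplicationConnection delivered the line.
        cmd = NewActiveTaskCommand(task_id)
        callback = getattr(handler, f"on_{cmd.NAME}")  # -> handler.on_NEW_ACTIVE_TASK
        await callback(conn, cmd)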

synapse/util/task_scheduler.py

@@ -57,14 +57,13 @@ class TaskScheduler:
     the code launching the task.
     You can also specify the `result` (and/or an `error`) when returning from the function.

-    The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting
-    to launch now, the launch will still not happen before the next loop run.
-
-    Tasks will be run on the worker specified with `run_background_tasks_on` config,
-    or the main one by default.
+    The reconciliation loop runs every minute, so this is not a precise scheduler.

     There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already
     full. In this regard, please take great care that scheduled tasks can actually finish.
     For now there is no mechanism to stop a running task if it is stuck.
+
+    Tasks will be run on the worker specified with `run_background_tasks_on` config,
+    or the main one by default.
     """

     # Precision of the scheduler, evaluation of tasks to run will only happen
@@ -85,7 +84,7 @@ class TaskScheduler:
         self._actions: Dict[
             str,
             Callable[
-                [ScheduledTask, bool],
+                [ScheduledTask],
                 Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
             ],
         ] = {}
@@ -98,11 +97,13 @@ class TaskScheduler:
                 "handle_scheduled_tasks",
                 self._handle_scheduled_tasks,
             )
+        else:
+            self.replication_client = hs.get_replication_command_handler()

     def register_action(
         self,
         function: Callable[
-            [ScheduledTask, bool],
+            [ScheduledTask],
             Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
         ],
         action_name: str,
@@ -115,10 +116,9 @@ class TaskScheduler:
         calling `schedule_task` but rather in an `__init__` method.

         Args:
-            function: The function to be executed for this action. The parameters
-                passed to the function when launched are the `ScheduledTask` being run,
-                and a `first_launch` boolean to signal if it's a resumed task or the first
-                launch of it. The function should return a tuple of new `status`, `result`
+            function: The function to be executed for this action. The parameter
+                passed to the function when launched is the `ScheduledTask` being run.
+                The function should return a tuple of new `status`, `result`
                 and `error` as specified in `ScheduledTask`.
             action_name: The name of the action to be associated with the function
         """
@@ -171,6 +171,12 @@
         )
         await self._store.insert_scheduled_task(task)

+        if status == TaskStatus.ACTIVE:
+            if self._run_background_tasks:
+                await self._launch_task(task)
+            else:
+                self.replication_client.send_new_active_task(task.id)
+
         return task.id

     async def update_task(
@@ -265,21 +271,13 @@
         Args:
             id: id of the task to delete
         """
-        if self.task_is_running(id):
-            raise Exception(f"Task {id} is currently running and can't be deleted")
+        task = await self.get_task(id)
+        if task is None:
+            raise Exception(f"Task {id} does not exist")
+        if task.status == TaskStatus.ACTIVE:
+            raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")

         await self._store.delete_scheduled_task(id)

-    def task_is_running(self, id: str) -> bool:
-        """Check if a task is currently running.
-
-        Can only be called from the worker handling the task scheduling.
-
-        Args:
-            id: id of the task to check
-        """
-        assert self._run_background_tasks
-        return id in self._running_tasks
-
     async def _handle_scheduled_tasks(self) -> None:
         """Main loop taking care of launching tasks and cleaning up old ones."""
         await self._launch_scheduled_tasks()
@@ -288,29 +286,11 @@
     async def _launch_scheduled_tasks(self) -> None:
         """Retrieve and launch scheduled tasks that should be running at that time."""
         for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
-            if not self.task_is_running(task.id):
-                if (
-                    len(self._running_tasks)
-                    < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
-                ):
-                    await self._launch_task(task, first_launch=False)
-            else:
-                if (
-                    self._clock.time_msec()
-                    > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS
-                ):
-                    logger.warn(
-                        f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck"
-                    )
-
+            await self._launch_task(task)
         for task in await self.get_tasks(
             statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
         ):
-            if (
-                not self.task_is_running(task.id)
-                and len(self._running_tasks)
-                < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
-            ):
-                await self._launch_task(task, first_launch=True)
+            await self._launch_task(task)

         running_tasks_gauge.set(len(self._running_tasks))
@@ -320,27 +300,27 @@
             statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
         ):
             # FAILED and COMPLETE tasks should never be running
-            assert not self.task_is_running(task.id)
+            assert task.id not in self._running_tasks
             if (
                 self._clock.time_msec()
                 > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
             ):
                 await self._store.delete_scheduled_task(task.id)

-    async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None:
+    async def _launch_task(self, task: ScheduledTask) -> None:
         """Launch a scheduled task now.

         Args:
             task: the task to launch
-            first_launch: `True` if it's the first time is launched, `False` otherwise
         """
-        assert task.action in self._actions
+        assert self._run_background_tasks

+        assert task.action in self._actions
         function = self._actions[task.action]

         async def wrapper() -> None:
             try:
-                (status, result, error) = await function(task, first_launch)
+                (status, result, error) = await function(task)
             except Exception:
                 f = Failure()
                 logger.error(
@@ -360,6 +340,20 @@
                 )

             self._running_tasks.remove(task.id)

+        if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+            return
+
+        if (
+            self._clock.time_msec()
+            > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS
+        ):
+            logger.warn(
+                f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck"
+            )
+
+        if task.id in self._running_tasks:
+            return
+
         self._running_tasks.add(task.id)
         await self.update_task(task.id, status=TaskStatus.ACTIVE)
         description = f"{task.id}-{task.action}"
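
Taken together, these hunks change the action callback contract (no `first_launch` flag any more) and make immediate launch the default for a task scheduled without a timestamp. A hedged usage sketch under those assumptions; the action name, params and the `hs` argument are illustrative, not part of the change:

    from typing import Optional, Tuple

    from synapse.types import JsonMapping, ScheduledTask, TaskStatus


    async def _purge_example(task: ScheduledTask) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
        # Actions now receive only the ScheduledTask; a resumed run can inspect task.result.
        return TaskStatus.COMPLETE, {"done": True}, None


    async def schedule_purge(hs) -> str:
        scheduler = hs.get_task_scheduler()
        scheduler.register_action(_purge_example, "purge_example")
        # With no timestamp the task starts ACTIVE: it is launched at once on the
        # background-tasks worker, or announced to it via NEW_ACTIVE_TASK otherwise.
        return await scheduler.schedule_task("purge_example", params={"target": "example"})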

tests/util/test_task_scheduler.py

@@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus
 from synapse.util import Clock
 from synapse.util.task_scheduler import TaskScheduler

-from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase, override_config


-class TestTaskScheduler(unittest.HomeserverTestCase):
+class TestTaskScheduler(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.task_scheduler = hs.get_task_scheduler()
         self.task_scheduler.register_action(self._test_task, "_test_task")
@@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.task_scheduler.register_action(self._resumable_task, "_resumable_task")

     async def _test_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         # This test task will copy the parameters to the result
         result = None
@@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.assertIsNone(task)

     async def _sleeping_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         # Sleep for a second
         await deferLater(self.reactor, 1, lambda: None)
@@ -85,24 +86,18 @@
     def test_schedule_lot_of_tasks(self) -> None:
         """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
-        timestamp = self.clock.time_msec() + 30 * 1000
         task_ids = []
         for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
             task_ids.append(
                 self.get_success(
                     self.task_scheduler.schedule_task(
                         "_sleeping_task",
-                        timestamp=timestamp,
                         params={"val": i},
                     )
                 )
             )

-        # The timestamp being 30s after now the task should been executed
-        # after the first scheduling loop is run
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
-        # This is to give the time to the sleeping tasks to finish
+        # This is to give the active tasks time to finish
         self.reactor.advance(1)

         # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
@@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         )

         scheduled_tasks = [
-            t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
+            t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
         ]
         self.assertEquals(len(scheduled_tasks), 1)

+        # We need to wait for the next run of the scheduler loop
         self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
         self.reactor.advance(1)
@@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         )

     async def _raising_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         raise Exception("raising")
@@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         """Schedule a task raising an exception and check it runs to failure and report exception content."""
         task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))

         task = self.get_success(self.task_scheduler.get_task(task_id))
         assert task is not None
         self.assertEqual(task.status, TaskStatus.FAILED)
         self.assertEqual(task.error, "raising")

     async def _resumable_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         if task.result and "in_progress" in task.result:
             return TaskStatus.COMPLETE, {"success": True}, None
@@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
         task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))

         task = self.get_success(self.task_scheduler.get_task(task_id))
         assert task is not None
         self.assertEqual(task.status, TaskStatus.ACTIVE)
@@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.assertEqual(task.status, TaskStatus.COMPLETE)
         assert task.result is not None
         self.assertTrue(task.result.get("success"))
+
+
+class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.task_scheduler = hs.get_task_scheduler()
+        self.task_scheduler.register_action(self._test_task, "_test_task")
+
+    async def _test_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        return (TaskStatus.COMPLETE, None, None)
+
+    @override_config({"run_background_tasks_on": "worker1"})
+    def test_schedule_task(self) -> None:
+        """Check that a task scheduled to run now is launched right away on the background worker."""
+        bg_worker_hs = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={"worker_name": "worker1"},
+        )
+        bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task")
+
+        task_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                "_test_task",
+            )
+        )
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)