diff --git a/changelog.d/10017.misc b/changelog.d/10017.misc new file mode 100644 index 0000000000..4777b7fb57 --- /dev/null +++ b/changelog.d/10017.misc @@ -0,0 +1 @@ +Add a batching queue implementation. diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py new file mode 100644 index 0000000000..44bbb7b1a8 --- /dev/null +++ b/synapse/util/batching_queue.py @@ -0,0 +1,153 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import ( + Awaitable, + Callable, + Dict, + Generic, + Hashable, + List, + Set, + Tuple, + TypeVar, +) + +from twisted.internet import defer + +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock + +logger = logging.getLogger(__name__) + + +V = TypeVar("V") +R = TypeVar("R") + + +class BatchingQueue(Generic[V, R]): + """A queue that batches up work, calling the provided processing function + with all pending work (for a given key). + + The provided processing function will only be called once at a time for each + key. It will be called the next reactor tick after `add_to_queue` has been + called, and will keep being called until the queue has been drained (for the + given key). + + Note that the return value of `add_to_queue` will be the return value of the + processing function that processed the given item. This means that the + returned value will likely include data for other items that were in the + batch. + """ + + def __init__( + self, + name: str, + clock: Clock, + process_batch_callback: Callable[[List[V]], Awaitable[R]], + ): + self._name = name + self._clock = clock + + # The set of keys currently being processed. + self._processing_keys = set() # type: Set[Hashable] + + # The currently pending batch of values by key, with a Deferred to call + # with the result of the corresponding `_process_batch_callback` call. + self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]] + + # The function to call with batches of values. + self._process_batch_callback = process_batch_callback + + LaterGauge( + "synapse_util_batching_queue_number_queued", + "The number of items waiting in the queue across all keys", + labels=("name",), + caller=lambda: sum(len(v) for v in self._next_values.values()), + ) + + LaterGauge( + "synapse_util_batching_queue_number_of_keys", + "The number of distinct keys that have items queued", + labels=("name",), + caller=lambda: len(self._next_values), + ) + + async def add_to_queue(self, value: V, key: Hashable = ()) -> R: + """Adds the value to the queue with the given key, returning the result + of the processing function for the batch that included the given value. + + The optional `key` argument allows sharding the queue by some key. The + queues will then be processed in parallel, i.e. the process batch + function will be called in parallel with batched values from a single + key. + """ + + # First we create a defer and add it and the value to the list of + # pending items. + d = defer.Deferred() + self._next_values.setdefault(key, []).append((value, d)) + + # If we're not currently processing the key fire off a background + # process to start processing. + if key not in self._processing_keys: + run_as_background_process(self._name, self._process_queue, key) + + return await make_deferred_yieldable(d) + + async def _process_queue(self, key: Hashable) -> None: + """A background task to repeatedly pull things off the queue for the + given key and call the `self._process_batch_callback` with the values. + """ + + try: + if key in self._processing_keys: + return + + self._processing_keys.add(key) + + while True: + # We purposefully wait a reactor tick to allow us to batch + # together requests that we're about to receive. A common + # pattern is to call `add_to_queue` multiple times at once, and + # deferring to the next reactor tick allows us to batch all of + # those up. + await self._clock.sleep(0) + + next_values = self._next_values.pop(key, []) + if not next_values: + # We've exhausted the queue. + break + + try: + values = [value for value, _ in next_values] + results = await self._process_batch_callback(values) + + for _, deferred in next_values: + with PreserveLoggingContext(): + deferred.callback(results) + + except Exception as e: + for _, deferred in next_values: + if deferred.called: + continue + + with PreserveLoggingContext(): + deferred.errback(e) + + finally: + self._processing_keys.discard(key) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py new file mode 100644 index 0000000000..5def1e56c9 --- /dev/null +++ b/tests/util/test_batching_queue.py @@ -0,0 +1,169 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.logging.context import make_deferred_yieldable +from synapse.util.batching_queue import BatchingQueue + +from tests.server import get_clock +from tests.unittest import TestCase + + +class BatchingQueueTestCase(TestCase): + def setUp(self): + self.clock, hs_clock = get_clock() + + self._pending_calls = [] + self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) + + async def _process_queue(self, values): + d = defer.Deferred() + self._pending_calls.append((values, d)) + return await make_deferred_yieldable(d) + + def test_simple(self): + """Tests the basic case of calling `add_to_queue` once and having + `_process_queue` return. + """ + + self.assertFalse(self._pending_calls) + + queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo")) + + # The queue should wait a reactor tick before calling the processing + # function. + self.assertFalse(self._pending_calls) + self.assertFalse(queue_d.called) + + # We should see a call to `_process_queue` after a reactor tick. + self.clock.pump([0]) + + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo"]) + self.assertFalse(queue_d.called) + + # Return value of the `_process_queue` should be propagated back. + self._pending_calls.pop()[1].callback("bar") + + self.assertEqual(self.successResultOf(queue_d), "bar") + + def test_batching(self): + """Test that multiple calls at the same time get batched up into one + call to `_process_queue`. + """ + + self.assertFalse(self._pending_calls) + + queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) + queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + + self.clock.pump([0]) + + # We should see only *one* call to `_process_queue` + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"]) + self.assertFalse(queue_d1.called) + self.assertFalse(queue_d2.called) + + # Return value of the `_process_queue` should be propagated back to both. + self._pending_calls.pop()[1].callback("bar") + + self.assertEqual(self.successResultOf(queue_d1), "bar") + self.assertEqual(self.successResultOf(queue_d2), "bar") + + def test_queuing(self): + """Test that we queue up requests while a `_process_queue` is being + called. + """ + + self.assertFalse(self._pending_calls) + + queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) + self.clock.pump([0]) + + queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + + # We should see only *one* call to `_process_queue` + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo1"]) + self.assertFalse(queue_d1.called) + self.assertFalse(queue_d2.called) + + # Return value of the `_process_queue` should be propagated back to the + # first. + self._pending_calls.pop()[1].callback("bar1") + + self.assertEqual(self.successResultOf(queue_d1), "bar1") + self.assertFalse(queue_d2.called) + + # We should now see a second call to `_process_queue` + self.clock.pump([0]) + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo2"]) + self.assertFalse(queue_d2.called) + + # Return value of the `_process_queue` should be propagated back to the + # second. + self._pending_calls.pop()[1].callback("bar2") + + self.assertEqual(self.successResultOf(queue_d2), "bar2") + + def test_different_keys(self): + """Test that calls to different keys get processed in parallel.""" + + self.assertFalse(self._pending_calls) + + queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1)) + self.clock.pump([0]) + queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2)) + self.clock.pump([0]) + + # We queue up another item with key=2 to check that we will keep taking + # things off the queue. + queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2)) + + # We should see two calls to `_process_queue` + self.assertEqual(len(self._pending_calls), 2) + self.assertEqual(self._pending_calls[0][0], ["foo1"]) + self.assertEqual(self._pending_calls[1][0], ["foo2"]) + self.assertFalse(queue_d1.called) + self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + + # Return value of the `_process_queue` should be propagated back to the + # first. + self._pending_calls.pop(0)[1].callback("bar1") + + self.assertEqual(self.successResultOf(queue_d1), "bar1") + self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + + # Return value of the `_process_queue` should be propagated back to the + # second. + self._pending_calls.pop()[1].callback("bar2") + + self.assertEqual(self.successResultOf(queue_d2), "bar2") + self.assertFalse(queue_d3.called) + + # We should now see a call `_pending_calls` for `foo3` + self.clock.pump([0]) + self.assertEqual(len(self._pending_calls), 1) + self.assertEqual(self._pending_calls[0][0], ["foo3"]) + self.assertFalse(queue_d3.called) + + # Return value of the `_process_queue` should be propagated back to the + # third deferred. + self._pending_calls.pop()[1].callback("bar4") + + self.assertEqual(self.successResultOf(queue_d3), "bar4")