Add a batching queue implementation. (#10017)
This commit is contained in:
parent
1c6a19002c
commit
7958eadcd1
|
@ -0,0 +1 @@
|
||||||
|
Add a batching queue implementation.
|
|
@ -0,0 +1,153 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Hashable,
|
||||||
|
List,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
|
from synapse.metrics import LaterGauge
|
||||||
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
V = TypeVar("V")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchingQueue(Generic[V, R]):
|
||||||
|
"""A queue that batches up work, calling the provided processing function
|
||||||
|
with all pending work (for a given key).
|
||||||
|
|
||||||
|
The provided processing function will only be called once at a time for each
|
||||||
|
key. It will be called the next reactor tick after `add_to_queue` has been
|
||||||
|
called, and will keep being called until the queue has been drained (for the
|
||||||
|
given key).
|
||||||
|
|
||||||
|
Note that the return value of `add_to_queue` will be the return value of the
|
||||||
|
processing function that processed the given item. This means that the
|
||||||
|
returned value will likely include data for other items that were in the
|
||||||
|
batch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
clock: Clock,
|
||||||
|
process_batch_callback: Callable[[List[V]], Awaitable[R]],
|
||||||
|
):
|
||||||
|
self._name = name
|
||||||
|
self._clock = clock
|
||||||
|
|
||||||
|
# The set of keys currently being processed.
|
||||||
|
self._processing_keys = set() # type: Set[Hashable]
|
||||||
|
|
||||||
|
# The currently pending batch of values by key, with a Deferred to call
|
||||||
|
# with the result of the corresponding `_process_batch_callback` call.
|
||||||
|
self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
|
||||||
|
|
||||||
|
# The function to call with batches of values.
|
||||||
|
self._process_batch_callback = process_batch_callback
|
||||||
|
|
||||||
|
LaterGauge(
|
||||||
|
"synapse_util_batching_queue_number_queued",
|
||||||
|
"The number of items waiting in the queue across all keys",
|
||||||
|
labels=("name",),
|
||||||
|
caller=lambda: sum(len(v) for v in self._next_values.values()),
|
||||||
|
)
|
||||||
|
|
||||||
|
LaterGauge(
|
||||||
|
"synapse_util_batching_queue_number_of_keys",
|
||||||
|
"The number of distinct keys that have items queued",
|
||||||
|
labels=("name",),
|
||||||
|
caller=lambda: len(self._next_values),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
|
||||||
|
"""Adds the value to the queue with the given key, returning the result
|
||||||
|
of the processing function for the batch that included the given value.
|
||||||
|
|
||||||
|
The optional `key` argument allows sharding the queue by some key. The
|
||||||
|
queues will then be processed in parallel, i.e. the process batch
|
||||||
|
function will be called in parallel with batched values from a single
|
||||||
|
key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# First we create a defer and add it and the value to the list of
|
||||||
|
# pending items.
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._next_values.setdefault(key, []).append((value, d))
|
||||||
|
|
||||||
|
# If we're not currently processing the key fire off a background
|
||||||
|
# process to start processing.
|
||||||
|
if key not in self._processing_keys:
|
||||||
|
run_as_background_process(self._name, self._process_queue, key)
|
||||||
|
|
||||||
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
async def _process_queue(self, key: Hashable) -> None:
|
||||||
|
"""A background task to repeatedly pull things off the queue for the
|
||||||
|
given key and call the `self._process_batch_callback` with the values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if key in self._processing_keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._processing_keys.add(key)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# We purposefully wait a reactor tick to allow us to batch
|
||||||
|
# together requests that we're about to receive. A common
|
||||||
|
# pattern is to call `add_to_queue` multiple times at once, and
|
||||||
|
# deferring to the next reactor tick allows us to batch all of
|
||||||
|
# those up.
|
||||||
|
await self._clock.sleep(0)
|
||||||
|
|
||||||
|
next_values = self._next_values.pop(key, [])
|
||||||
|
if not next_values:
|
||||||
|
# We've exhausted the queue.
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
values = [value for value, _ in next_values]
|
||||||
|
results = await self._process_batch_callback(values)
|
||||||
|
|
||||||
|
for _, deferred in next_values:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
deferred.callback(results)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
for _, deferred in next_values:
|
||||||
|
if deferred.called:
|
||||||
|
continue
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
deferred.errback(e)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._processing_keys.discard(key)
|
|
@ -0,0 +1,169 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
from synapse.util.batching_queue import BatchingQueue
|
||||||
|
|
||||||
|
from tests.server import get_clock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class BatchingQueueTestCase(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.clock, hs_clock = get_clock()
|
||||||
|
|
||||||
|
self._pending_calls = []
|
||||||
|
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
||||||
|
|
||||||
|
async def _process_queue(self, values):
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._pending_calls.append((values, d))
|
||||||
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
"""Tests the basic case of calling `add_to_queue` once and having
|
||||||
|
`_process_queue` return.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
|
||||||
|
|
||||||
|
# The queue should wait a reactor tick before calling the processing
|
||||||
|
# function.
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
self.assertFalse(queue_d.called)
|
||||||
|
|
||||||
|
# We should see a call to `_process_queue` after a reactor tick.
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo"])
|
||||||
|
self.assertFalse(queue_d.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back.
|
||||||
|
self._pending_calls.pop()[1].callback("bar")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d), "bar")
|
||||||
|
|
||||||
|
def test_batching(self):
|
||||||
|
"""Test that multiple calls at the same time get batched up into one
|
||||||
|
call to `_process_queue`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||||
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||||
|
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
# We should see only *one* call to `_process_queue`
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
|
||||||
|
self.assertFalse(queue_d1.called)
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to both.
|
||||||
|
self._pending_calls.pop()[1].callback("bar")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d1), "bar")
|
||||||
|
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
||||||
|
|
||||||
|
def test_queuing(self):
|
||||||
|
"""Test that we queue up requests while a `_process_queue` is being
|
||||||
|
called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||||
|
|
||||||
|
# We should see only *one* call to `_process_queue`
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||||
|
self.assertFalse(queue_d1.called)
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# first.
|
||||||
|
self._pending_calls.pop()[1].callback("bar1")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# We should now see a second call to `_process_queue`
|
||||||
|
self.clock.pump([0])
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo2"])
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# second.
|
||||||
|
self._pending_calls.pop()[1].callback("bar2")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||||
|
|
||||||
|
def test_different_keys(self):
|
||||||
|
"""Test that calls to different keys get processed in parallel."""
|
||||||
|
|
||||||
|
self.assertFalse(self._pending_calls)
|
||||||
|
|
||||||
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
|
||||||
|
self.clock.pump([0])
|
||||||
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
|
||||||
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
# We queue up another item with key=2 to check that we will keep taking
|
||||||
|
# things off the queue.
|
||||||
|
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2))
|
||||||
|
|
||||||
|
# We should see two calls to `_process_queue`
|
||||||
|
self.assertEqual(len(self._pending_calls), 2)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||||
|
self.assertEqual(self._pending_calls[1][0], ["foo2"])
|
||||||
|
self.assertFalse(queue_d1.called)
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# first.
|
||||||
|
self._pending_calls.pop(0)[1].callback("bar1")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||||
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# second.
|
||||||
|
self._pending_calls.pop()[1].callback("bar2")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# We should now see a call `_pending_calls` for `foo3`
|
||||||
|
self.clock.pump([0])
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
self.assertEqual(self._pending_calls[0][0], ["foo3"])
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
|
||||||
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
|
# third deferred.
|
||||||
|
self._pending_calls.pop()[1].callback("bar4")
|
||||||
|
|
||||||
|
self.assertEqual(self.successResultOf(queue_d3), "bar4")
|
Loading…
Reference in New Issue