parent
d9f44fd0b9
commit
78b5102ae7
|
@ -0,0 +1 @@
|
||||||
|
Fix up `BatchingQueue` implementation.
|
|
@ -25,10 +25,11 @@ from typing import (
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
from synapse.metrics import LaterGauge
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -38,6 +39,24 @@ logger = logging.getLogger(__name__)
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
number_queued = Gauge(
|
||||||
|
"synapse_util_batching_queue_number_queued",
|
||||||
|
"The number of items waiting in the queue across all keys",
|
||||||
|
labelnames=("name",),
|
||||||
|
)
|
||||||
|
|
||||||
|
number_in_flight = Gauge(
|
||||||
|
"synapse_util_batching_queue_number_pending",
|
||||||
|
"The number of items across all keys either being processed or waiting in a queue",
|
||||||
|
labelnames=("name",),
|
||||||
|
)
|
||||||
|
|
||||||
|
number_of_keys = Gauge(
|
||||||
|
"synapse_util_batching_queue_number_of_keys",
|
||||||
|
"The number of distinct keys that have items queued",
|
||||||
|
labelnames=("name",),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BatchingQueue(Generic[V, R]):
|
class BatchingQueue(Generic[V, R]):
|
||||||
"""A queue that batches up work, calling the provided processing function
|
"""A queue that batches up work, calling the provided processing function
|
||||||
|
@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
|
||||||
called, and will keep being called until the queue has been drained (for the
|
called, and will keep being called until the queue has been drained (for the
|
||||||
given key).
|
given key).
|
||||||
|
|
||||||
|
If the processing function raises an exception then the exception is proxied
|
||||||
|
through to the callers waiting on that batch of work.
|
||||||
|
|
||||||
Note that the return value of `add_to_queue` will be the return value of the
|
Note that the return value of `add_to_queue` will be the return value of the
|
||||||
processing function that processed the given item. This means that the
|
processing function that processed the given item. This means that the
|
||||||
returned value will likely include data for other items that were in the
|
returned value will likely include data for other items that were in the
|
||||||
batch.
|
batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: A name for the queue, used for logging contexts and metrics.
|
||||||
|
This must be unique, otherwise the metrics will be wrong.
|
||||||
|
clock: The clock to use to schedule work.
|
||||||
|
process_batch_callback: The callback to to be run to process a batch of
|
||||||
|
work.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -73,19 +102,15 @@ class BatchingQueue(Generic[V, R]):
|
||||||
# The function to call with batches of values.
|
# The function to call with batches of values.
|
||||||
self._process_batch_callback = process_batch_callback
|
self._process_batch_callback = process_batch_callback
|
||||||
|
|
||||||
LaterGauge(
|
number_queued.labels(self._name).set_function(
|
||||||
"synapse_util_batching_queue_number_queued",
|
lambda: sum(len(q) for q in self._next_values.values())
|
||||||
"The number of items waiting in the queue across all keys",
|
|
||||||
labels=("name",),
|
|
||||||
caller=lambda: sum(len(v) for v in self._next_values.values()),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
LaterGauge(
|
number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
|
||||||
"synapse_util_batching_queue_number_of_keys",
|
|
||||||
"The number of distinct keys that have items queued",
|
self._number_in_flight_metric = number_in_flight.labels(
|
||||||
labels=("name",),
|
self._name
|
||||||
caller=lambda: len(self._next_values),
|
) # type: Gauge
|
||||||
)
|
|
||||||
|
|
||||||
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
|
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
|
||||||
"""Adds the value to the queue with the given key, returning the result
|
"""Adds the value to the queue with the given key, returning the result
|
||||||
|
@ -107,17 +132,18 @@ class BatchingQueue(Generic[V, R]):
|
||||||
if key not in self._processing_keys:
|
if key not in self._processing_keys:
|
||||||
run_as_background_process(self._name, self._process_queue, key)
|
run_as_background_process(self._name, self._process_queue, key)
|
||||||
|
|
||||||
return await make_deferred_yieldable(d)
|
with self._number_in_flight_metric.track_inprogress():
|
||||||
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
async def _process_queue(self, key: Hashable) -> None:
|
async def _process_queue(self, key: Hashable) -> None:
|
||||||
"""A background task to repeatedly pull things off the queue for the
|
"""A background task to repeatedly pull things off the queue for the
|
||||||
given key and call the `self._process_batch_callback` with the values.
|
given key and call the `self._process_batch_callback` with the values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
if key in self._processing_keys:
|
||||||
if key in self._processing_keys:
|
return
|
||||||
return
|
|
||||||
|
|
||||||
|
try:
|
||||||
self._processing_keys.add(key)
|
self._processing_keys.add(key)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
@ -137,16 +163,16 @@ class BatchingQueue(Generic[V, R]):
|
||||||
values = [value for value, _ in next_values]
|
values = [value for value, _ in next_values]
|
||||||
results = await self._process_batch_callback(values)
|
results = await self._process_batch_callback(values)
|
||||||
|
|
||||||
for _, deferred in next_values:
|
with PreserveLoggingContext():
|
||||||
with PreserveLoggingContext():
|
for _, deferred in next_values:
|
||||||
deferred.callback(results)
|
deferred.callback(results)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
for _, deferred in next_values:
|
with PreserveLoggingContext():
|
||||||
if deferred.called:
|
for _, deferred in next_values:
|
||||||
continue
|
if deferred.called:
|
||||||
|
continue
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
deferred.errback(e)
|
deferred.errback(e)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -14,7 +14,12 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.util.batching_queue import BatchingQueue
|
from synapse.util.batching_queue import (
|
||||||
|
BatchingQueue,
|
||||||
|
number_in_flight,
|
||||||
|
number_of_keys,
|
||||||
|
number_queued,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.server import get_clock
|
from tests.server import get_clock
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.clock, hs_clock = get_clock()
|
self.clock, hs_clock = get_clock()
|
||||||
|
|
||||||
|
# We ensure that we remove any existing metrics for "test_queue".
|
||||||
|
try:
|
||||||
|
number_queued.remove("test_queue")
|
||||||
|
number_of_keys.remove("test_queue")
|
||||||
|
number_in_flight.remove("test_queue")
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
self._pending_calls = []
|
self._pending_calls = []
|
||||||
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
||||||
|
|
||||||
|
@ -32,6 +45,41 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self._pending_calls.append((values, d))
|
self._pending_calls.append((values, d))
|
||||||
return await make_deferred_yieldable(d)
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
def _assert_metrics(self, queued, keys, in_flight):
|
||||||
|
"""Assert that the metrics are correct"""
|
||||||
|
|
||||||
|
self.assertEqual(len(number_queued.collect()), 1)
|
||||||
|
self.assertEqual(len(number_queued.collect()[0].samples), 1)
|
||||||
|
self.assertEqual(
|
||||||
|
number_queued.collect()[0].samples[0].labels,
|
||||||
|
{"name": self.queue._name},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
number_queued.collect()[0].samples[0].value,
|
||||||
|
queued,
|
||||||
|
"number_queued",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(number_of_keys.collect()), 1)
|
||||||
|
self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
|
||||||
|
self.assertEqual(
|
||||||
|
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(number_in_flight.collect()), 1)
|
||||||
|
self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
|
||||||
|
self.assertEqual(
|
||||||
|
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
number_in_flight.collect()[0].samples[0].value,
|
||||||
|
in_flight,
|
||||||
|
"number_in_flight",
|
||||||
|
)
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
"""Tests the basic case of calling `add_to_queue` once and having
|
"""Tests the basic case of calling `add_to_queue` once and having
|
||||||
`_process_queue` return.
|
`_process_queue` return.
|
||||||
|
@ -41,6 +89,8 @@ class BatchingQueueTestCase(TestCase):
|
||||||
|
|
||||||
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
|
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
|
||||||
|
|
||||||
|
self._assert_metrics(queued=1, keys=1, in_flight=1)
|
||||||
|
|
||||||
# The queue should wait a reactor tick before calling the processing
|
# The queue should wait a reactor tick before calling the processing
|
||||||
# function.
|
# function.
|
||||||
self.assertFalse(self._pending_calls)
|
self.assertFalse(self._pending_calls)
|
||||||
|
@ -52,12 +102,15 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self.assertEqual(len(self._pending_calls), 1)
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
self.assertEqual(self._pending_calls[0][0], ["foo"])
|
self.assertEqual(self._pending_calls[0][0], ["foo"])
|
||||||
self.assertFalse(queue_d.called)
|
self.assertFalse(queue_d.called)
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=1)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back.
|
# Return value of the `_process_queue` should be propagated back.
|
||||||
self._pending_calls.pop()[1].callback("bar")
|
self._pending_calls.pop()[1].callback("bar")
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(queue_d), "bar")
|
self.assertEqual(self.successResultOf(queue_d), "bar")
|
||||||
|
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||||
|
|
||||||
def test_batching(self):
|
def test_batching(self):
|
||||||
"""Test that multiple calls at the same time get batched up into one
|
"""Test that multiple calls at the same time get batched up into one
|
||||||
call to `_process_queue`.
|
call to `_process_queue`.
|
||||||
|
@ -68,6 +121,8 @@ class BatchingQueueTestCase(TestCase):
|
||||||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||||
|
|
||||||
|
self._assert_metrics(queued=2, keys=1, in_flight=2)
|
||||||
|
|
||||||
self.clock.pump([0])
|
self.clock.pump([0])
|
||||||
|
|
||||||
# We should see only *one* call to `_process_queue`
|
# We should see only *one* call to `_process_queue`
|
||||||
|
@ -75,12 +130,14 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
|
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
|
||||||
self.assertFalse(queue_d1.called)
|
self.assertFalse(queue_d1.called)
|
||||||
self.assertFalse(queue_d2.called)
|
self.assertFalse(queue_d2.called)
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=2)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back to both.
|
# Return value of the `_process_queue` should be propagated back to both.
|
||||||
self._pending_calls.pop()[1].callback("bar")
|
self._pending_calls.pop()[1].callback("bar")
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(queue_d1), "bar")
|
self.assertEqual(self.successResultOf(queue_d1), "bar")
|
||||||
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||||
|
|
||||||
def test_queuing(self):
|
def test_queuing(self):
|
||||||
"""Test that we queue up requests while a `_process_queue` is being
|
"""Test that we queue up requests while a `_process_queue` is being
|
||||||
|
@ -92,13 +149,20 @@ class BatchingQueueTestCase(TestCase):
|
||||||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||||
self.clock.pump([0])
|
self.clock.pump([0])
|
||||||
|
|
||||||
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
|
|
||||||
|
# We queue up work after the process function has been called, testing
|
||||||
|
# that they get correctly queued up.
|
||||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||||
|
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3"))
|
||||||
|
|
||||||
# We should see only *one* call to `_process_queue`
|
# We should see only *one* call to `_process_queue`
|
||||||
self.assertEqual(len(self._pending_calls), 1)
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||||
self.assertFalse(queue_d1.called)
|
self.assertFalse(queue_d1.called)
|
||||||
self.assertFalse(queue_d2.called)
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
self._assert_metrics(queued=2, keys=1, in_flight=3)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back to the
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
# first.
|
# first.
|
||||||
|
@ -106,18 +170,24 @@ class BatchingQueueTestCase(TestCase):
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||||
self.assertFalse(queue_d2.called)
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
self._assert_metrics(queued=2, keys=1, in_flight=2)
|
||||||
|
|
||||||
# We should now see a second call to `_process_queue`
|
# We should now see a second call to `_process_queue`
|
||||||
self.clock.pump([0])
|
self.clock.pump([0])
|
||||||
self.assertEqual(len(self._pending_calls), 1)
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
self.assertEqual(self._pending_calls[0][0], ["foo2"])
|
self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
|
||||||
self.assertFalse(queue_d2.called)
|
self.assertFalse(queue_d2.called)
|
||||||
|
self.assertFalse(queue_d3.called)
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=2)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back to the
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
# second.
|
# second.
|
||||||
self._pending_calls.pop()[1].callback("bar2")
|
self._pending_calls.pop()[1].callback("bar2")
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||||
|
self.assertEqual(self.successResultOf(queue_d3), "bar2")
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||||
|
|
||||||
def test_different_keys(self):
|
def test_different_keys(self):
|
||||||
"""Test that calls to different keys get processed in parallel."""
|
"""Test that calls to different keys get processed in parallel."""
|
||||||
|
@ -140,6 +210,7 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self.assertFalse(queue_d1.called)
|
self.assertFalse(queue_d1.called)
|
||||||
self.assertFalse(queue_d2.called)
|
self.assertFalse(queue_d2.called)
|
||||||
self.assertFalse(queue_d3.called)
|
self.assertFalse(queue_d3.called)
|
||||||
|
self._assert_metrics(queued=1, keys=1, in_flight=3)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back to the
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
# first.
|
# first.
|
||||||
|
@ -148,6 +219,7 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||||
self.assertFalse(queue_d2.called)
|
self.assertFalse(queue_d2.called)
|
||||||
self.assertFalse(queue_d3.called)
|
self.assertFalse(queue_d3.called)
|
||||||
|
self._assert_metrics(queued=1, keys=1, in_flight=2)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back to the
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
# second.
|
# second.
|
||||||
|
@ -161,9 +233,11 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self.assertEqual(len(self._pending_calls), 1)
|
self.assertEqual(len(self._pending_calls), 1)
|
||||||
self.assertEqual(self._pending_calls[0][0], ["foo3"])
|
self.assertEqual(self._pending_calls[0][0], ["foo3"])
|
||||||
self.assertFalse(queue_d3.called)
|
self.assertFalse(queue_d3.called)
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=1)
|
||||||
|
|
||||||
# Return value of the `_process_queue` should be propagated back to the
|
# Return value of the `_process_queue` should be propagated back to the
|
||||||
# third deferred.
|
# third deferred.
|
||||||
self._pending_calls.pop()[1].callback("bar4")
|
self._pending_calls.pop()[1].callback("bar4")
|
||||||
|
|
||||||
self.assertEqual(self.successResultOf(queue_d3), "bar4")
|
self.assertEqual(self.successResultOf(queue_d3), "bar4")
|
||||||
|
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||||
|
|
Loading…
Reference in New Issue