Update `delay_cancellation` to accept any awaitable (#12468)
This will mainly be useful when dealing with module callbacks, which are all typed as returning `Awaitable`s instead of coroutines or `Deferred`s. Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
parent
b82fff66df
commit
a50fb411b3
|
@ -0,0 +1 @@
|
||||||
|
Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.
|
|
@ -41,7 +41,6 @@ from prometheus_client import Histogram
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
|
@ -794,7 +793,7 @@ class DatabasePool:
|
||||||
# We also wait until everything above is done before releasing the
|
# We also wait until everything above is done before releasing the
|
||||||
# `CancelledError`, so that logging contexts won't get used after they have been
|
# `CancelledError`, so that logging contexts won't get used after they have been
|
||||||
# finished.
|
# finished.
|
||||||
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
|
return await delay_cancellation(_runInteraction())
|
||||||
|
|
||||||
async def runWithConnection(
|
async def runWithConnection(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -25,6 +26,7 @@ from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Hashable,
|
Hashable,
|
||||||
|
@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||||
return new_deferred
|
return new_deferred
|
||||||
|
|
||||||
|
|
||||||
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
@overload
|
||||||
"""Delay cancellation of a `Deferred` until it resolves.
|
def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
||||||
|
"""Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
|
||||||
|
|
||||||
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
|
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
|
||||||
resolve with a `CancelledError` until the original `Deferred` resolves.
|
resolve with a `CancelledError` until the original awaitable resolves.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
deferred: The `Deferred` to protect against cancellation. May optionally follow
|
deferred: The coroutine or `Deferred` to protect against cancellation. May
|
||||||
the Synapse logcontext rules.
|
optionally follow the Synapse logcontext rules.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new `Deferred`, which will contain the result of the original `Deferred`.
|
A new `Deferred`, which will contain the result of the original coroutine or
|
||||||
The new `Deferred` will not propagate cancellation through to the original.
|
`Deferred`. The new `Deferred` will not propagate cancellation through to the
|
||||||
When cancelled, the new `Deferred` will wait until the original `Deferred`
|
original coroutine or `Deferred`.
|
||||||
resolves before failing with a `CancelledError`.
|
|
||||||
|
|
||||||
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
|
When cancelled, the new `Deferred` will wait until the original coroutine or
|
||||||
|
`Deferred` resolves before failing with a `CancelledError`.
|
||||||
|
|
||||||
|
The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
|
||||||
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
|
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
|
||||||
wrapped with `make_deferred_yieldable`.
|
wrapped with `make_deferred_yieldable`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# First, convert the awaitable into a `Deferred`.
|
||||||
|
if isinstance(awaitable, defer.Deferred):
|
||||||
|
deferred = awaitable
|
||||||
|
elif asyncio.iscoroutine(awaitable):
|
||||||
|
# Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
|
||||||
|
# type-checking, but we'd need Twisted >= 21.2.
|
||||||
|
deferred = defer.ensureDeferred(awaitable)
|
||||||
|
else:
|
||||||
|
# We have no idea what to do with this awaitable.
|
||||||
|
# We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
|
||||||
|
# `make_awaitable`, and let the caller `await` it normally.
|
||||||
|
return awaitable
|
||||||
|
|
||||||
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
|
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
|
||||||
# before the new deferred is cancelled, we `pause` it to stop the cancellation
|
# before the new deferred is cancelled, we `pause` it to stop the cancellation
|
||||||
# propagating. we then `unpause` it once the wrapped deferred completes, to
|
# propagating. we then `unpause` it once the wrapped deferred completes, to
|
||||||
|
|
|
@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
|
||||||
class DelayCancellationTests(TestCase):
|
class DelayCancellationTests(TestCase):
|
||||||
"""Tests for the `delay_cancellation` function."""
|
"""Tests for the `delay_cancellation` function."""
|
||||||
|
|
||||||
def test_cancellation(self):
|
def test_deferred_cancellation(self):
|
||||||
"""Test that cancellation of the new `Deferred` waits for the original."""
|
"""Test that cancellation of the new `Deferred` waits for the original."""
|
||||||
deferred: "Deferred[str]" = Deferred()
|
deferred: "Deferred[str]" = Deferred()
|
||||||
wrapper_deferred = delay_cancellation(deferred)
|
wrapper_deferred = delay_cancellation(deferred)
|
||||||
|
@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
|
||||||
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
|
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
|
||||||
self.failureResultOf(wrapper_deferred, CancelledError)
|
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||||
|
|
||||||
|
def test_coroutine_cancellation(self):
|
||||||
|
"""Test that cancellation of the new `Deferred` waits for the original."""
|
||||||
|
blocking_deferred: "Deferred[None]" = Deferred()
|
||||||
|
completion_deferred: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
async def task():
|
||||||
|
await blocking_deferred
|
||||||
|
completion_deferred.callback(None)
|
||||||
|
# Raise an exception. Twisted should consume it, otherwise unwanted
|
||||||
|
# tracebacks will be printed in logs.
|
||||||
|
raise ValueError("abc")
|
||||||
|
|
||||||
|
wrapper_deferred = delay_cancellation(task())
|
||||||
|
|
||||||
|
# Cancel the new `Deferred`.
|
||||||
|
wrapper_deferred.cancel()
|
||||||
|
self.assertNoResult(wrapper_deferred)
|
||||||
|
self.assertFalse(
|
||||||
|
blocking_deferred.called, "Cancellation was propagated too deep"
|
||||||
|
)
|
||||||
|
self.assertFalse(completion_deferred.called)
|
||||||
|
|
||||||
|
# Unblock the task.
|
||||||
|
blocking_deferred.callback(None)
|
||||||
|
self.assertTrue(completion_deferred.called)
|
||||||
|
|
||||||
|
# Now that the original coroutine has failed, we should get a `CancelledError`.
|
||||||
|
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||||
|
|
||||||
def test_suppresses_second_cancellation(self):
|
def test_suppresses_second_cancellation(self):
|
||||||
"""Test that a second cancellation is suppressed.
|
"""Test that a second cancellation is suppressed.
|
||||||
|
|
||||||
|
@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
|
||||||
async def outer():
|
async def outer():
|
||||||
with LoggingContext("c") as c:
|
with LoggingContext("c") as c:
|
||||||
try:
|
try:
|
||||||
await delay_cancellation(defer.ensureDeferred(inner()))
|
await delay_cancellation(inner())
|
||||||
self.fail("`CancelledError` was not raised")
|
self.fail("`CancelledError` was not raised")
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
self.assertEqual(c, current_context())
|
self.assertEqual(c, current_context())
|
||||||
|
|
Loading…
Reference in New Issue