Fix make_deferred_yieldable to work with coroutines

This commit is contained in:
Erik Johnston 2019-12-10 11:22:12 +00:00
parent 0f3614f0f6
commit 9a2223d4c8
2 changed files with 32 additions and 1 deletions

View File

@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works. See doc/log_contexts.rst for details on how this works.
""" """
import inspect
import logging import logging
import threading import threading
import types import types
@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred): def make_deferred_yieldable(deferred):
"""Given a deferred, make it follow the Synapse logcontext rules: """Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
If the deferred has completed (or is not actually a Deferred), essentially If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the does nothing (just returns another completed deferred with the
@ -624,6 +626,11 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.) (This is more-or-less the opposite operation to run_in_background.)
""" """
if inspect.isawaitable(deferred):
# If we're given a coroutine we need to convert it to a deferred so that
# we can attach callbacks (and not immediately return).
deferred = defer.ensureDeferred(deferred)
if not isinstance(deferred, defer.Deferred): if not isinstance(deferred, defer.Deferred):
return deferred return deferred

View File

@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
nested_context = nested_logging_context(suffix="bar") nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar") self.assertEqual(nested_context.request, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
# an async function which retuns an incomplete coroutine, but doesn't
# follow the synapse rules.
async def blocking_function():
d = defer.Deferred()
reactor.callLater(0, d.callback, None)
await d
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
yield d1
# now it should be restored
self._check_test_key("one")
# a function which returns a deferred which has been "called", but # a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on # which had a function which returned another incomplete deferred on