Trace functions which return `Awaitable` (#15650)
This commit is contained in:
parent
4e6390cb10
commit
8bfded81f3
|
@ -0,0 +1 @@
|
|||
Add support for tracing functions which return `Awaitable`s.
|
|
@ -171,6 +171,7 @@ from functools import wraps
|
|||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
ContextManager,
|
||||
|
@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
|
|||
"""
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
# For this branch, we handle async functions like `async def func() -> RInner`.
|
||||
# In this branch, R = Awaitable[RInner], for some other type RInner
|
||||
@wraps(func)
|
||||
async def _wrapper(
|
||||
|
@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
|
|||
return await func(*args, **kwargs) # type: ignore[misc]
|
||||
|
||||
else:
|
||||
# The other case here handles both sync functions and those
|
||||
# decorated with inlineDeferred.
|
||||
# The other case here handles sync functions including those decorated with
|
||||
# `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
|
||||
@wraps(func)
|
||||
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
||||
scope = wrapping_logic(func, *args, **kwargs)
|
||||
scope.__enter__()
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
if isinstance(result, defer.Deferred):
|
||||
|
||||
def call_back(result: R) -> R:
|
||||
|
@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
|
|||
return result
|
||||
|
||||
def err_back(result: R) -> R:
|
||||
# TODO: Pass the error details into `scope.__exit__(...)` for
|
||||
# consistency with the other paths.
|
||||
scope.__exit__(None, None, None)
|
||||
return result
|
||||
|
||||
result.addCallbacks(call_back, err_back)
|
||||
|
||||
else:
|
||||
if inspect.isawaitable(result):
|
||||
logger.error(
|
||||
"@trace may not have wrapped %s correctly! "
|
||||
"The function is not async but returned a %s.",
|
||||
func.__qualname__,
|
||||
type(result).__name__,
|
||||
)
|
||||
elif inspect.isawaitable(result):
|
||||
|
||||
async def wrap_awaitable() -> Any:
|
||||
try:
|
||||
assert isinstance(result, Awaitable)
|
||||
awaited_result = await result
|
||||
scope.__exit__(None, None, None)
|
||||
return awaited_result
|
||||
except Exception as e:
|
||||
scope.__exit__(type(e), None, e.__traceback__)
|
||||
raise
|
||||
|
||||
# The original method returned an awaitable, eg. a coroutine, so we
|
||||
# create another awaitable wrapping it that calls
|
||||
# `scope.__exit__(...)`.
|
||||
return wrap_awaitable()
|
||||
else:
|
||||
# Just a simple sync function so we can just exit the scope and
|
||||
# return the result without any fuss.
|
||||
scope.__exit__(None, None, None)
|
||||
|
||||
return result
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import cast
|
||||
from typing import Awaitable, cast
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
|
@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
|||
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
||||
with functions that return deferreds
|
||||
"""
|
||||
reactor = MemoryReactorClock()
|
||||
|
||||
with LoggingContext("root context"):
|
||||
|
||||
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
|
||||
|
@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
|||
|
||||
result_d1 = fixture_deferred_func()
|
||||
|
||||
# let the tasks complete
|
||||
reactor.pump((2,) * 8)
|
||||
|
||||
self.assertEqual(self.successResultOf(result_d1), "foo")
|
||||
|
||||
# the span should have been reported
|
||||
|
@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
|||
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
||||
with async functions
|
||||
"""
|
||||
reactor = MemoryReactorClock()
|
||||
|
||||
with LoggingContext("root context"):
|
||||
|
||||
@trace_with_opname("fixture_async_func", tracer=self._tracer)
|
||||
|
@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase):
|
|||
|
||||
d1 = defer.ensureDeferred(fixture_async_func())
|
||||
|
||||
# let the tasks complete
|
||||
reactor.pump((2,) * 8)
|
||||
|
||||
self.assertEqual(self.successResultOf(d1), "foo")
|
||||
|
||||
# the span should have been reported
|
||||
|
@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase):
|
|||
[span.operation_name for span in self._reporter.get_spans()],
|
||||
["fixture_async_func"],
|
||||
)
|
||||
|
||||
def test_trace_decorator_awaitable_return(self) -> None:
|
||||
"""
|
||||
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
|
||||
with functions that return an awaitable (e.g. a coroutine)
|
||||
"""
|
||||
with LoggingContext("root context"):
|
||||
# Something we can return without `await` to get a coroutine
|
||||
async def fixture_async_func() -> str:
|
||||
return "foo"
|
||||
|
||||
# The actual kind of function we want to test that returns an awaitable
|
||||
@trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
|
||||
@tag_args
|
||||
def fixture_awaitable_return_func() -> Awaitable[str]:
|
||||
return fixture_async_func()
|
||||
|
||||
# Something we can run with `defer.ensureDeferred(runner())` and pump the
|
||||
# whole async tasks through to completion.
|
||||
async def runner() -> str:
|
||||
return await fixture_awaitable_return_func()
|
||||
|
||||
d1 = defer.ensureDeferred(runner())
|
||||
|
||||
self.assertEqual(self.successResultOf(d1), "foo")
|
||||
|
||||
# the span should have been reported
|
||||
self.assertEqual(
|
||||
[span.operation_name for span in self._reporter.get_spans()],
|
||||
["fixture_awaitable_return_func"],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue