From d0fc1e904a3060b0f459be9aa7df9b9f1501e294 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Nov 2024 15:26:14 +0000 Subject: [PATCH] Fix cancellation tests with new Twisted. (#17906) The latest Twisted release changed how they implemented `__await__` on deferreds, which broke the machinery we used to test cancellation. This PR changes things a bit to instead patch the `__await__` method, which is a stable API. This mostly doesn't change the core logic, except for fixing two bugs: - We previously did not intercept all await points - After cancellation we now need to not only unblock currently blocked await points, but also make sure we don't block any future await points. c.f. https://github.com/twisted/twisted/pull/12226 --------- Co-authored-by: Devon Hudson --- changelog.d/17906.bugfix | 1 + tests/http/server/_base.py | 107 ++++++++++++++++++++++++++++--------- 2 files changed, 84 insertions(+), 24 deletions(-) create mode 100644 changelog.d/17906.bugfix diff --git a/changelog.d/17906.bugfix b/changelog.d/17906.bugfix new file mode 100644 index 0000000000..f38ce6a590 --- /dev/null +++ b/changelog.d/17906.bugfix @@ -0,0 +1 @@ +Fix tests to run with latest Twisted. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 731b0c4e59..dff5a5d262 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -27,6 +27,7 @@ from typing import ( Callable, ContextManager, Dict, + Generator, List, Optional, Set, @@ -49,7 +50,10 @@ from synapse.http.server import ( respond_with_json, ) from synapse.http.site import SynapseRequest -from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.logging.context import ( + LoggingContext, + make_deferred_yieldable, +) from synapse.types import JsonDict from tests.server import FakeChannel, make_request @@ -199,7 +203,7 @@ def make_request_with_cancellation_test( # # We would like to trigger a cancellation at the first `await`, re-run the # request and cancel at the second `await`, and so on. By patching - # `Deferred.__next__`, we can intercept `await`s, track which ones we have or + # `Deferred.__await__`, we can intercept `await`s, track which ones we have or # have not seen, and force them to block when they wouldn't have. # The set of previously seen `await`s. @@ -211,7 +215,7 @@ def make_request_with_cancellation_test( ) for request_number in itertools.count(1): - deferred_patch = Deferred__next__Patch(seen_awaits, request_number) + deferred_patch = Deferred__await__Patch(seen_awaits, request_number) try: with mock.patch( @@ -250,6 +254,8 @@ def make_request_with_cancellation_test( ) if respond_mock.called: + _log_for_request(request_number, "--- response finished ---") + # The request ran to completion and we are done with testing it. # `respond_with_json` writes the response asynchronously, so we @@ -311,8 +317,8 @@ def make_request_with_cancellation_test( assert False, "unreachable" # noqa: B011 -class Deferred__next__Patch: - """A `Deferred.__next__` patch that will intercept `await`s and force them +class Deferred__await__Patch: + """A `Deferred.__await__` patch that will intercept `await`s and force them to block once it sees a new `await`. When done with the patch, `unblock_awaits()` must be called to clean up after any @@ -322,7 +328,7 @@ class Deferred__next__Patch: Usage: seen_awaits = set() - deferred_patch = Deferred__next__Patch(seen_awaits, 1) + deferred_patch = Deferred__await__Patch(seen_awaits, 1) try: with deferred_patch.patch(): # do things @@ -335,14 +341,14 @@ class Deferred__next__Patch: """ Args: seen_awaits: The set of stack traces of `await`s that have been previously - seen. When the `Deferred.__next__` patch sees a new `await`, it will add + seen. When the `Deferred.__await__` patch sees a new `await`, it will add it to the set. request_number: The request number to log against. """ self._request_number = request_number self._seen_awaits = seen_awaits - self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore] + self._original_Deferred__await__ = Deferred.__await__ # type: ignore[misc,unused-ignore] # The number of `await`s on `Deferred`s we have seen so far. self.awaits_seen = 0 @@ -350,8 +356,13 @@ class Deferred__next__Patch: # Whether we have seen a new `await` not in `seen_awaits`. self.new_await_seen = False + # Whether to block new await points we see. This gets set to False once + # we have cancelled the request to allow things to run after + # cancellation. + self._block_new_awaits = True + # To force `await`s on resolved `Deferred`s to block, we make up a new - # unresolved `Deferred` and return it out of `Deferred.__next__` / + # unresolved `Deferred` and return it out of `Deferred.__await__` / # `coroutine.send()`. We have to resolve it later, in case the `await`ing # coroutine is part of some shared processing, such as `@cached`. self._to_unblock: Dict[Deferred, Union[object, Failure]] = {} @@ -360,15 +371,15 @@ class Deferred__next__Patch: self._previous_stack: List[inspect.FrameInfo] = [] def patch(self) -> ContextManager[Mock]: - """Returns a context manager which patches `Deferred.__next__`.""" + """Returns a context manager which patches `Deferred.__await__`.""" - def Deferred___next__( - deferred: "Deferred[T]", value: object = None - ) -> "Deferred[T]": - """Intercepts `await`s on `Deferred`s and rigs them to block once we have - seen enough of them. + def Deferred___await__( + deferred: "Deferred[T]", + ) -> Generator["Deferred[T]", None, T]: + """Intercepts calls to `__await__`, which returns a generator + yielding deferreds that we await on. - `Deferred.__next__` will normally: + The generator for `__await__` will normally: * return `self` if the `Deferred` is unresolved, in which case `coroutine.send()` will return the `Deferred`, and `_defer.inlineCallbacks` will stop running the coroutine until the @@ -376,9 +387,43 @@ class Deferred__next__Patch: * raise a `StopIteration(result)`, containing the result of the `await`. * raise another exception, which will come out of the `await`. """ + + # Get the original generator. + gen = self._original_Deferred__await__(deferred) + + # Run the generator, handling each iteration to see if we need to + # block. + try: + while True: + # We've hit a new await point (or the deferred has + # completed), handle it. + handle_next_iteration(deferred) + + # Continue on. + yield gen.send(None) + except StopIteration as e: + # We need to convert `StopIteration` into a normal return. + return e.value + + def handle_next_iteration( + deferred: "Deferred[T]", + ) -> None: + """Intercepts `await`s on `Deferred`s and rigs them to block once we have + seen enough of them. + + Args: + deferred: The deferred that we've captured and are intercepting + `await` calls within. + """ + if not self._block_new_awaits: + # We're no longer blocking awaits points + return + self.awaits_seen += 1 - stack = _get_stack(skip_frames=1) + stack = _get_stack( + skip_frames=2 # Ignore this function and `Deferred___await__` in stack trace + ) stack_hash = _hash_stack(stack) if stack_hash not in self._seen_awaits: @@ -389,20 +434,29 @@ class Deferred__next__Patch: if not self.new_await_seen: # This `await` isn't interesting. Let it proceed normally. + _log_await_stack( + stack, + self._previous_stack, + self._request_number, + "already seen", + ) + # Don't log the stack. It's been seen before in a previous run. self._previous_stack = stack - return self._original_Deferred___next__(deferred, value) + return # We want to block at the current `await`. if deferred.called and not deferred.paused: - # This `Deferred` already has a result. - # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait - # on. This blocks the coroutine that did this `await`. + # This `Deferred` already has a result. We chain a new, + # unresolved, `Deferred` to the end of this Deferred that it + # will wait on. This blocks the coroutine that did this `await`. # We queue it up for unblocking later. new_deferred: "Deferred[T]" = Deferred() self._to_unblock[new_deferred] = deferred.result + deferred.addBoth(lambda _: make_deferred_yieldable(new_deferred)) + _log_await_stack( stack, self._previous_stack, @@ -411,7 +465,9 @@ class Deferred__next__Patch: ) self._previous_stack = stack - return make_deferred_yieldable(new_deferred) + # Continue iterating on the deferred now that we've blocked it + # again. + return # This `Deferred` does not have a result yet. # The `await` will block normally, so we don't have to do anything. @@ -423,9 +479,9 @@ class Deferred__next__Patch: ) self._previous_stack = stack - return self._original_Deferred___next__(deferred, value) + return - return mock.patch.object(Deferred, "__next__", new=Deferred___next__) + return mock.patch.object(Deferred, "__await__", new=Deferred___await__) def unblock_awaits(self) -> None: """Unblocks any shared processing that we forced to block. @@ -433,6 +489,9 @@ class Deferred__next__Patch: Must be called when done, otherwise processing shared between multiple requests, such as database queries started by `@cached`, will become permanently stuck. """ + # Also disable blocking at future await points + self._block_new_awaits = False + to_unblock = self._to_unblock self._to_unblock = {} for deferred, result in to_unblock.items():