Improve type hints for cached decorator. (#15658)
The cached decorators always return a Deferred, which was not properly propagated. It was close enough when wrapping coroutines, but failed if a bare function was wrapped.
This commit is contained in:
parent
379eb2d7ab
commit
1f55c04cbc
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
|
@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
|
||||||
|
|
||||||
from typing import Callable, Optional, Type
|
from typing import Callable, Optional, Type
|
||||||
|
|
||||||
|
from mypy.erasetype import remove_instance_last_known_values
|
||||||
from mypy.nodes import ARG_NAMED_OPT
|
from mypy.nodes import ARG_NAMED_OPT
|
||||||
from mypy.plugin import MethodSigContext, Plugin
|
from mypy.plugin import MethodSigContext, Plugin
|
||||||
from mypy.typeops import bind_self
|
from mypy.typeops import bind_self
|
||||||
from mypy.types import CallableType, NoneType, UnionType
|
from mypy.types import CallableType, Instance, NoneType, UnionType
|
||||||
|
|
||||||
|
|
||||||
class SynapsePlugin(Plugin):
|
class SynapsePlugin(Plugin):
|
||||||
|
@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||||
arg_names.append("on_invalidate")
|
arg_names.append("on_invalidate")
|
||||||
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
|
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
|
||||||
|
|
||||||
|
# Finally we ensure the return type is a Deferred.
|
||||||
|
if (
|
||||||
|
isinstance(signature.ret_type, Instance)
|
||||||
|
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
|
||||||
|
):
|
||||||
|
# If it is already a Deferred, nothing to do.
|
||||||
|
ret_type = signature.ret_type
|
||||||
|
else:
|
||||||
|
ret_arg = None
|
||||||
|
if isinstance(signature.ret_type, Instance):
|
||||||
|
# If a coroutine, wrap the coroutine's return type in a Deferred.
|
||||||
|
if signature.ret_type.type.fullname == "typing.Coroutine":
|
||||||
|
ret_arg = signature.ret_type.args[2]
|
||||||
|
|
||||||
|
# If an awaitable, wrap the awaitable's final value in a Deferred.
|
||||||
|
elif signature.ret_type.type.fullname == "typing.Awaitable":
|
||||||
|
ret_arg = signature.ret_type.args[0]
|
||||||
|
|
||||||
|
# Otherwise, wrap the return value in a Deferred.
|
||||||
|
if ret_arg is None:
|
||||||
|
ret_arg = signature.ret_type
|
||||||
|
|
||||||
|
# This should be able to use ctx.api.named_generic_type, but that doesn't seem
|
||||||
|
# to find the correct symbol for anything more than 1 module deep.
|
||||||
|
#
|
||||||
|
# modules is not part of CheckerPluginInterface. The following is a combination
|
||||||
|
# of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
|
||||||
|
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
|
||||||
|
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
|
||||||
|
|
||||||
signature = signature.copy_modified(
|
signature = signature.copy_modified(
|
||||||
arg_types=arg_types,
|
arg_types=arg_types,
|
||||||
arg_names=arg_names,
|
arg_names=arg_names,
|
||||||
arg_kinds=arg_kinds,
|
arg_kinds=arg_kinds,
|
||||||
|
ret_type=ret_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return signature
|
return signature
|
||||||
|
|
|
@ -1099,7 +1099,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
# `get_joined_hosts` is called with the "current" state group for the
|
# `get_joined_hosts` is called with the "current" state group for the
|
||||||
# room, and so consecutive calls will be for consecutive state groups
|
# room, and so consecutive calls will be for consecutive state groups
|
||||||
# which point to the previous state group.
|
# which point to the previous state group.
|
||||||
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
|
cache = await self._get_joined_hosts_cache(room_id)
|
||||||
|
|
||||||
# If the state group in the cache matches, we already have the data we need.
|
# If the state group in the cache matches, we already have the data we need.
|
||||||
if state_entry.state_group == cache.state_group:
|
if state_entry.state_group == cache.state_group:
|
||||||
|
|
|
@ -220,7 +220,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
self.prune_unread_entries = prune_unread_entries
|
self.prune_unread_entries = prune_unread_entries
|
||||||
|
|
||||||
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
def __get__(
|
||||||
|
self, obj: Optional[Any], owner: Optional[Type]
|
||||||
|
) -> Callable[..., "defer.Deferred[Any]"]:
|
||||||
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
||||||
name=self.name,
|
name=self.name,
|
||||||
max_entries=self.max_entries,
|
max_entries=self.max_entries,
|
||||||
|
@ -232,7 +234,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
get_cache_key = self.cache_key_builder
|
get_cache_key = self.cache_key_builder
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||||
# whenever we are invalidated
|
# whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
from typing import Generator
|
from typing import Any, Generator
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -49,15 +49,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_match(
|
def test_regex_user_id_prefix_match(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
self.event.sender = "@irc_foobar:matrix.org"
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -65,15 +63,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_no_match(
|
def test_regex_user_id_prefix_no_match(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
self.event.sender = "@someone_else:matrix.org"
|
self.event.sender = "@someone_else:matrix.org"
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -81,17 +77,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_member_is_checked(
|
def test_regex_room_member_is_checked(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
self.event.sender = "@someone_else:matrix.org"
|
self.event.sender = "@someone_else:matrix.org"
|
||||||
self.event.type = "m.room.member"
|
self.event.type = "m.room.member"
|
||||||
self.event.state_key = "@irc_foobar:matrix.org"
|
self.event.state_key = "@irc_foobar:matrix.org"
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -99,17 +93,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_id_match(
|
def test_regex_room_id_match(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||||
)
|
)
|
||||||
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -117,25 +109,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_room_id_no_match(
|
def test_regex_room_id_no_match(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||||
_regex("!some_prefix.*some_suffix:matrix.org")
|
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||||
)
|
)
|
||||||
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_alias_match(
|
def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self,
|
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
|
@ -145,10 +133,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
self.store.get_local_users_in_room = simple_async_mock([])
|
self.store.get_local_users_in_room = simple_async_mock([])
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -192,7 +178,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_alias_no_match(
|
def test_regex_alias_no_match(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
|
@ -213,7 +199,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_multiple_matches(
|
def test_regex_multiple_matches(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
|
@ -223,18 +209,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
self.store.get_local_users_in_room = simple_async_mock([])
|
self.store.get_local_users_in_room = simple_async_mock([])
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_interested_in_self(
|
def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self,
|
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
|
||||||
# make sure invites get through
|
# make sure invites get through
|
||||||
self.service.sender = "@appservice:name"
|
self.service.sender = "@appservice:name"
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
|
@ -243,18 +225,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
self.event.state_key = self.service.sender
|
self.event.state_key = self.service.sender
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_member_list_match(
|
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self,
|
|
||||||
) -> Generator["defer.Deferred[object]", object, None]:
|
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
# Note that @irc_fo:here is the AS user.
|
# Note that @irc_fo:here is the AS user.
|
||||||
self.store.get_local_users_in_room = simple_async_mock(
|
self.store.get_local_users_in_room = simple_async_mock(
|
||||||
|
@ -265,10 +243,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield self.service.is_interested_in_event(
|
||||||
self.service.is_interested_in_event(
|
self.event.event_id, self.event, self.store
|
||||||
self.event.event_id, self.event, self.store
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,15 +33,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
|
||||||
destination retries, as well as testing tht we can set and get
|
destination retries, as well as testing tht we can set and get
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
d = self.store.get_destination_retry_timings("example.com")
|
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
|
||||||
r = self.get_success(d)
|
|
||||||
self.assertIsNone(r)
|
self.assertIsNone(r)
|
||||||
|
|
||||||
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
|
self.get_success(
|
||||||
self.get_success(d)
|
self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
|
||||||
|
)
|
||||||
|
|
||||||
d = self.store.get_destination_retry_timings("example.com")
|
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
|
||||||
r = self.get_success(d)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
DestinationRetryTimings(
|
DestinationRetryTimings(
|
||||||
|
|
Loading…
Reference in New Issue