Add remaining type hints to `synapse.events`. (#11098)

This commit is contained in:
Patrick Cloke 2021-11-02 09:55:52 -04:00 committed by GitHub
parent 4535532526
commit c01bc5f43d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 184 additions and 109 deletions

1
changelog.d/11098.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.events`.

View File

@ -22,13 +22,7 @@ files =
synapse/config,
synapse/crypto,
synapse/event_auth.py,
synapse/events/builder.py,
synapse/events/presence_router.py,
synapse/events/snapshot.py,
synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/utils.py,
synapse/events/validator.py,
synapse/events,
synapse/federation,
synapse/groups,
synapse/handlers,

View File

@ -16,8 +16,23 @@
import abc
import os
from typing import Dict, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Literal
from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
if TYPE_CHECKING:
from synapse.events.builder import EventBuilder
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
# dict to frozen_dicts is expensive.
@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
class DictProperty:
T = TypeVar("T")
# DictProperty (and DefaultDictProperty) require the classes they're used with to
# have a _dict property to pull properties from.
#
# TODO _DictPropertyInstance should not include EventBuilder but due to
# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and
# DefaultDictProperty get applied to EventBuilder when it is in a Union with
# EventBase. This is the least invasive hack to get mypy to comply.
#
# Note that DictProperty/DefaultDictProperty cannot actually be used with
# EventBuilder as it lacks a _dict property.
_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"]
class DictProperty(Generic[T]):
"""An object property which delegates to the `_dict` within its parent object."""
__slots__ = ["key"]
@ -45,12 +79,33 @@ class DictProperty:
def __init__(self, key: str):
self.key = key
def __get__(self, instance, owner=None):
@overload
def __get__(
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DictProperty":
...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T:
...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> Union[T, "DictProperty"]:
# if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value
if instance is None:
return self
try:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict[self.key]
except KeyError as e1:
# We want this to look like a regular attribute error (mostly so that
@ -65,10 +120,12 @@ class DictProperty:
"'%s' has no '%s' property" % (type(instance), self.key)
) from e1.__context__
def __set__(self, instance, v):
def __set__(self, instance: _DictPropertyInstance, v: T) -> None:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
instance._dict[self.key] = v
def __delete__(self, instance):
def __delete__(self, instance: _DictPropertyInstance) -> None:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
try:
del instance._dict[self.key]
except KeyError as e1:
@ -77,7 +134,7 @@ class DictProperty:
) from e1.__context__
class DefaultDictProperty(DictProperty):
class DefaultDictProperty(DictProperty, Generic[T]):
"""An extension of DictProperty which provides a default if the property is
not present in the parent's _dict.
@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty):
__slots__ = ["default"]
def __init__(self, key, default):
def __init__(self, key: str, default: T):
super().__init__(key)
self.default = default
def __get__(self, instance, owner=None):
@overload
def __get__(
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DefaultDictProperty":
...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T:
...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> Union[T, "DefaultDictProperty"]:
if instance is None:
return self
assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict.get(self.key, self.default)
@ -111,22 +189,22 @@ class _EventInternalMetadata:
# in the DAG)
self.outlier = False
out_of_band_membership: bool = DictProperty("out_of_band_membership")
send_on_behalf_of: str = DictProperty("send_on_behalf_of")
recheck_redaction: bool = DictProperty("recheck_redaction")
soft_failed: bool = DictProperty("soft_failed")
proactively_send: bool = DictProperty("proactively_send")
redacted: bool = DictProperty("redacted")
txn_id: str = DictProperty("txn_id")
token_id: int = DictProperty("token_id")
historical: bool = DictProperty("historical")
out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership")
send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of")
recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction")
soft_failed: DictProperty[bool] = DictProperty("soft_failed")
proactively_send: DictProperty[bool] = DictProperty("proactively_send")
redacted: DictProperty[bool] = DictProperty("redacted")
txn_id: DictProperty[str] = DictProperty("txn_id")
token_id: DictProperty[int] = DictProperty("token_id")
historical: DictProperty[bool] = DictProperty("historical")
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
before: RoomStreamToken = DictProperty("before")
after: RoomStreamToken = DictProperty("after")
order: Tuple[int, int] = DictProperty("order")
before: DictProperty[RoomStreamToken] = DictProperty("before")
after: DictProperty[RoomStreamToken] = DictProperty("after")
order: DictProperty[Tuple[int, int]] = DictProperty("order")
def get_dict(self) -> JsonDict:
return dict(self._dict)
@ -162,9 +240,6 @@ class _EventInternalMetadata:
If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.
Returns:
bool
"""
return self._dict.get("recheck_redaction", False)
@ -176,32 +251,23 @@ class _EventInternalMetadata:
sent to clients.
2. They should not be added to the forward extremities (and
therefore not to current state).
Returns:
bool
"""
return self._dict.get("soft_failed", False)
def should_proactively_send(self):
def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and
servers.
This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.
Returns:
bool
"""
return self._dict.get("proactively_send", True)
def is_redacted(self):
def is_redacted(self) -> bool:
"""Whether the event has been redacted.
This is used for efficiently checking whether an event has been
marked as redacted without needing to make another database call.
Returns:
bool
"""
return self._dict.get("redacted", False)
@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta):
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = DictProperty("auth_events")
depth = DictProperty("depth")
content = DictProperty("content")
hashes = DictProperty("hashes")
origin = DictProperty("origin")
origin_server_ts = DictProperty("origin_server_ts")
prev_events = DictProperty("prev_events")
redacts = DefaultDictProperty("redacts", None)
room_id = DictProperty("room_id")
sender = DictProperty("sender")
state_key = DictProperty("state_key")
type = DictProperty("type")
user_id = DictProperty("sender")
depth: DictProperty[int] = DictProperty("depth")
content: DictProperty[JsonDict] = DictProperty("content")
hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
origin: DictProperty[str] = DictProperty("origin")
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
room_id: DictProperty[str] = DictProperty("room_id")
sender: DictProperty[str] = DictProperty("sender")
# TODO state_key should be Optional[str], this is generally asserted in Synapse
# by calling is_state() first (which ensures this), but it is hard (not possible?)
# to properly annotate that calling is_state() asserts that state_key exists
# and is non-None.
state_key: DictProperty[str] = DictProperty("state_key")
type: DictProperty[str] = DictProperty("type")
user_id: DictProperty[str] = DictProperty("sender")
@property
def event_id(self) -> str:
raise NotImplementedError()
@property
def membership(self):
def membership(self) -> str:
return self.content["membership"]
def is_state(self):
def is_state(self) -> bool:
return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self) -> JsonDict:
@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta):
return d
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Any:
return self._dict.get(key, default)
def get_internal_metadata_dict(self):
def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now=None) -> JsonDict:
def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict:
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta):
return template_json
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
def __getitem__(self, field: str) -> Optional[Any]:
return self._dict[field]
def __contains__(self, field):
def __contains__(self, field: str) -> bool:
return field in self._dict
def items(self):
def items(self) -> List[Tuple[str, Optional[Any]]]:
return list(self._dict.items())
def keys(self):
def keys(self) -> Iterable[str]:
return self._dict.keys()
def prev_event_ids(self):
def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
The list of event IDs of this event's prev_events
"""
return [e for e, _ in self.prev_events]
return [e for e, _ in self._dict["prev_events"]]
def auth_event_ids(self):
def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
The list of event IDs of this event's auth_events
"""
return [e for e, _ in self.auth_events]
return [e for e, _ in self._dict["auth_events"]]
def freeze(self):
def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
def __str__(self):
def __str__(self) -> str:
return self.__repr__()
def __repr__(self):
def __repr__(self) -> str:
rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
return (
@ -443,7 +508,7 @@ class FrozenEventV2(EventBase):
else:
frozen_dict = event_dict
self._event_id = None
self._event_id: Optional[str] = None
super().__init__(
frozen_dict,
@ -455,7 +520,7 @@ class FrozenEventV2(EventBase):
)
@property
def event_id(self):
def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
@ -465,23 +530,23 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self):
def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
The list of event IDs of this event's prev_events
"""
return self.prev_events
return self._dict["prev_events"]
def auth_event_ids(self):
def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
The list of event IDs of this event's auth_events
"""
return self.auth_events
return self._dict["auth_events"]
class FrozenEventV3(FrozenEventV2):
@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2):
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
def event_id(self):
def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id
def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
def _event_type_from_format_version(
format_version: int,
) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
"""Returns the python type to use to construct an Event object for the
given event format version.
Args:
format_version (int): The event format version
format_version: The event format version
Returns:
type: A type that can be initialized as per the initializer of

View File

@ -55,7 +55,7 @@ class EventValidator:
]
for k in required:
if not hasattr(event, k):
if k not in event:
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values

View File

@ -1643,7 +1643,7 @@ class FederationEventHandler:
event: the event whose auth_events we want
Returns:
all of the events in `event.auth_events`, after deduplication
all of the events listed in `event.auth_events_ids`, after deduplication
Raises:
AuthError if we were unable to fetch the auth_events for any reason.

View File

@ -1318,6 +1318,8 @@ class EventCreationHandler:
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
assert event.redacts is not None
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@ -1413,6 +1415,8 @@ class EventCreationHandler:
)
if event.type == EventTypes.Redaction:
assert event.redacts is not None
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@ -1500,6 +1504,8 @@ class EventCreationHandler:
next_batch_id = event.content.get(
EventContentFields.MSC2716_NEXT_BATCH_ID
)
conflicting_insertion_event_id = None
if next_batch_id:
conflicting_insertion_event_id = (
await self.store.get_insertion_event_by_batch_id(
event.room_id, next_batch_id

View File

@ -525,7 +525,7 @@ class RoomCreationHandler:
):
await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
UserID.from_string(old_event.state_key),
new_room_id,
"ban",
ratelimit=False,

View File

@ -355,7 +355,7 @@ class RoomBatchHandler:
for (event, context) in reversed(events_to_persist):
await self.event_creation_handler.handle_new_client_event(
await self.create_requester_for_user_id_from_app_service(
event["sender"], app_service_requester.app_service
event.sender, app_service_requester.app_service
),
event=event,
context=context,

View File

@ -1669,7 +1669,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
# the prev_events consist solely of the previous membership event.
prev_event_ids = [previous_membership_event.event_id]
auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
auth_event_ids = (
list(previous_membership_event.auth_event_ids()) + prev_event_ids
)
event, context = await self.event_creation_handler.create_event(
requester,

View File

@ -232,6 +232,8 @@ class BulkPushRuleEvaluator:
# that user, as they might not be already joined.
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
if not isinstance(display_name, str):
display_name = None
if count_as_unread:
# Add an element for the current user if the event needs to be marked as
@ -268,7 +270,7 @@ def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
display_name: str,
display_name: Optional[str],
cache: Dict[str, bool],
) -> bool:
for cond in conditions:

View File

@ -18,7 +18,7 @@ import re
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache
@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent:
self._value_cache = _flatten_dict(event)
def matches(
self, condition: Dict[str, Any], user_id: str, display_name: str
self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
@ -172,7 +172,7 @@ class PushRuleEvaluatorForEvent:
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name: str) -> bool:
def _contains_display_name(self, display_name: Optional[str]) -> bool:
if not display_name:
return False
@ -222,7 +222,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
def _flatten_dict(
d: Union[EventBase, dict],
d: Union[EventBase, JsonDict],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
@ -233,7 +233,7 @@ def _flatten_dict(
for key, value in d.items():
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
elif isinstance(value, dict):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result

View File

@ -191,7 +191,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
batch_id_to_connect_to = base_insertion_event["content"][
batch_id_to_connect_to = base_insertion_event.content[
EventContentFields.MSC2716_NEXT_BATCH_ID
]

View File

@ -247,7 +247,7 @@ class StateHandler:
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: List[str]
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids

View File

@ -24,6 +24,7 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
event_to_auth_chain: Dict[str, Sequence[str]],
) -> None:
"""Calculate the chain cover index for the given events.
@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@ -1794,7 +1795,7 @@ class PersistEventsStore:
)
# Insert an edge for every prev_event connection
for prev_event_id in event.prev_events:
for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn(
txn,
table="insertion_event_edges",

View File

@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
):
) -> Dict[str, ProfileInfo]:
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context
)
async def get_joined_users_from_state(self, room_id, state_entry):
async def get_joined_users_from_state(
self, room_id, state_entry
) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context,
event=None,
context=None,
):
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.