Additional type hints for the sync REST servlet. (#10666)

This commit is contained in:
Patrick Cloke 2021-08-23 08:14:42 -04:00 committed by GitHub
parent 2af6d31b78
commit bd7d398b05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 61 deletions

1
changelog.d/10666.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View File

@ -30,6 +30,7 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import current_context from synapse.logging.context import current_context
@ -231,7 +232,7 @@ class SyncResult:
""" """
next_batch: StreamToken next_batch: StreamToken
presence: List[JsonDict] presence: List[UserPresenceState]
account_data: List[JsonDict] account_data: List[JsonDict]
joined: List[JoinedSyncResult] joined: List[JoinedSyncResult]
invited: List[InvitedSyncResult] invited: List[InvitedSyncResult]
@ -2177,14 +2178,14 @@ class SyncResultBuilder:
joined_room_ids: List of rooms the user is joined to joined_room_ids: List of rooms the user is joined to
# The following mirror the fields in a sync response # The following mirror the fields in a sync response
presence (list) presence
account_data (list) account_data
joined (list[JoinedSyncResult]) joined
invited (list[InvitedSyncResult]) invited
knocked (list[KnockedSyncResult]) knocked
archived (list[ArchivedSyncResult]) archived
groups (GroupsSyncResult|None) groups
to_device (list) to_device
""" """
sync_config: SyncConfig sync_config: SyncConfig
@ -2193,7 +2194,7 @@ class SyncResultBuilder:
now_token: StreamToken now_token: StreamToken
joined_room_ids: FrozenSet[str] joined_room_ids: FrozenSet[str]
presence: List[JsonDict] = attr.Factory(list) presence: List[UserPresenceState] = attr.Factory(list)
account_data: List[JsonDict] = attr.Factory(list) account_data: List[JsonDict] = attr.Factory(list)
joined: List[JoinedSyncResult] = attr.Factory(list) joined: List[JoinedSyncResult] = attr.Factory(list)
invited: List[InvitedSyncResult] = attr.Factory(list) invited: List[InvitedSyncResult] = attr.Factory(list)

View File

@ -14,17 +14,26 @@
import itertools import itertools
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.events.utils import ( from synapse.events.utils import (
format_event_for_client_v2_without_room_id, format_event_for_client_v2_without_room_id,
format_event_raw, format_event_raw,
) )
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import KnockedSyncResult, SyncConfig from synapse.handlers.sync import (
ArchivedSyncResult,
InvitedSyncResult,
JoinedSyncResult,
KnockedSyncResult,
SyncConfig,
SyncResult,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
return 200, {} return 200, {}
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# We know that the the requester has an access token since appservices
# cannot use sync.
response_content = await self.encode_response( response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection time_now, sync_result, requester.access_token_id, filter_collection
) )
@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete") logger.debug("Event formatting complete")
return 200, response_content return 200, response_content
async def encode_response(self, time_now, sync_result, access_token_id, filter): async def encode_response(
self,
time_now: int,
sync_result: SyncResult,
access_token_id: Optional[int],
filter: FilterCollection,
) -> JsonDict:
logger.debug("Formatting events in sync response") logger.debug("Formatting events in sync response")
if filter.event_format == "client": if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id event_formatter = format_event_for_client_v2_without_room_id
@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
logger.debug("building sync response dict") logger.debug("building sync response dict")
response: dict = defaultdict(dict) response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store) response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.account_data: if sync_result.account_data:
@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
if archived: if archived:
response["rooms"][Membership.LEAVE] = archived response["rooms"][Membership.LEAVE] = archived
# By the time we get here groups is no longer optional.
assert sync_result.groups is not None
if sync_result.groups.join: if sync_result.groups.join:
response["groups"][Membership.JOIN] = sync_result.groups.join response["groups"][Membership.JOIN] = sync_result.groups.join
if sync_result.groups.invite: if sync_result.groups.invite:
@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
return response return response
@staticmethod @staticmethod
def encode_presence(events, time_now): def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
return { return {
"events": [ "events": [
{ {
@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
} }
async def encode_joined( async def encode_joined(
self, rooms, time_now, token_id, event_fields, event_formatter self,
): rooms: List[JoinedSyncResult],
time_now: int,
token_id: Optional[int],
event_fields: List[str],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
Args: Args:
rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync rooms: list of sync results for rooms this user is joined to
results for rooms this user is joined to time_now: current time - used as a baseline for age calculations
time_now(int): current time - used as a baseline for age token_id: ID of the user's auth token - used for namespacing
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty, event_fields: List of event fields to include. If empty,
all fields will be returned. all fields will be returned.
event_formatter (func[dict]): function to convert from federation format event_formatter: function to convert from federation format
to client format to client format
Returns: Returns:
dict[str, dict[str, object]]: the joined rooms list, in our The joined rooms list, in our response format
response format
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
return joined return joined
async def encode_invited(self, rooms, time_now, token_id, event_formatter): async def encode_invited(
self,
rooms: List[InvitedSyncResult],
time_now: int,
token_id: Optional[int],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
Args: Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of rooms: list of sync results for rooms this user is invited to
sync results for rooms this user is invited to time_now: current time - used as a baseline for age calculations
time_now(int): current time - used as a baseline for age token_id: ID of the user's auth token - used for namespacing
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_formatter (func[dict]): function to convert from federation format event_formatter: function to convert from federation format
to client format to client format
Returns: Returns:
dict[str, dict[str, object]]: the invited rooms list, in our The invited rooms list, in our response format
response format
""" """
invited = {} invited = {}
for room in rooms: for room in rooms:
@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
self, self,
rooms: List[KnockedSyncResult], rooms: List[KnockedSyncResult],
time_now: int, time_now: int,
token_id: int, token_id: Optional[int],
event_formatter: Callable[[Dict], Dict], event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]: ) -> Dict[str, Dict[str, Any]]:
""" """
@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
return knocked return knocked
async def encode_archived( async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter self,
): rooms: List[ArchivedSyncResult],
time_now: int,
token_id: Optional[int],
event_fields: List[str],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
Args: Args:
rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of rooms: list of sync results for rooms this user is joined to
sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations
time_now(int): current time - used as a baseline for age token_id: ID of the user's auth token - used for namespacing
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty, event_fields: List of event fields to include. If empty,
all fields will be returned. all fields will be returned.
event_formatter (func[dict]): function to convert from federation format event_formatter: function to convert from federation format to client format
to client format
Returns: Returns:
dict[str, dict[str, object]]: The invited rooms list, in our The archived rooms list, in our response format
response format
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
return joined return joined
async def encode_room( async def encode_room(
self, room, time_now, token_id, joined, only_fields, event_formatter self,
): room: Union[JoinedSyncResult, ArchivedSyncResult],
time_now: int,
token_id: Optional[int],
joined: bool,
only_fields: Optional[List[str]],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
""" """
Args: Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a room: sync result for a single room
single room time_now: current time - used as a baseline for age calculations
time_now (int): current time - used as a baseline for age token_id: ID of the user's auth token - used for namespacing
calculations
token_id (int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
joined (bool): True if the user is joined to this room - will mean joined: True if the user is joined to this room - will mean
we handle ephemeral events we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include. only_fields: Optional. The list of event fields to include.
event_formatter (func[dict]): function to convert from federation format event_formatter: function to convert from federation format
to client format to client format
Returns: Returns:
dict[str, object]: the room, encoded in our response format The room, encoded in our response format
""" """
def serialize(events): def serialize(events):
@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
account_data = room.account_data account_data = room.account_data
result = { result: JsonDict = {
"timeline": { "timeline": {
"events": serialized_timeline, "events": serialized_timeline,
"prev_batch": await room.timeline.prev_batch.to_string(self.store), "prev_batch": await room.timeline.prev_batch.to_string(self.store),
@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
} }
if joined: if joined:
assert isinstance(room, JoinedSyncResult)
ephemeral_events = room.ephemeral ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
return result return result
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)