diff --git a/changelog.d/10728.misc b/changelog.d/10728.misc new file mode 100644 index 0000000000..39a37b90b1 --- /dev/null +++ b/changelog.d/10728.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fb5ad2906e..aefaaa8ae8 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -16,9 +16,11 @@ import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from urllib.parse import urlparse +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -28,15 +30,17 @@ from synapse.api.errors import ( ) from synapse.config.emailconfig import ThreepidBehaviour from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed, validate_email @@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=self.config.email_password_reset_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet): self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. + requester = None if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) try: @@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise user_id = requester.user.to_string() else: - requester = None try: result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], @@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise @@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet): # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. if new_password: - password_hash = await self.auth_handler.hash(new_password) + password_hash: Optional[str] = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None @@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config @@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): template_text=self.config.email_add_threepid_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): "/add_threepid/email/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() @@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config.email_add_threepid_template_failure_html ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> None: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "/add_threepid/msisdn/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet): class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet): class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ @@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} -def assert_valid_next_link(hs: "HomeServer", next_link: str): +def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: """ Raises a SynapseError if a given next_link value is invalid @@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str): class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) response = {"user_id": requester.user.to_string()} @@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 68fb08d0ba..0152a0c66a 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + def on_PUT( + self, request: Request, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a28acd4041..8f3dd2a101 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -375,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): unstable=True, ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -390,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) if not self.hs.config.enable_registration: @@ -730,7 +726,11 @@ class RegisterRestServlet(RestServlet): return 200, return_dict async def _do_appservice_registration( - self, username, as_token, body, should_issue_refresh_token: bool = False + self, + username: str, + as_token: str, + body: JsonDict, + should_issue_refresh_token: bool = False, ) -> JsonDict: user_id = await self.registration_handler.appservice_register( username, as_token diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 07ea39a8a3..d4a4adb50c 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -14,26 +14,35 @@ import logging from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() - async def on_POST(self, request, room_id, event_id): + async def on_POST( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReportEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 3172aba605..ed96978448 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -14,10 +14,14 @@ import logging import re +from typing import TYPE_CHECKING, Awaitable, List, Tuple + +from twisted.web.server import Request from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import AuthError, Codes, SynapseError from synapse.appservice import ApplicationService +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -25,10 +29,14 @@ from synapse.http.servlet import ( parse_string, parse_strings_from_args, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet): ), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet): self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: ( most_recent_prev_event_id, most_recent_prev_event_depth, @@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def _create_insertion_event_dict( self, sender: str, room_id: str, origin_server_ts: int - ): + ) -> JsonDict: """Creates an event dict for an "insertion" event with the proper fields and a random chunk ID. @@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet): origin_server_ts: Timestamp when the event was sent Returns: - Tuple of event ID and stream ordering position + The new event dictionary to insert. """ next_chunk_id = random_string(8) @@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet): return create_requester(user_id, app_service=app_service) - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) if not requester.app_service: @@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["state_events_at_start", "events"]) + assert request.args is not None prev_events_from_query = parse_strings_from_args(request.args, "prev_event") chunk_id_from_query = parse_string(request, "chunk_id") @@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet): ], } - def on_GET(self, request, room_id): + def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: return 501, "Not implemented" - def on_PUT(self, request, room_id): + def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled if msc2716_enabled: diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index 263596be86..37e39570f6 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -13,16 +13,23 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet): "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_PUT(self, request, room_id, session_id): + async def on_PUT( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet): ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) return 200, ret - async def on_GET(self, request, room_id, session_id): + async def on_GET( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - async def on_DELETE(self, request, room_id, session_id): + async def on_DELETE( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): PATTERNS = client_patterns("/room_keys/version$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_GET(self, request, version): + async def on_GET( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet): raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - async def on_DELETE(self, request, version): + async def on_DELETE( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet): await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - async def on_PUT(self, request, version): + async def on_PUT( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Update the information about a given version of the user's room_keys backup. @@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomKeysServlet(hs).register(http_server) RoomKeysVersionServlet(hs).register(http_server) RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index d537d811d8..3322c8ef48 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -13,15 +13,21 @@ # limitations under the License. import logging -from typing import Tuple +from typing import TYPE_CHECKING, Awaitable, Tuple from synapse.http import servlet +from synapse.http.server import HttpServer from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag, trace from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet): self.device_message_handler = hs.get_device_message_handler() @trace(opname="sendToDevice") - def on_PUT(self, request, message_type, txn_id): + def on_PUT( + self, request: SynapseRequest, message_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("message_type", message_type) set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( request, self._put, request, message_type, txn_id ) - async def _put(self, request, message_type, txn_id): + async def _put( + self, request: SynapseRequest, message_type: str, txn_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester, message_type, content["messages"] ) - response: Tuple[int, dict] = (200, {}) - return response + return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 65c37be3e9..1259058b9b 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,12 +14,24 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.presence import UserPresenceState +from synapse.events import EventBase from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, @@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet): The room, encoded in our response format """ - def serialize(events): + def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: return self._event_serializer.serialize_events( events, time_now=time_now,