Add missing type hints to synapse.replication.http. (#11856)

This commit is contained in:
Patrick Cloke 2022-02-08 07:44:39 -05:00 committed by GitHub
parent 8b309adb43
commit 63d90f10ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 258 additions and 162 deletions

1
changelog.d/11856.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to replication code.

View File

@ -40,7 +40,7 @@ class ReplicationRestResource(JsonResource):
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs: "HomeServer"):
def register_servlets(self, hs: "HomeServer") -> None:
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)

View File

@ -15,16 +15,20 @@
import abc
import logging
import re
import urllib
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from prometheus_client import Counter, Gauge
from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@ -113,10 +117,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request) -> None:
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
raise RuntimeError("Missing Authorization header.")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
@ -129,7 +135,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod
async def _serialize_payload(**kwargs):
async def _serialize_payload(**kwargs) -> JsonDict:
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
@ -144,19 +150,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return {}
@abc.abstractmethod
async def _handle_request(self, request, **kwargs):
async def _handle_request(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
tuple[int, dict]: HTTP status code and a JSON serialisable dict
to be used as response body of request.
HTTP status code and a JSON serialisable dict to be used as response
body of request.
"""
pass
@classmethod
def make_client(cls, hs: "HomeServer"):
def make_client(cls, hs: "HomeServer") -> Callable:
"""Create a client that makes requests.
Returns a callable that accepts the same parameters as
@ -182,7 +189,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
)
@trace(opname="outgoing_replication_request")
async def send_request(*, instance_name="master", **kwargs):
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
@ -268,7 +275,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return send_request
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Called by the server to register this as a handler to the
appropriate path.
"""
@ -289,7 +296,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
self.__class__.__name__,
)
async def _check_auth_and_handle(self, request, **kwargs):
async def _check_auth_and_handle(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.

View File

@ -13,10 +13,14 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -48,14 +52,18 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, account_data_type, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, account_data_type):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_for_user(
@ -89,14 +97,18 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, account_data_type):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_to_room(
@ -130,14 +142,18 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, tag: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, tag):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_tag_to_room(
@ -173,11 +189,13 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag):
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
return {}
async def _handle_request(self, request, user_id, room_id, tag):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
room_id,
@ -187,7 +205,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id}
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)

View File

@ -13,9 +13,13 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -63,14 +67,16 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id):
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {}
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
user_devices = await self.device_list_updater.user_device_resync(user_id)
return 200, user_devices
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)

View File

@ -13,17 +13,22 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from twisted.web.server import Request
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_event_handler = hs.get_federation_event_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
async def _serialize_payload( # type: ignore[override]
store: "DataStore",
room_id: str,
event_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
) -> JsonDict:
"""
Args:
store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
room_id
event_and_contexts
backfilled: Whether or not the events are the result of backfilling
"""
event_payloads = []
for event, context in event_and_contexts:
@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
async def _handle_request(self, request):
async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
async def _serialize_payload(edu_type, origin, content):
async def _serialize_payload( # type: ignore[override]
edu_type: str, origin: str, content: JsonDict
) -> JsonDict:
return {"origin": origin, "content": content}
async def _handle_request(self, request, edu_type):
async def _handle_request( # type: ignore[override]
self, request: Request, edu_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
logger.info("Got %r edu from %s", edu_type, origin)
result = await self.registry.on_edu(edu_type, origin, edu_content)
await self.registry.on_edu(edu_type, origin, edu_content)
return 200, result
return 200, {}
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
async def _serialize_payload(query_type, args):
async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override]
"""
Args:
query_type (str)
args (dict): The arguments received for the given query type
query_type
args: The arguments received for the given query type
"""
return {"args": args}
async def _handle_request(self, request, query_type):
async def _handle_request( # type: ignore[override]
self, request: Request, query_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
async def _serialize_payload(room_id, args):
async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
"""
Args:
room_id (str)
room_id
"""
return {}
async def _handle_request(self, request, room_id):
async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id)
return 200, {}
@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
async def _serialize_payload(room_id, room_version):
async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override]
return {"room_version": room_version.identifier}
async def _handle_request(self, request, room_id):
async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple, cast
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -39,25 +43,24 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
async def _serialize_payload(
user_id,
device_id,
initial_display_name,
is_guest,
is_appservice_ghost,
should_issue_refresh_token,
auth_provider_id,
auth_provider_session_id,
):
async def _serialize_payload( # type: ignore[override]
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool,
is_appservice_ghost: bool,
should_issue_refresh_token: bool,
auth_provider_id: Optional[str],
auth_provider_session_id: Optional[str],
) -> JsonDict:
"""
Args:
user_id (int)
device_id (str|None): Device ID to use, if None a new one is
generated.
initial_display_name (str|None)
is_guest (bool)
is_appservice_ghost (bool)
should_issue_refresh_token (bool)
user_id
device_id: Device ID to use, if None a new one is generated.
initial_display_name
is_guest
is_appservice_ghost
should_issue_refresh_token
"""
return {
"device_id": device_id,
@ -69,7 +72,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"auth_provider_session_id": auth_provider_session_id,
}
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
device_id = content["device_id"]
@ -91,8 +96,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
auth_provider_session_id=auth_provider_session_id,
)
return 200, res
return 200, cast(JsonDict, res)
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RegisterDeviceReplicationServlet(hs).register(http_server)

View File

@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
@ -53,7 +54,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore
async def _serialize_payload( # type: ignore[override]
requester: Requester,
room_id: str,
user_id: str,
@ -77,7 +78,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request( # type: ignore
async def _handle_request( # type: ignore[override]
self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@ -122,13 +123,13 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore
async def _serialize_payload( # type: ignore[override]
requester: Requester,
room_id: str,
user_id: str,
remote_room_hosts: List[str],
content: JsonDict,
):
) -> JsonDict:
"""
Args:
requester: The user making the request, according to the access token.
@ -143,12 +144,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request( # type: ignore
async def _handle_request( # type: ignore[override]
self,
request: SynapseRequest,
room_id: str,
user_id: str,
):
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@ -192,7 +193,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.member_handler = hs.get_room_member_handler()
@staticmethod
async def _serialize_payload( # type: ignore
async def _serialize_payload( # type: ignore[override]
invite_event_id: str,
txn_id: Optional[str],
requester: Requester,
@ -215,7 +216,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request( # type: ignore
async def _handle_request( # type: ignore[override]
self, request: SynapseRequest, invite_event_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@ -262,12 +263,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
self.member_handler = hs.get_room_member_handler()
@staticmethod
async def _serialize_payload( # type: ignore
async def _serialize_payload( # type: ignore[override]
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
):
) -> JsonDict:
"""
Args:
knock_event_id: The ID of the knock to be rescinded.
@ -281,11 +282,11 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request( # type: ignore
async def _handle_request( # type: ignore[override]
self,
request: SynapseRequest,
knock_event_id: str,
):
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@ -329,7 +330,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
async def _serialize_payload( # type: ignore
async def _serialize_payload( # type: ignore[override]
room_id: str, user_id: str, change: str
) -> JsonDict:
"""
@ -345,7 +346,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {}
async def _handle_request( # type: ignore
async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str, user_id: str, change: str
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
@ -360,7 +361,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return 200, {}
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)

View File

@ -13,11 +13,14 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -49,18 +52,17 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
async def _serialize_payload(user_id):
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {}
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id)
)
return (
200,
{},
)
return (200, {})
class ReplicationPresenceSetState(ReplicationEndpoint):
@ -92,16 +94,21 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
async def _serialize_payload(
user_id, state, ignore_status_msg=False, force_notify=False
):
async def _serialize_payload( # type: ignore[override]
user_id: str,
state: JsonDict,
ignore_status_msg: bool = False,
force_notify: bool = False,
) -> JsonDict:
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
"force_notify": force_notify,
}
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
await self._presence_handler.set_state(
@ -111,12 +118,9 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
content["force_notify"],
)
return (
200,
{},
)
return (200, {})
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationBumpPresenceActiveTime(hs).register(http_server)
ReplicationPresenceSetState(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -48,7 +52,7 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
self.pusher_pool = hs.get_pusherpool()
@staticmethod
async def _serialize_payload(app_id, pushkey, user_id):
async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDict: # type: ignore[override]
payload = {
"app_id": app_id,
"pushkey": pushkey,
@ -56,7 +60,9 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return payload
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
app_id = content["app_id"]
@ -67,5 +73,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {}
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemovePusherRestServlet(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -36,34 +40,34 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
async def _serialize_payload(
user_id,
password_hash,
was_guest,
make_guest,
appservice_id,
create_profile_with_displayname,
admin,
user_type,
address,
shadow_banned,
):
async def _serialize_payload( # type: ignore[override]
user_id: str,
password_hash: Optional[str],
was_guest: bool,
make_guest: bool,
appservice_id: Optional[str],
create_profile_with_displayname: Optional[str],
admin: bool,
user_type: Optional[str],
address: Optional[str],
shadow_banned: bool,
) -> JsonDict:
"""
Args:
user_id (str): The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
shadow_banned (bool): Whether to shadow-ban the user
user_id: The desired user ID to register.
password_hash: Optional. The password hash for this user.
was_guest: Optional. Whether this is a guest account being upgraded
to a non-guest account.
make_guest: True if the the new user should be guest, false to add a
regular user account.
appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname: Optionally create a profile for the
user, setting their displayname to the given value
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
or None for a normal user.
address: the IP address used to perform the regitration.
shadow_banned: Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@ -77,7 +81,9 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
await self.registration_handler.check_registration_ratelimit(content["address"])
@ -110,18 +116,21 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
async def _serialize_payload(user_id, auth_result, access_token):
async def _serialize_payload( # type: ignore[override]
user_id: str, auth_result: JsonDict, access_token: Optional[str]
) -> JsonDict:
"""
Args:
user_id (str): The user ID that consented
auth_result (dict): The authenticated credentials of the newly
registered user.
access_token (str|None): The access token of the newly logged in
user_id: The user ID that consented
auth_result: The authenticated credentials of the newly registered user.
access_token: The access token of the newly logged in
device, or None if `inhibit_login` enabled.
"""
return {"auth_result": auth_result, "access_token": access_token}
async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
auth_result = content["auth_result"]
@ -134,6 +143,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return 200, {}
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRegisterServlet(hs).register(http_server)
ReplicationPostRegisterActionsServlet(hs).register(http_server)

View File

@ -13,18 +13,22 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Tuple
from twisted.web.server import Request
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -70,18 +74,24 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(
event_id, store, event, context, requester, ratelimit, extra_users
):
async def _serialize_payload( # type: ignore[override]
event_id: str,
store: "DataStore",
event: EventBase,
context: EventContext,
requester: Requester,
ratelimit: bool,
extra_users: List[UserID],
) -> JsonDict:
"""
Args:
event_id (str)
store (DataStore)
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
event_id
store
requester
event
context
ratelimit
extra_users: Any extra users to notify about event
"""
serialized_context = await context.serialize(event, store)
@ -100,7 +110,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload
async def _handle_request(self, request, event_id):
async def _handle_request( # type: ignore[override]
self, request: Request, event_id: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
@ -120,8 +132,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
request.requester = requester
logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
@ -139,5 +149,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
)
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationSendEventRestServlet(hs).register(http_server)

View File

@ -13,11 +13,15 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -57,10 +61,14 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streams()
@staticmethod
async def _serialize_payload(stream_name, from_token, upto_token):
async def _serialize_payload( # type: ignore[override]
stream_name: str, from_token: int, upto_token: int
) -> JsonDict:
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name):
async def _handle_request( # type: ignore[override]
self, request: Request, stream_name: str
) -> Tuple[int, JsonDict]:
stream = self.streams.get(stream_name)
if stream is None:
raise SynapseError(400, "Unknown stream")
@ -78,5 +86,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
)
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationGetStreamUpdates(hs).register(http_server)