Add missing type hints to synapse.replication.http. (#11856)
This commit is contained in:
parent
8b309adb43
commit
63d90f10ec
|
@ -0,0 +1 @@
|
|||
Add missing type hints to replication code.
|
|
@ -40,7 +40,7 @@ class ReplicationRestResource(JsonResource):
|
|||
super().__init__(hs, canonical_json=False, extract_context=True)
|
||||
self.register_servlets(hs)
|
||||
|
||||
def register_servlets(self, hs: "HomeServer"):
|
||||
def register_servlets(self, hs: "HomeServer") -> None:
|
||||
send_event.register_servlets(hs, self)
|
||||
federation.register_servlets(hs, self)
|
||||
presence.register_servlets(hs, self)
|
||||
|
|
|
@ -15,16 +15,20 @@
|
|||
import abc
|
||||
import logging
|
||||
import re
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
|
||||
|
||||
from prometheus_client import Counter, Gauge
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import HttpResponseException, SynapseError
|
||||
from synapse.http import RequestTimedOutError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.logging import opentracing
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
@ -113,10 +117,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||
if hs.config.worker.worker_replication_secret:
|
||||
self._replication_secret = hs.config.worker.worker_replication_secret
|
||||
|
||||
def _check_auth(self, request) -> None:
|
||||
def _check_auth(self, request: Request) -> None:
|
||||
# Get the authorization header.
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
|
||||
if not auth_headers:
|
||||
raise RuntimeError("Missing Authorization header.")
|
||||
if len(auth_headers) > 1:
|
||||
raise RuntimeError("Too many Authorization headers.")
|
||||
parts = auth_headers[0].split(b" ")
|
||||
|
@ -129,7 +135,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||
raise RuntimeError("Invalid Authorization header.")
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _serialize_payload(**kwargs):
|
||||
async def _serialize_payload(**kwargs) -> JsonDict:
|
||||
"""Static method that is called when creating a request.
|
||||
|
||||
Concrete implementations should have explicit parameters (rather than
|
||||
|
@ -144,19 +150,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _handle_request(self, request, **kwargs):
|
||||
async def _handle_request(
|
||||
self, request: Request, **kwargs: Any
|
||||
) -> Tuple[int, JsonDict]:
|
||||
"""Handle incoming request.
|
||||
|
||||
This is called with the request object and PATH_ARGS.
|
||||
|
||||
Returns:
|
||||
tuple[int, dict]: HTTP status code and a JSON serialisable dict
|
||||
to be used as response body of request.
|
||||
HTTP status code and a JSON serialisable dict to be used as response
|
||||
body of request.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def make_client(cls, hs: "HomeServer"):
|
||||
def make_client(cls, hs: "HomeServer") -> Callable:
|
||||
"""Create a client that makes requests.
|
||||
|
||||
Returns a callable that accepts the same parameters as
|
||||
|
@ -182,7 +189,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||
)
|
||||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
async def send_request(*, instance_name="master", **kwargs):
|
||||
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
||||
with outgoing_gauge.track_inprogress():
|
||||
if instance_name == local_instance_name:
|
||||
raise Exception("Trying to send HTTP request to self")
|
||||
|
@ -268,7 +275,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||
|
||||
return send_request
|
||||
|
||||
def register(self, http_server):
|
||||
def register(self, http_server: HttpServer) -> None:
|
||||
"""Called by the server to register this as a handler to the
|
||||
appropriate path.
|
||||
"""
|
||||
|
@ -289,7 +296,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||
self.__class__.__name__,
|
||||
)
|
||||
|
||||
async def _check_auth_and_handle(self, request, **kwargs):
|
||||
async def _check_auth_and_handle(
|
||||
self, request: Request, **kwargs: Any
|
||||
) -> Tuple[int, JsonDict]:
|
||||
"""Called on new incoming requests when caching is enabled. Checks
|
||||
if there is a cached response for the request and returns that,
|
||||
otherwise calls `_handle_request` and caches its response.
|
||||
|
|
|
@ -13,10 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -48,14 +52,18 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id, account_data_type, content):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str, account_data_type: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
payload = {
|
||||
"content": content,
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
async def _handle_request(self, request, user_id, account_data_type):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str, account_data_type: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
max_stream_id = await self.handler.add_account_data_for_user(
|
||||
|
@ -89,14 +97,18 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id, room_id, account_data_type, content):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str, room_id: str, account_data_type: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
payload = {
|
||||
"content": content,
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
async def _handle_request(self, request, user_id, room_id, account_data_type):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str, room_id: str, account_data_type: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
max_stream_id = await self.handler.add_account_data_to_room(
|
||||
|
@ -130,14 +142,18 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id, room_id, tag, content):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str, room_id: str, tag: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
payload = {
|
||||
"content": content,
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
async def _handle_request(self, request, user_id, room_id, tag):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str, room_id: str, tag: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
max_stream_id = await self.handler.add_tag_to_room(
|
||||
|
@ -173,11 +189,13 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id, room_id, tag):
|
||||
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
|
||||
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, user_id, room_id, tag):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str, room_id: str, tag: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
max_stream_id = await self.handler.remove_tag_from_room(
|
||||
user_id,
|
||||
room_id,
|
||||
|
@ -187,7 +205,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
|
|||
return 200, {"max_stream_id": max_stream_id}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationUserAccountDataRestServlet(hs).register(http_server)
|
||||
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
|
||||
ReplicationAddTagRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,9 +13,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -63,14 +67,16 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id):
|
||||
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
||||
|
||||
return 200, user_devices
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,17 +13,22 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
self.federation_event_handler = hs.get_federation_event_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
store: "DataStore",
|
||||
room_id: str,
|
||||
event_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
store
|
||||
room_id (str)
|
||||
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
|
||||
backfilled (bool): Whether or not the events are the result of
|
||||
backfilling
|
||||
room_id
|
||||
event_and_contexts
|
||||
backfilled: Whether or not the events are the result of backfilling
|
||||
"""
|
||||
event_payloads = []
|
||||
for event, context in event_and_contexts:
|
||||
|
@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
|
||||
return payload
|
||||
|
||||
async def _handle_request(self, request):
|
||||
async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
|
||||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
|||
self.registry = hs.get_federation_registry()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(edu_type, origin, content):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
edu_type: str, origin: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
return {"origin": origin, "content": content}
|
||||
|
||||
async def _handle_request(self, request, edu_type):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, edu_type: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
with Measure(self.clock, "repl_fed_send_edu_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
|||
|
||||
logger.info("Got %r edu from %s", edu_type, origin)
|
||||
|
||||
result = await self.registry.on_edu(edu_type, origin, edu_content)
|
||||
await self.registry.on_edu(edu_type, origin, edu_content)
|
||||
|
||||
return 200, result
|
||||
return 200, {}
|
||||
|
||||
|
||||
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||
|
@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
|||
self.registry = hs.get_federation_registry()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(query_type, args):
|
||||
async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override]
|
||||
"""
|
||||
Args:
|
||||
query_type (str)
|
||||
args (dict): The arguments received for the given query type
|
||||
query_type
|
||||
args: The arguments received for the given query type
|
||||
"""
|
||||
return {"args": args}
|
||||
|
||||
async def _handle_request(self, request, query_type):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, query_type: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
with Measure(self.clock, "repl_fed_query_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(room_id, args):
|
||||
async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
|
||||
"""
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, room_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.store.clean_room_for_join(room_id)
|
||||
|
||||
return 200, {}
|
||||
|
@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(room_id, room_version):
|
||||
async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override]
|
||||
return {"room_version": room_version.identifier}
|
||||
|
||||
async def _handle_request(self, request, room_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
||||
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
||||
ReplicationFederationSendEduRestServlet(hs).register(http_server)
|
||||
ReplicationGetQueryRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,10 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, cast
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -39,25 +43,24 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(
|
||||
user_id,
|
||||
device_id,
|
||||
initial_display_name,
|
||||
is_guest,
|
||||
is_appservice_ghost,
|
||||
should_issue_refresh_token,
|
||||
auth_provider_id,
|
||||
auth_provider_session_id,
|
||||
):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str,
|
||||
device_id: Optional[str],
|
||||
initial_display_name: Optional[str],
|
||||
is_guest: bool,
|
||||
is_appservice_ghost: bool,
|
||||
should_issue_refresh_token: bool,
|
||||
auth_provider_id: Optional[str],
|
||||
auth_provider_session_id: Optional[str],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
user_id (int)
|
||||
device_id (str|None): Device ID to use, if None a new one is
|
||||
generated.
|
||||
initial_display_name (str|None)
|
||||
is_guest (bool)
|
||||
is_appservice_ghost (bool)
|
||||
should_issue_refresh_token (bool)
|
||||
user_id
|
||||
device_id: Device ID to use, if None a new one is generated.
|
||||
initial_display_name
|
||||
is_guest
|
||||
is_appservice_ghost
|
||||
should_issue_refresh_token
|
||||
"""
|
||||
return {
|
||||
"device_id": device_id,
|
||||
|
@ -69,7 +72,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||
"auth_provider_session_id": auth_provider_session_id,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
device_id = content["device_id"]
|
||||
|
@ -91,8 +96,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||
auth_provider_session_id=auth_provider_session_id,
|
||||
)
|
||||
|
||||
return 200, res
|
||||
return 200, cast(JsonDict, res)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
RegisterDeviceReplicationServlet(hs).register(http_server)
|
||||
|
|
|
@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
|||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
|
@ -53,7 +54,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload( # type: ignore
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
|
@ -77,7 +78,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request( # type: ignore
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: SynapseRequest, room_id: str, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -122,13 +123,13 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload( # type: ignore
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
remote_room_hosts: List[str],
|
||||
content: JsonDict,
|
||||
):
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
requester: The user making the request, according to the access token.
|
||||
|
@ -143,12 +144,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
|
|||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request( # type: ignore
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
remote_room_hosts = content["remote_room_hosts"]
|
||||
|
@ -192,7 +193,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||
self.member_handler = hs.get_room_member_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload( # type: ignore
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
invite_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
|
@ -215,7 +216,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request( # type: ignore
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: SynapseRequest, invite_event_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -262,12 +263,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
|
|||
self.member_handler = hs.get_room_member_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload( # type: ignore
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
knock_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
content: JsonDict,
|
||||
):
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
knock_event_id: The ID of the knock to be rescinded.
|
||||
|
@ -281,11 +282,11 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
|
|||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request( # type: ignore
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
knock_event_id: str,
|
||||
):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
txn_id = content["txn_id"]
|
||||
|
@ -329,7 +330,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||
self.distributor = hs.get_distributor()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload( # type: ignore
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
room_id: str, user_id: str, change: str
|
||||
) -> JsonDict:
|
||||
"""
|
||||
|
@ -345,7 +346,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||
|
||||
return {}
|
||||
|
||||
async def _handle_request( # type: ignore
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, room_id: str, user_id: str, change: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
logger.info("user membership change: %s in %s", user_id, room_id)
|
||||
|
@ -360,7 +361,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationRemoteJoinRestServlet(hs).register(http_server)
|
||||
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
|
||||
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,11 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -49,18 +52,17 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
|
|||
self._presence_handler = hs.get_presence_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id):
|
||||
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self._presence_handler.bump_presence_active_time(
|
||||
UserID.from_string(user_id)
|
||||
)
|
||||
|
||||
return (
|
||||
200,
|
||||
{},
|
||||
)
|
||||
return (200, {})
|
||||
|
||||
|
||||
class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||
|
@ -92,16 +94,21 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
|||
self._presence_handler = hs.get_presence_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(
|
||||
user_id, state, ignore_status_msg=False, force_notify=False
|
||||
):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str,
|
||||
state: JsonDict,
|
||||
ignore_status_msg: bool = False,
|
||||
force_notify: bool = False,
|
||||
) -> JsonDict:
|
||||
return {
|
||||
"state": state,
|
||||
"ignore_status_msg": ignore_status_msg,
|
||||
"force_notify": force_notify,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
await self._presence_handler.set_state(
|
||||
|
@ -111,12 +118,9 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
|||
content["force_notify"],
|
||||
)
|
||||
|
||||
return (
|
||||
200,
|
||||
{},
|
||||
)
|
||||
return (200, {})
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationBumpPresenceActiveTime(hs).register(http_server)
|
||||
ReplicationPresenceSetState(hs).register(http_server)
|
||||
|
|
|
@ -13,10 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -48,7 +52,7 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
|||
self.pusher_pool = hs.get_pusherpool()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(app_id, pushkey, user_id):
|
||||
async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDict: # type: ignore[override]
|
||||
payload = {
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
|
@ -56,7 +60,9 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
|||
|
||||
return payload
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
app_id = content["app_id"]
|
||||
|
@ -67,5 +73,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationRemovePusherRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,10 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -36,34 +40,34 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(
|
||||
user_id,
|
||||
password_hash,
|
||||
was_guest,
|
||||
make_guest,
|
||||
appservice_id,
|
||||
create_profile_with_displayname,
|
||||
admin,
|
||||
user_type,
|
||||
address,
|
||||
shadow_banned,
|
||||
):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str,
|
||||
password_hash: Optional[str],
|
||||
was_guest: bool,
|
||||
make_guest: bool,
|
||||
appservice_id: Optional[str],
|
||||
create_profile_with_displayname: Optional[str],
|
||||
admin: bool,
|
||||
user_type: Optional[str],
|
||||
address: Optional[str],
|
||||
shadow_banned: bool,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
user_id (str): The desired user ID to register.
|
||||
password_hash (str|None): Optional. The password hash for this user.
|
||||
was_guest (bool): Optional. Whether this is a guest account being
|
||||
upgraded to a non-guest account.
|
||||
make_guest (boolean): True if the the new user should be guest,
|
||||
false to add a regular user account.
|
||||
appservice_id (str|None): The ID of the appservice registering the user.
|
||||
create_profile_with_displayname (unicode|None): Optionally create a
|
||||
profile for the user, setting their displayname to the given value
|
||||
admin (boolean): is an admin user?
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
address (str|None): the IP address used to perform the regitration.
|
||||
shadow_banned (bool): Whether to shadow-ban the user
|
||||
user_id: The desired user ID to register.
|
||||
password_hash: Optional. The password hash for this user.
|
||||
was_guest: Optional. Whether this is a guest account being upgraded
|
||||
to a non-guest account.
|
||||
make_guest: True if the the new user should be guest, false to add a
|
||||
regular user account.
|
||||
appservice_id: The ID of the appservice registering the user.
|
||||
create_profile_with_displayname: Optionally create a profile for the
|
||||
user, setting their displayname to the given value
|
||||
admin: is an admin user?
|
||||
user_type: type of user. One of the values from api.constants.UserTypes,
|
||||
or None for a normal user.
|
||||
address: the IP address used to perform the regitration.
|
||||
shadow_banned: Whether to shadow-ban the user
|
||||
"""
|
||||
return {
|
||||
"password_hash": password_hash,
|
||||
|
@ -77,7 +81,9 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
"shadow_banned": shadow_banned,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
await self.registration_handler.check_registration_ratelimit(content["address"])
|
||||
|
@ -110,18 +116,21 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
|||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id, auth_result, access_token):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
user_id: str, auth_result: JsonDict, access_token: Optional[str]
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
user_id (str): The user ID that consented
|
||||
auth_result (dict): The authenticated credentials of the newly
|
||||
registered user.
|
||||
access_token (str|None): The access token of the newly logged in
|
||||
user_id: The user ID that consented
|
||||
auth_result: The authenticated credentials of the newly registered user.
|
||||
access_token: The access token of the newly logged in
|
||||
device, or None if `inhibit_login` enabled.
|
||||
"""
|
||||
return {"auth_result": auth_result, "access_token": access_token}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
auth_result = content["auth_result"]
|
||||
|
@ -134,6 +143,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationRegisterServlet(hs).register(http_server)
|
||||
ReplicationPostRegisterActionsServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,18 +13,22 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -70,18 +74,24 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(
|
||||
event_id, store, event, context, requester, ratelimit, extra_users
|
||||
):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
event_id: str,
|
||||
store: "DataStore",
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
requester: Requester,
|
||||
ratelimit: bool,
|
||||
extra_users: List[UserID],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
event_id (str)
|
||||
store (DataStore)
|
||||
requester (Requester)
|
||||
event (FrozenEvent)
|
||||
context (EventContext)
|
||||
ratelimit (bool)
|
||||
extra_users (list(UserID)): Any extra users to notify about event
|
||||
event_id
|
||||
store
|
||||
requester
|
||||
event
|
||||
context
|
||||
ratelimit
|
||||
extra_users: Any extra users to notify about event
|
||||
"""
|
||||
serialized_context = await context.serialize(event, store)
|
||||
|
||||
|
@ -100,7 +110,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
|
||||
return payload
|
||||
|
||||
async def _handle_request(self, request, event_id):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, event_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
with Measure(self.clock, "repl_send_event_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -120,8 +132,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
ratelimit = content["ratelimit"]
|
||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||
|
||||
request.requester = requester
|
||||
|
||||
logger.info(
|
||||
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
||||
)
|
||||
|
@ -139,5 +149,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationSendEventRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,11 +13,15 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import parse_integer
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -57,10 +61,14 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||
self.streams = hs.get_replication_streams()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(stream_name, from_token, upto_token):
|
||||
async def _serialize_payload( # type: ignore[override]
|
||||
stream_name: str, from_token: int, upto_token: int
|
||||
) -> JsonDict:
|
||||
return {"from_token": from_token, "upto_token": upto_token}
|
||||
|
||||
async def _handle_request(self, request, stream_name):
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, stream_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
stream = self.streams.get(stream_name)
|
||||
if stream is None:
|
||||
raise SynapseError(400, "Unknown stream")
|
||||
|
@ -78,5 +86,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationGetStreamUpdates(hs).register(http_server)
|
||||
|
|
Loading…
Reference in New Issue