Faster joins: Support for calling `/federation/v1/state` (#12013)
This is an endpoint that we have server-side support for, but no client-side support. It's going to be useful for resyncing partial-stated rooms, so let's introduce it.
This commit is contained in:
parent
066171643b
commit
7273011f60
|
@ -0,0 +1 @@
|
|||
Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.
|
|
@ -47,6 +47,11 @@ class FederationBase:
|
|||
) -> EventBase:
|
||||
"""Checks that event is correctly signed by the sending server.
|
||||
|
||||
Also checks the content hash, and redacts the event if there is a mismatch.
|
||||
|
||||
Also runs the event through the spam checker; if it fails, redacts the event
|
||||
and flags it as soft-failed.
|
||||
|
||||
Args:
|
||||
room_version: The room version of the PDU
|
||||
pdu: the event to be checked
|
||||
|
@ -55,7 +60,10 @@ class FederationBase:
|
|||
* the original event if the checks pass
|
||||
* a redacted version of the event (if the signature
|
||||
matched but the hash did not)
|
||||
* throws a SynapseError if the signature check failed."""
|
||||
|
||||
Raises:
|
||||
SynapseError if the signature check failed.
|
||||
"""
|
||||
try:
|
||||
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
|
||||
except SynapseError as e:
|
||||
|
|
|
@ -419,26 +419,90 @@ class FederationClient(FederationBase):
|
|||
|
||||
return state_event_ids, auth_event_ids
|
||||
|
||||
async def get_room_state(
    self,
    destination: str,
    room_id: str,
    event_id: str,
    room_version: RoomVersion,
) -> Tuple[List[EventBase], List[EventBase]]:
    """Fetches the state of a room at a given event, via the `/state` endpoint.

    Events with incorrect or unverifiable signatures or hashes are dropped from
    the response, and duplicate events are removed.

    (Size limits and other event-format checks are *not* performed.)

    The result carries no ordering guarantee, so callers must take care to
    process the events in an order that handles dependencies.

    Args:
        destination: the remote server to send the request to
        room_id: the room whose state is wanted
        event_id: the event at which the state is wanted
        room_version: the version of the room; handed to the transport layer
            and to the signature/hash checks

    Returns:
        a tuple of (state events, auth events)
    """
    response = await self.transport_layer.get_room_state(
        room_version,
        destination,
        room_id,
        event_id,
    )

    # De-duplicate the response so that nothing is checked more than once. (In
    # particular, an event may appear in `auth_events` as well as in `state`,
    # which is redundant.)
    #
    # Nothing here depends on the order of the events, so keying them by event
    # ID in a dict is sufficient.
    state_by_id = {}
    for ev in response.state:
        state_by_id[ev.event_id] = ev

    auth_by_id = {}
    for ev in response.auth_events:
        if ev.event_id in state_by_id:
            # already covered by the state events
            continue
        auth_by_id[ev.event_id] = ev

    logger.info(
        "Processing from /state: %d state events, %d auth events",
        len(state_by_id),
        len(auth_by_id),
    )

    # Check the auth events first, then the state events.
    checked_auth = await self._check_sigs_and_hash_and_fetch(
        destination, auth_by_id.values(), room_version
    )
    checked_state = await self._check_sigs_and_hash_and_fetch(
        destination, state_by_id.values(), room_version
    )

    return checked_state, checked_auth
|
||||
|
||||
async def _check_sigs_and_hash_and_fetch(
|
||||
self,
|
||||
origin: str,
|
||||
pdus: Collection[EventBase],
|
||||
room_version: RoomVersion,
|
||||
) -> List[EventBase]:
|
||||
"""Takes a list of PDUs and checks the signatures and hashes of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of
|
||||
that PDU.
|
||||
"""Checks the signatures and hashes of a list of events.
|
||||
|
||||
If a PDU fails its signature check then we check if we have it in
|
||||
the database, and if not then request it from the sender's server (if that
|
||||
is different from `origin`). If that still fails, the event is omitted from
|
||||
the returned list.
|
||||
|
||||
If a PDU fails its content hash check then it is redacted.
|
||||
|
||||
The given list of PDUs are not modified, instead the function returns
|
||||
Also runs each event through the spam checker; if it fails, redacts the event
|
||||
and flags it as soft-failed.
|
||||
|
||||
The given list of PDUs are not modified; instead the function returns
|
||||
a new list.
|
||||
|
||||
Args:
|
||||
origin
|
||||
pdu
|
||||
room_version
|
||||
origin: The server that sent us these events
|
||||
pdus: The events to be checked
|
||||
room_version: the version of the room these events are in
|
||||
|
||||
Returns:
|
||||
A list of PDUs that have valid signatures and hashes.
|
||||
|
@ -469,11 +533,16 @@ class FederationClient(FederationBase):
|
|||
origin: str,
|
||||
room_version: RoomVersion,
|
||||
) -> Optional[EventBase]:
|
||||
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
|
||||
its signature check then we check if we have it in the database and if
|
||||
not then request if from the originating server of that PDU.
|
||||
"""Takes a PDU and checks its signatures and hashes.
|
||||
|
||||
If then PDU fails its content hash check then it is redacted.
|
||||
If the PDU fails its signature check then we check if we have it in the
|
||||
database; if not, we then request it from the sender's server (if that is not the
|
||||
same as `origin`). If that still fails, we return None.
|
||||
|
||||
If the PDU fails its content hash check, it is redacted.
|
||||
|
||||
Also runs the event through the spam checker; if it fails, redacts the event
|
||||
and flags it as soft-failed.
|
||||
|
||||
Args:
|
||||
origin
|
||||
|
|
|
@ -65,13 +65,12 @@ class TransportLayerClient:
|
|||
async def get_room_state_ids(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
) -> JsonDict:
|
||||
"""Requests all state for a given room from the given server at the
|
||||
given event. Returns the state's event_id's
|
||||
"""Requests the IDs of all state for a given room at the given event.
|
||||
|
||||
Args:
|
||||
destination: The host name of the remote homeserver we want
|
||||
to get the state from.
|
||||
context: The name of the context we want the state of
|
||||
room_id: the room we want the state of
|
||||
event_id: The event we want the context at.
|
||||
|
||||
Returns:
|
||||
|
@ -87,6 +86,29 @@ class TransportLayerClient:
|
|||
try_trailing_slash_on_400=True,
|
||||
)
|
||||
|
||||
async def get_room_state(
    self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> "StateRequestResponse":
    """Requests the full state for a given room at the given event.

    Args:
        room_version: the version of the room (required to build the event objects)
        destination: The host name of the remote homeserver we want
            to get the state from.
        room_id: the room we want the state of
        event_id: The event we want the context at.

    Returns:
        The parsed response from the remote homeserver, as a
        `StateRequestResponse` holding the state events and the auth events.
    """
    path = _create_v1_path("/state/%s", room_id)
    return await self.client.get_json(
        destination,
        path=path,
        args={"event_id": event_id},
        # Use the `/state`-specific parser, so that the result is a
        # `StateRequestResponse` rather than a plain JSON dict.
        parser=_StateParser(room_version),
    )
|
||||
|
||||
async def get_event(
|
||||
self, destination: str, event_id: str, timeout: Optional[int] = None
|
||||
) -> JsonDict:
|
||||
|
@ -1284,6 +1306,14 @@ class SendJoinResponse:
|
|||
servers_in_room: Optional[List[str]] = None
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True)
class StateRequestResponse:
    """The parsed response of a `/state` request."""

    # NB: the order of the fields matters: `_StateParser` constructs this class
    # positionally.

    # The events parsed from the `auth_chain` field of the response.
    auth_events: List[EventBase]

    # The events parsed from the `pdus` field of the response.
    state: List[EventBase]
|
||||
|
||||
|
||||
@ijson.coroutine
|
||||
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
|
||||
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
|
||||
|
@ -1411,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
|||
self._response.event_dict, self._room_version
|
||||
)
|
||||
return self._response
|
||||
|
||||
|
||||
class _StateParser(ByteParser[StateRequestResponse]):
    """A parser for the response to `/state` requests.

    The body is fed chunk-by-chunk to a pair of ijson coroutines: one watching
    the `pdus` array and one watching the `auth_chain` array. Each is wired up
    to the corresponding list on the `StateRequestResponse` being built.

    Args:
        room_version: The version of the room.
    """

    CONTENT_TYPE = "application/json"

    def __init__(self, room_version: RoomVersion):
        self._room_version = room_version
        self._response = StateRequestResponse([], [])

        # (list to populate, ijson prefix to watch), one pair per array that we
        # pull out of the response.
        targets = (
            (self._response.state, "pdus.item"),
            (self._response.auth_events, "auth_chain.item"),
        )
        self._coros = [
            ijson.items_coro(
                _event_list_parser(room_version, events),
                prefix,
                use_float=True,
            )
            for events, prefix in targets
        ]

    def write(self, data: bytes) -> int:
        """Passes a chunk of the response body to each of the coroutines.

        Returns:
            The number of bytes consumed, which is always all of `data`.
        """
        for coro in self._coros:
            coro.send(data)
        return len(data)

    def finish(self) -> StateRequestResponse:
        """Returns the response object that was built up during parsing."""
        return self._response
|
||||
|
|
|
@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
|
|||
)
|
||||
return body
|
||||
|
||||
@overload
|
||||
async def get_json(
|
||||
self,
|
||||
destination: str,
|
||||
|
@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Literal[None] = None,
|
||||
max_response_size: Optional[int] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_json(
|
||||
self,
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryArgs] = ...,
|
||||
retry_on_dns_fail: bool = ...,
|
||||
timeout: Optional[int] = ...,
|
||||
ignore_backoff: bool = ...,
|
||||
try_trailing_slash_on_400: bool = ...,
|
||||
parser: ByteParser[T] = ...,
|
||||
max_response_size: Optional[int] = ...,
|
||||
) -> T:
|
||||
...
|
||||
|
||||
async def get_json(
|
||||
self,
|
||||
destination: str,
|
||||
path: str,
|
||||
args: Optional[QueryArgs] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser] = None,
|
||||
max_response_size: Optional[int] = None,
|
||||
):
|
||||
"""GETs some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
|
@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
|
|||
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
|
||||
response we should try appending a trailing slash to the end of
|
||||
the request. Workaround for #3622 in Synapse <= v0.99.3.
|
||||
|
||||
parser: The parser to use to decode the response. Defaults to
|
||||
parsing as JSON.
|
||||
|
||||
max_response_size: The maximum size to read from the response. If None,
|
||||
uses the default.
|
||||
|
||||
Returns:
|
||||
Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
|
|||
else:
|
||||
_sec_timeout = self.default_timeout
|
||||
|
||||
if parser is None:
|
||||
parser = JsonParser()
|
||||
|
||||
body = await _handle_response(
|
||||
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
|
||||
self.reactor,
|
||||
_sec_timeout,
|
||||
request,
|
||||
response,
|
||||
start_ms,
|
||||
parser=parser,
|
||||
max_response_size=max_response_size,
|
||||
)
|
||||
|
||||
return body
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2022 Matrix.org Federation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import twisted.web.client
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import FederatingHomeserverTestCase
|
||||
|
||||
|
||||
class FederationClientTest(FederatingHomeserverTestCase):
    """Tests for the federation client, run against a mocked-out HTTP agent."""

    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
        super().prepare(reactor, clock, homeserver)

        # Mock out the Agent used by the federation client, which is easier than
        # catching the HTTPS connection and doing the TLS stuff.
        agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
        self._mock_agent = agent
        homeserver.get_federation_http_client().agent = agent

    def test_get_room_state(self):
        """`get_room_state` should hit `/state` and de-duplicate the response."""
        creator = f"@creator:{self.OTHER_SERVER_NAME}"
        test_room_id = "!room_id"

        # Mock up some events to use in the response.
        #
        # In real life, these would have things in `prev_events` and
        # `auth_events`, but that's a bit annoying to mock up, and the code
        # under test doesn't care, so we don't bother.
        create_ev = self.add_hashes_and_signatures(
            {
                "room_id": test_room_id,
                "type": "m.room.create",
                "state_key": "",
                "sender": creator,
                "content": {"creator": creator},
                "prev_events": [],
                "auth_events": [],
                "origin_server_ts": 500,
            }
        )
        member_ev = self.add_hashes_and_signatures(
            {
                "room_id": test_room_id,
                "type": "m.room.member",
                "sender": creator,
                "state_key": creator,
                "content": {"membership": "join"},
                "prev_events": [],
                "auth_events": [],
                "origin_server_ts": 600,
            }
        )
        power_levels_ev = self.add_hashes_and_signatures(
            {
                "room_id": test_room_id,
                "type": "m.room.power_levels",
                "sender": creator,
                "state_key": "",
                "content": {},
                "prev_events": [],
                "auth_events": [],
                "origin_server_ts": 700,
            }
        )

        # Mock up the response, and have the agent return it. Note that every
        # event in `auth_chain` also appears in `pdus`.
        response_json = {
            "pdus": [
                create_ev,
                member_ev,
                power_levels_ev,
            ],
            "auth_chain": [
                create_ev,
                member_ev,
            ],
        }
        self._mock_agent.request.return_value = defer.succeed(
            _mock_response(response_json)
        )

        # Now fire off the request.
        federation_client = self.hs.get_federation_client()
        state_resp, auth_resp = self.get_success(
            federation_client.get_room_state(
                "yet_another_server",
                test_room_id,
                "event_id",
                RoomVersions.V9,
            )
        )

        # Check the right call got made to the agent ...
        self._mock_agent.request.assert_called_once_with(
            b"GET",
            b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
            headers=mock.ANY,
            bodyProducer=None,
        )

        # ... and that the response is correct.

        # The auth_resp should be empty because all the events are also in state.
        self.assertEqual(auth_resp, [])

        # All of the events should be returned in state_resp, though not
        # necessarily in the same order. We just check the type on the
        # assumption that if the type is right, so is the rest of the event.
        returned_types = [e.type for e in state_resp]
        self.assertCountEqual(
            returned_types,
            ["m.room.create", "m.room.member", "m.room.power_levels"],
        )
|
||||
|
||||
|
||||
def _mock_response(resp: JsonDict):
    """Builds a mock `200 OK` HTTP response whose body is `resp`, JSON-encoded.

    Args:
        resp: the JSON object to serve as the response body

    Returns:
        A sealed `mock.Mock` with the `code`, `phrase`, `headers`, `length` and
        `deliverBody` attributes set.
    """
    encoded = json.dumps(resp).encode("utf-8")

    def _deliver(protocol: Protocol):
        # Hand over the whole body in one go, then signal the end of the
        # response.
        protocol.dataReceived(encoded)
        protocol.connectionLost(Failure(twisted.web.client.ResponseDone()))

    json_headers = twisted.web.client.Headers({"content-Type": ["application/json"]})

    response = mock.Mock(
        code=200,
        phrase=b"OK",
        headers=json_headers,
        length=len(encoded),
        deliverBody=_deliver,
    )

    # Seal the mock, so that accessing any attribute other than those set above
    # raises an AttributeError rather than silently creating a child mock.
    mock.seal(response)
    return response
|
|
@ -51,7 +51,10 @@ from twisted.web.server import Request
|
|||
|
||||
from synapse import events
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.site import SynapseRequest, SynapseSite
|
||||
|
@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||
client_ip=client_ip,
|
||||
)
|
||||
|
||||
def add_hashes_and_signatures(
    self,
    event_dict: JsonDict,
    room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
) -> JsonDict:
    """Adds hashes and signatures to the given event dict

    The event is signed as `OTHER_SERVER_NAME`, using
    `OTHER_SERVER_SIGNATURE_KEY`.

    Args:
        event_dict: the event to be hashed and signed
        room_version: the room version to hash and sign the event for.
            Defaults to the default room version.

    Returns:
        The modified event dict, for convenience
    """
    signing_args = {
        "signature_name": self.OTHER_SERVER_NAME,
        "signing_key": self.OTHER_SERVER_SIGNATURE_KEY,
    }
    # NB: inside the method body this bare name is the module-level helper
    # imported from `synapse.crypto.event_signing`, not this method.
    add_hashes_and_signatures(room_version, event_dict, **signing_args)
    return event_dict
|
||||
|
||||
|
||||
def _auth_header_for_request(
|
||||
origin: str,
|
||||
|
|
Loading…
Reference in New Issue