Tests: replace mocked Authenticator with the real thing (#11913)
If we prepopulate the test homeserver with a key for a remote homeserver, we can make federation requests to it without having to stub out the authenticator. This has two advantages: * means that what we are testing is closer to reality (ie, we now have complete tests for the incoming-request-authorisation flow) * some tests require that other objects be signed by the remote server (eg, the event in `/send_join`), and doing that would require a whole separate set of mocking out. It's much simpler just to use real keys.
This commit is contained in:
parent
d36943c4df
commit
c3db7a0b59
|
@ -0,0 +1 @@
|
||||||
|
Tests: replace mocked `Authenticator` with the real thing.
|
|
@ -47,7 +47,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the room complexity
|
# Get the room complexity
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEquals(200, channel.code)
|
self.assertEquals(200, channel.code)
|
||||||
|
@ -59,7 +59,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
|
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
|
||||||
|
|
||||||
# Get the room complexity again -- make sure it's our artificial value
|
# Get the room complexity again -- make sure it's our artificial value
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEquals(200, channel.code)
|
self.assertEquals(200, channel.code)
|
||||||
|
|
|
@ -113,7 +113,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
|
||||||
room_1 = self.helper.create_room_as(u1, tok=u1_token)
|
room_1 = self.helper.create_room_as(u1, tok=u1_token)
|
||||||
self.inject_room_member(room_1, "@user:other.example.com", "join")
|
self.inject_room_member(room_1, "@user:other.example.com", "join")
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
|
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
|
@ -145,7 +145,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
room_1 = self.helper.create_room_as(u1, tok=u1_token)
|
room_1 = self.helper.create_room_as(u1, tok=u1_token)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
|
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEquals(403, channel.code, channel.result)
|
self.assertEquals(403, channel.code, channel.result)
|
||||||
|
|
|
@ -245,7 +245,7 @@ class FederationKnockingTestCase(
|
||||||
self.hs, room_id, user_id
|
self.hs, room_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
|
"/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
|
||||||
% (
|
% (
|
||||||
|
@ -288,7 +288,7 @@ class FederationKnockingTestCase(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the signed knock event into the room
|
# Send the signed knock event into the room
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
"/_matrix/federation/v1/send_knock/%s/%s"
|
"/_matrix/federation/v1/send_knock/%s/%s"
|
||||||
% (room_id, signed_knock_event.event_id),
|
% (room_id, signed_knock_event.event_id),
|
||||||
|
|
|
@ -22,10 +22,9 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
"""Test that unauthenticated requests to the public rooms directory 403 when
|
"""Test that unauthenticated requests to the public rooms directory 403 when
|
||||||
allow_public_rooms_over_federation is False.
|
allow_public_rooms_over_federation is False.
|
||||||
"""
|
"""
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/federation/v1/publicRooms",
|
"/_matrix/federation/v1/publicRooms",
|
||||||
federation_auth_origin=b"example.com",
|
|
||||||
)
|
)
|
||||||
self.assertEquals(403, channel.code)
|
self.assertEquals(403, channel.code)
|
||||||
|
|
||||||
|
@ -34,9 +33,8 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
"""Test that unauthenticated requests to the public rooms directory 200 when
|
"""Test that unauthenticated requests to the public rooms directory 200 when
|
||||||
allow_public_rooms_over_federation is True.
|
allow_public_rooms_over_federation is True.
|
||||||
"""
|
"""
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/federation/v1/publicRooms",
|
"/_matrix/federation/v1/publicRooms",
|
||||||
federation_auth_origin=b"example.com",
|
|
||||||
)
|
)
|
||||||
self.assertEquals(200, channel.code)
|
self.assertEquals(200, channel.code)
|
||||||
|
|
|
@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor, clock, homeserver):
|
||||||
|
super().prepare(reactor, clock, homeserver)
|
||||||
# Create some users and a room to play with during the tests
|
# Create some users and a room to play with during the tests
|
||||||
self.user_id = self.register_user("kermit", "monkey")
|
self.user_id = self.register_user("kermit", "monkey")
|
||||||
self.invitee = self.register_user("invitee", "hackme")
|
self.invitee = self.register_user("invitee", "hackme")
|
||||||
|
@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
def _send_event_over_federation(self) -> None:
|
def _send_event_over_federation(self) -> None:
|
||||||
"""Send a dummy event over federation and check that the request succeeds."""
|
"""Send a dummy event over federation and check that the request succeeds."""
|
||||||
body = {
|
body = {
|
||||||
"origin": self.hs.config.server.server_name,
|
|
||||||
"origin_server_ts": self.clock.time_msec(),
|
|
||||||
"pdus": [
|
"pdus": [
|
||||||
{
|
{
|
||||||
"sender": self.user_id,
|
"sender": self.user_id,
|
||||||
|
@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_signed_federation_request(
|
||||||
method="PUT",
|
method="PUT",
|
||||||
path="/_matrix/federation/v1/send/1",
|
path="/_matrix/federation/v1/send/1",
|
||||||
content=body,
|
content=body,
|
||||||
federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
|
@ -17,6 +17,7 @@ import gc
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
|
@ -36,9 +37,11 @@ from typing import (
|
||||||
)
|
)
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from canonicaljson import json
|
import canonicaljson
|
||||||
|
import signedjson.key
|
||||||
|
import unpaddedbase64
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred, succeed
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.python.threadpool import ThreadPool
|
from twisted.python.threadpool import ThreadPool
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
@ -49,8 +52,7 @@ from twisted.web.server import Request
|
||||||
from synapse import events
|
from synapse import events
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.ratelimiting import FederationRateLimitConfig
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
from synapse.federation.transport import server as federation_server
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.site import SynapseRequest, SynapseSite
|
from synapse.http.site import SynapseRequest, SynapseSite
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
|
@ -61,10 +63,10 @@ from synapse.logging.context import (
|
||||||
)
|
)
|
||||||
from synapse.rest import RegisterServletsFunc
|
from synapse.rest import RegisterServletsFunc
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.keys import FetchKeyResult
|
||||||
from synapse.types import JsonDict, UserID, create_requester
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
|
||||||
|
|
||||||
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
||||||
from tests.test_utils import event_injection, setup_awaitable_errors
|
from tests.test_utils import event_injection, setup_awaitable_errors
|
||||||
|
@ -755,42 +757,116 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
class FederatingHomeserverTestCase(HomeserverTestCase):
|
class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
A federating homeserver that authenticates incoming requests as `other.example.com`.
|
A federating homeserver, set up to validate incoming federation requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
OTHER_SERVER_NAME = "other.example.com"
|
||||||
|
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||||
|
super().prepare(reactor, clock, hs)
|
||||||
|
|
||||||
|
# poke the other server's signing key into the key store, so that we don't
|
||||||
|
# make requests for it
|
||||||
|
verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
|
||||||
|
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
hs.get_datastore().store_server_verify_keys(
|
||||||
|
from_server=self.OTHER_SERVER_NAME,
|
||||||
|
ts_added_ms=clock.time_msec(),
|
||||||
|
verify_keys=[
|
||||||
|
(
|
||||||
|
self.OTHER_SERVER_NAME,
|
||||||
|
verify_key_id,
|
||||||
|
FetchKeyResult(
|
||||||
|
verify_key=verify_key,
|
||||||
|
valid_until_ts=clock.time_msec() + 1000,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
d = super().create_resource_dict()
|
d = super().create_resource_dict()
|
||||||
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
|
d["/_matrix/federation"] = TransportLayerServer(self.hs)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def make_signed_federation_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
content: Optional[JsonDict] = None,
|
||||||
|
await_result: bool = True,
|
||||||
|
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||||
|
client_ip: str = "127.0.0.1",
|
||||||
|
) -> FakeChannel:
|
||||||
|
"""Make an inbound signed federation request to this server
|
||||||
|
|
||||||
class TestTransportLayerServer(JsonResource):
|
The request is signed as if it came from "other.example.com", which our HS
|
||||||
"""A test implementation of TransportLayerServer
|
already has the keys for.
|
||||||
|
"""
|
||||||
|
|
||||||
authenticates incoming requests as `other.example.com`.
|
if custom_headers is None:
|
||||||
"""
|
custom_headers = []
|
||||||
|
else:
|
||||||
|
custom_headers = list(custom_headers)
|
||||||
|
|
||||||
def __init__(self, hs):
|
custom_headers.append(
|
||||||
super().__init__(hs)
|
(
|
||||||
|
"Authorization",
|
||||||
class Authenticator:
|
_auth_header_for_request(
|
||||||
def authenticate_request(self, request, content):
|
origin=self.OTHER_SERVER_NAME,
|
||||||
return succeed("other.example.com")
|
destination=self.hs.hostname,
|
||||||
|
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
|
||||||
authenticator = Authenticator()
|
method=method,
|
||||||
|
path=path,
|
||||||
ratelimiter = FederationRateLimiter(
|
content=content,
|
||||||
hs.get_clock(),
|
),
|
||||||
FederationRateLimitConfig(
|
)
|
||||||
window_size=1,
|
|
||||||
sleep_limit=1,
|
|
||||||
sleep_delay=1,
|
|
||||||
reject_limit=1000,
|
|
||||||
concurrent=1000,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
federation_server.register_servlets(hs, self, authenticator, ratelimiter)
|
return make_request(
|
||||||
|
self.reactor,
|
||||||
|
self.site,
|
||||||
|
method=method,
|
||||||
|
path=path,
|
||||||
|
content=content,
|
||||||
|
shorthand=False,
|
||||||
|
await_result=await_result,
|
||||||
|
custom_headers=custom_headers,
|
||||||
|
client_ip=client_ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_header_for_request(
|
||||||
|
origin: str,
|
||||||
|
destination: str,
|
||||||
|
signing_key: signedjson.key.SigningKey,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
content: Optional[JsonDict],
|
||||||
|
) -> str:
|
||||||
|
"""Build a suitable Authorization header for an outgoing federation request"""
|
||||||
|
request_description: JsonDict = {
|
||||||
|
"method": method,
|
||||||
|
"uri": path,
|
||||||
|
"destination": destination,
|
||||||
|
"origin": origin,
|
||||||
|
}
|
||||||
|
if content is not None:
|
||||||
|
request_description["content"] = content
|
||||||
|
signature_base64 = unpaddedbase64.encode_base64(
|
||||||
|
signing_key.sign(
|
||||||
|
canonicaljson.encode_canonical_json(request_description)
|
||||||
|
).signature
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"X-Matrix origin={origin},"
|
||||||
|
f"key={signing_key.alg}:{signing_key.version},"
|
||||||
|
f"sig={signature_base64}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def override_config(extra_config):
|
def override_config(extra_config):
|
||||||
|
|
Loading…
Reference in New Issue