Merge pull request #530 from matrix-org/erikj/server_refactor

Remove redundant BaseHomeServer
This commit is contained in:
Erik Johnston 2016-01-27 17:36:31 +00:00
commit 5610880003
12 changed files with 191 additions and 545 deletions

View File

@ -50,16 +50,14 @@ from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse import events from synapse import events
@ -95,80 +94,37 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
def build_resource_for_web_client(hs):
webclient_path = hs.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_client_resource(self):
return ClientRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
return MediaRepositoryResource(self)
def build_resource_for_server_key(self):
return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
else:
return None
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
def _listener_http(self, config, listener_config): def _listener_http(self, config, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_address = listener_config.get("bind_address", "")
@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls: if tls and config.no_tls:
return return
metrics_resource = self.get_resource_for_metrics()
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
for name in res["names"]: for name in res["names"]:
if name == "client": if name == "client":
client_resource = self.get_client_resource() client_resource = ClientRestResource(self)
if res["compress"]: if res["compress"]:
client_resource = gz_wrap(client_resource) client_resource = gz_wrap(client_resource)
@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer):
if name == "federation": if name == "federation":
resources.update({ resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(), FEDERATION_PREFIX: TransportLayerServer(self),
}) })
if name in ["static", "client"]: if name in ["static", "client"]:
resources.update({ resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(), STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
}) })
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
resources.update({ resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(), MEDIA_PREFIX: MediaRepositoryResource(self),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(), CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
}) })
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources.update({ resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(), SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(), SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
}) })
if name == "webclient": if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client() resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and metrics_resource: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = metrics_resource resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources) root_resource = create_resource_tree(resources)
if tls: if tls:
@ -675,7 +633,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B. the mapping should looks like _resource_id(A,C) = B.
Args: Args:
resource (Resource): The *parent* Resource resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached. path_seg (str): The name of the child Resource to be attached.
Returns: Returns:
str: A unique string which can be a key to the child Resource. str: A unique string which can be a key to the child Resource.
@ -761,6 +719,7 @@ def run(hs):
auto_close_fds=False, auto_close_fds=False,
verbose=True, verbose=True,
logger=logger, logger=logger,
chdir=os.path.dirname(os.path.abspath(__file__)),
) )
daemon.start() daemon.start()

View File

@ -17,15 +17,10 @@
""" """
from .replication import ReplicationLayer from .replication import ReplicationLayer
from .transport import TransportLayer from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver): def initialize_http_replication(homeserver):
transport = TransportLayer( transport = TransportLayerClient(homeserver)
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()
)
return ReplicationLayer(homeserver, transport) return ReplicationLayer(homeserver, transport)

View File

@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.transport_layer = transport_layer self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.federation_client = self self.federation_client = self

View File

@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol. communicate over a different (albeit still reliable) protocol.
""" """
from .server import TransportLayerServer
from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates
transactions and other requests to/from HTTP.
Attributes:
server_name (str): Local home server host
server (synapse.http.server.HttpServer): the http server to
register listeners on
client (synapse.http.client.HttpClient): the http client used to
send requests
request_handler (TransportRequestHandler): The handler to fire when we
receive requests for data.
received_handler (TransportReceivedHandler): The handler to fire when
we receive data.
"""
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
server (synapse.protocol.http.HttpServer): the http server to
register listeners on
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object): class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers""" """Sends federation HTTP requests to other servers"""
def __init__(self, hs):
self.server_name = hs.hostname
self.client = hs.get_http_client()
@log_function @log_function
def get_room_state(self, destination, room_id, event_id): def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the """ Requests all state for a given room from the given server at the

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function from synapse.http.server import JsonResource
from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
import logging import logging
@ -28,9 +29,41 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TransportLayerServer(object): class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
super(TransportLayerServer, self).__init__(hs)
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
)
self.register_servlets()
def register_servlets(self):
register_servlets(
self.hs,
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
)
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request): def authenticate_request(self, request):
@ -98,37 +131,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content)) defer.returnValue((origin, content))
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
Args:
handler (TransportReceivedHandler)
"""
FederationSendServlet(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
def register_request_handler(self, handler):
""" Register a handler that will be fired when we get asked for data.
Args:
handler (TransportRequestHandler)
"""
for servletclass in SERVLET_CLASSES:
servletclass(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter): def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler self.handler = handler
self.authenticator = authenticator self.authenticator = authenticator
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/" PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs) super(FederationSendServlet, self).__init__(
handler, server_name=server_name, **kwargs
)
self.server_name = server_name self.server_name = server_name
# This is when someone is trying to send us a bunch of data. # This is when someone is trying to send us a bunch of data.
@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
FederationStateServlet, FederationStateServlet,
@ -451,3 +459,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
) )
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in SERVLET_CLASSES:
servletclass(
handler=hs.get_replication_layer(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)

View File

@ -20,6 +20,8 @@
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -36,8 +38,10 @@ from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering from synapse.api.filtering import Filtering
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
class BaseHomeServer(object):
class HomeServer(object):
"""A basic homeserver object without lazy component builders. """A basic homeserver object without lazy component builders.
This will need all of the components it requires to either be passed as This will need all of the components it requires to either be passed as
@ -102,36 +106,6 @@ class BaseHomeServer(object):
for depname in kwargs: for depname in kwargs:
setattr(self, depname, kwargs[depname]) setattr(self, depname, kwargs[depname])
@classmethod
def _make_dependency_method(cls, depname):
def _get(self):
if hasattr(self, depname):
return getattr(self, depname)
if hasattr(self, "build_%s" % (depname)):
# Prevent cyclic dependencies from deadlocking
if depname in self._building:
raise ValueError("Cyclic dependency while building %s" % (
depname,
))
self._building[depname] = 1
builder = getattr(self, "build_%s" % (depname))
dep = builder()
setattr(self, depname, dep)
del self._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (
type(self).__name__, depname,
)
)
setattr(BaseHomeServer, "get_%s" % (depname), _get)
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type. # X-Forwarded-For is handled by our custom request type.
return request.getClientIP() return request.getClientIP()
@ -142,24 +116,6 @@ class BaseHomeServer(object):
def is_mine_id(self, string): def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname return string.split(":", 1)[1] == self.hostname
# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname)
class HomeServer(BaseHomeServer):
"""A homeserver object that will construct most of its dependencies as
required.
It still requires the following to be specified by the caller:
resource_for_client
resource_for_web_client
resource_for_federation
resource_for_content_repo
http_client
db_pool
"""
def build_clock(self): def build_clock(self):
return Clock() return Clock()
@ -224,3 +180,55 @@ class HomeServer(BaseHomeServer):
def build_pusherpool(self): def build_pusherpool(self):
return PusherPool(self) return PusherPool(self)
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
def _make_dependency_method(depname):
def _get(hs):
try:
return getattr(hs, depname)
except AttributeError:
pass
try:
builder = getattr(hs, "build_%s" % (depname))
except AttributeError:
builder = None
if builder:
# Prevent cyclic dependencies from deadlocking
if depname in hs._building:
raise ValueError("Cyclic dependency while building %s" % (
depname,
))
hs._building[depname] = 1
dep = builder()
setattr(hs, depname, dep)
del hs._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (
type(hs).__name__, depname,
)
)
setattr(HomeServer, "get_%s" % (depname), _get)
# Build magic accessors for every dependency
for depname in HomeServer.DEPENDENCIES:
_make_dependency_method(depname)

View File

@ -1,303 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# trial imports
from twisted.internet import defer
from tests import unittest
# python imports
from mock import Mock, ANY
from ..utils import MockHttpResource, MockClock, setup_test_homeserver
from synapse.federation import initialize_http_replication
from synapse.events import FrozenEvent
def make_pdu(prev_pdus=[], **kwargs):
"""Provide some default fields for making a PduTuple."""
pdu_fields = {
"state_key": None,
"prev_events": prev_pdus,
}
pdu_fields.update(kwargs)
return FrozenEvent(pdu_fields)
class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource()
self.mock_http_client = Mock(spec=[
"get_json",
"put_json",
])
self.mock_persistence = Mock(spec=[
"prep_send_transaction",
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_auth_chain",
])
self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None)
)
retry_timings_res = {
"destination": "",
"retry_last_ts": 0,
"retry_interval": 0,
}
self.mock_persistence.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
)
self.mock_persistence.get_auth_chain.return_value = []
self.clock = MockClock()
hs = yield setup_test_homeserver(
resource_for_federation=self.mock_resource,
http_client=self.mock_http_client,
datastore=self.mock_persistence,
clock=self.clock,
keyring=Mock(),
)
self.federation = initialize_http_replication(hs)
self.distributor = hs.get_distributor()
@defer.inlineCallbacks
def test_get_state(self):
mock_handler = Mock(spec=[
"get_state_for_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_state_for_pdu.return_value = defer.succeed([])
# Empty context initially
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertFalse(response["pdus"])
# Now lets give the context some state
mock_handler.get_state_for_pdu.return_value = (
defer.succeed([
make_pdu(
event_id="the-pdu-id",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.topic",
origin_server_ts=123456789000,
depth=1,
content={"topic": "The topic"},
state_key="",
power_level=1000,
prev_state="last-pdu-id",
),
])
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
@defer.inlineCallbacks
def test_get_pdu(self):
mock_handler = Mock(spec=[
"get_persisted_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(None)
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(404, code)
# Now insert such a PDU
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(
make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
)
)
)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
self.assertEquals("m.text", response["pdus"][0]["type"])
@defer.inlineCallbacks
def test_send_pdu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
)
pdu = make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
)
yield self.federation.send_pdu(pdu, ["remote"])
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin_server_ts": 1000000,
"origin": "test",
"pdus": [
pdu.get_pdu_json(),
],
'pdu_failures': [],
},
json_data_callback=ANY,
long_retries=True,
)
@defer.inlineCallbacks
def test_send_edu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
)
yield self.federation.send_edu(
destination="remote",
edu_type="m.test",
content={"testing": "content here"},
)
# MockClock ensures we can guess these timestamps
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin": "test",
"origin_server_ts": 1000000,
"pdus": [],
"edus": [
{
"edu_type": "m.test",
"content": {"testing": "content here"},
}
],
'pdu_failures': [],
},
json_data_callback=ANY,
long_retries=True,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()
recv_observer.return_value = defer.succeed(())
self.federation.register_edu_handler("m.test", recv_observer)
yield self.mock_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1001000/",
"""{
"origin": "remote",
"origin_server_ts": 1001000,
"pdus": [],
"edus": [
{
"origin": "remote",
"destination": "test",
"edu_type": "m.test",
"content": {"testing": "reply here"}
}
]
}"""
)
recv_observer.assert_called_with(
"remote", {"testing": "reply here"}
)
@defer.inlineCallbacks
def test_send_query(self):
self.mock_http_client.get_json.return_value = defer.succeed(
{"your": "response"}
)
response = yield self.federation.make_query(
destination="remote",
query_type="a-question",
args={"one": "1", "two": "2"},
)
self.assertEquals({"your": "response"}, response)
self.mock_http_client.get_json.assert_called_with(
destination="remote",
path="/_matrix/federation/v1/query/a-question",
args={"one": "1", "two": "2"},
retry_on_dns_fail=True,
)
@defer.inlineCallbacks
def test_recv_query(self):
recv_handler = Mock()
recv_handler.return_value = defer.succeed({"another": "response"})
self.federation.register_query_handler("a-question", recv_handler)
code, response = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/query/a-question?three=3&four=4",
None
)
self.assertEquals(200, code)
self.assertEquals({"another": "response"}, response)
recv_handler.assert_called_with(
{"three": "3", "four": "4"}
)

View File

@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase):
} }
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
clock = Mock(spec=[
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
])
clock.time_msec.return_value = 1000000
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"get_presence_list", "get_presence_list",
"get_rooms_for_user", "get_rooms_for_user",
]), ]),
clock=Mock(spec=[ clock=clock,
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
]),
) )
hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None, allow_guest=False): def _get_user_by_req(req=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False) return Requester(UserID.from_string(myid), "", False)

View File

@ -16,10 +16,10 @@
from tests import unittest from tests import unittest
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.server import BaseHomeServer from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
mock_homeserver = BaseHomeServer(hostname="my.domain") mock_homeserver = HomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase): class UserIDTestCase(unittest.TestCase):
@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase):
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
UserID.from_string("") UserID.from_string("")
def test_build(self): def test_build(self):
user = UserID("5678efgh", "my.domain") user = UserID("5678efgh", "my.domain")

View File

@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.federation.transport import server
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -80,6 +82,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
fed = kargs.get("resource_for_federation", None)
if fed:
server.register_servlets(
hs,
resource=fed,
authenticator=server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(),
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent
),
)
defer.returnValue(hs) defer.returnValue(hs)