Merge pull request #4428 from matrix-org/rav/matrix_federation_agent
Move SRV magic into an Agent-like thing
This commit is contained in:
commit
a0ae475219
|
@ -0,0 +1 @@
|
|||
Move SRV logic into the Agent layer
|
|
@ -13,15 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.internet.error import ConnectError
|
||||
|
||||
from synapse.http.federation.srv_resolver import Server, resolve_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name):
|
|||
))
|
||||
|
||||
return host, port
|
||||
|
||||
|
||||
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
|
||||
timeout=None):
|
||||
"""Construct an endpoint for the given matrix destination.
|
||||
|
||||
Args:
|
||||
reactor: Twisted reactor.
|
||||
destination (unicode): The name of the server to connect to.
|
||||
tls_client_options_factory
|
||||
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
|
||||
Factory which generates TLS options for client connections.
|
||||
timeout (int): connection timeout in seconds
|
||||
"""
|
||||
|
||||
domain, port = parse_server_name(destination)
|
||||
|
||||
endpoint_kw_args = {}
|
||||
|
||||
if timeout is not None:
|
||||
endpoint_kw_args.update(timeout=timeout)
|
||||
|
||||
if tls_client_options_factory is None:
|
||||
transport_endpoint = HostnameEndpoint
|
||||
default_port = 8008
|
||||
else:
|
||||
# the SNI string should be the same as the Host header, minus the port.
|
||||
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
|
||||
# the Host header and SNI should therefore be the server_name of the remote
|
||||
# server.
|
||||
tls_options = tls_client_options_factory.get_options(domain)
|
||||
|
||||
def transport_endpoint(reactor, host, port, timeout):
|
||||
return wrapClientTLS(
|
||||
tls_options,
|
||||
HostnameEndpoint(reactor, host, port, timeout=timeout),
|
||||
)
|
||||
default_port = 8448
|
||||
|
||||
if port is None:
|
||||
return SRVClientEndpoint(
|
||||
reactor, "matrix", domain, protocol="tcp",
|
||||
default_port=default_port, endpoint=transport_endpoint,
|
||||
endpoint_kw_args=endpoint_kw_args
|
||||
)
|
||||
else:
|
||||
return transport_endpoint(
|
||||
reactor, domain, port, **endpoint_kw_args
|
||||
)
|
||||
|
||||
|
||||
class SRVClientEndpoint(object):
|
||||
"""An endpoint which looks up SRV records for a service.
|
||||
Cycles through the list of servers starting with each call to connect
|
||||
picking the next server.
|
||||
Implements twisted.internet.interfaces.IStreamClientEndpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, service, domain, protocol="tcp",
|
||||
default_port=None, endpoint=HostnameEndpoint,
|
||||
endpoint_kw_args={}):
|
||||
self.reactor = reactor
|
||||
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
|
||||
|
||||
if default_port is not None:
|
||||
self.default_server = Server(
|
||||
host=domain,
|
||||
port=default_port,
|
||||
)
|
||||
else:
|
||||
self.default_server = None
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.endpoint_kw_args = endpoint_kw_args
|
||||
|
||||
self.servers = None
|
||||
self.used_servers = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_servers(self):
|
||||
self.used_servers = []
|
||||
self.servers = yield resolve_service(self.service_name)
|
||||
|
||||
def pick_server(self):
|
||||
if not self.servers:
|
||||
if self.used_servers:
|
||||
self.servers = self.used_servers
|
||||
self.used_servers = []
|
||||
self.servers.sort()
|
||||
elif self.default_server:
|
||||
return self.default_server
|
||||
else:
|
||||
raise ConnectError(
|
||||
"No server available for %s" % self.service_name
|
||||
)
|
||||
|
||||
# look for all servers with the same priority
|
||||
min_priority = self.servers[0].priority
|
||||
weight_indexes = list(
|
||||
(index, server.weight + 1)
|
||||
for index, server in enumerate(self.servers)
|
||||
if server.priority == min_priority
|
||||
)
|
||||
|
||||
total_weight = sum(weight for index, weight in weight_indexes)
|
||||
target_weight = random.randint(0, total_weight)
|
||||
for index, weight in weight_indexes:
|
||||
target_weight -= weight
|
||||
if target_weight <= 0:
|
||||
server = self.servers[index]
|
||||
# XXX: this looks totally dubious:
|
||||
#
|
||||
# (a) we never reuse a server until we have been through
|
||||
# all of the servers at the same priority, so if the
|
||||
# weights are A: 100, B:1, we always do ABABAB instead of
|
||||
# AAAA...AAAB (approximately).
|
||||
#
|
||||
# (b) After using all the servers at the lowest priority,
|
||||
# we move onto the next priority. We should only use the
|
||||
# second priority if servers at the top priority are
|
||||
# unreachable.
|
||||
#
|
||||
del self.servers[index]
|
||||
self.used_servers.append(server)
|
||||
return server
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
if self.servers is None:
|
||||
yield self.fetch_servers()
|
||||
server = self.pick_server()
|
||||
logger.info("Connecting to %s:%s", server.host, server.port)
|
||||
endpoint = self.endpoint(
|
||||
self.reactor, server.host, server.port, **self.endpoint_kw_args
|
||||
)
|
||||
connection = yield endpoint.connect(protocolFactory)
|
||||
defer.returnValue(connection)
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||
from twisted.web.iweb import IAgent
|
||||
|
||||
from synapse.http.endpoint import parse_server_name
|
||||
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@implementer(IAgent)
|
||||
class MatrixFederationAgent(object):
|
||||
"""An Agent-like thing which provides a `request` method which will look up a matrix
|
||||
server and send an HTTP request to it.
|
||||
|
||||
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
||||
|
||||
Args:
|
||||
reactor (IReactor): twisted reactor to use for underlying requests
|
||||
|
||||
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||
factory to use for fetching client tls options, or none to disable TLS.
|
||||
|
||||
srv_resolver (SrvResolver|None):
|
||||
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor, tls_client_options_factory, _srv_resolver=None,
|
||||
):
|
||||
self._reactor = reactor
|
||||
self._tls_client_options_factory = tls_client_options_factory
|
||||
if _srv_resolver is None:
|
||||
_srv_resolver = SrvResolver()
|
||||
self._srv_resolver = _srv_resolver
|
||||
|
||||
self._pool = HTTPConnectionPool(reactor)
|
||||
self._pool.retryAutomatically = False
|
||||
self._pool.maxPersistentPerHost = 5
|
||||
self._pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||
"""
|
||||
Args:
|
||||
method (bytes): HTTP method: GET/POST/etc
|
||||
|
||||
uri (bytes): Absolute URI to be retrieved
|
||||
|
||||
headers (twisted.web.http_headers.Headers|None):
|
||||
HTTP headers to send with the request, or None to
|
||||
send no extra headers.
|
||||
|
||||
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
||||
An object which can generate bytes to make up the
|
||||
body of this request (for example, the properly encoded contents of
|
||||
a file for a file upload). Or None if the request is to have
|
||||
no body.
|
||||
|
||||
Returns:
|
||||
Deferred[twisted.web.iweb.IResponse]:
|
||||
fires when the header of the response has been received (regardless of the
|
||||
response status code). Fails if there is any problem which prevents that
|
||||
response from being received (including problems that prevent the request
|
||||
from being sent).
|
||||
"""
|
||||
|
||||
parsed_uri = URI.fromBytes(uri)
|
||||
server_name_bytes = parsed_uri.netloc
|
||||
host, port = parse_server_name(server_name_bytes.decode("ascii"))
|
||||
|
||||
# XXX disabling TLS is really only supported here for the benefit of the
|
||||
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||
# the code support the unit tests.
|
||||
if self._tls_client_options_factory is None:
|
||||
tls_options = None
|
||||
else:
|
||||
tls_options = self._tls_client_options_factory.get_options(host)
|
||||
|
||||
if port is not None:
|
||||
target = (host, port)
|
||||
else:
|
||||
server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
|
||||
if not server_list:
|
||||
target = (host, 8448)
|
||||
logger.debug("No SRV record for %s, using %s", host, target)
|
||||
else:
|
||||
target = pick_server_from_list(server_list)
|
||||
|
||||
class EndpointFactory(object):
|
||||
@staticmethod
|
||||
def endpointForURI(_uri):
|
||||
logger.info("Connecting to %s:%s", target[0], target[1])
|
||||
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
|
||||
if tls_options is not None:
|
||||
ep = wrapClientTLS(tls_options, ep)
|
||||
return ep
|
||||
|
||||
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
||||
res = yield make_deferred_yieldable(
|
||||
agent.request(method, uri, headers, bodyProducer)
|
||||
)
|
||||
defer.returnValue(res)
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
import attr
|
||||
|
@ -51,74 +52,118 @@ class Server(object):
|
|||
expires = attr.ib(default=0)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
||||
"""Look up a SRV record, with caching
|
||||
def pick_server_from_list(server_list):
|
||||
"""Randomly choose a server from the server list
|
||||
|
||||
Args:
|
||||
server_list (list[Server]): list of candidate servers
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, int]: (host, port) pair for the chosen server
|
||||
"""
|
||||
if not server_list:
|
||||
raise RuntimeError("pick_server_from_list called with empty list")
|
||||
|
||||
# TODO: currently we only use the lowest-priority servers. We should maintain a
|
||||
# cache of servers known to be "down" and filter them out
|
||||
|
||||
min_priority = min(s.priority for s in server_list)
|
||||
eligible_servers = list(s for s in server_list if s.priority == min_priority)
|
||||
total_weight = sum(s.weight for s in eligible_servers)
|
||||
target_weight = random.randint(0, total_weight)
|
||||
|
||||
for s in eligible_servers:
|
||||
target_weight -= s.weight
|
||||
|
||||
if target_weight <= 0:
|
||||
return s.host, s.port
|
||||
|
||||
# this should be impossible.
|
||||
raise RuntimeError(
|
||||
"pick_server_from_list got to end of eligible server list.",
|
||||
)
|
||||
|
||||
|
||||
class SrvResolver(object):
|
||||
"""Interface to the dns client to do SRV lookups, with result caching.
|
||||
|
||||
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
||||
but the cache never gets populated), so we add our own caching layer here.
|
||||
|
||||
Args:
|
||||
service_name (unicode|bytes): record to look up
|
||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||
cache (dict): cache object
|
||||
clock (object): clock implementation. must provide a time() method.
|
||||
|
||||
Returns:
|
||||
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
|
||||
get_time (callable): clock implementation. Should return seconds since the epoch
|
||||
"""
|
||||
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
|
||||
# byteses; however they will obviously end up as separate entries in the cache. We
|
||||
# should pick one form and stick with it.
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
|
||||
self._dns_client = dns_client
|
||||
self._cache = cache
|
||||
self._get_time = get_time
|
||||
|
||||
try:
|
||||
answers, _, _ = yield make_deferred_yieldable(
|
||||
dns_client.lookupService(service_name),
|
||||
)
|
||||
except DNSNameError:
|
||||
# TODO: cache this. We can get the SOA out of the exception, and use
|
||||
# the negative-TTL value.
|
||||
defer.returnValue([])
|
||||
except DomainError as e:
|
||||
# We failed to resolve the name (other than a NameError)
|
||||
# Try something in the cache, else rereaise
|
||||
cache_entry = cache.get(service_name, None)
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(self, service_name):
|
||||
"""Look up a SRV record
|
||||
|
||||
Args:
|
||||
service_name (bytes): record to look up
|
||||
|
||||
Returns:
|
||||
Deferred[list[Server]]:
|
||||
a list of the SRV records, or an empty list if none found
|
||||
"""
|
||||
now = int(self._get_time())
|
||||
|
||||
if not isinstance(service_name, bytes):
|
||||
raise TypeError("%r is not a byte string" % (service_name,))
|
||||
|
||||
cache_entry = self._cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
logger.warn(
|
||||
"Failed to resolve %r, falling back to cache. %r",
|
||||
service_name, e
|
||||
if all(s.expires > now for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
|
||||
try:
|
||||
answers, _, _ = yield make_deferred_yieldable(
|
||||
self._dns_client.lookupService(service_name),
|
||||
)
|
||||
defer.returnValue(list(cache_entry))
|
||||
else:
|
||||
raise e
|
||||
except DNSNameError:
|
||||
# TODO: cache this. We can get the SOA out of the exception, and use
|
||||
# the negative-TTL value.
|
||||
defer.returnValue([])
|
||||
except DomainError as e:
|
||||
# We failed to resolve the name (other than a NameError)
|
||||
# Try something in the cache, else rereaise
|
||||
cache_entry = self._cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
logger.warn(
|
||||
"Failed to resolve %r, falling back to cache. %r",
|
||||
service_name, e
|
||||
)
|
||||
defer.returnValue(list(cache_entry))
|
||||
else:
|
||||
raise e
|
||||
|
||||
if (len(answers) == 1
|
||||
and answers[0].type == dns.SRV
|
||||
and answers[0].payload
|
||||
and answers[0].payload.target == dns.Name(b'.')):
|
||||
raise ConnectError("Service %s unavailable" % service_name)
|
||||
if (len(answers) == 1
|
||||
and answers[0].type == dns.SRV
|
||||
and answers[0].payload
|
||||
and answers[0].payload.target == dns.Name(b'.')):
|
||||
raise ConnectError("Service %s unavailable" % service_name)
|
||||
|
||||
servers = []
|
||||
servers = []
|
||||
|
||||
for answer in answers:
|
||||
if answer.type != dns.SRV or not answer.payload:
|
||||
continue
|
||||
for answer in answers:
|
||||
if answer.type != dns.SRV or not answer.payload:
|
||||
continue
|
||||
|
||||
payload = answer.payload
|
||||
payload = answer.payload
|
||||
|
||||
servers.append(Server(
|
||||
host=payload.target.name,
|
||||
port=payload.port,
|
||||
priority=payload.priority,
|
||||
weight=payload.weight,
|
||||
expires=int(clock.time()) + answer.ttl,
|
||||
))
|
||||
servers.append(Server(
|
||||
host=payload.target.name,
|
||||
port=payload.port,
|
||||
priority=payload.priority,
|
||||
weight=payload.weight,
|
||||
expires=now + answer.ttl,
|
||||
))
|
||||
|
||||
servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
|
||||
cache[service_name] = list(servers)
|
||||
defer.returnValue(servers)
|
||||
self._cache[service_name] = list(servers)
|
||||
defer.returnValue(servers)
|
||||
|
|
|
@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
|
|||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.internet.task import _EPSILON, Cooperator
|
||||
from twisted.web._newclient import ResponseDone
|
||||
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
|
||||
from twisted.web.client import FileBodyProducer
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
import synapse.metrics
|
||||
|
@ -44,7 +44,7 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.metrics import Measure
|
||||
|
@ -66,20 +66,6 @@ else:
|
|||
MAXINT = sys.maxint
|
||||
|
||||
|
||||
class MatrixFederationEndpointFactory(object):
|
||||
def __init__(self, hs):
|
||||
self.reactor = hs.get_reactor()
|
||||
self.tls_client_options_factory = hs.tls_client_options_factory
|
||||
|
||||
def endpointForURI(self, uri):
|
||||
destination = uri.netloc.decode('ascii')
|
||||
|
||||
return matrix_federation_endpoint(
|
||||
self.reactor, destination, timeout=10,
|
||||
tls_client_options_factory=self.tls_client_options_factory
|
||||
)
|
||||
|
||||
|
||||
_next_id = 1
|
||||
|
||||
|
||||
|
@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
|
|||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
reactor = hs.get_reactor()
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
pool.retryAutomatically = False
|
||||
pool.maxPersistentPerHost = 5
|
||||
pool.cachedConnectionTimeout = 2 * 60
|
||||
self.agent = Agent.usingEndpointFactory(
|
||||
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
||||
|
||||
self.agent = MatrixFederationAgent(
|
||||
hs.get_reactor(),
|
||||
hs.tls_client_options_factory,
|
||||
)
|
||||
self.clock = hs.get_clock()
|
||||
self._store = hs.get_datastore()
|
||||
|
@ -316,9 +300,9 @@ class MatrixFederationHttpClient(object):
|
|||
headers_dict[b"Authorization"] = auth_headers
|
||||
|
||||
logger.info(
|
||||
"{%s} [%s] Sending request: %s %s",
|
||||
"{%s} [%s] Sending request: %s %s; timeout %fs",
|
||||
request.txn_id, request.destination, request.method,
|
||||
url_str,
|
||||
url_str, _sec_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -338,12 +322,11 @@ class MatrixFederationHttpClient(object):
|
|||
reactor=self.hs.get_reactor(),
|
||||
)
|
||||
|
||||
response = yield make_deferred_yieldable(
|
||||
request_deferred,
|
||||
)
|
||||
response = yield request_deferred
|
||||
except DNSLookupError as e:
|
||||
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
|
||||
except Exception as e:
|
||||
logger.info("Failed to send request: %s", e)
|
||||
raise_from(RequestSendFailed(e, can_retry=True), e)
|
||||
|
||||
logger.info(
|
||||
|
|
|
@ -0,0 +1,183 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import treq
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.test.ssl_helpers import ServerTLSContext
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MatrixFederationAgentTests(TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
self.mock_resolver = Mock()
|
||||
|
||||
self.agent = MatrixFederationAgent(
|
||||
reactor=self.reactor,
|
||||
tls_client_options_factory=ClientTLSOptionsFactory(None),
|
||||
_srv_resolver=self.mock_resolver,
|
||||
)
|
||||
|
||||
def _make_connection(self, client_factory):
|
||||
"""Builds a test server, and completes the outgoing client connection
|
||||
|
||||
Returns:
|
||||
HTTPChannel: the test server
|
||||
"""
|
||||
|
||||
# build the test server
|
||||
server_tls_protocol = _build_test_server()
|
||||
|
||||
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||
# a FakeTransport.
|
||||
#
|
||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||
# stubbing that out here.
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
|
||||
|
||||
# tell the server tls protocol to send its stuff back to the client, too
|
||||
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
||||
|
||||
# finally, give the reactor a pump to get the TLS juices flowing.
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
return server_tls_protocol.wrappedProtocol
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_get_request(self, uri):
|
||||
"""
|
||||
Sends a simple GET request via the agent, and checks its logcontext management
|
||||
"""
|
||||
with LoggingContext("one") as context:
|
||||
fetch_d = self.agent.request(b'GET', uri)
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(fetch_d)
|
||||
|
||||
# should have reset logcontext to the sentinel
|
||||
_check_logcontext(LoggingContext.sentinel)
|
||||
|
||||
try:
|
||||
fetch_res = yield fetch_d
|
||||
defer.returnValue(fetch_res)
|
||||
finally:
|
||||
_check_logcontext(context)
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
happy-path test of a GET request
|
||||
"""
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(client_factory)
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/foo/bar')
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b'host'),
|
||||
[b'testserv:8448']
|
||||
)
|
||||
content = request.content.read()
|
||||
self.assertEqual(content, b'')
|
||||
|
||||
# Deferred is still without a result
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# send the headers
|
||||
request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
|
||||
request.write('')
|
||||
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
response = self.successResultOf(test_d)
|
||||
|
||||
# that should give us a Response object
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
# Send the body
|
||||
request.write('{ "a": 1 }'.encode('ascii'))
|
||||
request.finish()
|
||||
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
# check it can be read
|
||||
json = self.successResultOf(treq.json_content(response))
|
||||
self.assertEqual(json, {"a": 1})
|
||||
|
||||
|
||||
def _check_logcontext(context):
|
||||
current = LoggingContext.current_context()
|
||||
if current is not context:
|
||||
raise AssertionError(
|
||||
"Expected logcontext %s but was %s" % (context, current),
|
||||
)
|
||||
|
||||
|
||||
def _build_test_server():
|
||||
"""Construct a test server
|
||||
|
||||
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
|
||||
|
||||
Returns:
|
||||
TLSMemoryBIOProtocol
|
||||
"""
|
||||
server_factory = Factory.forProtocol(HTTPChannel)
|
||||
# Request.finish expects the factory to have a 'log' method.
|
||||
server_factory.log = _log_request
|
||||
|
||||
server_tls_factory = TLSMemoryBIOFactory(
|
||||
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
|
||||
)
|
||||
|
||||
return server_tls_factory.buildProtocol(None)
|
||||
|
||||
|
||||
def _log_request(request):
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
|
@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
|
|||
from twisted.internet.error import ConnectError
|
||||
from twisted.names import dns, error
|
||||
|
||||
from synapse.http.federation.srv_resolver import resolve_service
|
||||
from synapse.http.federation.srv_resolver import SrvResolver
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests import unittest
|
||||
|
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock.lookupService.return_value = result_deferred
|
||||
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_lookup():
|
||||
|
||||
with LoggingContext("one") as ctx:
|
||||
resolve_d = resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
resolve_d = resolver.resolve_service(service_name)
|
||||
|
||||
self.assertNoResult(resolve_d)
|
||||
|
||||
|
@ -83,16 +83,15 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 0
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
servers = yield resolver.resolve_service(service_name)
|
||||
|
||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||
|
||||
|
@ -106,17 +105,18 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 999999999
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
|
||||
resolver = SrvResolver(
|
||||
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
|
||||
)
|
||||
|
||||
servers = yield resolver.resolve_service(service_name)
|
||||
|
||||
self.assertFalse(dns_client_mock.lookupService.called)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
|
@ -128,12 +128,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
with self.assertRaises(error.DNSServerError):
|
||||
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
|
||||
yield resolver.resolve_service(service_name)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_name_error(self):
|
||||
|
@ -141,13 +142,12 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
servers = yield resolver.resolve_service(service_name)
|
||||
|
||||
self.assertEquals(len(servers), 0)
|
||||
self.assertEquals(len(cache), 0)
|
||||
|
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
resolve_d = resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
resolve_d = resolver.resolve_service(service_name)
|
||||
self.assertNoResult(resolve_d)
|
||||
|
||||
# returning a single "." should make the lookup fail with a ConenctError
|
||||
|
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
resolve_d = resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
resolve_d = resolver.resolve_service(service_name)
|
||||
self.assertNoResult(resolve_d)
|
||||
|
||||
lookup_deferred.callback((
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import TimeoutError
|
||||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
@ -26,11 +27,20 @@ from synapse.http.matrixfederationclient import (
|
|||
MatrixFederationHttpClient,
|
||||
MatrixFederationRequest,
|
||||
)
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.server import FakeTransport
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
def check_logcontext(context):
|
||||
current = LoggingContext.current_context()
|
||||
if current is not context:
|
||||
raise AssertionError(
|
||||
"Expected logcontext %s but was %s" % (context, current),
|
||||
)
|
||||
|
||||
|
||||
class FederationClientTests(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
|
@ -43,6 +53,70 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.cl = MatrixFederationHttpClient(self.hs)
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
def test_client_get(self):
|
||||
"""
|
||||
happy-path test of a GET request
|
||||
"""
|
||||
@defer.inlineCallbacks
|
||||
def do_request():
|
||||
with LoggingContext("one") as context:
|
||||
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(fetch_d)
|
||||
|
||||
# should have reset logcontext to the sentinel
|
||||
check_logcontext(LoggingContext.sentinel)
|
||||
|
||||
try:
|
||||
fetch_res = yield fetch_d
|
||||
defer.returnValue(fetch_res)
|
||||
finally:
|
||||
check_logcontext(context)
|
||||
|
||||
test_d = do_request()
|
||||
|
||||
self.pump()
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8008)
|
||||
|
||||
# complete the connection and wire it up to a fake transport
|
||||
protocol = factory.buildProtocol(None)
|
||||
transport = StringTransport()
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
# that should have made it send the request to the transport
|
||||
self.assertRegex(transport.value(), b"^GET /foo/bar")
|
||||
|
||||
# Deferred is still without a result
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# Send it the HTTP response
|
||||
res_json = '{ "a": 1 }'.encode('ascii')
|
||||
protocol.dataReceived(
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Server: Fake\r\n"
|
||||
b"Content-Type: application/json\r\n"
|
||||
b"Content-Length: %i\r\n"
|
||||
b"\r\n"
|
||||
b"%s" % (len(res_json), res_json)
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
res = self.successResultOf(test_d)
|
||||
|
||||
# check the response is as expected
|
||||
self.assertEqual(res, {"a": 1})
|
||||
|
||||
def test_dns_error(self):
|
||||
"""
|
||||
If the DNS lookup returns an error, it will bubble up.
|
||||
|
@ -54,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.assertIsInstance(f.value, RequestSendFailed)
|
||||
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
||||
|
||||
def test_client_connection_refused(self):
|
||||
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
|
||||
|
||||
self.pump()
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(d)
|
||||
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8008)
|
||||
e = Exception("go away")
|
||||
factory.clientConnectionFailed(None, e)
|
||||
self.pump(0.5)
|
||||
|
||||
f = self.failureResultOf(d)
|
||||
|
||||
self.assertIsInstance(f.value, RequestSendFailed)
|
||||
self.assertIs(f.value.inner_exception, e)
|
||||
|
||||
def test_client_never_connect(self):
|
||||
"""
|
||||
If the HTTP request is not connected and is timed out, it'll give a
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
from six import text_type
|
||||
|
@ -22,6 +23,8 @@ from synapse.util import Clock
|
|||
|
||||
from tests.utils import setup_test_homeserver as _sth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimedOutException(Exception):
|
||||
"""
|
||||
|
@ -339,7 +342,7 @@ def get_clock():
|
|||
return (clock, hs_clock)
|
||||
|
||||
|
||||
@attr.s
|
||||
@attr.s(cmp=False)
|
||||
class FakeTransport(object):
|
||||
"""
|
||||
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
||||
|
@ -414,6 +417,11 @@ class FakeTransport(object):
|
|||
self.buffer = self.buffer + byt
|
||||
|
||||
def _write():
|
||||
if not self.buffer:
|
||||
# nothing to do. Don't write empty buffers: it upsets the
|
||||
# TLSMemoryBIOProtocol
|
||||
return
|
||||
|
||||
if getattr(self.other, "transport") is not None:
|
||||
self.other.dataReceived(self.buffer)
|
||||
self.buffer = b""
|
||||
|
@ -421,7 +429,10 @@ class FakeTransport(object):
|
|||
|
||||
self._reactor.callLater(0.0, _write)
|
||||
|
||||
_write()
|
||||
# always actually do the write asynchronously. Some protocols (notably the
|
||||
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
|
||||
# still doing a write. Doing a callLater here breaks the cycle.
|
||||
self._reactor.callLater(0.0, _write)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
for x in seq:
|
||||
|
|
Loading…
Reference in New Issue