Merge pull request #5864 from matrix-org/erikj/reliable_lookups
Refactor MatrixFederationAgent to retry SRV.
This commit is contained in:
commit
dfd10f5133
|
@ -0,0 +1 @@
|
|||
Correctly retry all hosts returned from SRV when we fail to connect.
|
|
@ -14,21 +14,21 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
import attr
|
||||
from netaddr import IPAddress
|
||||
from netaddr import AddrFormatError, IPAddress
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||
from twisted.web.client import Agent, HTTPConnectionPool
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IAgent
|
||||
from twisted.web.iweb import IAgent, IAgentEndpointFactory
|
||||
|
||||
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||
from synapse.http.federation.srv_resolver import Server, SrvResolver
|
||||
from synapse.http.federation.well_known_resolver import WellKnownResolver
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@implementer(IAgent)
|
||||
class MatrixFederationAgent(object):
|
||||
"""An Agent-like thing which provides a `request` method which will look up a matrix
|
||||
server and send an HTTP request to it.
|
||||
"""An Agent-like thing which provides a `request` method which correctly
|
||||
handles resolving matrix server names when using matrix://. Handles standard
|
||||
https URIs as normal.
|
||||
|
||||
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
||||
|
||||
|
@ -65,17 +66,19 @@ class MatrixFederationAgent(object):
|
|||
):
|
||||
self._reactor = reactor
|
||||
self._clock = Clock(reactor)
|
||||
|
||||
self._tls_client_options_factory = tls_client_options_factory
|
||||
if _srv_resolver is None:
|
||||
_srv_resolver = SrvResolver()
|
||||
self._srv_resolver = _srv_resolver
|
||||
|
||||
self._pool = HTTPConnectionPool(reactor)
|
||||
self._pool.retryAutomatically = False
|
||||
self._pool.maxPersistentPerHost = 5
|
||||
self._pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
self._agent = Agent.usingEndpointFactory(
|
||||
self._reactor,
|
||||
MatrixHostnameEndpointFactory(
|
||||
reactor, tls_client_options_factory, _srv_resolver
|
||||
),
|
||||
pool=self._pool,
|
||||
)
|
||||
|
||||
if _well_known_resolver is None:
|
||||
_well_known_resolver = WellKnownResolver(
|
||||
self._reactor,
|
||||
|
@ -93,19 +96,15 @@ class MatrixFederationAgent(object):
|
|||
"""
|
||||
Args:
|
||||
method (bytes): HTTP method: GET/POST/etc
|
||||
|
||||
uri (bytes): Absolute URI to be retrieved
|
||||
|
||||
headers (twisted.web.http_headers.Headers|None):
|
||||
HTTP headers to send with the request, or None to
|
||||
send no extra headers.
|
||||
|
||||
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
||||
An object which can generate bytes to make up the
|
||||
body of this request (for example, the properly encoded contents of
|
||||
a file for a file upload). Or None if the request is to have
|
||||
no body.
|
||||
|
||||
Returns:
|
||||
Deferred[twisted.web.iweb.IResponse]:
|
||||
fires when the header of the response has been received (regardless of the
|
||||
|
@ -113,210 +112,207 @@ class MatrixFederationAgent(object):
|
|||
response from being received (including problems that prevent the request
|
||||
from being sent).
|
||||
"""
|
||||
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
|
||||
res = yield self._route_matrix_uri(parsed_uri)
|
||||
# We use urlparse as that will set `port` to None if there is no
|
||||
# explicit port.
|
||||
parsed_uri = urllib.parse.urlparse(uri)
|
||||
|
||||
# set up the TLS connection params
|
||||
# If this is a matrix:// URI check if the server has delegated matrix
|
||||
# traffic using well-known delegation.
|
||||
#
|
||||
# XXX disabling TLS is really only supported here for the benefit of the
|
||||
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||
# the code support the unit tests.
|
||||
if self._tls_client_options_factory is None:
|
||||
tls_options = None
|
||||
else:
|
||||
tls_options = self._tls_client_options_factory.get_options(
|
||||
res.tls_server_name.decode("ascii")
|
||||
# We have to do this here and not in the endpoint as we need to rewrite
|
||||
# the host header with the delegated server name.
|
||||
delegated_server = None
|
||||
if (
|
||||
parsed_uri.scheme == b"matrix"
|
||||
and not _is_ip_literal(parsed_uri.hostname)
|
||||
and not parsed_uri.port
|
||||
):
|
||||
well_known_result = yield self._well_known_resolver.get_well_known(
|
||||
parsed_uri.hostname
|
||||
)
|
||||
delegated_server = well_known_result.delegated_server
|
||||
|
||||
# make sure that the Host header is set correctly
|
||||
if delegated_server:
|
||||
# Ok, the server has delegated matrix traffic to somewhere else, so
|
||||
# lets rewrite the URL to replace the server with the delegated
|
||||
# server name.
|
||||
uri = urllib.parse.urlunparse(
|
||||
(
|
||||
parsed_uri.scheme,
|
||||
delegated_server,
|
||||
parsed_uri.path,
|
||||
parsed_uri.params,
|
||||
parsed_uri.query,
|
||||
parsed_uri.fragment,
|
||||
)
|
||||
)
|
||||
parsed_uri = urllib.parse.urlparse(uri)
|
||||
|
||||
# We need to make sure the host header is set to the netloc of the
|
||||
# server.
|
||||
if headers is None:
|
||||
headers = Headers()
|
||||
else:
|
||||
headers = headers.copy()
|
||||
|
||||
if not headers.hasHeader(b"host"):
|
||||
headers.addRawHeader(b"host", res.host_header)
|
||||
headers.addRawHeader(b"host", parsed_uri.netloc)
|
||||
|
||||
class EndpointFactory(object):
|
||||
@staticmethod
|
||||
def endpointForURI(_uri):
|
||||
ep = LoggingHostnameEndpoint(
|
||||
self._reactor, res.target_host, res.target_port
|
||||
)
|
||||
if tls_options is not None:
|
||||
ep = wrapClientTLS(tls_options, ep)
|
||||
return ep
|
||||
|
||||
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
||||
res = yield make_deferred_yieldable(
|
||||
agent.request(method, uri, headers, bodyProducer)
|
||||
self._agent.request(method, uri, headers, bodyProducer)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
|
||||
"""Helper for `request`: determine the routing for a Matrix URI
|
||||
|
||||
Args:
|
||||
parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
|
||||
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
|
||||
if there is no explicit port given.
|
||||
@implementer(IAgentEndpointFactory)
|
||||
class MatrixHostnameEndpointFactory(object):
|
||||
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
|
||||
"""
|
||||
|
||||
lookup_well_known (bool): True if we should look up the .well-known file if
|
||||
there is no SRV record.
|
||||
def __init__(self, reactor, tls_client_options_factory, srv_resolver):
|
||||
self._reactor = reactor
|
||||
self._tls_client_options_factory = tls_client_options_factory
|
||||
|
||||
Returns:
|
||||
Deferred[_RoutingResult]
|
||||
"""
|
||||
# check for an IP literal
|
||||
try:
|
||||
ip_address = IPAddress(parsed_uri.host.decode("ascii"))
|
||||
except Exception:
|
||||
# not an IP address
|
||||
ip_address = None
|
||||
if srv_resolver is None:
|
||||
srv_resolver = SrvResolver()
|
||||
|
||||
if ip_address:
|
||||
port = parsed_uri.port
|
||||
if port == -1:
|
||||
port = 8448
|
||||
return _RoutingResult(
|
||||
host_header=parsed_uri.netloc,
|
||||
tls_server_name=parsed_uri.host,
|
||||
target_host=parsed_uri.host,
|
||||
target_port=port,
|
||||
)
|
||||
self._srv_resolver = srv_resolver
|
||||
|
||||
if parsed_uri.port != -1:
|
||||
# there is an explicit port
|
||||
return _RoutingResult(
|
||||
host_header=parsed_uri.netloc,
|
||||
tls_server_name=parsed_uri.host,
|
||||
target_host=parsed_uri.host,
|
||||
target_port=parsed_uri.port,
|
||||
)
|
||||
|
||||
if lookup_well_known:
|
||||
# try a .well-known lookup
|
||||
well_known_result = yield self._well_known_resolver.get_well_known(
|
||||
parsed_uri.host
|
||||
)
|
||||
well_known_server = well_known_result.delegated_server
|
||||
|
||||
if well_known_server:
|
||||
# if we found a .well-known, start again, but don't do another
|
||||
# .well-known lookup.
|
||||
|
||||
# parse the server name in the .well-known response into host/port.
|
||||
# (This code is lifted from twisted.web.client.URI.fromBytes).
|
||||
if b":" in well_known_server:
|
||||
well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
|
||||
try:
|
||||
well_known_port = int(well_known_port)
|
||||
except ValueError:
|
||||
# the part after the colon could not be parsed as an int
|
||||
# - we assume it is an IPv6 literal with no port (the closing
|
||||
# ']' stops it being parsed as an int)
|
||||
well_known_host, well_known_port = well_known_server, -1
|
||||
else:
|
||||
well_known_host, well_known_port = well_known_server, -1
|
||||
|
||||
new_uri = URI(
|
||||
scheme=parsed_uri.scheme,
|
||||
netloc=well_known_server,
|
||||
host=well_known_host,
|
||||
port=well_known_port,
|
||||
path=parsed_uri.path,
|
||||
params=parsed_uri.params,
|
||||
query=parsed_uri.query,
|
||||
fragment=parsed_uri.fragment,
|
||||
)
|
||||
|
||||
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
|
||||
return res
|
||||
|
||||
# try a SRV lookup
|
||||
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
|
||||
server_list = yield self._srv_resolver.resolve_service(service_name)
|
||||
|
||||
if not server_list:
|
||||
target_host = parsed_uri.host
|
||||
port = 8448
|
||||
logger.debug(
|
||||
"No SRV record for %s, using %s:%i",
|
||||
parsed_uri.host.decode("ascii"),
|
||||
target_host.decode("ascii"),
|
||||
port,
|
||||
)
|
||||
else:
|
||||
target_host, port = pick_server_from_list(server_list)
|
||||
logger.debug(
|
||||
"Picked %s:%i from SRV records for %s",
|
||||
target_host.decode("ascii"),
|
||||
port,
|
||||
parsed_uri.host.decode("ascii"),
|
||||
)
|
||||
|
||||
return _RoutingResult(
|
||||
host_header=parsed_uri.netloc,
|
||||
tls_server_name=parsed_uri.host,
|
||||
target_host=target_host,
|
||||
target_port=port,
|
||||
def endpointForURI(self, parsed_uri):
|
||||
return MatrixHostnameEndpoint(
|
||||
self._reactor,
|
||||
self._tls_client_options_factory,
|
||||
self._srv_resolver,
|
||||
parsed_uri,
|
||||
)
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class LoggingHostnameEndpoint(object):
|
||||
"""A wrapper for HostnameEndpint which logs when it connects"""
|
||||
class MatrixHostnameEndpoint(object):
|
||||
"""An endpoint that resolves matrix:// URLs using Matrix server name
|
||||
resolution (i.e. via SRV). Does not check for well-known delegation.
|
||||
|
||||
def __init__(self, reactor, host, port, *args, **kwargs):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
|
||||
Args:
|
||||
reactor (IReactor)
|
||||
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||
factory to use for fetching client tls options, or none to disable TLS.
|
||||
srv_resolver (SrvResolver): The SRV resolver to use
|
||||
parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
|
||||
to connect to.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
|
||||
self._reactor = reactor
|
||||
|
||||
self._parsed_uri = parsed_uri
|
||||
|
||||
# set up the TLS connection params
|
||||
#
|
||||
# XXX disabling TLS is really only supported here for the benefit of the
|
||||
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||
# the code support the unit tests.
|
||||
|
||||
if tls_client_options_factory is None:
|
||||
self._tls_options = None
|
||||
else:
|
||||
self._tls_options = tls_client_options_factory.get_options(
|
||||
self._parsed_uri.host.decode("ascii")
|
||||
)
|
||||
|
||||
self._srv_resolver = srv_resolver
|
||||
|
||||
def connect(self, protocol_factory):
|
||||
logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
|
||||
return self.ep.connect(protocol_factory)
|
||||
"""Implements IStreamClientEndpoint interface
|
||||
"""
|
||||
|
||||
return run_in_background(self._do_connect, protocol_factory)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_connect(self, protocol_factory):
|
||||
first_exception = None
|
||||
|
||||
server_list = yield self._resolve_server()
|
||||
|
||||
for server in server_list:
|
||||
host = server.host
|
||||
port = server.port
|
||||
|
||||
try:
|
||||
logger.info("Connecting to %s:%i", host.decode("ascii"), port)
|
||||
endpoint = HostnameEndpoint(self._reactor, host, port)
|
||||
if self._tls_options:
|
||||
endpoint = wrapClientTLS(self._tls_options, endpoint)
|
||||
result = yield make_deferred_yieldable(
|
||||
endpoint.connect(protocol_factory)
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
|
||||
)
|
||||
if not first_exception:
|
||||
first_exception = e
|
||||
|
||||
# We return the first failure because that's probably the most interesting.
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
|
||||
# This shouldn't happen as we should always have at least one host/port
|
||||
# to try and if that doesn't work then we'll have an exception.
|
||||
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _resolve_server(self):
|
||||
"""Resolves the server name to a list of hosts and ports to attempt to
|
||||
connect to.
|
||||
|
||||
Returns:
|
||||
Deferred[list[Server]]
|
||||
"""
|
||||
|
||||
if self._parsed_uri.scheme != b"matrix":
|
||||
return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
|
||||
|
||||
# Note: We don't do well-known lookup as that needs to have happened
|
||||
# before now, due to needing to rewrite the Host header of the HTTP
|
||||
# request.
|
||||
|
||||
# We reparse the URI so that defaultPort is -1 rather than 80
|
||||
parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
|
||||
|
||||
host = parsed_uri.hostname
|
||||
port = parsed_uri.port
|
||||
|
||||
# If there is an explicit port or the host is an IP address we bypass
|
||||
# SRV lookups and just use the given host/port.
|
||||
if port or _is_ip_literal(host):
|
||||
return [Server(host, port or 8448)]
|
||||
|
||||
server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
|
||||
|
||||
if server_list:
|
||||
return server_list
|
||||
|
||||
# No SRV records, so we fallback to host and 8448
|
||||
return [Server(host, 8448)]
|
||||
|
||||
|
||||
@attr.s
|
||||
class _RoutingResult(object):
|
||||
"""The result returned by `_route_matrix_uri`.
|
||||
def _is_ip_literal(host):
|
||||
"""Test if the given host name is either an IPv4 or IPv6 literal.
|
||||
|
||||
Contains the parameters needed to direct a federation connection to a particular
|
||||
server.
|
||||
Args:
|
||||
host (bytes)
|
||||
|
||||
Where a SRV record points to several servers, this object contains a single server
|
||||
chosen from the list.
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
host_header = attr.ib()
|
||||
"""
|
||||
The value we should assign to the Host header (host:port from the matrix
|
||||
URI, or .well-known).
|
||||
host = host.decode("ascii")
|
||||
|
||||
:type: bytes
|
||||
"""
|
||||
|
||||
tls_server_name = attr.ib()
|
||||
"""
|
||||
The server name we should set in the SNI (typically host, without port, from the
|
||||
matrix URI or .well-known)
|
||||
|
||||
:type: bytes
|
||||
"""
|
||||
|
||||
target_host = attr.ib()
|
||||
"""
|
||||
The hostname (or IP literal) we should route the TCP connection to (the target of the
|
||||
SRV record, or the hostname from the URL/.well-known)
|
||||
|
||||
:type: bytes
|
||||
"""
|
||||
|
||||
target_port = attr.ib()
|
||||
"""
|
||||
The port we should route the TCP connection to (the target of the SRV record, or
|
||||
the port from the URL/.well-known, or 8448)
|
||||
|
||||
:type: int
|
||||
"""
|
||||
try:
|
||||
IPAddress(host)
|
||||
return True
|
||||
except AddrFormatError:
|
||||
return False
|
||||
|
|
|
@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
SERVER_CACHE = {}
|
||||
|
||||
|
||||
@attr.s
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class Server(object):
|
||||
"""
|
||||
Our record of an individual server which can be tried to reach a destination.
|
||||
|
@ -53,34 +53,47 @@ class Server(object):
|
|||
expires = attr.ib(default=0)
|
||||
|
||||
|
||||
def pick_server_from_list(server_list):
|
||||
"""Randomly choose a server from the server list
|
||||
|
||||
Args:
|
||||
server_list (list[Server]): list of candidate servers
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, int]: (host, port) pair for the chosen server
|
||||
def _sort_server_list(server_list):
|
||||
"""Given a list of SRV records sort them into priority order and shuffle
|
||||
each priority with the given weight.
|
||||
"""
|
||||
if not server_list:
|
||||
raise RuntimeError("pick_server_from_list called with empty list")
|
||||
priority_map = {}
|
||||
|
||||
# TODO: currently we only use the lowest-priority servers. We should maintain a
|
||||
# cache of servers known to be "down" and filter them out
|
||||
for server in server_list:
|
||||
priority_map.setdefault(server.priority, []).append(server)
|
||||
|
||||
min_priority = min(s.priority for s in server_list)
|
||||
eligible_servers = list(s for s in server_list if s.priority == min_priority)
|
||||
total_weight = sum(s.weight for s in eligible_servers)
|
||||
target_weight = random.randint(0, total_weight)
|
||||
results = []
|
||||
for priority in sorted(priority_map):
|
||||
servers = priority_map[priority]
|
||||
|
||||
for s in eligible_servers:
|
||||
target_weight -= s.weight
|
||||
# This algorithms roughly follows the algorithm described in RFC2782,
|
||||
# changed to remove an off-by-one error.
|
||||
#
|
||||
# N.B. Weights can be zero, which means that they should be picked
|
||||
# rarely.
|
||||
|
||||
if target_weight <= 0:
|
||||
return s.host, s.port
|
||||
total_weight = sum(s.weight for s in servers)
|
||||
|
||||
# this should be impossible.
|
||||
raise RuntimeError("pick_server_from_list got to end of eligible server list.")
|
||||
# Total weight can become zero if there are only zero weight servers
|
||||
# left, which we handle by just shuffling and appending to the results.
|
||||
while servers and total_weight:
|
||||
target_weight = random.randint(1, total_weight)
|
||||
|
||||
for s in servers:
|
||||
target_weight -= s.weight
|
||||
|
||||
if target_weight <= 0:
|
||||
break
|
||||
|
||||
results.append(s)
|
||||
servers.remove(s)
|
||||
total_weight -= s.weight
|
||||
|
||||
if servers:
|
||||
random.shuffle(servers)
|
||||
results.extend(servers)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class SrvResolver(object):
|
||||
|
@ -120,7 +133,7 @@ class SrvResolver(object):
|
|||
if cache_entry:
|
||||
if all(s.expires > now for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
return servers
|
||||
return _sort_server_list(servers)
|
||||
|
||||
try:
|
||||
answers, _, _ = yield make_deferred_yieldable(
|
||||
|
@ -169,4 +182,4 @@ class SrvResolver(object):
|
|||
)
|
||||
|
||||
self._cache[service_name] = list(servers)
|
||||
return servers
|
||||
return _sort_server_list(servers)
|
||||
|
|
|
@ -20,7 +20,6 @@ from synapse.federation.federation_server import server_matches_acl_event
|
|||
from tests import unittest
|
||||
|
||||
|
||||
@unittest.DEBUG
|
||||
class ServerACLsTestCase(unittest.TestCase):
|
||||
def test_blacklisted_server(self):
|
||||
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
||||
|
|
|
@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import (
|
|||
from synapse.logging.context import LoggingContext
|
||||
from synapse.util.caches.ttlcache import TTLCache
|
||||
|
||||
from tests import unittest
|
||||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.unittest import TestCase
|
||||
from tests.utils import default_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -67,7 +67,7 @@ def get_connection_factory():
|
|||
return test_server_connection_factory
|
||||
|
||||
|
||||
class MatrixFederationAgentTests(TestCase):
|
||||
class MatrixFederationAgentTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
|
@ -1069,8 +1069,64 @@ class MatrixFederationAgentTests(TestCase):
|
|||
r = self.successResultOf(fetch_d)
|
||||
self.assertEqual(r.delegated_server, None)
|
||||
|
||||
def test_srv_fallbacks(self):
|
||||
"""Test that other SRV results are tried if the first one fails.
|
||||
"""
|
||||
|
||||
class TestCachePeriodFromHeaders(TestCase):
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
||||
Server(host=b"target.com", port=8443),
|
||||
Server(host=b"target.com", port=8444),
|
||||
]
|
||||
self.reactor.lookups["target.com"] = "1.2.3.4"
|
||||
|
||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||
b"_matrix._tcp.testserv"
|
||||
)
|
||||
|
||||
# We should see an attempt to connect to the first server
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 8443)
|
||||
|
||||
# Fonx the connection
|
||||
client_factory.clientConnectionFailed(None, Exception("nope"))
|
||||
|
||||
# There's a 300ms delay in HostnameEndpoint
|
||||
self.reactor.pump((0.4,))
|
||||
|
||||
# Hasn't failed yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# We shouldnow see an attempt to connect to the second server
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 8444)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(client_factory, expected_sni=b"testserv")
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/foo/bar")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
|
||||
|
||||
# finish the request
|
||||
request.finish()
|
||||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
|
||||
class TestCachePeriodFromHeaders(unittest.TestCase):
|
||||
def test_cache_control(self):
|
||||
# uppercase
|
||||
self.assertEqual(
|
||||
|
|
|
@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry = Mock(spec_set=["expires", "priority", "weight"])
|
||||
entry.expires = 0
|
||||
entry.priority = 0
|
||||
entry.weight = 0
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry = Mock(spec_set=["expires", "priority", "weight"])
|
||||
entry.expires = 999999999
|
||||
entry.priority = 0
|
||||
entry.weight = 0
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
resolver = SrvResolver(
|
||||
|
|
|
@ -74,7 +74,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
|||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertEqual(filtered[i].content["a"], "b")
|
||||
|
||||
@tests.unittest.DEBUG
|
||||
@defer.inlineCallbacks
|
||||
def test_erased_user(self):
|
||||
# 4 message events, from erased and unerased users, with a membership
|
||||
|
|
Loading…
Reference in New Issue