Don't send IP addresses as SNI (#4452)
The problem here is that we have cut-and-pasted an impl from Twisted, and then failed to maintain it. It was fixed in Twisted in https://github.com/twisted/twisted/pull/1047/files; let's do the same here.
This commit is contained in:
parent
6b90ae6efc
commit
97fd29c019
|
@ -0,0 +1 @@
|
||||||
|
Don't send IP addresses as SNI
|
|
@ -17,6 +17,7 @@ from zope.interface import implementer
|
||||||
|
|
||||||
from OpenSSL import SSL, crypto
|
from OpenSSL import SSL, crypto
|
||||||
from twisted.internet._sslverify import _defaultCurveName
|
from twisted.internet._sslverify import _defaultCurveName
|
||||||
|
from twisted.internet.abstract import isIPAddress, isIPv6Address
|
||||||
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
||||||
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
@ -98,8 +99,14 @@ class ClientTLSOptions(object):
|
||||||
|
|
||||||
def __init__(self, hostname, ctx):
|
def __init__(self, hostname, ctx):
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
self._hostname = hostname
|
|
||||||
self._hostnameBytes = _idnaBytes(hostname)
|
if isIPAddress(hostname) or isIPv6Address(hostname):
|
||||||
|
self._hostnameBytes = hostname.encode('ascii')
|
||||||
|
self._sendSNI = False
|
||||||
|
else:
|
||||||
|
self._hostnameBytes = _idnaBytes(hostname)
|
||||||
|
self._sendSNI = True
|
||||||
|
|
||||||
ctx.set_info_callback(
|
ctx.set_info_callback(
|
||||||
_tolerateErrors(self._identityVerifyingInfoCallback)
|
_tolerateErrors(self._identityVerifyingInfoCallback)
|
||||||
)
|
)
|
||||||
|
@ -111,7 +118,9 @@ class ClientTLSOptions(object):
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
||||||
if where & SSL.SSL_CB_HANDSHAKE_START:
|
# Literal IPv4 and IPv6 addresses are not permitted
|
||||||
|
# as host names according to the RFCs
|
||||||
|
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
|
||||||
connection.set_tlsext_host_name(self._hostnameBytes)
|
connection.set_tlsext_host_name(self._hostnameBytes)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
_srv_resolver=self.mock_resolver,
|
_srv_resolver=self.mock_resolver,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_connection(self, client_factory):
|
def _make_connection(self, client_factory, expected_sni):
|
||||||
"""Builds a test server, and completes the outgoing client connection
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -69,9 +69,17 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
# tell the server tls protocol to send its stuff back to the client, too
|
# tell the server tls protocol to send its stuff back to the client, too
|
||||||
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
||||||
|
|
||||||
# finally, give the reactor a pump to get the TLS juices flowing.
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
|
# check the SNI
|
||||||
|
server_name = server_tls_protocol._tlsConnection.get_servername()
|
||||||
|
self.assertEqual(
|
||||||
|
server_name,
|
||||||
|
expected_sni,
|
||||||
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
|
)
|
||||||
|
|
||||||
# fish the test server back out of the server-side TLS protocol.
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
return server_tls_protocol.wrappedProtocol
|
return server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
|
@ -113,7 +121,10 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8448)
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(client_factory)
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
expected_sni=b"testserv",
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
|
@ -150,6 +161,52 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
json = self.successResultOf(treq.json_content(response))
|
json = self.successResultOf(treq.json_content(response))
|
||||||
self.assertEqual(json, {"a": 1})
|
self.assertEqual(json, {"a": 1})
|
||||||
|
|
||||||
|
def test_get_ip_address(self):
|
||||||
|
"""
|
||||||
|
Test the behaviour when the server name contains an explicit IP (with no port)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
|
||||||
|
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||||
|
|
||||||
|
# then there will be a getaddrinfo on the IP
|
||||||
|
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
|
||||||
|
|
||||||
|
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
|
||||||
|
|
||||||
|
# Nothing happened yet
|
||||||
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
|
self.mock_resolver.resolve_service.assert_called_once()
|
||||||
|
|
||||||
|
# Make sure treq is trying to connect
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, '1.2.3.4')
|
||||||
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
expected_sni=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b'GET')
|
||||||
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
|
# XXX currently broken
|
||||||
|
# self.assertEqual(
|
||||||
|
# request.requestHeaders.getRawHeaders(b'host'),
|
||||||
|
# [b'1.2.3.4:8448']
|
||||||
|
# )
|
||||||
|
|
||||||
|
# finish the request
|
||||||
|
request.finish()
|
||||||
|
self.reactor.pump((0.1,))
|
||||||
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
|
|
||||||
def _check_logcontext(context):
|
def _check_logcontext(context):
|
||||||
current = LoggingContext.current_context()
|
current = LoggingContext.current_context()
|
||||||
|
|
Loading…
Reference in New Issue