Fix tests on Twisted trunk. (#16528)
Twisted trunk makes a change to the `TLSMemoryBIOFactory` where
the underlying protocol is changed from `TLSMemoryBIOProtocol` to
`BufferingTLSTransport` to improve performance of TLS code (see
https://github.com/twisted/twisted/issues/11989).
In order to properly hook this code up in tests we need to pass the test
reactor's clock into `TLSMemoryBIOFactory` to avoid the global (trial)
reactor being used by default.
Twisted does something similar internally for tests:
157cd8e659/src/twisted/web/test/test_agent.py (L871-L874)
This commit is contained in:
parent
95076f77c1
commit
e182dbb5b9
|
@ -0,0 +1 @@
|
||||||
|
Fix running unit tests on Twisted trunk.
|
|
@ -15,14 +15,20 @@ import os.path
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from incremental import Version
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
import twisted
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from OpenSSL.SSL import Connection
|
from OpenSSL.SSL import Connection
|
||||||
from twisted.internet.address import IPv4Address
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
from twisted.internet.interfaces import (
|
||||||
|
IOpenSSLServerConnectionCreator,
|
||||||
|
IProtocolFactory,
|
||||||
|
IReactorTime,
|
||||||
|
)
|
||||||
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
||||||
from twisted.protocols.tls import TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||||
|
|
||||||
|
@ -153,6 +159,33 @@ class TestServerTLSConnectionFactory:
|
||||||
return Connection(ctx, None)
|
return Connection(ctx, None)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_server_factory_for_tls(
|
||||||
|
factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes]
|
||||||
|
) -> TLSMemoryBIOFactory:
|
||||||
|
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||||
|
|
||||||
|
The resultant factory will create a TLS server which presents a certificate
|
||||||
|
signed by our test CA, valid for the domains in `sanlist`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory: protocol factory to wrap
|
||||||
|
sanlist: list of domains the cert should be valid for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IProtocolFactory
|
||||||
|
"""
|
||||||
|
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||||
|
# Twisted > 23.8.0 has a different API that accepts a clock.
|
||||||
|
if twisted.version <= Version("Twisted", 23, 8, 0):
|
||||||
|
return TLSMemoryBIOFactory(
|
||||||
|
connection_creator, isClient=False, wrappedFactory=factory
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return TLSMemoryBIOFactory(
|
||||||
|
connection_creator, isClient=False, wrappedFactory=factory, clock=clock # type: ignore[call-arg]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# A dummy address, useful for tests that use FakeTransport and don't care about where
|
# A dummy address, useful for tests that use FakeTransport and don't care about where
|
||||||
# packets are going to/coming from.
|
# packets are going to/coming from.
|
||||||
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
|
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
|
||||||
|
|
|
@ -31,7 +31,7 @@ from twisted.internet.interfaces import (
|
||||||
IProtocolFactory,
|
IProtocolFactory,
|
||||||
)
|
)
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOProtocol
|
||||||
from twisted.web._newclient import ResponseNeverReceived
|
from twisted.web._newclient import ResponseNeverReceived
|
||||||
from twisted.web.client import Agent
|
from twisted.web.client import Agent
|
||||||
from twisted.web.http import HTTPChannel, Request
|
from twisted.web.http import HTTPChannel, Request
|
||||||
|
@ -57,11 +57,7 @@ from synapse.types import ISynapseReactor
|
||||||
from synapse.util.caches.ttlcache import TTLCache
|
from synapse.util.caches.ttlcache import TTLCache
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.http import (
|
from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
|
||||||
TestServerTLSConnectionFactory,
|
|
||||||
dummy_address,
|
|
||||||
get_test_ca_cert_file,
|
|
||||||
)
|
|
||||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
from tests.utils import checked_cast, default_config
|
from tests.utils import checked_cast, default_config
|
||||||
|
|
||||||
|
@ -125,7 +121,18 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
# build the test server
|
# build the test server
|
||||||
server_factory = _get_test_protocol_factory()
|
server_factory = _get_test_protocol_factory()
|
||||||
if ssl:
|
if ssl:
|
||||||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
|
server_factory = wrap_server_factory_for_tls(
|
||||||
|
server_factory,
|
||||||
|
self.reactor,
|
||||||
|
tls_sanlist
|
||||||
|
or [
|
||||||
|
b"DNS:testserv",
|
||||||
|
b"DNS:target-server",
|
||||||
|
b"DNS:xn--bcher-kva.com",
|
||||||
|
b"IP:1.2.3.4",
|
||||||
|
b"IP:::1",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
server_protocol = server_factory.buildProtocol(dummy_address)
|
server_protocol = server_factory.buildProtocol(dummy_address)
|
||||||
assert server_protocol is not None
|
assert server_protocol is not None
|
||||||
|
@ -435,8 +442,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
request.finish()
|
request.finish()
|
||||||
|
|
||||||
# now we make another test server to act as the upstream HTTP server.
|
# now we make another test server to act as the upstream HTTP server.
|
||||||
server_ssl_protocol = _wrap_server_factory_for_tls(
|
server_ssl_protocol = wrap_server_factory_for_tls(
|
||||||
_get_test_protocol_factory()
|
_get_test_protocol_factory(),
|
||||||
|
self.reactor,
|
||||||
|
sanlist=[
|
||||||
|
b"DNS:testserv",
|
||||||
|
b"DNS:target-server",
|
||||||
|
b"DNS:xn--bcher-kva.com",
|
||||||
|
b"IP:1.2.3.4",
|
||||||
|
b"IP:::1",
|
||||||
|
],
|
||||||
).buildProtocol(dummy_address)
|
).buildProtocol(dummy_address)
|
||||||
|
|
||||||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
||||||
|
@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
|
||||||
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
||||||
|
|
||||||
|
|
||||||
def _wrap_server_factory_for_tls(
|
|
||||||
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
|
|
||||||
) -> TLSMemoryBIOFactory:
|
|
||||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
|
||||||
The resultant factory will create a TLS server which presents a certificate
|
|
||||||
signed by our test CA, valid for the domains in `sanlist`
|
|
||||||
Args:
|
|
||||||
factory: protocol factory to wrap
|
|
||||||
sanlist: list of domains the cert should be valid for
|
|
||||||
Returns:
|
|
||||||
interfaces.IProtocolFactory
|
|
||||||
"""
|
|
||||||
if sanlist is None:
|
|
||||||
sanlist = [
|
|
||||||
b"DNS:testserv",
|
|
||||||
b"DNS:target-server",
|
|
||||||
b"DNS:xn--bcher-kva.com",
|
|
||||||
b"IP:1.2.3.4",
|
|
||||||
b"IP:::1",
|
|
||||||
]
|
|
||||||
|
|
||||||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
|
||||||
return TLSMemoryBIOFactory(
|
|
||||||
connection_creator, isClient=False, wrappedFactory=factory
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_test_protocol_factory() -> IProtocolFactory:
|
def _get_test_protocol_factory() -> IProtocolFactory:
|
||||||
"""Get a protocol Factory which will build an HTTPChannel
|
"""Get a protocol Factory which will build an HTTPChannel
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -29,18 +29,14 @@ from twisted.internet.endpoints import (
|
||||||
)
|
)
|
||||||
from twisted.internet.interfaces import IProtocol, IProtocolFactory
|
from twisted.internet.interfaces import IProtocol, IProtocolFactory
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOProtocol
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
from synapse.http.client import BlocklistingReactorWrapper
|
from synapse.http.client import BlocklistingReactorWrapper
|
||||||
from synapse.http.connectproxyclient import BasicProxyCredentials
|
from synapse.http.connectproxyclient import BasicProxyCredentials
|
||||||
from synapse.http.proxyagent import ProxyAgent, parse_proxy
|
from synapse.http.proxyagent import ProxyAgent, parse_proxy
|
||||||
|
|
||||||
from tests.http import (
|
from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls
|
||||||
TestServerTLSConnectionFactory,
|
|
||||||
dummy_address,
|
|
||||||
get_test_https_policy,
|
|
||||||
)
|
|
||||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
from tests.utils import checked_cast
|
from tests.utils import checked_cast
|
||||||
|
@ -272,7 +268,9 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
the server Protocol returned by server_factory
|
the server Protocol returned by server_factory
|
||||||
"""
|
"""
|
||||||
if ssl:
|
if ssl:
|
||||||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
|
server_factory = wrap_server_factory_for_tls(
|
||||||
|
server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"]
|
||||||
|
)
|
||||||
|
|
||||||
server_protocol = server_factory.buildProtocol(dummy_address)
|
server_protocol = server_factory.buildProtocol(dummy_address)
|
||||||
assert server_protocol is not None
|
assert server_protocol is not None
|
||||||
|
@ -639,8 +637,8 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
request.finish()
|
request.finish()
|
||||||
|
|
||||||
# now we make another test server to act as the upstream HTTP server.
|
# now we make another test server to act as the upstream HTTP server.
|
||||||
server_ssl_protocol = _wrap_server_factory_for_tls(
|
server_ssl_protocol = wrap_server_factory_for_tls(
|
||||||
_get_test_protocol_factory()
|
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
|
||||||
).buildProtocol(dummy_address)
|
).buildProtocol(dummy_address)
|
||||||
|
|
||||||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
||||||
|
@ -806,7 +804,9 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
request.finish()
|
request.finish()
|
||||||
|
|
||||||
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
||||||
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
ssl_factory = wrap_server_factory_for_tls(
|
||||||
|
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
|
||||||
|
)
|
||||||
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
|
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
|
||||||
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
|
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
|
||||||
http_server = ssl_protocol.wrappedProtocol
|
http_server = ssl_protocol.wrappedProtocol
|
||||||
|
@ -870,30 +870,6 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
|
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
|
||||||
|
|
||||||
|
|
||||||
def _wrap_server_factory_for_tls(
|
|
||||||
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
|
|
||||||
) -> TLSMemoryBIOFactory:
|
|
||||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
|
||||||
|
|
||||||
The resultant factory will create a TLS server which presents a certificate
|
|
||||||
signed by our test CA, valid for the domains in `sanlist`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factory: protocol factory to wrap
|
|
||||||
sanlist: list of domains the cert should be valid for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
interfaces.IProtocolFactory
|
|
||||||
"""
|
|
||||||
if sanlist is None:
|
|
||||||
sanlist = [b"DNS:test.com"]
|
|
||||||
|
|
||||||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
|
||||||
return TLSMemoryBIOFactory(
|
|
||||||
connection_creator, isClient=False, wrappedFactory=factory
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_test_protocol_factory() -> IProtocolFactory:
|
def _get_test_protocol_factory() -> IProtocolFactory:
|
||||||
"""Get a protocol Factory which will build an HTTPChannel
|
"""Get a protocol Factory which will build an HTTPChannel
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -27,7 +25,11 @@ from synapse.rest.client import login
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
from tests.http import (
|
||||||
|
TestServerTLSConnectionFactory,
|
||||||
|
get_test_ca_cert_file,
|
||||||
|
wrap_server_factory_for_tls,
|
||||||
|
)
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
from tests.server import FakeChannel, FakeTransport, make_request
|
from tests.server import FakeChannel, FakeTransport, make_request
|
||||||
from tests.test_utils import SMALL_PNG
|
from tests.test_utils import SMALL_PNG
|
||||||
|
@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
|
||||||
|
|
||||||
# build the test server
|
# build the test server
|
||||||
server_tls_protocol = _build_test_server(get_connection_factory())
|
server_factory = Factory.forProtocol(HTTPChannel)
|
||||||
|
# Request.finish expects the factory to have a 'log' method.
|
||||||
|
server_factory.log = _log_request
|
||||||
|
|
||||||
|
server_tls_protocol = wrap_server_factory_for_tls(
|
||||||
|
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
|
||||||
|
).buildProtocol(None)
|
||||||
|
|
||||||
# now, tell the client protocol factory to build the client protocol (it will be a
|
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||||
|
@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# fish the test server back out of the server-side TLS protocol.
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
|
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
# give the reactor a pump to get the TLS juices flowing.
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
|
@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
return sum(len(files) for _, _, files in os.walk(path))
|
return sum(len(files) for _, _, files in os.walk(path))
|
||||||
|
|
||||||
|
|
||||||
def get_connection_factory() -> TestServerTLSConnectionFactory:
|
|
||||||
# this needs to happen once, but not until we are ready to run the first test
|
|
||||||
global test_server_connection_factory
|
|
||||||
if test_server_connection_factory is None:
|
|
||||||
test_server_connection_factory = TestServerTLSConnectionFactory(
|
|
||||||
sanlist=[b"DNS:example.com"]
|
|
||||||
)
|
|
||||||
return test_server_connection_factory
|
|
||||||
|
|
||||||
|
|
||||||
def _build_test_server(
|
|
||||||
connection_creator: IOpenSSLServerConnectionCreator,
|
|
||||||
) -> TLSMemoryBIOProtocol:
|
|
||||||
"""Construct a test server
|
|
||||||
|
|
||||||
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connection_creator: thing to build SSL connections
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TLSMemoryBIOProtocol
|
|
||||||
"""
|
|
||||||
server_factory = Factory.forProtocol(HTTPChannel)
|
|
||||||
# Request.finish expects the factory to have a 'log' method.
|
|
||||||
server_factory.log = _log_request
|
|
||||||
|
|
||||||
server_tls_factory = TLSMemoryBIOFactory(
|
|
||||||
connection_creator, isClient=False, wrappedFactory=server_factory
|
|
||||||
)
|
|
||||||
|
|
||||||
return server_tls_factory.buildProtocol(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _log_request(request: Request) -> None:
|
def _log_request(request: Request) -> None:
|
||||||
"""Implements Factory.log, which is expected by Request.finish"""
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
logger.info("Completed request %s", request)
|
logger.info("Completed request %s", request)
|
||||||
|
|
|
@ -43,9 +43,11 @@ from typing import (
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from incremental import Version
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
import twisted
|
||||||
from twisted.internet import address, tcp, threads, udp
|
from twisted.internet import address, tcp, threads, udp
|
||||||
from twisted.internet._resolver import SimpleResolverComplexifier
|
from twisted.internet._resolver import SimpleResolverComplexifier
|
||||||
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
|
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
|
||||||
|
@ -474,6 +476,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
|
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
|
||||||
return succeed(lookups[name])
|
return succeed(lookups[name])
|
||||||
|
|
||||||
|
# In order for the TLS protocol tests to work, modify _get_default_clock
|
||||||
|
# on newer Twisted versions to use the test reactor's clock.
|
||||||
|
#
|
||||||
|
# This is *super* dirty since it is never undone and relies on the next
|
||||||
|
# test to overwrite it.
|
||||||
|
if twisted.version > Version("Twisted", 23, 8, 0):
|
||||||
|
from twisted.protocols import tls
|
||||||
|
|
||||||
|
tls._get_default_clock = lambda: self # type: ignore[attr-defined]
|
||||||
|
|
||||||
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
|
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue