Fix tests on Twisted trunk. (#16528)

Twisted trunk makes a change to the `TLSMemoryBIOFactory` where
the underlying protocol is changed from `TLSMemoryBIOProtocol` to
`BufferingTLSTransport` to improve performance of TLS code (see
https://github.com/twisted/twisted/issues/11989).

In order to properly hook this code up in tests we need to pass the test
reactor's clock into `TLSMemoryBIOFactory` to avoid the global (trial)
reactor being used by default.

Twisted does something similar internally for tests:
157cd8e659/src/twisted/web/test/test_agent.py (L871-L874)
This commit is contained in:
Patrick Cloke 2023-10-25 07:39:45 -04:00 committed by GitHub
parent 95076f77c1
commit e182dbb5b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 111 deletions

1
changelog.d/16528.misc Normal file
View File

@ -0,0 +1 @@
Fix running unit tests on Twisted trunk.

View File

@ -15,14 +15,20 @@ import os.path
import subprocess
from typing import List
from incremental import Version
from zope.interface import implementer
import twisted
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.interfaces import (
IOpenSSLServerConnectionCreator,
IProtocolFactory,
IReactorTime,
)
from twisted.internet.ssl import Certificate, trustRootFromCertificates
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
@ -153,6 +159,33 @@ class TestServerTLSConnectionFactory:
return Connection(ctx, None)
def wrap_server_factory_for_tls(
factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes]
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
# Twisted > 23.8.0 has a different API that accepts a clock.
if twisted.version <= Version("Twisted", 23, 8, 0):
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
else:
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory, clock=clock # type: ignore[call-arg]
)
# A dummy address, useful for tests that use FakeTransport and don't care about where
# packets are going to/coming from.
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)

View File

@ -31,7 +31,7 @@ from twisted.internet.interfaces import (
IProtocolFactory,
)
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
from twisted.web.http import HTTPChannel, Request
@ -57,11 +57,7 @@ from synapse.types import ISynapseReactor
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_ca_cert_file,
)
from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.utils import checked_cast, default_config
@ -125,7 +121,18 @@ class MatrixFederationAgentTests(unittest.TestCase):
# build the test server
server_factory = _get_test_protocol_factory()
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_factory = wrap_server_factory_for_tls(
server_factory,
self.reactor,
tls_sanlist
or [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
],
)
server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
@ -435,8 +442,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
request.finish()
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
server_ssl_protocol = wrap_server_factory_for_tls(
_get_test_protocol_factory(),
self.reactor,
sanlist=[
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
],
).buildProtocol(dummy_address)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
]
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
Returns:

View File

@ -29,18 +29,14 @@ from twisted.internet.endpoints import (
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import BasicProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_https_policy,
)
from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
from tests.utils import checked_cast
@ -272,7 +268,9 @@ class MatrixFederationAgentTests(TestCase):
the server Protocol returned by server_factory
"""
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_factory = wrap_server_factory_for_tls(
server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"]
)
server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
@ -639,8 +637,8 @@ class MatrixFederationAgentTests(TestCase):
request.finish()
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
server_ssl_protocol = wrap_server_factory_for_tls(
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
).buildProtocol(dummy_address)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
@ -806,7 +804,9 @@ class MatrixFederationAgentTests(TestCase):
request.finish()
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
ssl_factory = wrap_server_factory_for_tls(
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
)
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
http_server = ssl_protocol.wrappedProtocol
@ -870,30 +870,6 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [b"DNS:test.com"]
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel

View File

@ -15,9 +15,7 @@ import logging
import os
from typing import Any, Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
@ -27,7 +25,11 @@ from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.http import (
TestServerTLSConnectionFactory,
get_test_ca_cert_file,
wrap_server_factory_for_tls,
)
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
# build the test server
server_tls_protocol = _build_test_server(get_connection_factory())
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_protocol = wrap_server_factory_for_tls(
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
).buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
)
# fish the test server back out of the server-side TLS protocol.
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
test_server_connection_factory = TestServerTLSConnectionFactory(
sanlist=[b"DNS:example.com"]
)
return test_server_connection_factory
def _build_test_server(
connection_creator: IOpenSSLServerConnectionCreator,
) -> TLSMemoryBIOProtocol:
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
Args:
connection_creator: thing to build SSL connections
Returns:
TLSMemoryBIOProtocol
"""
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_factory = TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=server_factory
)
return server_tls_factory.buildProtocol(None)
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)

View File

@ -43,9 +43,11 @@ from typing import (
from unittest.mock import Mock
import attr
from incremental import Version
from typing_extensions import ParamSpec
from zope.interface import implementer
import twisted
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@ -474,6 +476,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
# In order for the TLS protocol tests to work, modify _get_default_clock
# on newer Twisted versions to use the test reactor's clock.
#
# This is *super* dirty since it is never undone and relies on the next
# test to overwrite it.
if twisted.version > Version("Twisted", 23, 8, 0):
from twisted.protocols import tls
tls._get_default_clock = lambda: self # type: ignore[attr-defined]
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()