Add tests for blacklisting reactor/agent. (#9563)

This commit is contained in:
Patrick Cloke 2021-03-11 09:15:22 -05:00 committed by GitHub
parent 70d1b6abff
commit e55bd0e110
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 139 additions and 14 deletions

1
changelog.d/9563.misc Normal file
View File

@ -0,0 +1 @@
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.

View File

@ -39,6 +39,7 @@ from zope.interface import implementer, provider
from OpenSSL import SSL from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IAddress, IAddress,
IHostResolution, IHostResolution,
@ -151,16 +152,17 @@ class _IPBlacklistingResolver:
def resolveHostName( def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver: ) -> IResolutionReceiver:
r = recv()
addresses = [] # type: List[IAddress] addresses = [] # type: List[IAddress]
def _callback() -> None: def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False has_bad_ip = False
for i in addresses: for address in addresses:
ip_address = IPAddress(i.host) # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
# should go through this path.
if not isinstance(address, (IPv4Address, IPv6Address)):
continue
ip_address = IPAddress(address.host)
if check_against_blacklist( if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist ip_address, self._ip_whitelist, self._ip_blacklist
@ -175,15 +177,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no # request, but all we can really do from here is claim that there were no
# valid results. # valid results.
if not has_bad_ip: if not has_bad_ip:
for i in addresses: for address in addresses:
r.addressResolved(i) recv.addressResolved(address)
r.resolutionComplete() recv.resolutionComplete()
@provider(IResolutionReceiver) @provider(IResolutionReceiver)
class EndpointReceiver: class EndpointReceiver:
@staticmethod @staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None: def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass recv.resolutionBegan(resolutionInProgress)
@staticmethod @staticmethod
def addressResolved(address: IAddress) -> None: def addressResolved(address: IAddress) -> None:
@ -197,7 +199,7 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber EndpointReceiver, hostname, portNumber=portNumber
) )
return r return recv
@implementer(ISynapseReactor) @implementer(ISynapseReactor)
@ -346,7 +348,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(), contextFactory=self.hs.get_http_client_context_factory(),
pool=pool, pool=pool,
use_proxy=use_proxy, use_proxy=use_proxy,
) ) # type: IAgent
if self._ip_blacklist: if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent # If we have an IP blacklist, we then install the blacklisting Agent

View File

@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock from mock import Mock
from netaddr import IPSet
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.client import ResponseDone from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH from twisted.web.iweb import UNKNOWN_LENGTH
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size from synapse.api.errors import SynapseError
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
read_body_with_max_size,
)
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase from tests.unittest import TestCase
@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed. # The data is never consumed.
self.assertEqual(result.getvalue(), b"") self.assertEqual(result.getvalue(), b"")
class BlacklistingAgentTest(TestCase):
def setUp(self):
self.reactor, self.clock = get_clock()
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
# Configure the reactor's DNS resolver.
for (domain, ip) in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
def test_reactor(self):
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
),
)
# The unsafe domains and IPs should be rejected.
for domain in (self.unsafe_domain, self.unsafe_ip):
self.failureResultOf(
agent.request(b"GET", b"http://" + domain), DNSLookupError
)
# The safe domains IPs should be accepted.
for domain in (
self.safe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)
# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]
# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)
response = self.successResultOf(d)
self.assertEqual(response.code, 200)
def test_agent(self):
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
)
# The unsafe IPs should be rejected.
self.failureResultOf(
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
)
# The safe and unsafe domains and safe IPs should be accepted.
for domain in (
self.safe_domain,
self.unsafe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)
# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]
# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)
response = self.successResultOf(d)
self.assertEqual(response.code, 200)