Add tests for blacklisting reactor/agent. (#9563)
This commit is contained in:
parent
70d1b6abff
commit
e55bd0e110
|
@ -0,0 +1 @@
|
||||||
|
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
|
|
@ -39,6 +39,7 @@ from zope.interface import implementer, provider
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from OpenSSL.SSL import VERIFY_NONE
|
from OpenSSL.SSL import VERIFY_NONE
|
||||||
from twisted.internet import defer, error as twisted_error, protocol, ssl
|
from twisted.internet import defer, error as twisted_error, protocol, ssl
|
||||||
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
from twisted.internet.interfaces import (
|
from twisted.internet.interfaces import (
|
||||||
IAddress,
|
IAddress,
|
||||||
IHostResolution,
|
IHostResolution,
|
||||||
|
@ -151,16 +152,17 @@ class _IPBlacklistingResolver:
|
||||||
def resolveHostName(
|
def resolveHostName(
|
||||||
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
|
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
|
||||||
) -> IResolutionReceiver:
|
) -> IResolutionReceiver:
|
||||||
|
|
||||||
r = recv()
|
|
||||||
addresses = [] # type: List[IAddress]
|
addresses = [] # type: List[IAddress]
|
||||||
|
|
||||||
def _callback() -> None:
|
def _callback() -> None:
|
||||||
r.resolutionBegan(None)
|
|
||||||
|
|
||||||
has_bad_ip = False
|
has_bad_ip = False
|
||||||
for i in addresses:
|
for address in addresses:
|
||||||
ip_address = IPAddress(i.host)
|
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
|
||||||
|
# should go through this path.
|
||||||
|
if not isinstance(address, (IPv4Address, IPv6Address)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
ip_address = IPAddress(address.host)
|
||||||
|
|
||||||
if check_against_blacklist(
|
if check_against_blacklist(
|
||||||
ip_address, self._ip_whitelist, self._ip_blacklist
|
ip_address, self._ip_whitelist, self._ip_blacklist
|
||||||
|
@ -175,15 +177,15 @@ class _IPBlacklistingResolver:
|
||||||
# request, but all we can really do from here is claim that there were no
|
# request, but all we can really do from here is claim that there were no
|
||||||
# valid results.
|
# valid results.
|
||||||
if not has_bad_ip:
|
if not has_bad_ip:
|
||||||
for i in addresses:
|
for address in addresses:
|
||||||
r.addressResolved(i)
|
recv.addressResolved(address)
|
||||||
r.resolutionComplete()
|
recv.resolutionComplete()
|
||||||
|
|
||||||
@provider(IResolutionReceiver)
|
@provider(IResolutionReceiver)
|
||||||
class EndpointReceiver:
|
class EndpointReceiver:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
|
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
|
||||||
pass
|
recv.resolutionBegan(resolutionInProgress)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def addressResolved(address: IAddress) -> None:
|
def addressResolved(address: IAddress) -> None:
|
||||||
|
@ -197,7 +199,7 @@ class _IPBlacklistingResolver:
|
||||||
EndpointReceiver, hostname, portNumber=portNumber
|
EndpointReceiver, hostname, portNumber=portNumber
|
||||||
)
|
)
|
||||||
|
|
||||||
return r
|
return recv
|
||||||
|
|
||||||
|
|
||||||
@implementer(ISynapseReactor)
|
@implementer(ISynapseReactor)
|
||||||
|
@ -346,7 +348,7 @@ class SimpleHttpClient:
|
||||||
contextFactory=self.hs.get_http_client_context_factory(),
|
contextFactory=self.hs.get_http_client_context_factory(),
|
||||||
pool=pool,
|
pool=pool,
|
||||||
use_proxy=use_proxy,
|
use_proxy=use_proxy,
|
||||||
)
|
) # type: IAgent
|
||||||
|
|
||||||
if self._ip_blacklist:
|
if self._ip_blacklist:
|
||||||
# If we have an IP blacklist, we then install the blacklisting Agent
|
# If we have an IP blacklist, we then install the blacklisting Agent
|
||||||
|
|
|
@ -16,12 +16,23 @@ from io import BytesIO
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from netaddr import IPSet
|
||||||
|
|
||||||
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web.client import ResponseDone
|
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||||
|
from twisted.web.client import Agent, ResponseDone
|
||||||
from twisted.web.iweb import UNKNOWN_LENGTH
|
from twisted.web.iweb import UNKNOWN_LENGTH
|
||||||
|
|
||||||
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.http.client import (
|
||||||
|
BlacklistingAgentWrapper,
|
||||||
|
BlacklistingReactorWrapper,
|
||||||
|
BodyExceededMaxSize,
|
||||||
|
read_body_with_max_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tests.server import FakeTransport, get_clock
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
||||||
|
|
||||||
# The data is never consumed.
|
# The data is never consumed.
|
||||||
self.assertEqual(result.getvalue(), b"")
|
self.assertEqual(result.getvalue(), b"")
|
||||||
|
|
||||||
|
|
||||||
|
class BlacklistingAgentTest(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.reactor, self.clock = get_clock()
|
||||||
|
|
||||||
|
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
|
||||||
|
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
|
||||||
|
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
|
||||||
|
|
||||||
|
# Configure the reactor's DNS resolver.
|
||||||
|
for (domain, ip) in (
|
||||||
|
(self.safe_domain, self.safe_ip),
|
||||||
|
(self.unsafe_domain, self.unsafe_ip),
|
||||||
|
(self.allowed_domain, self.allowed_ip),
|
||||||
|
):
|
||||||
|
self.reactor.lookups[domain.decode()] = ip.decode()
|
||||||
|
self.reactor.lookups[ip.decode()] = ip.decode()
|
||||||
|
|
||||||
|
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
|
||||||
|
self.ip_blacklist = IPSet(["5.0.0.0/8"])
|
||||||
|
|
||||||
|
def test_reactor(self):
|
||||||
|
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
|
||||||
|
agent = Agent(
|
||||||
|
BlacklistingReactorWrapper(
|
||||||
|
self.reactor,
|
||||||
|
ip_whitelist=self.ip_whitelist,
|
||||||
|
ip_blacklist=self.ip_blacklist,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# The unsafe domains and IPs should be rejected.
|
||||||
|
for domain in (self.unsafe_domain, self.unsafe_ip):
|
||||||
|
self.failureResultOf(
|
||||||
|
agent.request(b"GET", b"http://" + domain), DNSLookupError
|
||||||
|
)
|
||||||
|
|
||||||
|
# The safe domains IPs should be accepted.
|
||||||
|
for domain in (
|
||||||
|
self.safe_domain,
|
||||||
|
self.allowed_domain,
|
||||||
|
self.safe_ip,
|
||||||
|
self.allowed_ip,
|
||||||
|
):
|
||||||
|
d = agent.request(b"GET", b"http://" + domain)
|
||||||
|
|
||||||
|
# Grab the latest TCP connection.
|
||||||
|
(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
client_factory,
|
||||||
|
_timeout,
|
||||||
|
_bindAddress,
|
||||||
|
) = self.reactor.tcpClients[-1]
|
||||||
|
|
||||||
|
# Make the connection and pump data through it.
|
||||||
|
client = client_factory.buildProtocol(None)
|
||||||
|
server = AccumulatingProtocol()
|
||||||
|
server.makeConnection(FakeTransport(client, self.reactor))
|
||||||
|
client.makeConnection(FakeTransport(server, self.reactor))
|
||||||
|
client.dataReceived(
|
||||||
|
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.successResultOf(d)
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
|
def test_agent(self):
|
||||||
|
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
|
||||||
|
agent = BlacklistingAgentWrapper(
|
||||||
|
Agent(self.reactor),
|
||||||
|
ip_whitelist=self.ip_whitelist,
|
||||||
|
ip_blacklist=self.ip_blacklist,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The unsafe IPs should be rejected.
|
||||||
|
self.failureResultOf(
|
||||||
|
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
|
||||||
|
)
|
||||||
|
|
||||||
|
# The safe and unsafe domains and safe IPs should be accepted.
|
||||||
|
for domain in (
|
||||||
|
self.safe_domain,
|
||||||
|
self.unsafe_domain,
|
||||||
|
self.allowed_domain,
|
||||||
|
self.safe_ip,
|
||||||
|
self.allowed_ip,
|
||||||
|
):
|
||||||
|
d = agent.request(b"GET", b"http://" + domain)
|
||||||
|
|
||||||
|
# Grab the latest TCP connection.
|
||||||
|
(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
client_factory,
|
||||||
|
_timeout,
|
||||||
|
_bindAddress,
|
||||||
|
) = self.reactor.tcpClients[-1]
|
||||||
|
|
||||||
|
# Make the connection and pump data through it.
|
||||||
|
client = client_factory.buildProtocol(None)
|
||||||
|
server = AccumulatingProtocol()
|
||||||
|
server.makeConnection(FakeTransport(client, self.reactor))
|
||||||
|
client.makeConnection(FakeTransport(server, self.reactor))
|
||||||
|
client.dataReceived(
|
||||||
|
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.successResultOf(d)
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
|
Loading…
Reference in New Issue