Merge pull request #677 from matrix-org/erikj/dns_cache

Read from DNS cache if within TTL
This commit is contained in:
Erik Johnston 2016-04-08 14:09:56 +01:00
commit 79fc4ff6f9
2 changed files with 55 additions and 18 deletions

View File

@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections import collections
import logging import logging
import random import random
import time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple( _Server = collections.namedtuple(
"_Server", "priority weight host port" "_Server", "priority weight host port expires"
) )
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
host=domain, host=domain,
port=default_port, port=default_port,
priority=0, priority=0,
weight=0 weight=0,
expires=0,
) )
else: else:
self.default_server = None self.default_server = None
@ -153,7 +155,13 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = [] servers = []
try: try:
@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
continue continue
payload = answer.payload payload = answer.payload
host = str(payload.target) host = str(payload.target)
srv_ttl = answer.ttl
try: try:
answers, _, _ = yield dns_client.lookupAddress(host) answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError: except DNSNameError:
continue continue
ips = [ for answer in answers:
answer.payload.dottedQuad() if answer.type == dns.A and answer.payload:
for answer in answers ip = answer.payload.dottedQuad()
if answer.type == dns.A and answer.payload host_ttl = min(srv_ttl, answer.ttl)
]
for ip in ips: servers.append(_Server(
servers.append(_Server( host=ip,
host=ip, port=int(payload.port),
port=int(payload.port), priority=int(payload.priority),
priority=int(payload.priority), weight=int(payload.weight),
weight=int(payload.weight) expires=int(clock.time()) + host_ttl,
)) ))
servers.sort() servers.sort()
cache[service_name] = list(servers) cache[service_name] = list(servers)

View File

@ -21,6 +21,8 @@ from mock import Mock
from synapse.http.endpoint import resolve_service from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
class DnsTestCase(unittest.TestCase): class DnsTestCase(unittest.TestCase):
@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(servers[0].host, ip_address) self.assertEquals(servers[0].host, ip_address)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_from_cache(self): def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.examle.com" service_name = "test_service.examle.com"
entry = Mock(spec_set=["expires"])
entry.expires = 0
cache = { cache = {
service_name: [object()] service_name: [entry]
} }
servers = yield resolve_service( servers = yield resolve_service(
@ -82,6 +87,31 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(len(servers), 1) self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name]) self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_from_cache(self):
clock = MockClock()
dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])
service_name = "test_service.examle.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
cache = {
service_name: [entry]
}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
)
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_empty_cache(self): def test_empty_cache(self):
dns_client_mock = Mock() dns_client_mock = Mock()