Merge pull request #677 from matrix-org/erikj/dns_cache
Read from DNS cache if within TTL
This commit is contained in:
commit
79fc4ff6f9
|
@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -31,7 +32,7 @@ SERVER_CACHE = {}
|
||||||
|
|
||||||
|
|
||||||
_Server = collections.namedtuple(
|
_Server = collections.namedtuple(
|
||||||
"_Server", "priority weight host port"
|
"_Server", "priority weight host port expires"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
|
||||||
host=domain,
|
host=domain,
|
||||||
port=default_port,
|
port=default_port,
|
||||||
priority=0,
|
priority=0,
|
||||||
weight=0
|
weight=0,
|
||||||
|
expires=0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.default_server = None
|
self.default_server = None
|
||||||
|
@ -153,7 +155,13 @@ class SRVClientEndpoint(object):
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
||||||
|
cache_entry = cache.get(service_name, None)
|
||||||
|
if cache_entry:
|
||||||
|
if all(s.expires > int(clock.time()) for s in cache_entry):
|
||||||
|
servers = list(cache_entry)
|
||||||
|
defer.returnValue(servers)
|
||||||
|
|
||||||
servers = []
|
servers = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = answer.payload
|
payload = answer.payload
|
||||||
|
|
||||||
host = str(payload.target)
|
host = str(payload.target)
|
||||||
|
srv_ttl = answer.ttl
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answers, _, _ = yield dns_client.lookupAddress(host)
|
answers, _, _ = yield dns_client.lookupAddress(host)
|
||||||
except DNSNameError:
|
except DNSNameError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ips = [
|
for answer in answers:
|
||||||
answer.payload.dottedQuad()
|
if answer.type == dns.A and answer.payload:
|
||||||
for answer in answers
|
ip = answer.payload.dottedQuad()
|
||||||
if answer.type == dns.A and answer.payload
|
host_ttl = min(srv_ttl, answer.ttl)
|
||||||
]
|
|
||||||
|
|
||||||
for ip in ips:
|
servers.append(_Server(
|
||||||
servers.append(_Server(
|
host=ip,
|
||||||
host=ip,
|
port=int(payload.port),
|
||||||
port=int(payload.port),
|
priority=int(payload.priority),
|
||||||
priority=int(payload.priority),
|
weight=int(payload.weight),
|
||||||
weight=int(payload.weight)
|
expires=int(clock.time()) + host_ttl,
|
||||||
))
|
))
|
||||||
|
|
||||||
servers.sort()
|
servers.sort()
|
||||||
cache[service_name] = list(servers)
|
cache[service_name] = list(servers)
|
||||||
|
|
|
@ -21,6 +21,8 @@ from mock import Mock
|
||||||
|
|
||||||
from synapse.http.endpoint import resolve_service
|
from synapse.http.endpoint import resolve_service
|
||||||
|
|
||||||
|
from tests.utils import MockClock
|
||||||
|
|
||||||
|
|
||||||
class DnsTestCase(unittest.TestCase):
|
class DnsTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
|
||||||
self.assertEquals(servers[0].host, ip_address)
|
self.assertEquals(servers[0].host, ip_address)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_from_cache(self):
|
def test_from_cache_expired_and_dns_fail(self):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
service_name = "test_service.examle.com"
|
service_name = "test_service.examle.com"
|
||||||
|
|
||||||
|
entry = Mock(spec_set=["expires"])
|
||||||
|
entry.expires = 0
|
||||||
|
|
||||||
cache = {
|
cache = {
|
||||||
service_name: [object()]
|
service_name: [entry]
|
||||||
}
|
}
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolve_service(
|
||||||
|
@ -82,6 +87,31 @@ class DnsTestCase(unittest.TestCase):
|
||||||
self.assertEquals(len(servers), 1)
|
self.assertEquals(len(servers), 1)
|
||||||
self.assertEquals(servers, cache[service_name])
|
self.assertEquals(servers, cache[service_name])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_from_cache(self):
|
||||||
|
clock = MockClock()
|
||||||
|
|
||||||
|
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||||
|
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||||
|
|
||||||
|
service_name = "test_service.examle.com"
|
||||||
|
|
||||||
|
entry = Mock(spec_set=["expires"])
|
||||||
|
entry.expires = 999999999
|
||||||
|
|
||||||
|
cache = {
|
||||||
|
service_name: [entry]
|
||||||
|
}
|
||||||
|
|
||||||
|
servers = yield resolve_service(
|
||||||
|
service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(dns_client_mock.lookupService.called)
|
||||||
|
|
||||||
|
self.assertEquals(len(servers), 1)
|
||||||
|
self.assertEquals(servers, cache[service_name])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_empty_cache(self):
|
def test_empty_cache(self):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
|
|
Loading…
Reference in New Issue