Read from DNS cache if within TTL

This commit is contained in:
Erik Johnston 2016-03-31 10:04:28 +01:00
parent a68c1b15aa
commit f699b8f997
2 changed files with 26 additions and 16 deletions

View File

@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections
import logging
import random
import time
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple(
"_Server", "priority weight host port"
"_Server", "priority weight host port expires"
)
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
host=domain,
port=default_port,
priority=0,
weight=0
weight=0,
expires=0,
)
else:
self.default_server = None
@ -154,6 +156,12 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(time.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
@ -173,26 +181,25 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
continue
payload = answer.payload
host = str(payload.target)
srv_ttl = answer.ttl
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
ips = [
answer.payload.dottedQuad()
for answer in answers
if answer.type == dns.A and answer.payload
]
for answer in answers:
if answer.type == dns.A and answer.payload:
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)
for ip in ips:
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
weight=int(payload.weight),
expires=int(time.time()) + host_ttl,
))
servers.sort()

View File

@ -69,8 +69,11 @@ class DnsTestCase(unittest.TestCase):
service_name = "test_service.examle.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
cache = {
service_name: [object()]
service_name: [entry]
}
servers = yield resolve_service(