Merge pull request #5844 from matrix-org/erikj/retry_well_known_lookup

Retry well-known lookup before expiry.
This commit is contained in:
Erik Johnston 2019-08-14 09:53:33 +01:00 committed by GitHub
commit 09f6152a11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 137 additions and 27 deletions

1
changelog.d/5844.misc Normal file
View File

@ -0,0 +1 @@
Retry well-known lookup before the cache expires, giving a grace period where the remote well-known can be down but we still use the old result.

View File

@ -44,6 +44,12 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
# lower bound for .well-known cache period # lower bound for .well-known cache period
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
# Attempt to refetch a cached well-known N% of the TTL before it expires.
# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
# we'll start trying to refetch 1 minute before it expires.
WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -80,15 +86,38 @@ class WellKnownResolver(object):
Deferred[WellKnownLookupResult]: The result of the lookup Deferred[WellKnownLookupResult]: The result of the lookup
""" """
try: try:
result = self._well_known_cache[server_name] prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
server_name
)
now = self._clock.time()
if now < expiry - WELL_KNOWN_GRACE_PERIOD_FACTOR * ttl:
return WellKnownLookupResult(delegated_server=prev_result)
except KeyError: except KeyError:
# TODO: should we linearise so that we don't end up doing two .well-known prev_result = None
# requests for the same server in parallel?
# TODO: should we linearise so that we don't end up doing two .well-known
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"): with Measure(self._clock, "get_well_known"):
result, cache_period = yield self._do_get_well_known(server_name) result, cache_period = yield self._do_get_well_known(server_name)
if cache_period > 0: except _FetchWellKnownFailure as e:
self._well_known_cache.set(server_name, result, cache_period) if prev_result and e.temporary:
# This is a temporary failure and we have a still valid cached
# result, so lets return that. Hopefully the next time we ask
# the remote will be back up again.
return WellKnownLookupResult(delegated_server=prev_result)
result = None
# add some randomness to the TTL to avoid a stampeding herd every hour
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)
return WellKnownLookupResult(delegated_server=result) return WellKnownLookupResult(delegated_server=result)
@ -99,40 +128,42 @@ class WellKnownResolver(object):
Args: Args:
server_name (bytes): name of the server, from the requested url server_name (bytes): name of the server, from the requested url
Raises:
_FetchWellKnownFailure if we fail to lookup a result
Returns: Returns:
Deferred[Tuple[bytes|None|object],int]: Deferred[Tuple[bytes,int]]: The lookup result and cache period.
result, cache period, where result is one of:
- the new server name from the .well-known (as a `bytes`)
- None if there was no .well-known file.
- INVALID_WELL_KNOWN if the .well-known was invalid
""" """
uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii") uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str) logger.info("Fetching %s", uri_str)
# We do this in two steps to differentiate between possibly transient
# errors (e.g. can't connect to host, 503 response) and more permenant
# errors (such as getting a 404 response).
try: try:
response = yield make_deferred_yieldable( response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri) self._well_known_agent.request(b"GET", uri)
) )
body = yield make_deferred_yieldable(readBody(response)) body = yield make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
except Exception as e:
logger.info("Error fetching %s: %s", uri_str, e)
raise _FetchWellKnownFailure(temporary=True)
try:
if response.code != 200: if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,)) raise Exception("Non-200 response %s" % (response.code,))
parsed_body = json.loads(body.decode("utf-8")) parsed_body = json.loads(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body) logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict") result = parsed_body["m.server"].encode("ascii")
if "m.server" not in parsed_body:
raise Exception("Missing key 'm.server'")
except Exception as e: except Exception as e:
logger.info("Error fetching %s: %s", uri_str, e) logger.info("Error fetching %s: %s", uri_str, e)
raise _FetchWellKnownFailure(temporary=False)
# add some randomness to the TTL to avoid a stampeding herd every hour
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
return (None, cache_period)
result = parsed_body["m.server"].encode("ascii")
cache_period = _cache_period_from_headers( cache_period = _cache_period_from_headers(
response.headers, time_now=self._reactor.seconds response.headers, time_now=self._reactor.seconds
@ -185,3 +216,10 @@ def _parse_cache_control(headers):
v = splits[1] if len(splits) > 1 else None v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v cache_controls[k] = v
return cache_controls return cache_controls
@attr.s()
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
temporary = attr.ib()

View File

@ -55,7 +55,7 @@ class TTLCache(object):
if e != SENTINEL: if e != SENTINEL:
self._expiry_list.remove(e) self._expiry_list.remove(e)
entry = _CacheEntry(expiry_time=expiry, key=key, value=value) entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry self._data[key] = entry
self._expiry_list.add(entry) self._expiry_list.add(entry)
@ -87,7 +87,8 @@ class TTLCache(object):
key: key to look up key: key to look up
Returns: Returns:
Tuple[Any, float]: the value from the cache, and the expiry time Tuple[Any, float, float]: the value from the cache, the expiry time
and the TTL
Raises: Raises:
KeyError if the entry is not found KeyError if the entry is not found
@ -99,7 +100,7 @@ class TTLCache(object):
self._metrics.inc_misses() self._metrics.inc_misses()
raise raise
self._metrics.inc_hits() self._metrics.inc_hits()
return e.value, e.expiry_time return e.value, e.expiry_time, e.ttl
def pop(self, key, default=SENTINEL): def pop(self, key, default=SENTINEL):
"""Remove a value from the cache """Remove a value from the cache
@ -158,5 +159,6 @@ class _CacheEntry(object):
# expiry_time is the first attribute, so that entries are sorted by expiry. # expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib() expiry_time = attr.ib()
ttl = attr.ib()
key = attr.ib() key = attr.ib()
value = attr.ib() value = attr.ib()

View File

@ -987,6 +987,75 @@ class MatrixFederationAgentTests(TestCase):
r = self.successResultOf(fetch_d) r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"other-server") self.assertEqual(r.delegated_server, b"other-server")
def test_well_known_cache_with_temp_failure(self):
"""Test that we refetch well-known before the cache expires, and that
it ignores transient errors.
"""
well_known_resolver = WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=self.tls_factory),
well_known_cache=self.well_known_cache,
)
self.reactor.lookups["testserv"] = "1.2.3.4"
fetch_d = well_known_resolver.get_well_known(b"testserv")
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
well_known_server = self._handle_well_known_connection(
client_factory,
expected_sni=b"testserv",
response_headers={b"Cache-Control": b"max-age=1000"},
content=b'{ "m.server": "target-server" }',
)
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
# close the tcp connection
well_known_server.loseConnection()
# Get close to the cache expiry, this will cause the resolver to do
# another lookup.
self.reactor.pump((900.0,))
fetch_d = well_known_resolver.get_well_known(b"testserv")
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
# fonx the connection attempt, this will be treated as a temporary
# failure.
client_factory.clientConnectionFailed(None, Exception("nope"))
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
# .well-known request fails.
self.reactor.pump((0.4,))
# Resolver should return cached value, despite the lookup failing.
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
# Expire the cache and repeat the request
self.reactor.pump((100.0,))
# Repated the request, this time it should fail if the lookup fails.
fetch_d = well_known_resolver.get_well_known(b"testserv")
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
client_factory.clientConnectionFailed(None, Exception("nope"))
self.reactor.pump((0.4,))
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
class TestCachePeriodFromHeaders(TestCase): class TestCachePeriodFromHeaders(TestCase):
def test_cache_control(self): def test_cache_control(self):

View File

@ -36,7 +36,7 @@ class CacheTestCase(unittest.TestCase):
self.assertTrue("one" in self.cache) self.assertTrue("one" in self.cache)
self.assertEqual(self.cache.get("one"), "1") self.assertEqual(self.cache.get("one"), "1")
self.assertEqual(self.cache["one"], "1") self.assertEqual(self.cache["one"], "1")
self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110)) self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110, 10))
self.assertEqual(self.cache._metrics.hits, 3) self.assertEqual(self.cache._metrics.hits, 3)
self.assertEqual(self.cache._metrics.misses, 0) self.assertEqual(self.cache._metrics.misses, 0)
@ -77,7 +77,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache["two"], "2") self.assertEqual(self.cache["two"], "2")
self.assertEqual(self.cache["three"], "3") self.assertEqual(self.cache["three"], "3")
self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120)) self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120, 20))
self.assertEqual(self.cache._metrics.hits, 5) self.assertEqual(self.cache._metrics.hits, 5)
self.assertEqual(self.cache._metrics.misses, 0) self.assertEqual(self.cache._metrics.misses, 0)