Cache token introspection response from OIDC provider (#16117)
This commit is contained in:
parent
eb0dbab15b
commit
54a51ff6c1
|
@ -0,0 +1 @@
|
||||||
|
Cache token introspection response from OIDC provider.
|
|
@ -39,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import Requester, UserID, create_requester
|
from synapse.types import Requester, UserID, create_requester
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||||
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -106,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
|
|
||||||
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
|
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
|
||||||
|
|
||||||
|
self._clock = hs.get_clock()
|
||||||
|
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
|
||||||
|
cache_name="introspection_token_cache",
|
||||||
|
clock=self._clock,
|
||||||
|
max_len=10000,
|
||||||
|
expiry_ms=5 * 60 * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(auth_method, PrivateKeyJWTWithKid):
|
if isinstance(auth_method, PrivateKeyJWTWithKid):
|
||||||
# Use the JWK as the client secret when using the private_key_jwt method
|
# Use the JWK as the client secret when using the private_key_jwt method
|
||||||
assert self._config.jwk, "No JWK provided"
|
assert self._config.jwk, "No JWK provided"
|
||||||
|
@ -144,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
Returns:
|
Returns:
|
||||||
The introspection response
|
The introspection response
|
||||||
"""
|
"""
|
||||||
|
# check the cache before doing a request
|
||||||
|
introspection_token = self._token_cache.get(token, None)
|
||||||
|
|
||||||
|
if introspection_token:
|
||||||
|
# check the expiration field of the token (if it exists)
|
||||||
|
exp = introspection_token.get("exp", None)
|
||||||
|
if exp:
|
||||||
|
time_now = self._clock.time()
|
||||||
|
expired = time_now > exp
|
||||||
|
if not expired:
|
||||||
|
return introspection_token
|
||||||
|
else:
|
||||||
|
return introspection_token
|
||||||
|
|
||||||
metadata = await self._issuer_metadata.get()
|
metadata = await self._issuer_metadata.get()
|
||||||
introspection_endpoint = metadata.get("introspection_endpoint")
|
introspection_endpoint = metadata.get("introspection_endpoint")
|
||||||
raw_headers: Dict[str, str] = {
|
raw_headers: Dict[str, str] = {
|
||||||
|
@ -157,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
|
|
||||||
# Fill the body/headers with credentials
|
# Fill the body/headers with credentials
|
||||||
uri, raw_headers, body = self._client_auth.prepare(
|
uri, raw_headers, body = self._client_auth.prepare(
|
||||||
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
|
method="POST",
|
||||||
|
uri=introspection_endpoint,
|
||||||
|
headers=raw_headers,
|
||||||
|
body=body,
|
||||||
)
|
)
|
||||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||||
|
|
||||||
|
@ -187,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
"The introspection endpoint returned an invalid JSON response."
|
"The introspection endpoint returned an invalid JSON response."
|
||||||
)
|
)
|
||||||
|
|
||||||
return IntrospectionToken(**resp)
|
expiration = resp.get("exp", None)
|
||||||
|
if expiration:
|
||||||
|
if self._clock.time() > expiration:
|
||||||
|
raise InvalidClientTokenError("Token is expired.")
|
||||||
|
|
||||||
|
introspection_token = IntrospectionToken(**resp)
|
||||||
|
|
||||||
|
# add token to cache
|
||||||
|
self._token_cache[token] = introspection_token
|
||||||
|
|
||||||
|
return introspection_token
|
||||||
|
|
||||||
async def is_server_admin(self, requester: Requester) -> bool:
|
async def is_server_admin(self, requester: Requester) -> bool:
|
||||||
return "urn:synapse:admin:*" in requester.scope
|
return "urn:synapse:admin:*" in requester.scope
|
||||||
|
|
|
@ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||||
self.assertEqual(error.value.code, 503)
|
self.assertEqual(error.value.code, 503)
|
||||||
|
|
||||||
|
def test_introspection_token_cache(self) -> None:
|
||||||
|
access_token = "open_sesame"
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse.json(
|
||||||
|
code=200,
|
||||||
|
payload={"active": "true", "scope": "guest", "jti": access_token},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# first call should cache response
|
||||||
|
# Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
|
||||||
|
# for regular auth code via the config
|
||||||
|
self.get_success(
|
||||||
|
self.auth._introspect_token(access_token) # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
|
||||||
|
self.assertEqual(introspection_token["jti"], access_token)
|
||||||
|
# there's been one http request
|
||||||
|
self.http_client.request.assert_called_once()
|
||||||
|
|
||||||
|
# second call should pull from cache, there should still be only one http request
|
||||||
|
token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
|
||||||
|
self.http_client.request.assert_called_once()
|
||||||
|
self.assertEqual(token["jti"], access_token)
|
||||||
|
|
||||||
|
# advance past five minutes and check that cache expired - there should be more than one http call now
|
||||||
|
self.reactor.advance(360)
|
||||||
|
token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
|
||||||
|
self.assertEqual(self.http_client.request.call_count, 2)
|
||||||
|
self.assertEqual(token_2["jti"], access_token)
|
||||||
|
|
||||||
|
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
|
||||||
|
# token with a soon-to-expire `exp` field to the cache
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse.json(
|
||||||
|
code=200,
|
||||||
|
payload={
|
||||||
|
"active": "true",
|
||||||
|
"scope": "guest",
|
||||||
|
"jti": "stale",
|
||||||
|
"exp": self.clock.time() + 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.auth._introspect_token("stale") # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
|
||||||
|
self.assertEqual(introspection_token["jti"], "stale")
|
||||||
|
self.assertEqual(self.http_client.request.call_count, 1)
|
||||||
|
|
||||||
|
# advance the reactor past the token expiry but less than the cache expiry
|
||||||
|
self.reactor.advance(120)
|
||||||
|
self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# check that the next call causes another http request (which will fail because the token is technically expired
|
||||||
|
# but the important thing is we discard the token from the cache and try the network)
|
||||||
|
self.get_failure(
|
||||||
|
self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
self.assertEqual(self.http_client.request.call_count, 2)
|
||||||
|
|
||||||
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
|
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
|
||||||
# We only generate a master key to simplify the test.
|
# We only generate a master key to simplify the test.
|
||||||
master_signing_key = generate_signing_key(device_id)
|
master_signing_key = generate_signing_key(device_id)
|
||||||
|
|
Loading…
Reference in New Issue