Merge pull request #1654 from matrix-org/rav/no_more_refresh_tokens

Stop generating refresh_tokens
This commit is contained in:
Richard van der Hoff 2016-12-01 11:42:53 +00:00 committed by GitHub
commit 471200074b
9 changed files with 26 additions and 210 deletions

View File

@ -791,7 +791,7 @@ class Auth(object):
Args: Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access", "refresh", type_string(str): The kind of token required (e.g. "access",
"delete_pusher") "delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired. verify_expiry(bool): Whether to verify whether the macaroon has expired.
user_id (str): The user_id required user_id (str): The user_id required
@ -820,8 +820,7 @@ class Auth(object):
else: else:
v.satisfy_general(lambda c: c.startswith("time < ")) v.satisfy_general(lambda c: c.startswith("time < "))
# access_tokens and refresh_tokens include a nonce for uniqueness: any # access_tokens include a nonce for uniqueness: any value is acceptable
# value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)

View File

@ -380,12 +380,10 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password) return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None, def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None): initial_display_name=None):
""" """
Gets login tuple for the user with the given user ID. Creates a new access token for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case. machanism (e.g. CAS), and the user_id converted to the canonical case.
@ -400,16 +398,13 @@ class AuthHandler(BaseHandler):
initial_display_name (str): display name to associate with the initial_display_name (str): display name to associate with the
device if it needs re-registering device if it needs re-registering
Returns: Returns:
A tuple of:
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -420,7 +415,7 @@ class AuthHandler(BaseHandler):
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
defer.returnValue((access_token, refresh_token)) defer.returnValue(access_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_exists(self, user_id): def check_user_exists(self, user_id):
@ -531,13 +526,6 @@ class AuthHandler(BaseHandler):
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None): def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
@ -551,16 +539,6 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat(caveat) macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
return m.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")

View File

@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet):
password=login_submission["password"], password=login_submission["password"],
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id, user_id, device_id,
login_submission.get("initial_device_display_name") login_submission.get("initial_device_display_name"),
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet):
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id, user_id, device_id,
login_submission.get("initial_device_display_name") login_submission.get("initial_device_display_name"),
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device( device_id = yield self._register_device(
registered_user_id, login_submission registered_user_id, login_submission
) )
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id, registered_user_id, device_id,
login_submission.get("initial_device_display_name") login_submission.get("initial_device_display_name"),
)
) )
result = { result = {
"user_id": registered_user_id, "user_id": registered_user_id,
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
else: else:

View File

@ -374,8 +374,7 @@ class RegisterRestServlet(RestServlet):
def _create_registration_details(self, user_id, params): def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token Allocates device_id if one was not given; also creates access_token.
and refresh_token.
Args: Args:
(str) user_id: full canonical @user:id (str) user_id: full canonical @user:id
@ -386,8 +385,8 @@ class RegisterRestServlet(RestServlet):
""" """
device_id = yield self._register_device(user_id, params) device_id = yield self._register_device(user_id, params)
access_token, refresh_token = ( access_token = (
yield self.auth_handler.get_login_tuple_for_user_id( yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name") initial_display_name=params.get("initial_device_display_name")
) )
@ -397,7 +396,6 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"refresh_token": refresh_token,
"device_id": device_id, "device_id": device_id,
}) })
@ -441,8 +439,6 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth_handler.generate_access_token( access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"] user_id, ["guest = true"]
) )
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
# so long as we don't return a refresh_token here.
defer.returnValue((200, { defer.returnValue((200, {
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,

View File

@ -15,8 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) raise AuthError(403, "tokenrefresh is no longer supported.")
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler()
refresh_result = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token
)
(user_id, new_refresh_token, device_id) = refresh_result
new_access_token = yield auth_handler.issue_access_token(
user_id, device_id
)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}))
except KeyError:
raise SynapseError(400, "Missing required key 'refresh_token'.")
except StoreError:
raise AuthError(403, "Did not recognize refresh token")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")

View File

@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
device_id (str): ID of the device to associate with the access
token
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
{
"id": next_id,
"user_id": user_id,
"token": token,
"device_id": device_id,
},
desc="add_refresh_token_to_user",
)
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_localpart=None, admin=False):
@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
token token
) )
def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
return self.runInteraction(
"exchange_refresh_token",
self._exchange_refresh_token,
refresh_token,
token_generator
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
new_token = token_generator(user_id)
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
return user_id, new_token, device_id
@defer.inlineCallbacks @defer.inlineCallbacks
def is_server_admin(self, user): def is_server_admin(self, user):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(

View File

@ -67,8 +67,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock( self.registration_handler.appservice_register = Mock(
return_value=user_id return_value=user_id
) )
self.auth_handler.get_login_tuple_for_user_id = Mock( self.auth_handler.get_access_token_for_user_id = Mock(
return_value=(token, "kermits_refresh_token") return_value=token
) )
(code, result) = yield self.servlet.on_POST(self.request) (code, result) = yield self.servlet.on_POST(self.request)
@ -76,11 +76,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname "home_server": self.hs.hostname
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
@ -126,8 +124,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}, None) }, None)
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_login_tuple_for_user_id = Mock( self.auth_handler.get_access_token_for_user_id = Mock(
return_value=(token, "kermits_refresh_token") return_value=token
) )
self.device_handler.check_device_registered = \ self.device_handler.check_device_registered = \
Mock(return_value=device_id) Mock(return_value=device_id)
@ -137,12 +135,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
self.auth_handler.get_login_tuple_for_user_id( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None) user_id, device_id=device_id, initial_device_display_name=None)

View File

@ -17,9 +17,6 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.util import stringutils
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -80,64 +77,12 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertTrue("token_id" in result) self.assertTrue("token_id" in result)
@defer.inlineCallbacks
def test_exchange_refresh_token_valid(self):
uid = stringutils.random_string(32)
device_id = stringutils.random_string(16)
generator = TokenGenerator()
last_token = generator.generate(uid)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token, device_id) "
"VALUES(?,?,?)",
(uid, last_token, device_id))
(found_user_id, refresh_token, device_id) = \
yield self.store.exchange_refresh_token(last_token,
generator.generate)
self.assertEqual(uid, found_user_id)
rows = yield self.db_pool.runQuery(
"SELECT token, device_id FROM refresh_tokens WHERE user_id = ?",
(uid, ))
self.assertEqual([(refresh_token, device_id)], rows)
# We issued token 1, then exchanged it for token 2
expected_refresh_token = u"%s-%d" % (uid, 2,)
self.assertEqual(expected_refresh_token, refresh_token)
@defer.inlineCallbacks
def test_exchange_refresh_token_none(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_exchange_refresh_token_invalid(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
wrong_token = "%s-wrong" % (last_token,)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
(uid, wrong_token,))
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
# add some tokens # add some tokens
generator = TokenGenerator()
refresh_token = generator.generate(self.user_id)
yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id) self.device_id)
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
self.device_id)
# now delete some # now delete some
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
@ -146,9 +91,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
# check they were deleted # check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1]) user = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertIsNone(user, "access token was not deleted by device_id") self.assertIsNone(user, "access token was not deleted by device_id")
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(refresh_token,
generator.generate)
# check the one not associated with the device was not deleted # check the one not associated with the device was not deleted
user = yield self.store.get_user_by_access_token(self.tokens[0]) user = yield self.store.get_user_by_access_token(self.tokens[0])