Merge pull request #918 from negzi/bugfix_for_token_expiry
Bug fix: expire invalid access tokens
This commit is contained in:
commit
209e04fa11
|
@ -629,7 +629,10 @@ class Auth(object):
|
|||
except AuthError:
|
||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||
# have been re-issued as macaroons.
|
||||
if self.hs.config.expire_access_token:
|
||||
raise
|
||||
ret = yield self._look_up_user_by_access_token(token)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -637,12 +637,13 @@ class AuthHandler(BaseHandler):
|
|||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||
defer.returnValue(refresh_token)
|
||||
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
def generate_access_token(self, user_id, extra_caveats=None,
|
||||
duration_in_ms=(60 * 60 * 1000)):
|
||||
extra_caveats = extra_caveats or []
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + (60 * 60 * 1000)
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
for caveat in extra_caveats:
|
||||
macaroon.add_first_party_caveat(caveat)
|
||||
|
|
|
@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler):
|
|||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds,
|
||||
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
||||
password_hash=None):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
token = self.auth_handler().generate_short_term_login_token(
|
||||
user_id, duration_seconds)
|
||||
token = self.auth_handler().generate_access_token(
|
||||
user_id, None, duration_in_ms)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
|
|
|
@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds,
|
||||
duration_in_ms=(duration_seconds * 1000),
|
||||
password_hash=password_hash
|
||||
)
|
||||
|
||||
|
|
|
@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
macaroon.add_first_party_caveat("time < 1") # ms
|
||||
macaroon.add_first_party_caveat("time < -2000") # ms
|
||||
|
||||
self.hs.clock.now = 5000 # seconds
|
||||
self.hs.config.expire_access_token = True
|
||||
|
@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase):
|
|||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_from_macaroon_with_valid_duration(self):
|
||||
# TODO(danielwh): Remove this mock when we remove the
|
||||
# get_user_by_access_token fallback.
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value={"name": "@baldrick:matrix.org"}
|
||||
)
|
||||
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value={"name": "@baldrick:matrix.org"}
|
||||
)
|
||||
|
||||
user_id = "@baldrick:matrix.org"
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
macaroon.add_first_party_caveat("time < 900000000") # ms
|
||||
|
||||
self.hs.clock.now = 5000 # seconds
|
||||
self.hs.config.expire_access_token = True
|
||||
|
||||
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
|
|
|
@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
http_client=None,
|
||||
expire_access_token=True)
|
||||
self.auth_handler = Mock(
|
||||
generate_short_term_login_token=Mock(return_value='secret'))
|
||||
generate_access_token=Mock(return_value='secret'))
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
self.hs.get_handlers().profile_handler = Mock()
|
||||
self.mock_handler = Mock(spec=[
|
||||
"generate_short_term_login_token",
|
||||
"generate_access_token",
|
||||
])
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
|
||||
|
|
Loading…
Reference in New Issue