Return the proper 403 Forbidden error during errors with JWT logins. (#7844)

This commit is contained in:
Patrick Cloke 2020-07-15 07:10:21 -04:00 committed by GitHub
parent 1d9dca02f9
commit 111e70d75c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 30 deletions

1
changelog.d/7844.bugfix Normal file
View File

@ -0,0 +1 @@
Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.

View File

@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims:
Providing the audience claim when not configured will cause validation to fail. Providing the audience claim when not configured will cause validation to fail.
In the case that the token is not valid, the homeserver must respond with In the case that the token is not valid, the homeserver must respond with
`401 Unauthorized` and an error code of `M_UNAUTHORIZED`. `403 Forbidden` and an error code of `M_FORBIDDEN`.
(Note that this differs from the token based logins which return a
`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
As with other login types, there are additional fields (e.g. `device_id` and As with other login types, there are additional fields (e.g. `device_id` and
`initial_device_display_name`) which can be included in the above request. `initial_device_display_name`) which can be included in the above request.

View File

@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None) token = login_submission.get("token", None)
if token is None: if token is None:
raise LoginError( raise LoginError(
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
) )
import jwt import jwt
@ -387,14 +387,12 @@ class LoginRestServlet(RestServlet):
except jwt.PyJWTError as e: except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client. # A JWT error occurred, return some info back to the client.
raise LoginError( raise LoginError(
401, 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.UNAUTHORIZED,
) )
user = payload.get("sub", None) user = payload.get("sub", None)
if user is None: if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string() user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login( result = await self._complete_login(

View File

@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self): def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret") channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
"JWT validation failed: Signature verification failed", "JWT validation failed: Signature verification failed",
@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self): def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000}) channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Signature has expired" channel.json_body["error"], "JWT validation failed: Signature has expired"
) )
@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self): def test_login_jwt_not_before(self):
now = int(time.time()) now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
"JWT validation failed: The token is not yet valid (nbf)", "JWT validation failed: The token is not yet valid (nbf)",
@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self): def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"}) channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT") self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config( @override_config(
@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
# An invalid issuer. # An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid issuer" channel.json_body["error"], "JWT validation failed: Invalid issuer"
) )
# Not providing an issuer. # Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
'JWT validation failed: Token is missing the "iss" claim', 'JWT validation failed: Token is missing the "iss" claim',
@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
# An invalid audience. # An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience" channel.json_body["error"], "JWT validation failed: Invalid audience"
) )
# Not providing an audience. # Not providing an audience.
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
'JWT validation failed: Token is missing the "aud" claim', 'JWT validation failed: Token is missing the "aud" claim',
@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self): def test_login_aud_no_config(self):
"""Test providing an audience without requiring it in the configuration.""" """Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience" channel.json_body["error"], "JWT validation failed: Invalid audience"
) )
@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
params = json.dumps({"type": "org.matrix.login.jwt"}) params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params) request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request) self.render(request)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self): def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
"JWT validation failed: Signature verification failed", "JWT validation failed: Signature verification failed",