Save login tokens in database (#13844)
* Save login tokens in database Signed-off-by: Quentin Gliech <quenting@element.io> * Add upgrade notes * Track login token reuse in a Prometheus metric Signed-off-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
parent
d902181de9
commit
8756d5c87e
|
@ -0,0 +1 @@
|
||||||
|
Save login tokens in database and prevent login token reuse.
|
|
@ -88,6 +88,15 @@ process, for example:
|
||||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Upgrading to v1.71.0
|
||||||
|
|
||||||
|
## Removal of the `generate_short_term_login_token` module API method
|
||||||
|
|
||||||
|
As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed.
|
||||||
|
|
||||||
|
Modules relying on it can instead use the `create_login_token` method.
|
||||||
|
|
||||||
|
|
||||||
# Upgrading to v1.69.0
|
# Upgrading to v1.69.0
|
||||||
|
|
||||||
## Changes to the receipts replication streams
|
## Changes to the receipts replication streams
|
||||||
|
|
|
@ -38,6 +38,7 @@ from typing import (
|
||||||
import attr
|
import attr
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import unpaddedbase64
|
import unpaddedbase64
|
||||||
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -48,6 +49,7 @@ from synapse.api.errors import (
|
||||||
Codes,
|
Codes,
|
||||||
InteractiveAuthIncompleteError,
|
InteractiveAuthIncompleteError,
|
||||||
LoginError,
|
LoginError,
|
||||||
|
NotFoundError,
|
||||||
StoreError,
|
StoreError,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
UserDeactivatedError,
|
UserDeactivatedError,
|
||||||
|
@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import defer_to_thread
|
from synapse.logging.context import defer_to_thread
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.databases.main.registration import (
|
||||||
|
LoginTokenExpired,
|
||||||
|
LoginTokenLookupResult,
|
||||||
|
LoginTokenReused,
|
||||||
|
)
|
||||||
from synapse.types import JsonDict, Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
from synapse.util import stringutils as stringutils
|
from synapse.util import stringutils as stringutils
|
||||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
||||||
from synapse.util.macaroons import LoginTokenAttributes
|
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from synapse.util.stringutils import base62_encode
|
from synapse.util.stringutils import base62_encode
|
||||||
from synapse.util.threepids import canonicalise_email
|
from synapse.util.threepids import canonicalise_email
|
||||||
|
@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
|
INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
|
||||||
|
|
||||||
|
invalid_login_token_counter = Counter(
|
||||||
|
"synapse_user_login_invalid_login_tokens",
|
||||||
|
"Counts the number of rejected m.login.token on /login",
|
||||||
|
["reason"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_client_dict_legacy_fields_to_identifier(
|
def convert_client_dict_legacy_fields_to_identifier(
|
||||||
submission: JsonDict,
|
submission: JsonDict,
|
||||||
|
@ -883,6 +895,25 @@ class AuthHandler:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def create_login_token_for_user_id(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
duration_ms: int = (2 * 60 * 1000),
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
login_token = self.generate_login_token()
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
expiry_ts = now + duration_ms
|
||||||
|
await self.store.add_login_token_to_user(
|
||||||
|
user_id=user_id,
|
||||||
|
token=login_token,
|
||||||
|
expiry_ts=expiry_ts,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
|
)
|
||||||
|
return login_token
|
||||||
|
|
||||||
async def create_refresh_token_for_user_id(
|
async def create_refresh_token_for_user_id(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -1401,6 +1432,18 @@ class AuthHandler:
|
||||||
return None
|
return None
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
def generate_login_token(self) -> str:
|
||||||
|
"""Generates an opaque string, for use as an short-term login token"""
|
||||||
|
|
||||||
|
# we use the following format for access tokens:
|
||||||
|
# syl_<random string>_<base62 crc check>
|
||||||
|
|
||||||
|
random_string = stringutils.random_string(20)
|
||||||
|
base = f"syl_{random_string}"
|
||||||
|
|
||||||
|
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||||
|
return f"{base}_{crc}"
|
||||||
|
|
||||||
def generate_access_token(self, for_user: UserID) -> str:
|
def generate_access_token(self, for_user: UserID) -> str:
|
||||||
"""Generates an opaque string, for use as an access token"""
|
"""Generates an opaque string, for use as an access token"""
|
||||||
|
|
||||||
|
@ -1427,16 +1470,17 @@ class AuthHandler:
|
||||||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||||
return f"{base}_{crc}"
|
return f"{base}_{crc}"
|
||||||
|
|
||||||
async def validate_short_term_login_token(
|
async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
|
||||||
self, login_token: str
|
|
||||||
) -> LoginTokenAttributes:
|
|
||||||
try:
|
try:
|
||||||
res = self.macaroon_gen.verify_short_term_login_token(login_token)
|
return await self.store.consume_login_token(login_token)
|
||||||
except Exception:
|
except LoginTokenExpired:
|
||||||
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
|
invalid_login_token_counter.labels("expired").inc()
|
||||||
|
except LoginTokenReused:
|
||||||
|
invalid_login_token_counter.labels("reused").inc()
|
||||||
|
except NotFoundError:
|
||||||
|
invalid_login_token_counter.labels("not found").inc()
|
||||||
|
|
||||||
await self.auth_blocking.check_auth_blocking(res.user_id)
|
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
|
||||||
return res
|
|
||||||
|
|
||||||
async def delete_access_token(self, access_token: str) -> None:
|
async def delete_access_token(self, access_token: str) -> None:
|
||||||
"""Invalidate a single access token
|
"""Invalidate a single access token
|
||||||
|
@ -1711,7 +1755,7 @@ class AuthHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a login token
|
# Create a login token
|
||||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
login_token = await self.create_login_token_for_user_id(
|
||||||
registered_user_id,
|
registered_user_id,
|
||||||
auth_provider_id=auth_provider_id,
|
auth_provider_id=auth_provider_id,
|
||||||
auth_provider_session_id=auth_provider_session_id,
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
|
|
|
@ -771,50 +771,11 @@ class ModuleApi:
|
||||||
auth_provider_session_id: The session ID got during login from the SSO IdP,
|
auth_provider_session_id: The session ID got during login from the SSO IdP,
|
||||||
if any.
|
if any.
|
||||||
"""
|
"""
|
||||||
# The deprecated `generate_short_term_login_token` method defaulted to an empty
|
return await self._hs.get_auth_handler().create_login_token_for_user_id(
|
||||||
# string for the `auth_provider_id` because of how the underlying macaroon was
|
|
||||||
# generated. This will change to a proper NULL-able field when the tokens get
|
|
||||||
# moved to the database.
|
|
||||||
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
|
||||||
user_id,
|
user_id,
|
||||||
auth_provider_id or "",
|
|
||||||
auth_provider_session_id,
|
|
||||||
duration_in_ms,
|
duration_in_ms,
|
||||||
)
|
|
||||||
|
|
||||||
def generate_short_term_login_token(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
duration_in_ms: int = (2 * 60 * 1000),
|
|
||||||
auth_provider_id: str = "",
|
|
||||||
auth_provider_session_id: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Generate a login token suitable for m.login.token authentication
|
|
||||||
|
|
||||||
Added in Synapse v1.9.0.
|
|
||||||
|
|
||||||
This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will
|
|
||||||
be removed in Synapse 1.71.0.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: gives the ID of the user that the token is for
|
|
||||||
|
|
||||||
duration_in_ms: the time that the token will be valid for
|
|
||||||
|
|
||||||
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
|
|
||||||
to get this token, if any. This is encoded in the token so that
|
|
||||||
/login can report stats on number of successful logins by IdP.
|
|
||||||
"""
|
|
||||||
logger.warn(
|
|
||||||
"A module configured on this server uses ModuleApi.generate_short_term_login_token(), "
|
|
||||||
"which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in "
|
|
||||||
"Synapse 1.71.0",
|
|
||||||
)
|
|
||||||
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
|
||||||
user_id,
|
|
||||||
auth_provider_id,
|
auth_provider_id,
|
||||||
auth_provider_session_id,
|
auth_provider_session_id,
|
||||||
duration_in_ms,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet):
|
||||||
The body of the JSON response.
|
The body of the JSON response.
|
||||||
"""
|
"""
|
||||||
token = login_submission["token"]
|
token = login_submission["token"]
|
||||||
auth_handler = self.auth_handler
|
res = await self.auth_handler.consume_login_token(token)
|
||||||
res = await auth_handler.validate_short_term_login_token(token)
|
|
||||||
|
|
||||||
return await self._complete_login(
|
return await self._complete_login(
|
||||||
res.user_id,
|
res.user_id,
|
||||||
|
|
|
@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.server_name = hs.config.server.server_name
|
self.server_name = hs.config.server.server_name
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.token_timeout = hs.config.experimental.msc3882_token_timeout
|
self.token_timeout = hs.config.experimental.msc3882_token_timeout
|
||||||
self.ui_auth = hs.config.experimental.msc3882_ui_auth
|
self.ui_auth = hs.config.experimental.msc3882_ui_auth
|
||||||
|
@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet):
|
||||||
can_skip_ui_auth=False, # Don't allow skipping of UI auth
|
can_skip_ui_auth=False, # Don't allow skipping of UI auth
|
||||||
)
|
)
|
||||||
|
|
||||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
login_token = await self.auth_handler.create_login_token_for_user_id(
|
||||||
user_id=requester.user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
auth_provider_id="org.matrix.msc3882.login_token_request",
|
auth_provider_id="org.matrix.msc3882.login_token_request",
|
||||||
duration_in_ms=self.token_timeout,
|
duration_ms=self.token_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
|
@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import (
|
||||||
|
Codes,
|
||||||
|
NotFoundError,
|
||||||
|
StoreError,
|
||||||
|
SynapseError,
|
||||||
|
ThreepidValidationError,
|
||||||
|
)
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
|
@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
|
||||||
because this external id is given to an other user."""
|
because this external id is given to an other user."""
|
||||||
|
|
||||||
|
|
||||||
|
class LoginTokenExpired(Exception):
|
||||||
|
"""Exception if the login token sent expired"""
|
||||||
|
|
||||||
|
|
||||||
|
class LoginTokenReused(Exception):
|
||||||
|
"""Exception if the login token sent was already used"""
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||||
class TokenLookupResult:
|
class TokenLookupResult:
|
||||||
"""Result of looking up an access token.
|
"""Result of looking up an access token.
|
||||||
|
@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
|
||||||
If None, the session can be refreshed indefinitely."""
|
If None, the session can be refreshed indefinitely."""
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
|
class LoginTokenLookupResult:
|
||||||
|
"""Result of looking up a login token."""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
"""The user this token belongs to."""
|
||||||
|
|
||||||
|
auth_provider_id: Optional[str]
|
||||||
|
"""The SSO Identity Provider that the user authenticated with, to get this token."""
|
||||||
|
|
||||||
|
auth_provider_session_id: Optional[str]
|
||||||
|
"""The session ID advertised by the SSO Identity Provider."""
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
"replace_refresh_token", _replace_refresh_token_txn
|
"replace_refresh_token", _replace_refresh_token_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def add_login_token_to_user(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
token: str,
|
||||||
|
expiry_ts: int,
|
||||||
|
auth_provider_id: Optional[str],
|
||||||
|
auth_provider_session_id: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Adds a short-term login token for the given user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID.
|
||||||
|
token: The new login token to add.
|
||||||
|
expiry_ts (milliseconds since the epoch): Time after which the login token
|
||||||
|
cannot be used.
|
||||||
|
auth_provider_id: The SSO Identity Provider that the user authenticated with
|
||||||
|
to get this token, if any
|
||||||
|
auth_provider_session_id: The session ID advertised by the SSO Identity
|
||||||
|
Provider, if any.
|
||||||
|
"""
|
||||||
|
await self.db_pool.simple_insert(
|
||||||
|
"login_tokens",
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"user_id": user_id,
|
||||||
|
"expiry_ts": expiry_ts,
|
||||||
|
"auth_provider_id": auth_provider_id,
|
||||||
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
|
},
|
||||||
|
desc="add_login_token_to_user",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _consume_login_token(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
token: str,
|
||||||
|
ts: int,
|
||||||
|
) -> LoginTokenLookupResult:
|
||||||
|
values = self.db_pool.simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
"login_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
retcols=(
|
||||||
|
"user_id",
|
||||||
|
"expiry_ts",
|
||||||
|
"used_ts",
|
||||||
|
"auth_provider_id",
|
||||||
|
"auth_provider_session_id",
|
||||||
|
),
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if values is None:
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
"login_tokens",
|
||||||
|
keyvalues={"token": token},
|
||||||
|
updatevalues={"used_ts": ts},
|
||||||
|
)
|
||||||
|
user_id = values["user_id"]
|
||||||
|
expiry_ts = values["expiry_ts"]
|
||||||
|
used_ts = values["used_ts"]
|
||||||
|
auth_provider_id = values["auth_provider_id"]
|
||||||
|
auth_provider_session_id = values["auth_provider_session_id"]
|
||||||
|
|
||||||
|
# Token was already used
|
||||||
|
if used_ts is not None:
|
||||||
|
raise LoginTokenReused()
|
||||||
|
|
||||||
|
# Token expired
|
||||||
|
if ts > int(expiry_ts):
|
||||||
|
raise LoginTokenExpired()
|
||||||
|
|
||||||
|
return LoginTokenLookupResult(
|
||||||
|
user_id=user_id,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
|
||||||
|
"""Lookup a login token and consume it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The login token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The data stored with that token, including the `user_id`. Returns `None` if
|
||||||
|
the token does not exist or if it expired.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFound if the login token was not found in database
|
||||||
|
LoginTokenExpired if the login token expired
|
||||||
|
LoginTokenReused if the login token was already used
|
||||||
|
"""
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"consume_login_token",
|
||||||
|
self._consume_login_token,
|
||||||
|
token,
|
||||||
|
self._clock.time_msec(),
|
||||||
|
)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def is_guest(self, user_id: str) -> bool:
|
async def is_guest(self, user_id: str) -> bool:
|
||||||
res = await self.db_pool.simple_select_one_onecol(
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
|
@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create a background job for removing expired login tokens
|
||||||
|
if hs.config.worker.run_background_tasks:
|
||||||
|
self._clock.looping_call(
|
||||||
|
self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
|
||||||
|
)
|
||||||
|
|
||||||
async def add_access_token_to_user(
|
async def add_access_token_to_user(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
approved,
|
approved,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@wrap_as_background_process("delete_expired_login_tokens")
|
||||||
|
async def _delete_expired_login_tokens(self) -> None:
|
||||||
|
"""Remove login tokens with expiry dates that have passed."""
|
||||||
|
|
||||||
|
def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
|
||||||
|
sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
|
||||||
|
txn.execute(sql, (ts,))
|
||||||
|
|
||||||
|
# We keep the expired tokens for an extra 5 minutes so we can measure how many
|
||||||
|
# times a token is being used after its expiry
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"delete_expired_login_tokens",
|
||||||
|
_delete_expired_login_tokens_txn,
|
||||||
|
now - (5 * 60 * 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Login tokens are short-lived tokens that are used for the m.login.token
|
||||||
|
-- login method, mainly during SSO logins
|
||||||
|
CREATE TABLE login_tokens (
|
||||||
|
token TEXT PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
expiry_ts BIGINT NOT NULL,
|
||||||
|
used_ts BIGINT,
|
||||||
|
auth_provider_id TEXT,
|
||||||
|
auth_provider_session_id TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- We're sometimes querying them by their session ID we got from their IDP
|
||||||
|
CREATE INDEX login_tokens_auth_provider_idx
|
||||||
|
ON login_tokens (auth_provider_id, auth_provider_session_id);
|
||||||
|
|
||||||
|
-- We're deleting them by their expiration time
|
||||||
|
CREATE INDEX login_tokens_expiry_time_idx
|
||||||
|
ON login_tokens (expiry_ts);
|
||||||
|
|
|
@ -24,7 +24,7 @@ from typing_extensions import Literal
|
||||||
|
|
||||||
from synapse.util import Clock, stringutils
|
from synapse.util import Clock, stringutils
|
||||||
|
|
||||||
MacaroonType = Literal["access", "delete_pusher", "session", "login"]
|
MacaroonType = Literal["access", "delete_pusher", "session"]
|
||||||
|
|
||||||
|
|
||||||
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
|
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||||
|
@ -111,19 +111,6 @@ class OidcSessionData:
|
||||||
"""The session ID of the ongoing UI Auth ("" if this is a login)"""
|
"""The session ID of the ongoing UI Auth ("" if this is a login)"""
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
|
||||||
class LoginTokenAttributes:
|
|
||||||
"""Data we store in a short-term login token"""
|
|
||||||
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
auth_provider_id: str
|
|
||||||
"""The SSO Identity Provider that the user authenticated with, to get this token."""
|
|
||||||
|
|
||||||
auth_provider_session_id: Optional[str]
|
|
||||||
"""The session ID advertised by the SSO Identity Provider."""
|
|
||||||
|
|
||||||
|
|
||||||
class MacaroonGenerator:
|
class MacaroonGenerator:
|
||||||
def __init__(self, clock: Clock, location: str, secret_key: bytes):
|
def __init__(self, clock: Clock, location: str, secret_key: bytes):
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
|
@ -165,35 +152,6 @@ class MacaroonGenerator:
|
||||||
macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
|
macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_short_term_login_token(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
auth_provider_id: str,
|
|
||||||
auth_provider_session_id: Optional[str] = None,
|
|
||||||
duration_in_ms: int = (2 * 60 * 1000),
|
|
||||||
) -> str:
|
|
||||||
"""Generate a short-term login token used during SSO logins
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The user for which the token is valid.
|
|
||||||
auth_provider_id: The SSO IdP the user used.
|
|
||||||
auth_provider_session_id: The session ID got during login from the SSO IdP.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A signed token valid for using as a ``m.login.token`` token.
|
|
||||||
"""
|
|
||||||
now = self._clock.time_msec()
|
|
||||||
expiry = now + duration_in_ms
|
|
||||||
macaroon = self._generate_base_macaroon("login")
|
|
||||||
macaroon.add_first_party_caveat(f"user_id = {user_id}")
|
|
||||||
macaroon.add_first_party_caveat(f"time < {expiry}")
|
|
||||||
macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}")
|
|
||||||
if auth_provider_session_id is not None:
|
|
||||||
macaroon.add_first_party_caveat(
|
|
||||||
f"auth_provider_session_id = {auth_provider_session_id}"
|
|
||||||
)
|
|
||||||
return macaroon.serialize()
|
|
||||||
|
|
||||||
def generate_oidc_session_token(
|
def generate_oidc_session_token(
|
||||||
self,
|
self,
|
||||||
state: str,
|
state: str,
|
||||||
|
@ -233,49 +191,6 @@ class MacaroonGenerator:
|
||||||
|
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
|
|
||||||
"""Verify a short-term-login macaroon
|
|
||||||
|
|
||||||
Checks that the given token is a valid, unexpired short-term-login token
|
|
||||||
minted by this server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: The login token to verify.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A set of attributes carried by this token, including the
|
|
||||||
``user_id`` and informations about the SSO IDP used during that
|
|
||||||
login.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
MacaroonVerificationFailedException if the verification failed
|
|
||||||
"""
|
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
|
||||||
|
|
||||||
v = self._base_verifier("login")
|
|
||||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
|
||||||
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
|
|
||||||
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
|
|
||||||
satisfy_expiry(v, self._clock.time_msec)
|
|
||||||
v.verify(macaroon, self._secret_key)
|
|
||||||
|
|
||||||
user_id = get_value_from_macaroon(macaroon, "user_id")
|
|
||||||
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
|
|
||||||
|
|
||||||
auth_provider_session_id: Optional[str] = None
|
|
||||||
try:
|
|
||||||
auth_provider_session_id = get_value_from_macaroon(
|
|
||||||
macaroon, "auth_provider_session_id"
|
|
||||||
)
|
|
||||||
except MacaroonVerificationFailedException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return LoginTokenAttributes(
|
|
||||||
user_id=user_id,
|
|
||||||
auth_provider_id=auth_provider_id,
|
|
||||||
auth_provider_session_id=auth_provider_session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def verify_guest_token(self, token: str) -> str:
|
def verify_guest_token(self, token: str) -> str:
|
||||||
"""Verify a guest access token macaroon
|
"""Verify a guest access token macaroon
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, ResourceLimitError
|
from synapse.api.errors import AuthError, ResourceLimitError
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client import login
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
|
||||||
class AuthTestCase(unittest.HomeserverTestCase):
|
class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
admin.register_servlets,
|
admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.user1 = self.register_user("a_user", "pass")
|
self.user1 = self.register_user("a_user", "pass")
|
||||||
|
|
||||||
|
def token_login(self, token: str) -> Optional[str]:
|
||||||
|
body = {
|
||||||
|
"type": "m.login.token",
|
||||||
|
"token": token,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/v3/login",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel.code == 200:
|
||||||
|
return channel.json_body["user_id"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def test_macaroon_caveats(self) -> None:
|
def test_macaroon_caveats(self) -> None:
|
||||||
token = self.macaroon_generator.generate_guest_access_token("a_user")
|
token = self.macaroon_generator.generate_guest_access_token("a_user")
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
@ -73,48 +93,61 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
v.satisfy_general(verify_guest)
|
v.satisfy_general(verify_guest)
|
||||||
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
||||||
|
|
||||||
def test_short_term_login_token_gives_user_id(self) -> None:
|
def test_login_token_gives_user_id(self) -> None:
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.get_success(
|
||||||
self.user1, "", duration_in_ms=5000
|
self.auth_handler.create_login_token_for_user_id(
|
||||||
|
self.user1,
|
||||||
|
duration_ms=(5 * 1000),
|
||||||
)
|
)
|
||||||
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
)
|
||||||
|
|
||||||
|
res = self.get_success(self.auth_handler.consume_login_token(token))
|
||||||
self.assertEqual(self.user1, res.user_id)
|
self.assertEqual(self.user1, res.user_id)
|
||||||
self.assertEqual("", res.auth_provider_id)
|
self.assertEqual(None, res.auth_provider_id)
|
||||||
|
|
||||||
|
def test_login_token_reuse_fails(self) -> None:
|
||||||
|
token = self.get_success(
|
||||||
|
self.auth_handler.create_login_token_for_user_id(
|
||||||
|
self.user1,
|
||||||
|
duration_ms=(5 * 1000),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(self.auth_handler.consume_login_token(token))
|
||||||
|
|
||||||
|
self.get_failure(
|
||||||
|
self.auth_handler.consume_login_token(token),
|
||||||
|
AuthError,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_login_token_expires(self) -> None:
|
||||||
|
token = self.get_success(
|
||||||
|
self.auth_handler.create_login_token_for_user_id(
|
||||||
|
self.user1,
|
||||||
|
duration_ms=(5 * 1000),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# when we advance the clock, the token should be rejected
|
# when we advance the clock, the token should be rejected
|
||||||
self.reactor.advance(6)
|
self.reactor.advance(6)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth_handler.validate_short_term_login_token(token),
|
self.auth_handler.consume_login_token(token),
|
||||||
AuthError,
|
AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_short_term_login_token_gives_auth_provider(self) -> None:
|
def test_login_token_gives_auth_provider(self) -> None:
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.get_success(
|
||||||
self.user1, auth_provider_id="my_idp"
|
self.auth_handler.create_login_token_for_user_id(
|
||||||
|
self.user1,
|
||||||
|
auth_provider_id="my_idp",
|
||||||
|
auth_provider_session_id="11-22-33-44",
|
||||||
|
duration_ms=(5 * 1000),
|
||||||
)
|
)
|
||||||
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
)
|
||||||
|
res = self.get_success(self.auth_handler.consume_login_token(token))
|
||||||
self.assertEqual(self.user1, res.user_id)
|
self.assertEqual(self.user1, res.user_id)
|
||||||
self.assertEqual("my_idp", res.auth_provider_id)
|
self.assertEqual("my_idp", res.auth_provider_id)
|
||||||
|
self.assertEqual("11-22-33-44", res.auth_provider_session_id)
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self) -> None:
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
|
||||||
self.user1, "", duration_in_ms=5000
|
|
||||||
)
|
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
|
||||||
|
|
||||||
res = self.get_success(
|
|
||||||
self.auth_handler.validate_short_term_login_token(macaroon.serialize())
|
|
||||||
)
|
|
||||||
self.assertEqual(self.user1, res.user_id)
|
|
||||||
|
|
||||||
# add another "user_id" caveat, which might allow us to override the
|
|
||||||
# user_id.
|
|
||||||
macaroon.add_first_party_caveat("user_id = b_user")
|
|
||||||
|
|
||||||
self.get_failure(
|
|
||||||
self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
|
|
||||||
AuthError,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_mau_limits_disabled(self) -> None:
|
def test_mau_limits_disabled(self) -> None:
|
||||||
self.auth_blocking._limit_usage_by_mau = False
|
self.auth_blocking._limit_usage_by_mau = False
|
||||||
|
@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
token = self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token(
|
self.auth_handler.create_login_token_for_user_id(self.user1)
|
||||||
self._get_macaroon().serialize()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertIsNotNone(self.token_login(token))
|
||||||
|
|
||||||
def test_mau_limits_exceeded_large(self) -> None:
|
def test_mau_limits_exceeded_large(self) -> None:
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||||
|
@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||||
return_value=make_awaitable(self.large_number_of_users)
|
return_value=make_awaitable(self.large_number_of_users)
|
||||||
)
|
)
|
||||||
self.get_failure(
|
token = self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token(
|
self.auth_handler.create_login_token_for_user_id(self.user1)
|
||||||
self._get_macaroon().serialize()
|
|
||||||
),
|
|
||||||
ResourceLimitError,
|
|
||||||
)
|
)
|
||||||
|
self.assertIsNone(self.token_login(token))
|
||||||
|
|
||||||
def test_mau_limits_parity(self) -> None:
|
def test_mau_limits_parity(self) -> None:
|
||||||
# Ensure we're not at the unix epoch.
|
# Ensure we're not at the unix epoch.
|
||||||
|
@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
),
|
),
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
self.get_failure(
|
token = self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token(
|
self.auth_handler.create_login_token_for_user_id(self.user1)
|
||||||
self._get_macaroon().serialize()
|
|
||||||
),
|
|
||||||
ResourceLimitError,
|
|
||||||
)
|
)
|
||||||
|
self.assertIsNone(self.token_login(token))
|
||||||
|
|
||||||
# If in monthly active cohort
|
# If in monthly active cohort
|
||||||
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
|
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
|
||||||
|
@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.user1, device_id=None, valid_until_ms=None
|
self.user1, device_id=None, valid_until_ms=None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(
|
token = self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token(
|
self.auth_handler.create_login_token_for_user_id(self.user1)
|
||||||
self._get_macaroon().serialize()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
self.assertIsNotNone(self.token_login(token))
|
||||||
|
|
||||||
def test_mau_limits_not_exceeded(self) -> None:
|
def test_mau_limits_not_exceeded(self) -> None:
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
|
@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||||
return_value=make_awaitable(self.small_number_of_users)
|
return_value=make_awaitable(self.small_number_of_users)
|
||||||
)
|
)
|
||||||
self.get_success(
|
token = self.get_success(
|
||||||
self.auth_handler.validate_short_term_login_token(
|
self.auth_handler.create_login_token_for_user_id(self.user1)
|
||||||
self._get_macaroon().serialize()
|
|
||||||
)
|
)
|
||||||
)
|
self.assertIsNotNone(self.token_login(token))
|
||||||
|
|
||||||
def _get_macaroon(self) -> pymacaroons.Macaroon:
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
|
||||||
self.user1, "", duration_in_ms=5000
|
|
||||||
)
|
|
||||||
return pymacaroons.Macaroon.deserialize(token)
|
|
||||||
|
|
|
@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(user_id, "@user:tesths")
|
self.assertEqual(user_id, "@user:tesths")
|
||||||
|
|
||||||
def test_short_term_login_token(self):
|
|
||||||
"""Test the generation and verification of short-term login tokens"""
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
|
||||||
user_id="@user:tesths",
|
|
||||||
auth_provider_id="oidc",
|
|
||||||
auth_provider_session_id="sid",
|
|
||||||
duration_in_ms=2 * 60 * 1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
info = self.macaroon_generator.verify_short_term_login_token(token)
|
|
||||||
self.assertEqual(info.user_id, "@user:tesths")
|
|
||||||
self.assertEqual(info.auth_provider_id, "oidc")
|
|
||||||
self.assertEqual(info.auth_provider_session_id, "sid")
|
|
||||||
|
|
||||||
# Raises with another secret key
|
|
||||||
with self.assertRaises(MacaroonVerificationFailedException):
|
|
||||||
self.other_macaroon_generator.verify_short_term_login_token(token)
|
|
||||||
|
|
||||||
# Wait a minute
|
|
||||||
self.reactor.pump([60])
|
|
||||||
# Shouldn't raise
|
|
||||||
self.macaroon_generator.verify_short_term_login_token(token)
|
|
||||||
# Wait another minute
|
|
||||||
self.reactor.pump([60])
|
|
||||||
# Should raise since it expired
|
|
||||||
with self.assertRaises(MacaroonVerificationFailedException):
|
|
||||||
self.macaroon_generator.verify_short_term_login_token(token)
|
|
||||||
|
|
||||||
def test_oidc_session_token(self):
|
def test_oidc_session_token(self):
|
||||||
"""Test the generation and verification of OIDC session cookies"""
|
"""Test the generation and verification of OIDC session cookies"""
|
||||||
state = "arandomstate"
|
state = "arandomstate"
|
||||||
|
|
Loading…
Reference in New Issue