Persist user interactive authentication sessions (#7302)
By persisting the user interactive authentication sessions to the database, this fixes situations where a user hits different works throughout their auth session and also allows sessions to persist through restarts of Synapse.
This commit is contained in:
parent
9d8ecc9e6c
commit
627b0f5f27
|
@ -0,0 +1 @@
|
||||||
|
Persist user interactive authentication sessions across workers and Synapse restarts.
|
|
@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
|
||||||
MonthlyActiveUsersWorkerStore,
|
MonthlyActiveUsersWorkerStore,
|
||||||
)
|
)
|
||||||
from synapse.storage.data_stores.main.presence import UserPresenceState
|
from synapse.storage.data_stores.main.presence import UserPresenceState
|
||||||
|
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
|
||||||
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
|
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
|
||||||
from synapse.types import ReadReceipt
|
from synapse.types import ReadReceipt
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
@ -439,6 +440,7 @@ class GenericWorkerSlavedStore(
|
||||||
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
||||||
# rather than going via the correct worker.
|
# rather than going via the correct worker.
|
||||||
UserDirectoryStore,
|
UserDirectoryStore,
|
||||||
|
UIAuthWorkerStore,
|
||||||
SlavedDeviceInboxStore,
|
SlavedDeviceInboxStore,
|
||||||
SlavedDeviceStore,
|
SlavedDeviceStore,
|
||||||
SlavedReceiptsStore,
|
SlavedReceiptsStore,
|
||||||
|
|
|
@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||||
from synapse.http.server import finish_request
|
from synapse.http.server import finish_request
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import defer_to_thread
|
from synapse.logging.context import defer_to_thread
|
||||||
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.push.mailer import load_jinja2_templates
|
from synapse.push.mailer import load_jinja2_templates
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import Requester, UserID
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||||
|
|
||||||
# This is not a cache per se, but a store of all current sessions that
|
|
||||||
# expire after N hours
|
|
||||||
self.sessions = ExpiringCache(
|
|
||||||
cache_name="register_sessions",
|
|
||||||
clock=hs.get_clock(),
|
|
||||||
expiry_ms=self.SESSION_EXPIRE_MS,
|
|
||||||
reset_expiry_on_get=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
account_handler = ModuleApi(hs, self)
|
account_handler = ModuleApi(hs, self)
|
||||||
self.password_providers = [
|
self.password_providers = [
|
||||||
module(config=config, account_handler=account_handler)
|
module(config=config, account_handler=account_handler)
|
||||||
|
@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
self._clock = self.hs.get_clock()
|
self._clock = self.hs.get_clock()
|
||||||
|
|
||||||
|
# Expire old UI auth sessions after a period of time.
|
||||||
|
if hs.config.worker_app is None:
|
||||||
|
self._clock.looping_call(
|
||||||
|
run_as_background_process,
|
||||||
|
5 * 60 * 1000,
|
||||||
|
"expire_old_sessions",
|
||||||
|
self._expire_old_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
# Load the SSO HTML templates.
|
# Load the SSO HTML templates.
|
||||||
|
|
||||||
# The following template is shown to the user during a client login via SSO,
|
# The following template is shown to the user during a client login via SSO,
|
||||||
|
@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
|
||||||
if "session" in authdict:
|
if "session" in authdict:
|
||||||
sid = authdict["session"]
|
sid = authdict["session"]
|
||||||
|
|
||||||
|
# Convert the URI and method to strings.
|
||||||
|
uri = request.uri.decode("utf-8")
|
||||||
|
method = request.uri.decode("utf-8")
|
||||||
|
|
||||||
# If there's no session ID, create a new session.
|
# If there's no session ID, create a new session.
|
||||||
if not sid:
|
if not sid:
|
||||||
session = self._create_session(
|
session = await self.store.create_ui_auth_session(
|
||||||
clientdict, (request.uri, request.method, clientdict), description
|
clientdict, uri, method, description
|
||||||
)
|
)
|
||||||
session_id = session["id"]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
session = self._get_session_info(sid)
|
try:
|
||||||
session_id = sid
|
session = await self.store.get_ui_auth_session(sid)
|
||||||
|
except StoreError:
|
||||||
|
raise SynapseError(400, "Unknown session ID: %s" % (sid,))
|
||||||
|
|
||||||
if not clientdict:
|
if not clientdict:
|
||||||
# This was designed to allow the client to omit the parameters
|
# This was designed to allow the client to omit the parameters
|
||||||
|
@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
|
||||||
# on a homeserver.
|
# on a homeserver.
|
||||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||||
# isn't arbitrary.
|
# isn't arbitrary.
|
||||||
clientdict = session["clientdict"]
|
clientdict = session.clientdict
|
||||||
|
|
||||||
# Ensure that the queried operation does not vary between stages of
|
# Ensure that the queried operation does not vary between stages of
|
||||||
# the UI authentication session. This is done by generating a stable
|
# the UI authentication session. This is done by generating a stable
|
||||||
# comparator based on the URI, method, and body (minus the auth dict)
|
# comparator based on the URI, method, and body (minus the auth dict)
|
||||||
# and storing it during the initial query. Subsequent queries ensure
|
# and storing it during the initial query. Subsequent queries ensure
|
||||||
# that this comparator has not changed.
|
# that this comparator has not changed.
|
||||||
comparator = (request.uri, request.method, clientdict)
|
comparator = (uri, method, clientdict)
|
||||||
if session["ui_auth"] != comparator:
|
if (session.uri, session.method, session.clientdict) != comparator:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Requested operation has changed during the UI authentication session.",
|
"Requested operation has changed during the UI authentication session.",
|
||||||
|
@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
raise InteractiveAuthIncompleteError(
|
raise InteractiveAuthIncompleteError(
|
||||||
self._auth_dict_for_flows(flows, session_id)
|
self._auth_dict_for_flows(flows, session.session_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
creds = session["creds"]
|
|
||||||
|
|
||||||
# check auth type currently being presented
|
# check auth type currently being presented
|
||||||
errordict = {} # type: Dict[str, Any]
|
errordict = {} # type: Dict[str, Any]
|
||||||
if "type" in authdict:
|
if "type" in authdict:
|
||||||
|
@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
result = await self._check_auth_dict(authdict, clientip)
|
result = await self._check_auth_dict(authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
creds[login_type] = result
|
await self.store.mark_ui_auth_stage_complete(
|
||||||
self._save_session(session)
|
session.session_id, login_type, result
|
||||||
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
if login_type == LoginType.EMAIL_IDENTITY:
|
if login_type == LoginType.EMAIL_IDENTITY:
|
||||||
# riot used to have a bug where it would request a new
|
# riot used to have a bug where it would request a new
|
||||||
|
@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
|
||||||
# so that the client can have another go.
|
# so that the client can have another go.
|
||||||
errordict = e.error_dict()
|
errordict = e.error_dict()
|
||||||
|
|
||||||
|
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
|
||||||
for f in flows:
|
for f in flows:
|
||||||
if len(set(f) - set(creds)) == 0:
|
if len(set(f) - set(creds)) == 0:
|
||||||
# it's very useful to know what args are stored, but this can
|
# it's very useful to know what args are stored, but this can
|
||||||
|
@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
|
||||||
list(clientdict),
|
list(clientdict),
|
||||||
)
|
)
|
||||||
|
|
||||||
return creds, clientdict, session_id
|
return creds, clientdict, session.session_id
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, session_id)
|
ret = self._auth_dict_for_flows(flows, session.session_id)
|
||||||
ret["completed"] = list(creds)
|
ret["completed"] = list(creds)
|
||||||
ret.update(errordict)
|
ret.update(errordict)
|
||||||
raise InteractiveAuthIncompleteError(ret)
|
raise InteractiveAuthIncompleteError(ret)
|
||||||
|
@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
|
||||||
if "session" not in authdict:
|
if "session" not in authdict:
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
sess = self._get_session_info(authdict["session"])
|
|
||||||
creds = sess["creds"]
|
|
||||||
|
|
||||||
result = await self.checkers[stagetype].check_auth(authdict, clientip)
|
result = await self.checkers[stagetype].check_auth(authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
creds[stagetype] = result
|
await self.store.mark_ui_auth_stage_complete(
|
||||||
self._save_session(sess)
|
authdict["session"], stagetype, result
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
|
||||||
sid = authdict["session"]
|
sid = authdict["session"]
|
||||||
return sid
|
return sid
|
||||||
|
|
||||||
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Store a key-value pair into the sessions data associated with this
|
Store a key-value pair into the sessions data associated with this
|
||||||
request. This data is stored server-side and cannot be modified by
|
request. This data is stored server-side and cannot be modified by
|
||||||
|
@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
|
||||||
key: The key to store the data under
|
key: The key to store the data under
|
||||||
value: The data to store
|
value: The data to store
|
||||||
"""
|
"""
|
||||||
sess = self._get_session_info(session_id)
|
try:
|
||||||
sess["serverdict"][key] = value
|
await self.store.set_ui_auth_session_data(session_id, key, value)
|
||||||
self._save_session(sess)
|
except StoreError:
|
||||||
|
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||||
|
|
||||||
def get_session_data(
|
async def get_session_data(
|
||||||
self, session_id: str, key: str, default: Optional[Any] = None
|
self, session_id: str, key: str, default: Optional[Any] = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
|
@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
|
||||||
key: The key to store the data under
|
key: The key to store the data under
|
||||||
default: Value to return if the key has not been set
|
default: Value to return if the key has not been set
|
||||||
"""
|
"""
|
||||||
sess = self._get_session_info(session_id)
|
try:
|
||||||
return sess["serverdict"].get(key, default)
|
return await self.store.get_ui_auth_session_data(session_id, key, default)
|
||||||
|
except StoreError:
|
||||||
|
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||||
|
|
||||||
|
async def _expire_old_sessions(self):
|
||||||
|
"""
|
||||||
|
Invalidate any user interactive authentication sessions that have expired.
|
||||||
|
"""
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
expiration_time = now - self.SESSION_EXPIRE_MS
|
||||||
|
await self.store.delete_old_ui_auth_sessions(expiration_time)
|
||||||
|
|
||||||
async def _check_auth_dict(
|
async def _check_auth_dict(
|
||||||
self, authdict: Dict[str, Any], clientip: str
|
self, authdict: Dict[str, Any], clientip: str
|
||||||
|
@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
|
||||||
"params": params,
|
"params": params,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_session(
|
|
||||||
self,
|
|
||||||
clientdict: Dict[str, Any],
|
|
||||||
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
|
|
||||||
description: str,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Creates a new user interactive authentication session.
|
|
||||||
|
|
||||||
The session can be used to track data across multiple requests, e.g. for
|
|
||||||
interactive authentication.
|
|
||||||
|
|
||||||
Each session has the following keys:
|
|
||||||
|
|
||||||
id:
|
|
||||||
A unique identifier for this session. Passed back to the client
|
|
||||||
and returned for each stage.
|
|
||||||
clientdict:
|
|
||||||
The dictionary from the client root level, not the 'auth' key.
|
|
||||||
ui_auth:
|
|
||||||
A tuple which is checked at each stage of the authentication to
|
|
||||||
ensure that the asked for operation has not changed.
|
|
||||||
creds:
|
|
||||||
A map, which maps each auth-type (str) to the relevant identity
|
|
||||||
authenticated by that auth-type (mostly str, but for captcha, bool).
|
|
||||||
serverdict:
|
|
||||||
A map of data that is stored server-side and cannot be modified
|
|
||||||
by the client.
|
|
||||||
description:
|
|
||||||
A string description of the operation that the current
|
|
||||||
authentication is authorising.
|
|
||||||
Returns:
|
|
||||||
The newly created session.
|
|
||||||
"""
|
|
||||||
session_id = None
|
|
||||||
while session_id is None or session_id in self.sessions:
|
|
||||||
session_id = stringutils.random_string(24)
|
|
||||||
|
|
||||||
self.sessions[session_id] = {
|
|
||||||
"id": session_id,
|
|
||||||
"clientdict": clientdict,
|
|
||||||
"ui_auth": ui_auth,
|
|
||||||
"creds": {},
|
|
||||||
"serverdict": {},
|
|
||||||
"description": description,
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.sessions[session_id]
|
|
||||||
|
|
||||||
def _get_session_info(self, session_id: str) -> dict:
|
|
||||||
"""
|
|
||||||
Gets a session given a session ID.
|
|
||||||
|
|
||||||
The session can be used to track data across multiple requests, e.g. for
|
|
||||||
interactive authentication.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self.sessions[session_id]
|
|
||||||
except KeyError:
|
|
||||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
|
||||||
|
|
||||||
async def get_access_token_for_user_id(
|
async def get_access_token_for_user_id(
|
||||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||||
):
|
):
|
||||||
|
@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
|
||||||
await self.store.user_delete_threepid(user_id, medium, address)
|
await self.store.user_delete_threepid(user_id, medium, address)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _save_session(self, session: Dict[str, Any]) -> None:
|
|
||||||
"""Update the last used time on the session to now and add it back to the session store."""
|
|
||||||
# TODO: Persistent storage
|
|
||||||
logger.debug("Saving session %s", session)
|
|
||||||
session["last_used"] = self.hs.get_clock().time_msec()
|
|
||||||
self.sessions[session["id"]] = session
|
|
||||||
|
|
||||||
async def hash(self, password: str) -> str:
|
async def hash(self, password: str) -> str:
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
|
||||||
|
@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get the HTML for the SSO redirect confirmation page.
|
Get the HTML for the SSO redirect confirmation page.
|
||||||
|
|
||||||
|
@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
|
||||||
Returns:
|
Returns:
|
||||||
The HTML to render.
|
The HTML to render.
|
||||||
"""
|
"""
|
||||||
session = self._get_session_info(session_id)
|
try:
|
||||||
|
session = await self.store.get_ui_auth_session(session_id)
|
||||||
|
except StoreError:
|
||||||
|
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||||
return self._sso_auth_confirm_template.render(
|
return self._sso_auth_confirm_template.render(
|
||||||
description=session["description"], redirect_url=redirect_url,
|
description=session.description, redirect_url=redirect_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
def complete_sso_ui_auth(
|
async def complete_sso_ui_auth(
|
||||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
||||||
):
|
):
|
||||||
"""Having figured out a mxid for this user, complete the HTTP request
|
"""Having figured out a mxid for this user, complete the HTTP request
|
||||||
|
@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
|
||||||
process.
|
process.
|
||||||
"""
|
"""
|
||||||
# Mark the stage of the authentication as successful.
|
# Mark the stage of the authentication as successful.
|
||||||
sess = self._get_session_info(session_id)
|
|
||||||
creds = sess["creds"]
|
|
||||||
|
|
||||||
# Save the user who authenticated with SSO, this will be used to ensure
|
# Save the user who authenticated with SSO, this will be used to ensure
|
||||||
# that the account be modified is also the person who logged in.
|
# that the account be modified is also the person who logged in.
|
||||||
creds[LoginType.SSO] = registered_user_id
|
await self.store.mark_ui_auth_stage_complete(
|
||||||
self._save_session(sess)
|
session_id, LoginType.SSO, registered_user_id
|
||||||
|
)
|
||||||
|
|
||||||
# Render the HTML and return.
|
# Render the HTML and return.
|
||||||
html_bytes = self._sso_auth_success_template.encode("utf-8")
|
html_bytes = self._sso_auth_success_template.encode("utf-8")
|
||||||
|
|
|
@ -206,7 +206,7 @@ class CasHandler:
|
||||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
self._auth_handler.complete_sso_ui_auth(
|
await self._auth_handler.complete_sso_ui_auth(
|
||||||
registered_user_id, session, request,
|
registered_user_id, session, request,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -149,7 +149,7 @@ class SamlHandler:
|
||||||
|
|
||||||
# Complete the interactive auth session or the login.
|
# Complete the interactive auth session or the login.
|
||||||
if current_session and current_session.ui_auth_session_id:
|
if current_session and current_session.ui_auth_session_id:
|
||||||
self._auth_handler.complete_sso_ui_auth(
|
await self._auth_handler.complete_sso_ui_auth(
|
||||||
user_id, current_session.ui_auth_session_id, request
|
user_id, current_session.ui_auth_session_id, request
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
|
||||||
self._cas_server_url = hs.config.cas_server_url
|
self._cas_server_url = hs.config.cas_server_url
|
||||||
self._cas_service_url = hs.config.cas_service_url
|
self._cas_service_url = hs.config.cas_service_url
|
||||||
|
|
||||||
def on_GET(self, request, stagetype):
|
async def on_GET(self, request, stagetype):
|
||||||
session = parse_string(request, "session")
|
session = parse_string(request, "session")
|
||||||
if not session:
|
if not session:
|
||||||
raise SynapseError(400, "No session supplied")
|
raise SynapseError(400, "No session supplied")
|
||||||
|
@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "Homeserver not configured for SSO.")
|
raise SynapseError(400, "Homeserver not configured for SSO.")
|
||||||
|
|
||||||
html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise SynapseError(404, "Unknown auth stage type")
|
raise SynapseError(404, "Unknown auth stage type")
|
||||||
|
|
|
@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
# registered a user for this session, so we could just return the
|
# registered a user for this session, so we could just return the
|
||||||
# user here. We carry on and go through the auth checks though,
|
# user here. We carry on and go through the auth checks though,
|
||||||
# for paranoia.
|
# for paranoia.
|
||||||
registered_user_id = self.auth_handler.get_session_data(
|
registered_user_id = await self.auth_handler.get_session_data(
|
||||||
session_id, "registered_user_id", None
|
session_id, "registered_user_id", None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
# remember that we've now registered that user account, and with
|
# remember that we've now registered that user account, and with
|
||||||
# what user ID (since the user may not have specified)
|
# what user ID (since the user may not have specified)
|
||||||
self.auth_handler.set_session_data(
|
await self.auth_handler.set_session_data(
|
||||||
session_id, "registered_user_id", registered_user_id
|
session_id, "registered_user_id", registered_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -66,6 +66,7 @@ from .stats import StatsStore
|
||||||
from .stream import StreamStore
|
from .stream import StreamStore
|
||||||
from .tags import TagsStore
|
from .tags import TagsStore
|
||||||
from .transactions import TransactionStore
|
from .transactions import TransactionStore
|
||||||
|
from .ui_auth import UIAuthStore
|
||||||
from .user_directory import UserDirectoryStore
|
from .user_directory import UserDirectoryStore
|
||||||
from .user_erasure_store import UserErasureStore
|
from .user_erasure_store import UserErasureStore
|
||||||
|
|
||||||
|
@ -112,6 +113,7 @@ class DataStore(
|
||||||
StatsStore,
|
StatsStore,
|
||||||
RelationsStore,
|
RelationsStore,
|
||||||
CacheInvalidationStore,
|
CacheInvalidationStore,
|
||||||
|
UIAuthStore,
|
||||||
):
|
):
|
||||||
def __init__(self, database: Database, db_conn, hs):
|
def __init__(self, database: Database, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ui_auth_sessions(
|
||||||
|
session_id TEXT NOT NULL, -- The session ID passed to the client.
|
||||||
|
creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
|
||||||
|
serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
|
||||||
|
clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
|
||||||
|
uri TEXT NOT NULL, -- The URI the UI authentication session is using.
|
||||||
|
method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
|
||||||
|
-- The clientdict, uri, and method make up an tuple that must be immutable
|
||||||
|
-- throughout the lifetime of the UI Auth session.
|
||||||
|
description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
|
||||||
|
UNIQUE (session_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
|
||||||
|
session_id TEXT NOT NULL, -- The corresponding UI Auth session.
|
||||||
|
stage_type TEXT NOT NULL, -- The stage type.
|
||||||
|
result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
|
||||||
|
UNIQUE (session_id, stage_type),
|
||||||
|
FOREIGN KEY (session_id)
|
||||||
|
REFERENCES ui_auth_sessions (session_id)
|
||||||
|
);
|
|
@ -0,0 +1,279 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
import synapse.util.stringutils as stringutils
|
||||||
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class UIAuthSessionData:
|
||||||
|
session_id = attr.ib(type=str)
|
||||||
|
# The dictionary from the client root level, not the 'auth' key.
|
||||||
|
clientdict = attr.ib(type=JsonDict)
|
||||||
|
# The URI and method the session was intiatied with. These are checked at
|
||||||
|
# each stage of the authentication to ensure that the asked for operation
|
||||||
|
# has not changed.
|
||||||
|
uri = attr.ib(type=str)
|
||||||
|
method = attr.ib(type=str)
|
||||||
|
# A string description of the operation that the current authentication is
|
||||||
|
# authorising.
|
||||||
|
description = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
|
class UIAuthWorkerStore(SQLBaseStore):
|
||||||
|
"""
|
||||||
|
Manage user interactive authentication sessions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def create_ui_auth_session(
|
||||||
|
self, clientdict: JsonDict, uri: str, method: str, description: str,
|
||||||
|
) -> UIAuthSessionData:
|
||||||
|
"""
|
||||||
|
Creates a new user interactive authentication session.
|
||||||
|
|
||||||
|
The session can be used to track the stages necessary to authenticate a
|
||||||
|
user across multiple HTTP requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clientdict:
|
||||||
|
The dictionary from the client root level, not the 'auth' key.
|
||||||
|
uri:
|
||||||
|
The URI this session was initiated with, this is checked at each
|
||||||
|
stage of the authentication to ensure that the asked for
|
||||||
|
operation has not changed.
|
||||||
|
method:
|
||||||
|
The method this session was initiated with, this is checked at each
|
||||||
|
stage of the authentication to ensure that the asked for
|
||||||
|
operation has not changed.
|
||||||
|
description:
|
||||||
|
A string description of the operation that the current
|
||||||
|
authentication is authorising.
|
||||||
|
Returns:
|
||||||
|
The newly created session.
|
||||||
|
Raises:
|
||||||
|
StoreError if a unique session ID cannot be generated.
|
||||||
|
"""
|
||||||
|
# The clientdict gets stored as JSON.
|
||||||
|
clientdict_json = json.dumps(clientdict)
|
||||||
|
|
||||||
|
# autogen a session ID and try to create it. We may clash, so just
|
||||||
|
# try a few times till one goes through, giving up eventually.
|
||||||
|
attempts = 0
|
||||||
|
while attempts < 5:
|
||||||
|
session_id = stringutils.random_string(24)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.db.simple_insert(
|
||||||
|
table="ui_auth_sessions",
|
||||||
|
values={
|
||||||
|
"session_id": session_id,
|
||||||
|
"clientdict": clientdict_json,
|
||||||
|
"uri": uri,
|
||||||
|
"method": method,
|
||||||
|
"description": description,
|
||||||
|
"serverdict": "{}",
|
||||||
|
"creation_time": self.hs.get_clock().time_msec(),
|
||||||
|
},
|
||||||
|
desc="create_ui_auth_session",
|
||||||
|
)
|
||||||
|
return UIAuthSessionData(
|
||||||
|
session_id, clientdict, uri, method, description
|
||||||
|
)
|
||||||
|
except self.db.engine.module.IntegrityError:
|
||||||
|
attempts += 1
|
||||||
|
raise StoreError(500, "Couldn't generate a session ID.")
|
||||||
|
|
||||||
|
async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
|
||||||
|
"""Retrieve a UI auth session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the session.
|
||||||
|
Returns:
|
||||||
|
A dict containing the device information.
|
||||||
|
Raises:
|
||||||
|
StoreError if the session is not found.
|
||||||
|
"""
|
||||||
|
result = await self.db.simple_select_one(
|
||||||
|
table="ui_auth_sessions",
|
||||||
|
keyvalues={"session_id": session_id},
|
||||||
|
retcols=("clientdict", "uri", "method", "description"),
|
||||||
|
desc="get_ui_auth_session",
|
||||||
|
)
|
||||||
|
|
||||||
|
result["clientdict"] = json.loads(result["clientdict"])
|
||||||
|
|
||||||
|
return UIAuthSessionData(session_id, **result)
|
||||||
|
|
||||||
|
async def mark_ui_auth_stage_complete(
|
||||||
|
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Mark a session stage as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the corresponding session.
|
||||||
|
stage_type: The completed stage type.
|
||||||
|
result: The result of the stage verification.
|
||||||
|
Raises:
|
||||||
|
StoreError if the session cannot be found.
|
||||||
|
"""
|
||||||
|
# Add (or update) the results of the current stage to the database.
|
||||||
|
#
|
||||||
|
# Note that we need to allow for the same stage to complete multiple
|
||||||
|
# times here so that registration is idempotent.
|
||||||
|
try:
|
||||||
|
await self.db.simple_upsert(
|
||||||
|
table="ui_auth_sessions_credentials",
|
||||||
|
keyvalues={"session_id": session_id, "stage_type": stage_type},
|
||||||
|
values={"result": json.dumps(result)},
|
||||||
|
desc="mark_ui_auth_stage_complete",
|
||||||
|
)
|
||||||
|
except self.db.engine.module.IntegrityError:
|
||||||
|
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
|
||||||
|
|
||||||
|
async def get_completed_ui_auth_stages(
|
||||||
|
self, session_id: str
|
||||||
|
) -> Dict[str, Union[str, bool, JsonDict]]:
|
||||||
|
"""
|
||||||
|
Retrieve the completed stages of a UI authentication session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the session.
|
||||||
|
Returns:
|
||||||
|
The completed stages mapped to the result of the verification of
|
||||||
|
that auth-type.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
for row in await self.db.simple_select_list(
|
||||||
|
table="ui_auth_sessions_credentials",
|
||||||
|
keyvalues={"session_id": session_id},
|
||||||
|
retcols=("stage_type", "result"),
|
||||||
|
desc="get_completed_ui_auth_stages",
|
||||||
|
):
|
||||||
|
results[row["stage_type"]] = json.loads(row["result"])
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
|
||||||
|
"""
|
||||||
|
Store a key-value pair into the sessions data associated with this
|
||||||
|
request. This data is stored server-side and cannot be modified by
|
||||||
|
the client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of this session as returned from check_auth
|
||||||
|
key: The key to store the data under
|
||||||
|
value: The data to store
|
||||||
|
Raises:
|
||||||
|
StoreError if the session cannot be found.
|
||||||
|
"""
|
||||||
|
await self.db.runInteraction(
|
||||||
|
"set_ui_auth_session_data",
|
||||||
|
self._set_ui_auth_session_data_txn,
|
||||||
|
session_id,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
|
||||||
|
# Get the current value.
|
||||||
|
result = self.db.simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table="ui_auth_sessions",
|
||||||
|
keyvalues={"session_id": session_id},
|
||||||
|
retcols=("serverdict",),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update it and add it back to the database.
|
||||||
|
serverdict = json.loads(result["serverdict"])
|
||||||
|
serverdict[key] = value
|
||||||
|
|
||||||
|
self.db.simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
table="ui_auth_sessions",
|
||||||
|
keyvalues={"session_id": session_id},
|
||||||
|
updatevalues={"serverdict": json.dumps(serverdict)},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_ui_auth_session_data(
|
||||||
|
self, session_id: str, key: str, default: Optional[Any] = None
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Retrieve data stored with set_session_data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of this session as returned from check_auth
|
||||||
|
key: The key to store the data under
|
||||||
|
default: Value to return if the key has not been set
|
||||||
|
Raises:
|
||||||
|
StoreError if the session cannot be found.
|
||||||
|
"""
|
||||||
|
result = await self.db.simple_select_one(
|
||||||
|
table="ui_auth_sessions",
|
||||||
|
keyvalues={"session_id": session_id},
|
||||||
|
retcols=("serverdict",),
|
||||||
|
desc="get_ui_auth_session_data",
|
||||||
|
)
|
||||||
|
|
||||||
|
serverdict = json.loads(result["serverdict"])
|
||||||
|
|
||||||
|
return serverdict.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class UIAuthStore(UIAuthWorkerStore):
|
||||||
|
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||||
|
"""
|
||||||
|
Remove sessions which were last used earlier than the expiration time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expiration_time: The latest time that is still considered valid.
|
||||||
|
This is an epoch time in milliseconds.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self.db.runInteraction(
|
||||||
|
"delete_old_ui_auth_sessions",
|
||||||
|
self._delete_old_ui_auth_sessions_txn,
|
||||||
|
expiration_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
|
||||||
|
# Get the expired sessions.
|
||||||
|
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
||||||
|
txn.execute(sql, [expiration_time])
|
||||||
|
session_ids = [r[0] for r in txn.fetchall()]
|
||||||
|
|
||||||
|
# Delete the corresponding completed credentials.
|
||||||
|
self.db.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="ui_auth_sessions_credentials",
|
||||||
|
column="session_id",
|
||||||
|
iterable=session_ids,
|
||||||
|
keyvalues={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally, delete the sessions.
|
||||||
|
self.db.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="ui_auth_sessions",
|
||||||
|
column="session_id",
|
||||||
|
iterable=session_ids,
|
||||||
|
keyvalues={},
|
||||||
|
)
|
|
@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
||||||
prepare_database(db_conn, self, config=None)
|
prepare_database(db_conn, self, config=None)
|
||||||
|
|
||||||
db_conn.create_function("rank", 1, _rank)
|
db_conn.create_function("rank", 1, _rank)
|
||||||
|
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
def is_deadlock(self, error):
|
def is_deadlock(self, error):
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, 403)
|
self.assertEqual(channel.code, 403)
|
||||||
|
|
||||||
|
def test_complete_operation_unknown_session(self):
|
||||||
|
"""
|
||||||
|
Attempting to mark an invalid session as complete should error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Make the initial request to register. (Later on a different password
|
||||||
|
# will be used.)
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"register",
|
||||||
|
{"username": "user", "type": "m.login.password", "password": "bar"},
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
# Returns a 401 as per the spec
|
||||||
|
self.assertEqual(request.code, 401)
|
||||||
|
# Grab the session
|
||||||
|
session = channel.json_body["session"]
|
||||||
|
# Assert our configured public key is being given
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
|
||||||
|
)
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(request.code, 200)
|
||||||
|
|
||||||
|
# Attempt to complete an unknown session, which should return an error.
|
||||||
|
unknown_session = session + "unknown"
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"auth/m.login.recaptcha/fallback/web?session="
|
||||||
|
+ unknown_session
|
||||||
|
+ "&g-recaptcha-response=a",
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(request.code, 400)
|
||||||
|
|
|
@ -512,8 +512,8 @@ class MockClock(object):
|
||||||
|
|
||||||
return t
|
return t
|
||||||
|
|
||||||
def looping_call(self, function, interval):
|
def looping_call(self, function, interval, *args, **kwargs):
|
||||||
self.loopers.append([function, interval / 1000.0, self.now])
|
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
|
||||||
|
|
||||||
def cancel_call_later(self, timer, ignore_errs=False):
|
def cancel_call_later(self, timer, ignore_errs=False):
|
||||||
if timer[2]:
|
if timer[2]:
|
||||||
|
@ -543,9 +543,9 @@ class MockClock(object):
|
||||||
self.timers.append(t)
|
self.timers.append(t)
|
||||||
|
|
||||||
for looped in self.loopers:
|
for looped in self.loopers:
|
||||||
func, interval, last = looped
|
func, interval, last, args, kwargs = looped
|
||||||
if last + interval < self.now:
|
if last + interval < self.now:
|
||||||
func()
|
func(*args, **kwargs)
|
||||||
looped[2] = self.now
|
looped[2] = self.now
|
||||||
|
|
||||||
def advance_time_msec(self, ms):
|
def advance_time_msec(self, ms):
|
||||||
|
|
3
tox.ini
3
tox.ini
|
@ -200,8 +200,9 @@ commands = mypy \
|
||||||
synapse/replication \
|
synapse/replication \
|
||||||
synapse/rest \
|
synapse/rest \
|
||||||
synapse/spam_checker_api \
|
synapse/spam_checker_api \
|
||||||
synapse/storage/engines \
|
synapse/storage/data_stores/main/ui_auth.py \
|
||||||
synapse/storage/database.py \
|
synapse/storage/database.py \
|
||||||
|
synapse/storage/engines \
|
||||||
synapse/streams \
|
synapse/streams \
|
||||||
synapse/util/caches/stream_change_cache.py \
|
synapse/util/caches/stream_change_cache.py \
|
||||||
tests/replication/tcp/streams \
|
tests/replication/tcp/streams \
|
||||||
|
|
Loading…
Reference in New Issue