From 91482cd6a0285faa837d7bd503855aa002cd3034 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 8 Oct 2015 11:22:15 +0100 Subject: [PATCH 01/20] Use raw string for regex here, otherwise \b is the backspace character. Fixes displayname matching. --- synapse/push/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index f1952b5a0f..0e0c61dec8 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -186,7 +186,7 @@ class Pusher(object): if not display_name: return False return re.search( - "\b%s\b" % re.escape(display_name), ev['content']['body'], + r"\b%s\b" % re.escape(display_name), ev['content']['body'], flags=re.IGNORECASE ) is not None From dc720217489e2a8cf528255502fe448a85e1ff52 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 8 Oct 2015 17:19:42 +0100 Subject: [PATCH 02/20] Add a flag to initial sync to indicate we want rooms that the user has left --- synapse/handlers/message.py | 13 ++++++++----- synapse/rest/client/v1/initial_sync.py | 4 +++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 30949ff7a6..b70258697b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -324,7 +324,8 @@ class MessageHandler(BaseHandler): ) @defer.inlineCallbacks - def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True): + def snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -335,17 +336,19 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config used to determine how many messages *PER ROOM* to return. as_client_event (bool): True to get events in client-server format. + include_archived (bool): True to get rooms that the user has left Returns: A list of dicts with "room_id" and "membership" keys for all rooms the user is currently invited or joined in on. Rooms where the user is joined on, may return a "messages" key with messages, depending on the specified PaginationConfig. """ + memberships = [Membership.INVITE, Membership.JOIN] + if include_archived: + memberships.append(Membership.LEAVE) + room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, - membership_list=[ - Membership.INVITE, Membership.JOIN, Membership.LEAVE - ] + user_id=user_id, membership_list=memberships ) user = UserID.from_string(user_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index bac68cc29f..52b7951b8f 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -29,10 +29,12 @@ class InitialSyncRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler + include_archived = request.args.get("archived", None) == "1" content = yield handler.snapshot_all_rooms( user_id=user.to_string(), pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + include_archived=include_archived, ) defer.returnValue((200, content)) From 51ef7256472106aaca99de836ae73564ee78349c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 8 Oct 2015 18:13:02 +0100 Subject: [PATCH 03/20] Use 'true' rather than '1' for archived flag --- synapse/rest/client/v1/initial_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 52b7951b8f..52c7943400 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -29,7 +29,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler - include_archived = request.args.get("archived", None) == "1" + include_archived = request.args.get("archived", None) == ["true"] content = yield handler.snapshot_all_rooms( user_id=user.to_string(), pagin_config=pagination_config, From 1b9802a0d99aafddd41088c94dc46bf88399e879 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 9 Oct 2015 19:13:08 +0100 Subject: [PATCH 04/20] Split the sections of EventStreamHandler.get_stream that handle presence into separate functions. This makes the code a bit easier to read, and means that we can reuse the logic when implementing the v2 sync API. --- synapse/handlers/events.py | 87 +++++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 891502c04f..92afa35d57 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -46,6 +46,56 @@ class EventStreamHandler(BaseHandler): self.notifier = hs.get_notifier() + @defer.inlineCallbacks + def started_stream(self, user): + """Tells the presence handler that we have started an eventstream for + the user: + + Args: + user (User): The user who started a stream. + Returns: + A deferred that completes once their presence has been updated. + """ + if user not in self._streams_per_user: + self._streams_per_user[user] = 0 + if user in self._stop_timer_per_user: + try: + self.clock.cancel_call_later( + self._stop_timer_per_user.pop(user) + ) + except: + logger.exception("Failed to cancel event timer") + else: + yield self.distributor.fire("started_user_eventstream", user) + + self._streams_per_user[user] += 1 + + def stopped_stream(self, user): + """If there are no streams for a user this starts a timer that will + notify the presence handler that we haven't got an event stream for + the user unless the user starts a new stream in 30 seconds. + + Args: + user (User): The user who stopped a stream. + """ + self._streams_per_user[user] -= 1 + if not self._streams_per_user[user]: + del self._streams_per_user[user] + + # 30 seconds of grace to allow the client to reconnect again + # before we think they're gone + def _later(): + logger.debug("_later stopped_user_eventstream %s", user) + + self._stop_timer_per_user.pop(user, None) + + return self.distributor.fire("stopped_user_eventstream", user) + + logger.debug("Scheduling _later: for %s", user) + self._stop_timer_per_user[user] = ( + self.clock.call_later(30, _later) + ) + @defer.inlineCallbacks @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, @@ -59,20 +109,7 @@ class EventStreamHandler(BaseHandler): try: if affect_presence: - if auth_user not in self._streams_per_user: - self._streams_per_user[auth_user] = 0 - if auth_user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(auth_user) - ) - except: - logger.exception("Failed to cancel event timer") - else: - yield self.distributor.fire( - "started_user_eventstream", auth_user - ) - self._streams_per_user[auth_user] += 1 + yield self.started_stream(auth_user) rm_handler = self.hs.get_handlers().room_member_handler @@ -114,27 +151,7 @@ class EventStreamHandler(BaseHandler): finally: if affect_presence: - self._streams_per_user[auth_user] -= 1 - if not self._streams_per_user[auth_user]: - del self._streams_per_user[auth_user] - - # 10 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug( - "_later stopped_user_eventstream %s", auth_user - ) - - self._stop_timer_per_user.pop(auth_user, None) - - return self.distributor.fire( - "stopped_user_eventstream", auth_user - ) - - logger.debug("Scheduling _later: for %s", auth_user) - self._stop_timer_per_user[auth_user] = ( - self.clock.call_later(30, _later) - ) + self.stopped_stream(auth_user) class EventHandler(BaseHandler): From 987803781e3870e1e1a3652612d833779d6cd290 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 9 Oct 2015 19:59:50 +0100 Subject: [PATCH 05/20] Fix some races in the synapse presence handler caused by not yielding on deferreds --- synapse/handlers/presence.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index e91e81831e..ce60642127 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -378,7 +378,7 @@ class PresenceHandler(BaseHandler): # TODO(paul): perform a presence push as part of start/stop poll so # we don't have to do this all the time - self.changed_presencelike_data(target_user, state) + yield self.changed_presencelike_data(target_user, state) def bump_presence_active_time(self, user, now=None): if now is None: @@ -422,12 +422,12 @@ class PresenceHandler(BaseHandler): @log_function def started_user_eventstream(self, user): # TODO(paul): Use "last online" state - self.set_state(user, user, {"presence": PresenceState.ONLINE}) + return self.set_state(user, user, {"presence": PresenceState.ONLINE}) @log_function def stopped_user_eventstream(self, user): # TODO(paul): Save current state as "last online" state - self.set_state(user, user, {"presence": PresenceState.OFFLINE}) + return self.set_state(user, user, {"presence": PresenceState.OFFLINE}) @defer.inlineCallbacks def user_joined_room(self, user, room_id): @@ -1263,6 +1263,11 @@ class UserPresenceCache(object): self.state = {"presence": PresenceState.OFFLINE} self.serial = None + def __repr__(self): + return "UserPresenceCache(state=%r, serial=%r)" % ( + self.state, self.serial + ) + def update(self, state, serial): assert("mtime_age" not in state) From c33f5c1a2414632f21183f41ecd4aef00e46a437 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Wed, 7 Oct 2015 14:45:57 +0100 Subject: [PATCH 06/20] Provide ability to login using CAS --- synapse/config/cas.py | 39 ++++++++++++++++++++ synapse/config/homeserver.py | 3 +- synapse/handlers/auth.py | 31 ++++++++++++++++ synapse/rest/client/v1/login.py | 64 ++++++++++++++++++++++++++++++++- 4 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 synapse/config/cas.py diff --git a/synapse/config/cas.py b/synapse/config/cas.py new file mode 100644 index 0000000000..81d034e8f0 --- /dev/null +++ b/synapse/config/cas.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class CasConfig(Config): + """Cas Configuration + + cas_server_url: URL of CAS server + """ + + def read_config(self, config): + cas_config = config.get("cas_config", None) + if cas_config: + self.cas_enabled = True + self.cas_server_url = cas_config["server_url"] + else: + self.cas_enabled = False + self.cas_server_url = None + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Enable CAS for registration and login. + #cas_config: + # server_url: "https://cas-server.com" + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index d77f045406..3039f3c0bf 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -26,12 +26,13 @@ from .metrics import MetricsConfig from .appservice import AppServiceConfig from .key import KeyConfig from .saml2 import SAML2Config +from .cas import CasConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, - AppServiceConfig, KeyConfig, SAML2Config, ): + AppServiceConfig, KeyConfig, SAML2Config, CasConfig): pass diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 793b3fcd8b..0ad28c4948 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -295,6 +295,37 @@ class AuthHandler(BaseHandler): refresh_token = yield self.issue_refresh_token(user_id) defer.returnValue((user_id, access_token, refresh_token)) + @defer.inlineCallbacks + def login_with_cas_user_id(self, user_id): + """ + Authenticates the user with the given user ID, intended to have been captured from a CAS response + + Args: + user_id (str): User ID + Returns: + A tuple of: + The user's ID. + The access token for the user's session. + The refresh token for the user's session. + Raises: + StoreError if there was a problem storing the token. + LoginError if there was an authentication problem. + """ + user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) + + logger.info("Logging in user %s", user_id) + access_token = yield self.issue_access_token(user_id) + refresh_token = yield self.issue_refresh_token(user_id) + defer.returnValue((user_id, access_token, refresh_token)) + + @defer.inlineCallbacks + def does_user_exist(self, user_id): + try: + yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(True) + except LoginError: + defer.returnValue(False) + @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): """Checks to see if a user with the given id exists. Will check case diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e580f71964..56e5cf79fe 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID from base import ClientV1RestServlet, client_path_pattern @@ -27,6 +27,9 @@ from saml2 import BINDING_HTTP_POST from saml2 import config from saml2.client import Saml2Client +import xml.etree.ElementTree as ET +import requests + logger = logging.getLogger(__name__) @@ -35,16 +38,23 @@ class LoginRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login$") PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" + CAS_TYPE = "m.login.cas" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.saml2_enabled = hs.config.saml2_enabled + self.cas_enabled = hs.config.cas_enabled + + self.cas_server_url = hs.config.cas_server_url + self.servername = hs.config.server_name def on_GET(self, request): flows = [{"type": LoginRestServlet.PASS_TYPE}] if self.saml2_enabled: flows.append({"type": LoginRestServlet.SAML2_TYPE}) + if self.cas_enabled: + flows.append({"type": LoginRestServlet.CAS_TYPE}) return (200, {"flows": flows}) def on_OPTIONS(self, request): @@ -67,6 +77,12 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): + url = "%s/proxyValidate" % (self.cas_server_url) + parameters = {"ticket": login_submission["ticket"], "service": login_submission["service"]} + response = requests.get(url, verify=False, params=parameters) + result = yield self.do_cas_login(response.text) + defer.returnValue(result) else: raise SynapseError(400, "Bad login type.") except KeyError: @@ -100,6 +116,41 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + @defer.inlineCallbacks + def do_cas_login(self, cas_response_body): + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not root[0].tag.endswith("authenticationSuccess"): + raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + user_id = "@%s:%s" % (user, self.servername) + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = yield auth_handler.login_with_cas_user_id(user_id) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + else: + user_id, access_token = yield self.handlers.registration_handler.register(localpart=user) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + class LoginFallbackRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/fallback$") @@ -173,6 +224,15 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue(None) defer.returnValue((200, {"status": "not_authenticated"})) +class CasRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/cas") + + def __init__(self, hs): + super(CasRestServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + + def on_GET(self, request): + return (200, {"serverUrl": self.cas_server_url}) def _parse_json(request): try: @@ -188,4 +248,6 @@ def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.saml2_enabled: SAML2RestServlet(hs).register(http_server) + if hs.config.cas_enabled: + CasRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) From 22112f8d14d1fcdb567c75484b3717e931d705db Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Thu, 8 Oct 2015 23:34:04 +0100 Subject: [PATCH 07/20] Formatting changes --- synapse/handlers/auth.py | 3 ++- synapse/rest/client/v1/login.py | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0ad28c4948..484f719253 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -298,7 +298,8 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def login_with_cas_user_id(self, user_id): """ - Authenticates the user with the given user ID, intended to have been captured from a CAS response + Authenticates the user with the given user ID, + intended to have been captured from a CAS response Args: user_id (str): User ID diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 56e5cf79fe..4de5f19591 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -77,9 +77,13 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) - elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): + elif self.cas_enabled and (login_submission["type"] == + LoginRestServlet.CAS_TYPE): url = "%s/proxyValidate" % (self.cas_server_url) - parameters = {"ticket": login_submission["ticket"], "service": login_submission["service"]} + parameters = { + "ticket": login_submission["ticket"], + "service": login_submission["service"] + } response = requests.get(url, verify=False, params=parameters) result = yield self.do_cas_login(response.text) defer.returnValue(result) @@ -130,7 +134,8 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.handlers.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: - user_id, access_token, refresh_token = yield auth_handler.login_with_cas_user_id(user_id) + user_id, access_token, refresh_token = yield + auth_handler.login_with_cas_user_id(user_id) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -139,7 +144,8 @@ class LoginRestServlet(ClientV1RestServlet): } else: - user_id, access_token = yield self.handlers.registration_handler.register(localpart=user) + user_id, access_token = yield + self.handlers.registration_handler.register(localpart=user) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -148,7 +154,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) @@ -224,6 +229,7 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue(None) defer.returnValue((200, {"status": "not_authenticated"})) + class CasRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/cas") @@ -234,6 +240,7 @@ class CasRestServlet(ClientV1RestServlet): def on_GET(self, request): return (200, {"serverUrl": self.cas_server_url}) + def _parse_json(request): try: content = json.loads(request.content.read()) From 625e13bfde35a3c6fdd2b3e8263838ec4d4fbcc3 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Fri, 9 Oct 2015 11:02:56 +0100 Subject: [PATCH 08/20] Add get_raw method to SimpleHttpClient, use this in CAS auth rather than requests --- synapse/http/client.py | 61 +++++++++++++++++++++------------ synapse/rest/client/v1/login.py | 13 ++++--- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 79c529291f..ca642a7a06 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -160,27 +160,8 @@ class SimpleHttpClient(object): On a non-2xx HTTP response. The response body will be used as the error message. """ - if len(args): - query_bytes = urllib.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) - - response = yield self.request( - "GET", - uri.encode("ascii"), - headers=Headers({ - b"User-Agent": [self.user_agent], - }) - ) - - body = yield preserve_context_over_fn(readBody, response) - - if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) - else: - # NB: This is explicitly not json.loads(body)'d because the contract - # of CodeMessageException is a *string* message. Callers can always - # load it into JSON if they want. - raise CodeMessageException(response.code, body) + body = yield self.get_raw(uri, args) + defer.returnValue(json.loads(body)) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}): @@ -209,7 +190,7 @@ class SimpleHttpClient(object): "PUT", uri.encode("ascii"), headers=Headers({ - b"User-Agent": [self.user_agent], + b"User-Agent": [self.version_string], "Content-Type": ["application/json"] }), bodyProducer=FileBodyProducer(StringIO(json_str)) @@ -225,6 +206,42 @@ class SimpleHttpClient(object): # load it into JSON if they want. raise CodeMessageException(response.code, body) + @defer.inlineCallbacks + def get_raw(self, uri, args={}): + """ Gets raw text from the given URI. + + Args: + uri (str): The URI to request, not including query parameters + args (dict): A dictionary used to create query strings, defaults to + None. + **Note**: The value of each key is assumed to be an iterable + and *not* a string. + Returns: + Deferred: Succeeds when we get *any* 2xx HTTP response, with the + HTTP body at text. + Raises: + On a non-2xx HTTP response. The response body will be used as the + error message. + """ + if len(args): + query_bytes = urllib.urlencode(args, True) + uri = "%s?%s" % (uri, query_bytes) + + response = yield self.request( + "GET", + uri.encode("ascii"), + headers=Headers({ + b"User-Agent": [self.version_string], + }) + ) + + body = yield preserve_context_over_fn(readBody, response) + + if 200 <= response.code < 300: + defer.returnValue(body) + else: + raise CodeMessageException(response.code, body) + class CaptchaServerHttpClient(SimpleHttpClient): """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4de5f19591..f5cd6a1960 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes +from synapse.http.client import SimpleHttpClient from synapse.types import UserID from base import ClientV1RestServlet, client_path_pattern @@ -28,7 +29,6 @@ from saml2 import config from saml2.client import Saml2Client import xml.etree.ElementTree as ET -import requests logger = logging.getLogger(__name__) @@ -79,13 +79,16 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): - url = "%s/proxyValidate" % (self.cas_server_url) - parameters = { + # TODO: get this from the homeserver rather than creating a new one for + # each request + http_client = SimpleHttpClient(self.hs) + uri = "%s/proxyValidate" % (self.cas_server_url,) + args = { "ticket": login_submission["ticket"], "service": login_submission["service"] } - response = requests.get(url, verify=False, params=parameters) - result = yield self.do_cas_login(response.text) + body = yield http_client.get_raw(uri, args) + result = yield self.do_cas_login(body) defer.returnValue(result) else: raise SynapseError(400, "Bad login type.") From e52f4dc5995fccd96a2a4084dc68a05da5a16838 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Fri, 9 Oct 2015 11:04:07 +0100 Subject: [PATCH 09/20] Use UserId to create FQ user id --- synapse/rest/client/v1/login.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f5cd6a1960..05095e7d6e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -133,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet): for child in root[0]: if child.tag.endswith("user"): user = child.text - user_id = "@%s:%s" % (user, self.servername) + user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.handlers.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: From a9c299c0befc5cfc10ed1a5282b6002a43b9b462 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Fri, 9 Oct 2015 11:04:30 +0100 Subject: [PATCH 10/20] Fix my broken line splitting --- synapse/rest/client/v1/login.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 05095e7d6e..1bd93526ad 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -137,8 +137,9 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.handlers.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: - user_id, access_token, refresh_token = yield - auth_handler.login_with_cas_user_id(user_id) + user_id, access_token, refresh_token = ( + yield auth_handler.login_with_cas_user_id(user_id) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -147,8 +148,9 @@ class LoginRestServlet(ClientV1RestServlet): } else: - user_id, access_token = yield - self.handlers.registration_handler.register(localpart=user) + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, From 95f7661170c842966e14b0a274347e73b90f1134 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Fri, 9 Oct 2015 11:05:02 +0100 Subject: [PATCH 11/20] Raise LoginError if CasResponse doensn't contain user --- synapse/rest/client/v1/login.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1bd93526ad..a99dcaab6f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -157,7 +157,7 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } - defer.returnValue((200, result)) + defer.returnValue((200, result)) raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) From a80ef851f7e624b9eee91f134b233f3c0742bb3e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 10 Oct 2015 12:35:39 +0100 Subject: [PATCH 12/20] Fix previous merge to s/version_string/user_agent/ --- synapse/http/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index ca642a7a06..9a5869abee 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -190,7 +190,7 @@ class SimpleHttpClient(object): "PUT", uri.encode("ascii"), headers=Headers({ - b"User-Agent": [self.version_string], + b"User-Agent": [self.user_agent], "Content-Type": ["application/json"] }), bodyProducer=FileBodyProducer(StringIO(json_str)) @@ -231,7 +231,7 @@ class SimpleHttpClient(object): "GET", uri.encode("ascii"), headers=Headers({ - b"User-Agent": [self.version_string], + b"User-Agent": [self.user_agent], }) ) From 782f7fb4899ef078b422ce779e931e156fde15be Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 10 Oct 2015 18:24:44 +0100 Subject: [PATCH 13/20] add steve to authors --- AUTHORS.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index 54ced67000..58a67c6b12 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -44,4 +44,7 @@ Eric Myhre repository API. Muthu Subramanian - * Add SAML2 support for registration and logins. + * Add SAML2 support for registration and login. + +Steven Hammerton + * Add CAS support for registration and login. From 7845f62c2207e9fa51f7a0aa7b60b49cf6436696 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 10:52:43 +0100 Subject: [PATCH 14/20] Parse both user and attributes from CAS response --- synapse/rest/client/v1/login.py | 64 +++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a99dcaab6f..0e12880ab5 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -125,6 +125,34 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_cas_login(self, cas_response_body): + (user, attributes) = self.parse_cas_response(cas_response_body) + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.login_with_cas_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) @@ -133,33 +161,17 @@ class LoginRestServlet(ClientV1RestServlet): for child in root[0]: if child.tag.endswith("user"): user = child.text - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.login_with_cas_user_id(user_id) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + return (user, attributes) class LoginFallbackRestServlet(ClientV1RestServlet): From 76421c496d1ee4ba5ea97fb24466156d0ddc0723 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 11:11:49 +0100 Subject: [PATCH 15/20] Allow optional config params for a required attribute and it's value, if specified any CAS user must have the given attribute and the value must equal --- synapse/config/cas.py | 15 +++++++++++++++ synapse/rest/client/v1/login.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 81d034e8f0..4d1dd8cc7b 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -27,13 +27,28 @@ class CasConfig(Config): if cas_config: self.cas_enabled = True self.cas_server_url = cas_config["server_url"] + + if "required_attribute" in cas_config: + self.cas_required_attribute = cas_config["required_attribute"] + else: + self.cas_required_attribute = None + + if "required_attribute_value" in cas_config: + self.cas_required_attribute_value = cas_config["required_attribute_value"] + else: + self.cas_required_attribute_value = None + else: self.cas_enabled = False self.cas_server_url = None + self.cas_required_attribute = None + self.cas_required_attribute_value = None def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable CAS for registration and login. #cas_config: # server_url: "https://cas-server.com" + # #required_attribute: something + # #required_attribute_value: true """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0e12880ab5..1e62beaff8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -45,8 +45,9 @@ class LoginRestServlet(ClientV1RestServlet): self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url + self.cas_required_attribute = hs.config.cas_required_attribute + self.cas_required_attribute_value = hs.config.cas_required_attribute_value self.servername = hs.config.server_name def on_GET(self, request): @@ -126,6 +127,19 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_cas_login(self, cas_response_body): (user, attributes) = self.parse_cas_response(cas_response_body) + + if self.cas_required_attribute is not None: + # If required attribute was not in CAS Response - Forbidden + if self.cas_required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if self.cas_required_attribute_value is not None: + actualValue = attributes[self.cas_required_attribute] + # If required attribute value does not match expected - Forbidden + if self.cas_required_attribute_value != actualValue: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.handlers.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) From 01a5f1991c8e54d0762cf1647c941d00c938f994 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 14:43:17 +0100 Subject: [PATCH 16/20] Support multiple required attributes in CAS response, and in a nicer config format too --- synapse/config/cas.py | 19 ++++--------------- synapse/rest/client/v1/login.py | 13 ++++++------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 4d1dd8cc7b..e884d03fe6 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -27,28 +27,17 @@ class CasConfig(Config): if cas_config: self.cas_enabled = True self.cas_server_url = cas_config["server_url"] - - if "required_attribute" in cas_config: - self.cas_required_attribute = cas_config["required_attribute"] - else: - self.cas_required_attribute = None - - if "required_attribute_value" in cas_config: - self.cas_required_attribute_value = cas_config["required_attribute_value"] - else: - self.cas_required_attribute_value = None - + self.cas_required_attributes = cas_config.get("required_attributes", None) else: self.cas_enabled = False self.cas_server_url = None - self.cas_required_attribute = None - self.cas_required_attribute_value = None + self.cas_required_attributes = {} def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable CAS for registration and login. #cas_config: # server_url: "https://cas-server.com" - # #required_attribute: something - # #required_attribute_value: true + # #required_attributes: + # # name: value """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1e62beaff8..84774e61aa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -46,8 +46,7 @@ class LoginRestServlet(ClientV1RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url - self.cas_required_attribute = hs.config.cas_required_attribute - self.cas_required_attribute_value = hs.config.cas_required_attribute_value + self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name def on_GET(self, request): @@ -128,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet): def do_cas_login(self, cas_response_body): (user, attributes) = self.parse_cas_response(cas_response_body) - if self.cas_required_attribute is not None: + for required_attribute in self.cas_required_attributes: # If required attribute was not in CAS Response - Forbidden - if self.cas_required_attribute not in attributes: + if required_attribute not in attributes: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) # Also need to check value - if self.cas_required_attribute_value is not None: - actualValue = attributes[self.cas_required_attribute] + if self.cas_required_attributes[required_attribute] is not None: + actualValue = attributes[required_attribute] # If required attribute value does not match expected - Forbidden - if self.cas_required_attribute_value != actualValue: + if self.cas_required_attributes[required_attribute] != actualValue: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() From 7f8fdc9814571723bfc120e43c6d21cde1c660a4 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 14:45:24 +0100 Subject: [PATCH 17/20] Remove not required parenthesis --- synapse/rest/client/v1/login.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 84774e61aa..8facb00126 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -125,7 +125,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_cas_login(self, cas_response_body): - (user, attributes) = self.parse_cas_response(cas_response_body) + user, attributes = self.parse_cas_response(cas_response_body) for required_attribute in self.cas_required_attributes: # If required attribute was not in CAS Response - Forbidden From ab7f9bb861791b9415d80f0e71d7b4b867b0a445 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 14:58:59 +0100 Subject: [PATCH 18/20] Default cas_required_attributes to empty dictionary --- synapse/config/cas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/config/cas.py b/synapse/config/cas.py index e884d03fe6..d268680729 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -27,7 +27,7 @@ class CasConfig(Config): if cas_config: self.cas_enabled = True self.cas_server_url = cas_config["server_url"] - self.cas_required_attributes = cas_config.get("required_attributes", None) + self.cas_required_attributes = cas_config.get("required_attributes", {}) else: self.cas_enabled = False self.cas_server_url = None From 83b464e4f70fbfcc338b0c3533359a8c53890cdc Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 15:05:34 +0100 Subject: [PATCH 19/20] Unpack dictionary in for loop for nicer syntax --- synapse/rest/client/v1/login.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8facb00126..c92dedcc0f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -127,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet): def do_cas_login(self, cas_response_body): user, attributes = self.parse_cas_response(cas_response_body) - for required_attribute in self.cas_required_attributes: + for required_attribute, required_value in self.cas_required_attributes.items(): # If required attribute was not in CAS Response - Forbidden if required_attribute not in attributes: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) # Also need to check value - if self.cas_required_attributes[required_attribute] is not None: - actualValue = attributes[required_attribute] + if required_value is not None: + actual_value = attributes[required_attribute] # If required attribute value does not match expected - Forbidden - if self.cas_required_attributes[required_attribute] != actualValue: + if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() From 739464fbc5dc328001fcc71e327938229c836204 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 16:02:17 +0100 Subject: [PATCH 20/20] Add a comment to clarify why we split on closing curly brace when reading CAS attribute tags --- synapse/rest/client/v1/login.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index c92dedcc0f..2e3e4f39f3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -177,6 +177,11 @@ class LoginRestServlet(ClientV1RestServlet): if child.tag.endswith("attributes"): attributes = {} for attribute in child: + # ElementTree library expands the namespace in attribute tags + # to the full URL of the namespace. + # See (https://docs.python.org/2/library/xml.etree.elementtree.html) + # We don't care about namespace here and it will always be encased in + # curly braces, so we remove them. if "}" in attribute.tag: attributes[attribute.tag.split("}")[1]] = attribute.text else: