Merge branch 'develop' into daniel/3pidinvites

This commit is contained in:
Daniel Wagner-Hall 2015-10-13 13:32:43 +01:00
commit 32a453d7ba
11 changed files with 296 additions and 69 deletions

View File

@ -44,4 +44,7 @@ Eric Myhre <hash at exultant.us>
repository API. repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com> Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and logins. * Add SAML2 support for registration and login.
Steven Hammerton <steven.hammerton at openmarket.com>
* Add CAS support for registration and login.

43
synapse/config/cas.py Normal file
View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class CasConfig(Config):
"""Cas Configuration
cas_server_url: URL of CAS server
"""
def read_config(self, config):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = True
self.cas_server_url = cas_config["server_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {})
else:
self.cas_enabled = False
self.cas_server_url = None
self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#cas_config:
# server_url: "https://cas-server.com"
# #required_attributes:
# # name: value
"""

View File

@ -26,12 +26,13 @@ from .metrics import MetricsConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
from .cas import CasConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, ): AppServiceConfig, KeyConfig, SAML2Config, CasConfig):
pass pass

View File

@ -295,6 +295,38 @@ class AuthHandler(BaseHandler):
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def login_with_cas_user_id(self, user_id):
"""
Authenticates the user with the given user ID,
intended to have been captured from a CAS response
Args:
user_id (str): User ID
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case

View File

@ -46,6 +46,56 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def started_stream(self, user):
"""Tells the presence handler that we have started an eventstream for
the user:
Args:
user (User): The user who started a stream.
Returns:
A deferred that completes once their presence has been updated.
"""
if user not in self._streams_per_user:
self._streams_per_user[user] = 0
if user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire("started_user_eventstream", user)
self._streams_per_user[user] += 1
def stopped_stream(self, user):
"""If there are no streams for a user this starts a timer that will
notify the presence handler that we haven't got an event stream for
the user unless the user starts a new stream in 30 seconds.
Args:
user (User): The user who stopped a stream.
"""
self._streams_per_user[user] -= 1
if not self._streams_per_user[user]:
del self._streams_per_user[user]
# 30 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug("_later stopped_user_eventstream %s", user)
self._stop_timer_per_user.pop(user, None)
return self.distributor.fire("stopped_user_eventstream", user)
logger.debug("Scheduling _later: for %s", user)
self._stop_timer_per_user[user] = (
self.clock.call_later(30, _later)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
@ -59,20 +109,7 @@ class EventStreamHandler(BaseHandler):
try: try:
if affect_presence: if affect_presence:
if auth_user not in self._streams_per_user: yield self.started_stream(auth_user)
self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
)
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
@ -114,27 +151,7 @@ class EventStreamHandler(BaseHandler):
finally: finally:
if affect_presence: if affect_presence:
self._streams_per_user[auth_user] -= 1 self.stopped_stream(auth_user)
if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user]
# 10 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug(
"_later stopped_user_eventstream %s", auth_user
)
self._stop_timer_per_user.pop(auth_user, None)
return self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
)
class EventHandler(BaseHandler): class EventHandler(BaseHandler):

View File

@ -324,7 +324,8 @@ class MessageHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True): def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined. """Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is This snapshot may include messages for all rooms where the user is
@ -335,17 +336,19 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return. config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns: Returns:
A list of dicts with "room_id" and "membership" keys for all rooms A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig. on the specified PaginationConfig.
""" """
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, user_id=user_id, membership_list=memberships
membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
) )
user = UserID.from_string(user_id) user = UserID.from_string(user_id)

View File

@ -378,7 +378,7 @@ class PresenceHandler(BaseHandler):
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
self.changed_presencelike_data(target_user, state) yield self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None): def bump_presence_active_time(self, user, now=None):
if now is None: if now is None:
@ -422,12 +422,12 @@ class PresenceHandler(BaseHandler):
@log_function @log_function
def started_user_eventstream(self, user): def started_user_eventstream(self, user):
# TODO(paul): Use "last online" state # TODO(paul): Use "last online" state
self.set_state(user, user, {"presence": PresenceState.ONLINE}) return self.set_state(user, user, {"presence": PresenceState.ONLINE})
@log_function @log_function
def stopped_user_eventstream(self, user): def stopped_user_eventstream(self, user):
# TODO(paul): Save current state as "last online" state # TODO(paul): Save current state as "last online" state
self.set_state(user, user, {"presence": PresenceState.OFFLINE}) return self.set_state(user, user, {"presence": PresenceState.OFFLINE})
@defer.inlineCallbacks @defer.inlineCallbacks
def user_joined_room(self, user, room_id): def user_joined_room(self, user, room_id):
@ -1263,6 +1263,11 @@ class UserPresenceCache(object):
self.state = {"presence": PresenceState.OFFLINE} self.state = {"presence": PresenceState.OFFLINE}
self.serial = None self.serial = None
def __repr__(self):
return "UserPresenceCache(state=%r, serial=%r)" % (
self.state, self.serial
)
def update(self, state, serial): def update(self, state, serial):
assert("mtime_age" not in state) assert("mtime_age" not in state)

View File

@ -160,27 +160,8 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the On a non-2xx HTTP response. The response body will be used as the
error message. error message.
""" """
if len(args): body = yield self.get_raw(uri, args)
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, uri, json_body, args={}): def put_json(self, uri, json_body, args={}):
@ -225,6 +206,42 @@ class SimpleHttpClient(object):
# load it into JSON if they want. # load it into JSON if they want.
raise CodeMessageException(response.code, body) raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def get_raw(self, uri, args={}):
""" Gets raw text from the given URI.
Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
Raises:
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(body)
else:
raise CodeMessageException(response.code, body)
class CaptchaServerHttpClient(SimpleHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
""" """

View File

@ -186,7 +186,7 @@ class Pusher(object):
if not display_name: if not display_name:
return False return False
return re.search( return re.search(
"\b%s\b" % re.escape(display_name), ev['content']['body'], r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE flags=re.IGNORECASE
) is not None ) is not None

View File

@ -29,10 +29,12 @@ class InitialSyncRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms( content = yield handler.snapshot_all_rooms(
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event,
include_archived=include_archived,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))

View File

@ -15,7 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
@ -27,6 +28,8 @@ from saml2 import BINDING_HTTP_POST
from saml2 import config from saml2 import config
from saml2.client import Saml2Client from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,16 +38,23 @@ class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
def on_GET(self, request): def on_GET(self, request):
flows = [{"type": LoginRestServlet.PASS_TYPE}] flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled: if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE}) flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
@ -67,6 +77,19 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -100,6 +123,74 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
class LoginFallbackRestServlet(ClientV1RestServlet): class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$") PATTERN = client_path_pattern("/login/fallback$")
@ -174,6 +265,17 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas")
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -188,4 +290,6 @@ def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled: if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server) SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)