Code cleanups and simplifications.
Also: share the saml client between redirect and response handlers.
This commit is contained in:
parent
69a43d9974
commit
426049247b
|
@ -57,7 +57,6 @@ class LoginType(object):
|
|||
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||
MSISDN = u"m.login.msisdn"
|
||||
RECAPTCHA = u"m.login.recaptcha"
|
||||
SSO = u"m.login.sso"
|
||||
TERMS = u"m.login.terms"
|
||||
DUMMY = u"m.login.dummy"
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
@ -25,6 +26,11 @@ class SAML2Config(Config):
|
|||
if not saml2_config or not saml2_config.get("enabled", True):
|
||||
return
|
||||
|
||||
try:
|
||||
check_requirements('saml2')
|
||||
except DependencyException as e:
|
||||
raise ConfigError(e.message)
|
||||
|
||||
self.saml2_enabled = True
|
||||
|
||||
import saml2.config
|
||||
|
@ -75,7 +81,6 @@ class SAML2Config(Config):
|
|||
# override them.
|
||||
#
|
||||
#saml2_config:
|
||||
# enabled: true
|
||||
# sp_config:
|
||||
# # point this to the IdP's metadata. You can use either a local file or
|
||||
# # (preferably) a URL.
|
||||
|
|
|
@ -767,9 +767,6 @@ class AuthHandler(BaseHandler):
|
|||
if canonical_user_id:
|
||||
defer.returnValue((canonical_user_id, None))
|
||||
|
||||
if login_type == LoginType.SSO:
|
||||
known_login_type = True
|
||||
|
||||
if not known_login_type:
|
||||
raise SynapseError(400, "Unknown login type %s" % login_type)
|
||||
|
||||
|
|
|
@ -34,10 +34,6 @@ from synapse.rest.well_known import WellKnownBuilder
|
|||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
import saml2
|
||||
from saml2.client import Saml2Client
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -378,28 +374,49 @@ class LoginRestServlet(RestServlet):
|
|||
defer.returnValue(result)
|
||||
|
||||
|
||||
class CasRedirectServlet(RestServlet):
|
||||
class BaseSsoRedirectServlet(RestServlet):
|
||||
"""Common base class for /login/sso/redirect impls"""
|
||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||
|
||||
def on_GET(self, request):
|
||||
args = request.args
|
||||
if b"redirectUrl" not in args:
|
||||
return 400, "Redirect URL not specified for SSO auth"
|
||||
client_redirect_url = args[b"redirectUrl"][0]
|
||||
sso_url = self.get_sso_url(client_redirect_url)
|
||||
request.redirect(sso_url)
|
||||
finish_request(request)
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
"""Get the URL to redirect to, to perform SSO auth
|
||||
|
||||
Args:
|
||||
client_redirect_url (bytes): the URL that we should redirect the
|
||||
client to when everything is done
|
||||
|
||||
Returns:
|
||||
bytes: URL to redirect to
|
||||
"""
|
||||
# to be implemented by subclasses
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CasRedirectServlet(RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(CasRedirectServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
||||
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
||||
|
||||
def on_GET(self, request):
|
||||
args = request.args
|
||||
if b"redirectUrl" not in args:
|
||||
return (400, "Redirect URL not specified for CAS auth")
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
client_redirect_url_param = urllib.parse.urlencode({
|
||||
b"redirectUrl": args[b"redirectUrl"][0]
|
||||
b"redirectUrl": client_redirect_url
|
||||
}).encode('ascii')
|
||||
hs_redirect_url = (self.cas_service_url +
|
||||
b"/_matrix/client/r0/login/cas/ticket")
|
||||
service_param = urllib.parse.urlencode({
|
||||
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||
}).encode('ascii')
|
||||
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
|
||||
finish_request(request)
|
||||
return b"%s/login?%s" % (self.cas_server_url, service_param)
|
||||
|
||||
|
||||
class CasTicketServlet(RestServlet):
|
||||
|
@ -482,41 +499,23 @@ class CasTicketServlet(RestServlet):
|
|||
return user, attributes
|
||||
|
||||
|
||||
class SSORedirectServlet(RestServlet):
|
||||
class SAMLRedirectServlet(BaseSsoRedirectServlet):
|
||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(SSORedirectServlet, self).__init__()
|
||||
self.saml2_sp_config = hs.config.saml2_sp_config
|
||||
self._saml_client = hs.get_saml_client()
|
||||
|
||||
def on_GET(self, request):
|
||||
args = request.args
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
reqid, info = self._saml_client.prepare_for_authenticate(
|
||||
relay_state=client_redirect_url,
|
||||
)
|
||||
|
||||
saml_client = Saml2Client(self.saml2_sp_config)
|
||||
reqid, info = saml_client.prepare_for_authenticate()
|
||||
|
||||
redirect_url = None
|
||||
|
||||
# Select the IdP URL to send the AuthN request to
|
||||
for key, value in info['headers']:
|
||||
if key is 'Location':
|
||||
redirect_url = value
|
||||
if key == 'Location':
|
||||
return value
|
||||
|
||||
if redirect_url is None:
|
||||
raise LoginError(401, "Unsuccessful SSO SAML2 redirect url response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
relay_state = "/_matrix/client/r0/login"
|
||||
if b"redirectUrl" in args:
|
||||
relay_state = args[b"redirectUrl"][0]
|
||||
|
||||
url_parts = list(urllib.parse.urlparse(redirect_url))
|
||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
||||
query.update({"RelayState": relay_state})
|
||||
url_parts[4] = urllib.parse.urlencode(query)
|
||||
|
||||
request.redirect(urllib.parse.urlunparse(url_parts))
|
||||
finish_request(request)
|
||||
# this shouldn't happen!
|
||||
raise Exception("prepare_for_authenticate didn't return a Location header")
|
||||
|
||||
|
||||
class SSOAuthHandler(object):
|
||||
|
@ -594,5 +593,5 @@ def register_servlets(hs, http_server):
|
|||
if hs.config.cas_enabled:
|
||||
CasRedirectServlet(hs).register(http_server)
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
if hs.config.saml2_enabled:
|
||||
SSORedirectServlet(hs).register(http_server)
|
||||
elif hs.config.saml2_enabled:
|
||||
SAMLRedirectServlet(hs).register(http_server)
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
import logging
|
||||
|
||||
import saml2
|
||||
from saml2.client import Saml2Client
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
@ -36,8 +35,7 @@ class SAML2ResponseResource(Resource):
|
|||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self)
|
||||
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
self._saml_client = hs.get_saml_client()
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
|
||||
def render_POST(self, request):
|
||||
|
|
|
@ -189,6 +189,7 @@ class HomeServer(object):
|
|||
'registration_handler',
|
||||
'account_validity_handler',
|
||||
'event_client_serializer',
|
||||
'saml_client',
|
||||
]
|
||||
|
||||
REQUIRED_ON_MASTER_STARTUP = [
|
||||
|
@ -522,6 +523,10 @@ class HomeServer(object):
|
|||
def build_event_client_serializer(self):
|
||||
return EventClientSerializer(self)
|
||||
|
||||
def build_saml_client(self):
|
||||
from saml2.client import Saml2Client
|
||||
return Saml2Client(self.config.saml2_sp_config)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
|
Loading…
Reference in New Issue