diff --git a/docs/password_auth_providers.rst b/docs/password_auth_providers.rst index ca05a76617..2dbebcd72c 100644 --- a/docs/password_auth_providers.rst +++ b/docs/password_auth_providers.rst @@ -30,22 +30,55 @@ Password auth provider classes must provide the following methods: and a ``synapse.handlers.auth._AccountHandler`` object which allows the password provider to check if accounts exist and/or create new ones. -``someprovider.check_password``\(*user_id*, *password*) - - This is the method that actually does the work. It is passed a qualified - ``@localpart:domain`` user id, and the password provided by the user. - - The method should return a Twisted ``Deferred`` object, which resolves to - ``True`` if authentication is successful, and ``False`` if not. - Optional methods ---------------- -Password provider classes may optionally provide the following methods. +Password auth provider classes may optionally provide the following methods. -*class* ``SomeProvider.get_db_schema_files()`` +*class* ``SomeProvider.get_db_schema_files``\() This method, if implemented, should return an Iterable of ``(name, stream)`` pairs of database schema files. Each file is applied in turn at initialisation, and a record is then made in the database so that it is not re-applied on the next start. + +``someprovider.get_supported_login_types``\() + + This method, if implemented, should return a ``dict`` mapping from a login + type identifier (such as ``m.login.password``) to an iterable giving the + fields which must be provided by the user in the submission to the + ``/login`` api. These fields are passed in the ``login_dict`` dictionary + to ``check_auth``. + + For example, if a password auth provider wants to implement a custom login + type of ``com.example.custom_login``, where the client is expected to pass + the fields ``secret1`` and ``secret2``, the provider should implement this + method and return the following dict:: + + {"com.example.custom_login": ("secret1", "secret2")} + +``someprovider.check_auth``\(*username*, *login_type*, *login_dict*) + + This method is the one that does the real work. If implemented, it will be + called for each login attempt where the login type matches one of the keys + returned by ``get_supported_login_types``. + + It is passed the (possibly UNqualified) ``user`` provided by the client, + the login type, and a dictionary of login secrets passed by the client. + + The method should return a Twisted ``Deferred`` object, which resolves to + the canonical ``@localpart:domain`` user id if authentication is successful, + and ``None`` if not. + +``someprovider.check_password``\(*user_id*, *password*) + + This method provides a simpler interface than ``get_supported_login_types`` + and ``check_auth`` for password auth providers that just want to provide a + mechanism for validating ``m.login.password`` logins. + + Iif implemented, it will be called to check logins with an + ``m.login.password`` login type. It is passed a qualified + ``@localpart:domain`` user id, and the password provided by the user. + + The method should return a Twisted ``Deferred`` object, which resolves to + ``True`` if authentication is successful, and ``False`` if not. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4ba75eb742..9799461d26 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -81,6 +81,11 @@ class AuthHandler(BaseHandler): login_types = set() if self._password_enabled: login_types.add(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + login_types.update( + provider.get_supported_login_types().keys() + ) self._supported_login_types = frozenset(login_types) @defer.inlineCallbacks @@ -501,14 +506,14 @@ class AuthHandler(BaseHandler): return self._supported_login_types @defer.inlineCallbacks - def validate_login(self, user_id, login_submission): + def validate_login(self, username, login_submission): """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate m.login.password auth types. Args: - user_id (str): user_id supplied by the user + username (str): username supplied by the user login_submission (dict): the whole of the login submission (including 'type' and other relevant fields) Returns: @@ -519,32 +524,81 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ - if not user_id.startswith('@'): - user_id = UserID( - user_id, self.hs.hostname + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname ).to_string() login_type = login_submission.get("type") + known_login_type = False - if login_type != LoginType.PASSWORD: - raise SynapseError(400, "Bad login type.") - if not self._password_enabled: - raise SynapseError(400, "Password login has been disabled.") - if "password" not in login_submission: - raise SynapseError(400, "Missing parameter: password") + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") - password = login_submission["password"] for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue(qualified_user_id) - canonical_user_id = yield self._check_local_password( - user_id, password, - ) + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue - if canonical_user_id: - defer.returnValue(canonical_user_id) + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) + + returned_user_id = yield provider.check_auth( + username, login_type, login_dict, + ) + if returned_user_id: + defer.returnValue(returned_user_id) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) + + if canonical_user_id: + defer.returnValue(canonical_user_id) + + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -773,11 +827,31 @@ class _AccountHandler(object): self._check_user_exists = check_user_exists - def check_user_exists(self, user_id): - """Check if user exissts. + def get_qualified_user_id(self, username): + """Qualify a user id, if necessary + + Takes a user id provided by the user and adds the @ and :domain to + qualify it, if necessary + + Args: + username (str): provided user id Returns: - Deferred(bool) + str: qualified @user:id + """ + if username.startswith('@'): + return username + return UserID(username, self.hs.hostname).to_string() + + def check_user_exists(self, user_id): + """Check if user exists. + + Args: + user_id (str): Complete @user:id + + Returns: + Deferred[str|None]: Canonical (case-corrected) user_id, or None + if the user is not registered. """ return self._check_user_exists(user_id)