Update the auth providers to be async. (#7935)
This commit is contained in:
parent
7078866969
commit
83434df381
|
@ -0,0 +1 @@
|
||||||
|
Convert the auth providers to be async/await.
|
|
@ -19,102 +19,103 @@ password auth provider module implementations:
|
||||||
|
|
||||||
Password auth provider classes must provide the following methods:
|
Password auth provider classes must provide the following methods:
|
||||||
|
|
||||||
*class* `SomeProvider.parse_config`(*config*)
|
* `parse_config(config)`
|
||||||
|
This method is passed the `config` object for this module from the
|
||||||
|
homeserver configuration file.
|
||||||
|
|
||||||
> This method is passed the `config` object for this module from the
|
It should perform any appropriate sanity checks on the provided
|
||||||
> homeserver configuration file.
|
configuration, and return an object which is then passed into
|
||||||
>
|
|
||||||
> It should perform any appropriate sanity checks on the provided
|
|
||||||
> configuration, and return an object which is then passed into
|
|
||||||
> `__init__`.
|
|
||||||
|
|
||||||
*class* `SomeProvider`(*config*, *account_handler*)
|
This method should have the `@staticmethod` decoration.
|
||||||
|
|
||||||
> The constructor is passed the config object returned by
|
* `__init__(self, config, account_handler)`
|
||||||
> `parse_config`, and a `synapse.module_api.ModuleApi` object which
|
|
||||||
> allows the password provider to check if accounts exist and/or create
|
The constructor is passed the config object returned by
|
||||||
> new ones.
|
`parse_config`, and a `synapse.module_api.ModuleApi` object which
|
||||||
|
allows the password provider to check if accounts exist and/or create
|
||||||
|
new ones.
|
||||||
|
|
||||||
## Optional methods
|
## Optional methods
|
||||||
|
|
||||||
Password auth provider classes may optionally provide the following
|
Password auth provider classes may optionally provide the following methods:
|
||||||
methods.
|
|
||||||
|
|
||||||
*class* `SomeProvider.get_db_schema_files`()
|
* `get_db_schema_files(self)`
|
||||||
|
|
||||||
> This method, if implemented, should return an Iterable of
|
This method, if implemented, should return an Iterable of
|
||||||
> `(name, stream)` pairs of database schema files. Each file is applied
|
`(name, stream)` pairs of database schema files. Each file is applied
|
||||||
> in turn at initialisation, and a record is then made in the database
|
in turn at initialisation, and a record is then made in the database
|
||||||
> so that it is not re-applied on the next start.
|
so that it is not re-applied on the next start.
|
||||||
|
|
||||||
`someprovider.get_supported_login_types`()
|
* `get_supported_login_types(self)`
|
||||||
|
|
||||||
> This method, if implemented, should return a `dict` mapping from a
|
This method, if implemented, should return a `dict` mapping from a
|
||||||
> login type identifier (such as `m.login.password`) to an iterable
|
login type identifier (such as `m.login.password`) to an iterable
|
||||||
> giving the fields which must be provided by the user in the submission
|
giving the fields which must be provided by the user in the submission
|
||||||
> to the `/login` api. These fields are passed in the `login_dict`
|
to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
|
||||||
> dictionary to `check_auth`.
|
These fields are passed in the `login_dict` dictionary to `check_auth`.
|
||||||
>
|
|
||||||
> For example, if a password auth provider wants to implement a custom
|
|
||||||
> login type of `com.example.custom_login`, where the client is expected
|
|
||||||
> to pass the fields `secret1` and `secret2`, the provider should
|
|
||||||
> implement this method and return the following dict:
|
|
||||||
>
|
|
||||||
> {"com.example.custom_login": ("secret1", "secret2")}
|
|
||||||
|
|
||||||
`someprovider.check_auth`(*username*, *login_type*, *login_dict*)
|
For example, if a password auth provider wants to implement a custom
|
||||||
|
login type of `com.example.custom_login`, where the client is expected
|
||||||
|
to pass the fields `secret1` and `secret2`, the provider should
|
||||||
|
implement this method and return the following dict:
|
||||||
|
|
||||||
> This method is the one that does the real work. If implemented, it
|
```python
|
||||||
> will be called for each login attempt where the login type matches one
|
{"com.example.custom_login": ("secret1", "secret2")}
|
||||||
> of the keys returned by `get_supported_login_types`.
|
```
|
||||||
>
|
|
||||||
> It is passed the (possibly UNqualified) `user` provided by the client,
|
|
||||||
> the login type, and a dictionary of login secrets passed by the
|
|
||||||
> client.
|
|
||||||
>
|
|
||||||
> The method should return a Twisted `Deferred` object, which resolves
|
|
||||||
> to the canonical `@localpart:domain` user id if authentication is
|
|
||||||
> successful, and `None` if not.
|
|
||||||
>
|
|
||||||
> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
|
|
||||||
> which case the second field is a callback which will be called with
|
|
||||||
> the result from the `/login` call (including `access_token`,
|
|
||||||
> `device_id`, etc.)
|
|
||||||
|
|
||||||
`someprovider.check_3pid_auth`(*medium*, *address*, *password*)
|
* `check_auth(self, username, login_type, login_dict)`
|
||||||
|
|
||||||
> This method, if implemented, is called when a user attempts to
|
This method does the real work. If implemented, it
|
||||||
> register or log in with a third party identifier, such as email. It is
|
will be called for each login attempt where the login type matches one
|
||||||
> passed the medium (ex. "email"), an address (ex.
|
of the keys returned by `get_supported_login_types`.
|
||||||
> "<jdoe@example.com>") and the user's password.
|
|
||||||
>
|
|
||||||
> The method should return a Twisted `Deferred` object, which resolves
|
|
||||||
> to a `str` containing the user's (canonical) User ID if
|
|
||||||
> authentication was successful, and `None` if not.
|
|
||||||
>
|
|
||||||
> As with `check_auth`, the `Deferred` may alternatively resolve to a
|
|
||||||
> `(user_id, callback)` tuple.
|
|
||||||
|
|
||||||
`someprovider.check_password`(*user_id*, *password*)
|
It is passed the (possibly unqualified) `user` field provided by the client,
|
||||||
|
the login type, and a dictionary of login secrets passed by the
|
||||||
|
client.
|
||||||
|
|
||||||
> This method provides a simpler interface than
|
The method should return an `Awaitable` object, which resolves
|
||||||
> `get_supported_login_types` and `check_auth` for password auth
|
to the canonical `@localpart:domain` user ID if authentication is
|
||||||
> providers that just want to provide a mechanism for validating
|
successful, and `None` if not.
|
||||||
> `m.login.password` logins.
|
|
||||||
>
|
|
||||||
> Iif implemented, it will be called to check logins with an
|
|
||||||
> `m.login.password` login type. It is passed a qualified
|
|
||||||
> `@localpart:domain` user id, and the password provided by the user.
|
|
||||||
>
|
|
||||||
> The method should return a Twisted `Deferred` object, which resolves
|
|
||||||
> to `True` if authentication is successful, and `False` if not.
|
|
||||||
|
|
||||||
`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*)
|
Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in
|
||||||
|
which case the second field is a callback which will be called with
|
||||||
|
the result from the `/login` call (including `access_token`,
|
||||||
|
`device_id`, etc.)
|
||||||
|
|
||||||
> This method, if implemented, is called when a user logs out. It is
|
* `check_3pid_auth(self, medium, address, password)`
|
||||||
> passed the qualified user ID, the ID of the deactivated device (if
|
|
||||||
> any: access tokens are occasionally created without an associated
|
This method, if implemented, is called when a user attempts to
|
||||||
> device ID), and the (now deactivated) access token.
|
register or log in with a third party identifier, such as email. It is
|
||||||
>
|
passed the medium (ex. "email"), an address (ex.
|
||||||
> It may return a Twisted `Deferred` object; the logout request will
|
"<jdoe@example.com>") and the user's password.
|
||||||
> wait for the deferred to complete but the result is ignored.
|
|
||||||
|
The method should return an `Awaitable` object, which resolves
|
||||||
|
to a `str` containing the user's (canonical) User id if
|
||||||
|
authentication was successful, and `None` if not.
|
||||||
|
|
||||||
|
As with `check_auth`, the `Awaitable` may alternatively resolve to a
|
||||||
|
`(user_id, callback)` tuple.
|
||||||
|
|
||||||
|
* `check_password(self, user_id, password)`
|
||||||
|
|
||||||
|
This method provides a simpler interface than
|
||||||
|
`get_supported_login_types` and `check_auth` for password auth
|
||||||
|
providers that just want to provide a mechanism for validating
|
||||||
|
`m.login.password` logins.
|
||||||
|
|
||||||
|
If implemented, it will be called to check logins with an
|
||||||
|
`m.login.password` login type. It is passed a qualified
|
||||||
|
`@localpart:domain` user id, and the password provided by the user.
|
||||||
|
|
||||||
|
The method should return an `Awaitable` object, which resolves
|
||||||
|
to `True` if authentication is successful, and `False` if not.
|
||||||
|
|
||||||
|
* `on_logged_out(self, user_id, device_id, access_token)`
|
||||||
|
|
||||||
|
This method, if implemented, is called when a user logs out. It is
|
||||||
|
passed the qualified user ID, the ID of the deactivated device (if
|
||||||
|
any: access tokens are occasionally created without an associated
|
||||||
|
device ID), and the (now deactivated) access token.
|
||||||
|
|
||||||
|
It may return an `Awaitable` object; the logout request will
|
||||||
|
wait for the `Awaitable` to complete, but the result is ignored.
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
@ -863,11 +864,15 @@ class AuthHandler(BaseHandler):
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
if hasattr(provider, "on_logged_out"):
|
||||||
await provider.on_logged_out(
|
# This might return an awaitable, if it does block the log out
|
||||||
|
# until it completes.
|
||||||
|
result = provider.on_logged_out(
|
||||||
user_id=str(user_info["user"]),
|
user_id=str(user_info["user"]),
|
||||||
device_id=user_info["device_id"],
|
device_id=user_info["device_id"],
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info["token_id"] is not None:
|
if user_info["token_id"] is not None:
|
||||||
|
|
|
@ -14,10 +14,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
|
@ -33,25 +33,25 @@ class UserInteractiveAuthChecker:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_enabled(self):
|
def is_enabled(self) -> bool:
|
||||||
"""Check if the configuration of the homeserver allows this checker to work
|
"""Check if the configuration of the homeserver allows this checker to work
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if this login type is enabled.
|
True if this login type is enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
"""Given the authentication dict from the client, attempt to check this step
|
"""Given the authentication dict from the client, attempt to check this step
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
authdict (dict): authentication dictionary from the client
|
authdict: authentication dictionary from the client
|
||||||
clientip (str): The IP address of the client.
|
clientip: The IP address of the client.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if authentication failed
|
SynapseError if authentication failed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: the result of authentication (to pass back to the client?)
|
The result of authentication (to pass back to the client?)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
|
||||||
def is_enabled(self):
|
def is_enabled(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict, clientip):
|
||||||
return defer.succeed(True)
|
return True
|
||||||
|
|
||||||
|
|
||||||
class TermsAuthChecker(UserInteractiveAuthChecker):
|
class TermsAuthChecker(UserInteractiveAuthChecker):
|
||||||
|
@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
|
||||||
def is_enabled(self):
|
def is_enabled(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict, clientip):
|
||||||
return defer.succeed(True)
|
return True
|
||||||
|
|
||||||
|
|
||||||
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||||
|
@ -89,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||||
def is_enabled(self):
|
def is_enabled(self):
|
||||||
return self._enabled
|
return self._enabled
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_auth(self, authdict, clientip):
|
||||||
def check_auth(self, authdict, clientip):
|
|
||||||
try:
|
try:
|
||||||
user_response = authdict["response"]
|
user_response = authdict["response"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -107,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||||
# TODO: get this from the homeserver rather than creating a new one for
|
# TODO: get this from the homeserver rather than creating a new one for
|
||||||
# each request
|
# each request
|
||||||
try:
|
try:
|
||||||
resp_body = yield self._http_client.post_urlencoded_get_json(
|
resp_body = await self._http_client.post_urlencoded_get_json(
|
||||||
self._url,
|
self._url,
|
||||||
args={
|
args={
|
||||||
"secret": self._secret,
|
"secret": self._secret,
|
||||||
|
@ -219,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
|
||||||
ThreepidBehaviour.LOCAL,
|
ThreepidBehaviour.LOCAL,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict, clientip):
|
||||||
return defer.ensureDeferred(self._check_threepid("email", authdict))
|
return await self._check_threepid("email", authdict)
|
||||||
|
|
||||||
|
|
||||||
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
||||||
|
@ -233,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
||||||
def is_enabled(self):
|
def is_enabled(self):
|
||||||
return bool(self.hs.config.account_threepid_delegate_msisdn)
|
return bool(self.hs.config.account_threepid_delegate_msisdn)
|
||||||
|
|
||||||
def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict, clientip):
|
||||||
return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
|
return await self._check_threepid("msisdn", authdict)
|
||||||
|
|
||||||
|
|
||||||
INTERACTIVE_AUTH_CHECKERS = [
|
INTERACTIVE_AUTH_CHECKERS = [
|
||||||
|
|
Loading…
Reference in New Issue