Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com> This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
This commit is contained in:
parent
ecd823d766
commit
947dbbdfd1
|
@ -0,0 +1 @@
|
|||
Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown.
|
|
@ -53,6 +53,7 @@
|
|||
- [Media](admin_api/media_admin_api.md)
|
||||
- [Purge History](admin_api/purge_history_api.md)
|
||||
- [Register Users](admin_api/register_api.md)
|
||||
- [Registration Tokens](usage/administration/admin_api/registration_tokens.md)
|
||||
- [Manipulate Room Membership](admin_api/room_membership.md)
|
||||
- [Rooms](admin_api/rooms.md)
|
||||
- [Server Notices](admin_api/server_notices.md)
|
||||
|
|
|
@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
|
|||
# is using
|
||||
# - one for registration that ratelimits registration requests based on the
|
||||
# client's IP address.
|
||||
# - one for checking the validity of registration tokens that ratelimits
|
||||
# requests based on the client's IP address.
|
||||
# - one for login that ratelimits login requests based on the client's IP
|
||||
# address.
|
||||
# - one for login that ratelimits login requests based on the account the
|
||||
|
@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
|
|||
# per_second: 0.17
|
||||
# burst_count: 3
|
||||
#
|
||||
#rc_registration_token_validity:
|
||||
# per_second: 0.1
|
||||
# burst_count: 5
|
||||
#
|
||||
#rc_login:
|
||||
# address:
|
||||
# per_second: 0.17
|
||||
|
@ -1169,6 +1175,15 @@ url_preview_accept_language:
|
|||
#
|
||||
#enable_3pid_lookup: true
|
||||
|
||||
# Require users to submit a token during registration.
|
||||
# Tokens can be managed using the admin API:
|
||||
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
|
||||
# Note that `enable_registration` must be set to `true`.
|
||||
# Disabling this option will not delete any tokens previously generated.
|
||||
# Defaults to false. Uncomment the following to require tokens:
|
||||
#
|
||||
#registration_requires_token: true
|
||||
|
||||
# If set, allows registration of standard or admin accounts by anyone who
|
||||
# has the shared secret, even if registration is otherwise disabled.
|
||||
#
|
||||
|
|
|
@ -0,0 +1,295 @@
|
|||
# Registration Tokens
|
||||
|
||||
This API allows you to manage tokens which can be used to authenticate
|
||||
registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
|
||||
To use it, you will need to enable the `registration_requires_token` config
|
||||
option, and authenticate by providing an `access_token` for a server admin:
|
||||
see [Admin API](../../usage/administration/admin_api).
|
||||
Note that this API is still experimental; not all clients may support it yet.
|
||||
|
||||
|
||||
## Registration token objects
|
||||
|
||||
Most endpoints make use of JSON objects that contain details about tokens.
|
||||
These objects have the following fields:
|
||||
- `token`: The token which can be used to authenticate registration.
|
||||
- `uses_allowed`: The number of times the token can be used to complete a
|
||||
registration before it becomes invalid.
|
||||
- `pending`: The number of pending uses the token has. When someone uses
|
||||
the token to authenticate themselves, the pending counter is incremented
|
||||
so that the token is not used more than the permitted number of times.
|
||||
When the person completes registration the pending counter is decremented,
|
||||
and the completed counter is incremented.
|
||||
- `completed`: The number of times the token has been used to successfully
|
||||
complete a registration.
|
||||
- `expiry_time`: The latest time the token is valid. Given as the number of
|
||||
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
|
||||
To convert this into a human-readable form you can remove the milliseconds
|
||||
and use the `date` command. For example, `date -d '@1625394937'`.
|
||||
|
||||
|
||||
## List all tokens
|
||||
|
||||
Lists all tokens and details about them. If the request is successful, the top
|
||||
level JSON object will have a `registration_tokens` key which is an array of
|
||||
registration token objects.
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/registration_tokens
|
||||
```
|
||||
|
||||
Optional query parameters:
|
||||
- `valid`: `true` or `false`. If `true`, only valid tokens are returned.
|
||||
If `false`, only tokens that have expired or have had all uses exhausted are
|
||||
returned. If omitted, all tokens are returned regardless of validity.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/registration_tokens
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{
|
||||
"registration_tokens": [
|
||||
{
|
||||
"token": "abcd",
|
||||
"uses_allowed": 3,
|
||||
"pending": 0,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
},
|
||||
{
|
||||
"token": "pqrs",
|
||||
"uses_allowed": 2,
|
||||
"pending": 1,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
},
|
||||
{
|
||||
"token": "wxyz",
|
||||
"uses_allowed": null,
|
||||
"pending": 0,
|
||||
"completed": 9,
|
||||
"expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Example using the `valid` query parameter:
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/registration_tokens?valid=false
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{
|
||||
"registration_tokens": [
|
||||
{
|
||||
"token": "pqrs",
|
||||
"uses_allowed": 2,
|
||||
"pending": 1,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
},
|
||||
{
|
||||
"token": "wxyz",
|
||||
"uses_allowed": null,
|
||||
"pending": 0,
|
||||
"completed": 9,
|
||||
"expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Get one token
|
||||
|
||||
Get details about a single token. If the request is successful, the response
|
||||
body will be a registration token object.
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/registration_tokens/<token>
|
||||
```
|
||||
|
||||
Path parameters:
|
||||
- `token`: The registration token to return details of.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/registration_tokens/abcd
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "abcd",
|
||||
"uses_allowed": 3,
|
||||
"pending": 0,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Create token
|
||||
|
||||
Create a new registration token. If the request is successful, the newly created
|
||||
token will be returned as a registration token object in the response body.
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/registration_tokens/new
|
||||
```
|
||||
|
||||
The request body must be a JSON object and can contain the following fields:
|
||||
- `token`: The registration token. A string of no more than 64 characters that
|
||||
consists only of characters matched by the regex `[A-Za-z0-9-_]`.
|
||||
Default: randomly generated.
|
||||
- `uses_allowed`: The integer number of times the token can be used to complete
|
||||
a registration before it becomes invalid.
|
||||
Default: `null` (unlimited uses).
|
||||
- `expiry_time`: The latest time the token is valid. Given as the number of
|
||||
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
|
||||
You could use, for example, `date '+%s000' -d 'tomorrow'`.
|
||||
Default: `null` (token does not expire).
|
||||
- `length`: The length of the token randomly generated if `token` is not
|
||||
specified. Must be between 1 and 64 inclusive. Default: `16`.
|
||||
|
||||
If a field is omitted the default is used.
|
||||
|
||||
Example using defaults:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/registration_tokens/new
|
||||
|
||||
{}
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "0M-9jbkf2t_Tgiw1",
|
||||
"uses_allowed": null,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": null
|
||||
}
|
||||
```
|
||||
|
||||
Example specifying some fields:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/registration_tokens/new
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1
|
||||
}
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": null
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Update token
|
||||
|
||||
Update the number of allowed uses or expiry time of a token. If the request is
|
||||
successful, the updated token will be returned as a registration token object
|
||||
in the response body.
|
||||
|
||||
```
|
||||
PUT /_synapse/admin/v1/registration_tokens/<token>
|
||||
```
|
||||
|
||||
Path parameters:
|
||||
- `token`: The registration token to update.
|
||||
|
||||
The request body must be a JSON object and can contain the following fields:
|
||||
- `uses_allowed`: The integer number of times the token can be used to complete
|
||||
a registration before it becomes invalid. By setting `uses_allowed` to `0`
|
||||
the token can be easily made invalid without deleting it.
|
||||
If `null` the token will have an unlimited number of uses.
|
||||
- `expiry_time`: The latest time the token is valid. Given as the number of
|
||||
milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
|
||||
If `null` the token will not expire.
|
||||
|
||||
If a field is omitted its value is not modified.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
PUT /_synapse/admin/v1/registration_tokens/defg
|
||||
|
||||
{
|
||||
"expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC
|
||||
}
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": 4781243146000
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Delete token
|
||||
|
||||
Delete a registration token. If the request is successful, the response body
|
||||
will be an empty JSON object.
|
||||
|
||||
```
|
||||
DELETE /_synapse/admin/v1/registration_tokens/<token>
|
||||
```
|
||||
|
||||
Path parameters:
|
||||
- `token`: The registration token to delete.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
DELETE /_synapse/admin/v1/registration_tokens/wxyz
|
||||
```
|
||||
```
|
||||
200 OK
|
||||
|
||||
{}
|
||||
```
|
||||
|
||||
|
||||
## Errors
|
||||
|
||||
If a request fails a "standard error response" will be returned as defined in
|
||||
the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
|
||||
|
||||
For example, if the token specified in a path parameter does not exist a
|
||||
`404 Not Found` error will be returned.
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/registration_tokens/1234
|
||||
```
|
||||
```
|
||||
404 Not Found
|
||||
|
||||
{
|
||||
"errcode": "M_NOT_FOUND",
|
||||
"error": "No such registration token: 1234"
|
||||
}
|
||||
```
|
|
@ -236,6 +236,7 @@ expressions:
|
|||
# Registration/login requests
|
||||
^/_matrix/client/(api/v1|r0|unstable)/login$
|
||||
^/_matrix/client/(r0|unstable)/register$
|
||||
^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
|
||||
|
||||
# Event sending requests
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
|
||||
|
|
|
@ -79,6 +79,7 @@ class LoginType:
|
|||
TERMS = "m.login.terms"
|
||||
SSO = "m.login.sso"
|
||||
DUMMY = "m.login.dummy"
|
||||
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
|
||||
|
||||
|
||||
# This is used in the `type` parameter for /register when called by
|
||||
|
|
|
@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
|
|||
ProfileRestServlet,
|
||||
)
|
||||
from synapse.rest.client.push_rule import PushRuleRestServlet
|
||||
from synapse.rest.client.register import RegisterRestServlet
|
||||
from synapse.rest.client.register import (
|
||||
RegisterRestServlet,
|
||||
RegistrationTokenValidityRestServlet,
|
||||
)
|
||||
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
|
||||
from synapse.rest.client.versions import VersionsRestServlet
|
||||
from synapse.rest.client.voip import VoipRestServlet
|
||||
|
@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
|
|||
resource = JsonResource(self, canonical_json=False)
|
||||
|
||||
RegisterRestServlet(self).register(resource)
|
||||
RegistrationTokenValidityRestServlet(self).register(resource)
|
||||
login.register_servlets(self, resource)
|
||||
ThreepidRestServlet(self).register(resource)
|
||||
DevicesRestServlet(self).register(resource)
|
||||
|
|
|
@ -79,6 +79,11 @@ class RatelimitConfig(Config):
|
|||
|
||||
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
|
||||
|
||||
self.rc_registration_token_validity = RateLimitConfig(
|
||||
config.get("rc_registration_token_validity", {}),
|
||||
defaults={"per_second": 0.1, "burst_count": 5},
|
||||
)
|
||||
|
||||
rc_login_config = config.get("rc_login", {})
|
||||
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
|
||||
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
|
||||
|
@ -143,6 +148,8 @@ class RatelimitConfig(Config):
|
|||
# is using
|
||||
# - one for registration that ratelimits registration requests based on the
|
||||
# client's IP address.
|
||||
# - one for checking the validity of registration tokens that ratelimits
|
||||
# requests based on the client's IP address.
|
||||
# - one for login that ratelimits login requests based on the client's IP
|
||||
# address.
|
||||
# - one for login that ratelimits login requests based on the account the
|
||||
|
@ -171,6 +178,10 @@ class RatelimitConfig(Config):
|
|||
# per_second: 0.17
|
||||
# burst_count: 3
|
||||
#
|
||||
#rc_registration_token_validity:
|
||||
# per_second: 0.1
|
||||
# burst_count: 5
|
||||
#
|
||||
#rc_login:
|
||||
# address:
|
||||
# per_second: 0.17
|
||||
|
|
|
@ -33,6 +33,9 @@ class RegistrationConfig(Config):
|
|||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
|
||||
self.registration_requires_token = config.get(
|
||||
"registration_requires_token", False
|
||||
)
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
|
@ -140,6 +143,9 @@ class RegistrationConfig(Config):
|
|||
"mechanism by removing the `access_token_lifetime` option."
|
||||
)
|
||||
|
||||
# The fallback template used for authenticating using a registration token
|
||||
self.registration_token_template = self.read_template("registration_token.html")
|
||||
|
||||
# The success template used during fallback auth.
|
||||
self.fallback_success_template = self.read_template("auth_success.html")
|
||||
|
||||
|
@ -199,6 +205,15 @@ class RegistrationConfig(Config):
|
|||
#
|
||||
#enable_3pid_lookup: true
|
||||
|
||||
# Require users to submit a token during registration.
|
||||
# Tokens can be managed using the admin API:
|
||||
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
|
||||
# Note that `enable_registration` must be set to `true`.
|
||||
# Disabling this option will not delete any tokens previously generated.
|
||||
# Defaults to false. Uncomment the following to require tokens:
|
||||
#
|
||||
#registration_requires_token: true
|
||||
|
||||
# If set, allows registration of standard or admin accounts by anyone who
|
||||
# has the shared secret, even if registration is otherwise disabled.
|
||||
#
|
||||
|
|
|
@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
|
|||
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
|
||||
# for.
|
||||
REQUEST_USER_ID = "request_user_id"
|
||||
|
||||
# used during registration to store the registration token used (if required) so that:
|
||||
# - we can prevent a token being used twice by one session
|
||||
# - we can 'use up' the token after registration has successfully completed
|
||||
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
|
||||
|
|
|
@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
|||
return await self._check_threepid("msisdn", authdict)
|
||||
|
||||
|
||||
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
|
||||
AUTH_TYPE = LoginType.REGISTRATION_TOKEN
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self._enabled = bool(hs.config.registration_requires_token)
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
return self._enabled
|
||||
|
||||
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||
if "token" not in authdict:
|
||||
raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
|
||||
if not isinstance(authdict["token"], str):
|
||||
raise LoginError(
|
||||
400, "Registration token must be a string", Codes.INVALID_PARAM
|
||||
)
|
||||
if "session" not in authdict:
|
||||
raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
|
||||
|
||||
# Get these here to avoid cyclic dependencies
|
||||
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
||||
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
|
||||
session = authdict["session"]
|
||||
token = authdict["token"]
|
||||
|
||||
# If the LoginType.REGISTRATION_TOKEN stage has already been completed,
|
||||
# return early to avoid incrementing `pending` again.
|
||||
stored_token = await auth_handler.get_session_data(
|
||||
session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
|
||||
)
|
||||
if stored_token:
|
||||
if token != stored_token:
|
||||
raise LoginError(
|
||||
400, "Registration token has changed", Codes.INVALID_PARAM
|
||||
)
|
||||
else:
|
||||
return token
|
||||
|
||||
if await self.store.registration_token_is_valid(token):
|
||||
# Increment pending counter, so that if token has limited uses it
|
||||
# can't be used up by someone else in the meantime.
|
||||
await self.store.set_registration_token_pending(token)
|
||||
# Store the token in the UIA session, so that once registration
|
||||
# is complete `completed` can be incremented.
|
||||
await auth_handler.set_session_data(
|
||||
session,
|
||||
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
|
||||
token,
|
||||
)
|
||||
# The token will be stored as the result of the authentication stage
|
||||
# in ui_auth_sessions_credentials. This allows the pending counter
|
||||
# for tokens to be decremented when expired sessions are deleted.
|
||||
return token
|
||||
else:
|
||||
raise LoginError(
|
||||
401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
|
||||
INTERACTIVE_AUTH_CHECKERS = [
|
||||
DummyAuthChecker,
|
||||
TermsAuthChecker,
|
||||
RecaptchaAuthChecker,
|
||||
EmailIdentityAuthChecker,
|
||||
MsisdnAuthChecker,
|
||||
RegistrationTokenAuthChecker,
|
||||
]
|
||||
"""A list of UserInteractiveAuthChecker classes"""
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
<html>
|
||||
<head>
|
||||
<title>Authentication</title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<form id="registrationForm" method="post" action="{{ myurl }}">
|
||||
<div>
|
||||
{% if error is defined %}
|
||||
<p class="error"><strong>Error: {{ error }}</strong></p>
|
||||
{% endif %}
|
||||
<p>
|
||||
Please enter a registration token.
|
||||
</p>
|
||||
<input type="hidden" name="session" value="{{ session }}" />
|
||||
<input type="text" name="token" />
|
||||
<input type="submit" value="Authenticate" />
|
||||
</div>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
|
@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
|
|||
)
|
||||
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
|
||||
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
|
||||
from synapse.rest.admin.registration_tokens import (
|
||||
ListRegistrationTokensRestServlet,
|
||||
NewRegistrationTokenRestServlet,
|
||||
RegistrationTokenRestServlet,
|
||||
)
|
||||
from synapse.rest.admin.rooms import (
|
||||
DeleteRoomRestServlet,
|
||||
ForwardExtremitiesRestServlet,
|
||||
|
@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
RoomEventContextServlet(hs).register(http_server)
|
||||
RateLimitRestServlet(hs).register(http_server)
|
||||
UsernameAvailableRestServlet(hs).register(http_server)
|
||||
ListRegistrationTokensRestServlet(hs).register(http_server)
|
||||
NewRegistrationTokenRestServlet(hs).register(http_server)
|
||||
RegistrationTokenRestServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_servlets_for_client_rest_resource(
|
||||
|
|
|
@ -0,0 +1,321 @@
|
|||
# Copyright 2021 Callum Brown
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import string
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_boolean,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListRegistrationTokensRestServlet(RestServlet):
|
||||
"""List registration tokens.
|
||||
|
||||
To list all tokens:
|
||||
|
||||
GET /_synapse/admin/v1/registration_tokens
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"registration_tokens": [
|
||||
{
|
||||
"token": "abcd",
|
||||
"uses_allowed": 3,
|
||||
"pending": 0,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
},
|
||||
{
|
||||
"token": "wxyz",
|
||||
"uses_allowed": null,
|
||||
"pending": 0,
|
||||
"completed": 9,
|
||||
"expiry_time": 1625394937000
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
The optional query parameter `valid` can be used to filter the response.
|
||||
If it is `true`, only valid tokens are returned. If it is `false`, only
|
||||
tokens that have expired or have had all uses exhausted are returned.
|
||||
If it is omitted, all tokens are returned regardless of validity.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/registration_tokens$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
valid = parse_boolean(request, "valid")
|
||||
token_list = await self.store.get_registration_tokens(valid)
|
||||
return 200, {"registration_tokens": token_list}
|
||||
|
||||
|
||||
class NewRegistrationTokenRestServlet(RestServlet):
|
||||
"""Create a new registration token.
|
||||
|
||||
For example, to create a token specifying some fields:
|
||||
|
||||
POST /_synapse/admin/v1/registration_tokens/new
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1
|
||||
}
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": null
|
||||
}
|
||||
|
||||
Defaults are used for any fields not specified.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/registration_tokens/new$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
# A string of all the characters allowed to be in a registration_token
|
||||
self.allowed_chars = string.ascii_letters + string.digits + "-_"
|
||||
self.allowed_chars_set = set(self.allowed_chars)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if "token" in body:
|
||||
token = body["token"]
|
||||
if not isinstance(token, str):
|
||||
raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
|
||||
if not (0 < len(token) <= 64):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"token must not be empty and must not be longer than 64 characters",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
if not set(token).issubset(self.allowed_chars_set):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"token must consist only of characters matched by the regex [A-Za-z0-9-_]",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
else:
|
||||
# Get length of token to generate (default is 16)
|
||||
length = body.get("length", 16)
|
||||
if not isinstance(length, int):
|
||||
raise SynapseError(
|
||||
400, "length must be an integer", Codes.INVALID_PARAM
|
||||
)
|
||||
if not (0 < length <= 64):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"length must be greater than zero and not greater than 64",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
# Generate token
|
||||
token = await self.store.generate_registration_token(
|
||||
length, self.allowed_chars
|
||||
)
|
||||
|
||||
uses_allowed = body.get("uses_allowed", None)
|
||||
if not (
|
||||
uses_allowed is None
|
||||
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
||||
):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"uses_allowed must be a non-negative integer or null",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
expiry_time = body.get("expiry_time", None)
|
||||
if not isinstance(expiry_time, (int, type(None))):
|
||||
raise SynapseError(
|
||||
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
|
||||
)
|
||||
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
||||
raise SynapseError(
|
||||
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
created = await self.store.create_registration_token(
|
||||
token, uses_allowed, expiry_time
|
||||
)
|
||||
if not created:
|
||||
raise SynapseError(
|
||||
400, f"Token already exists: {token}", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
resp = {
|
||||
"token": token,
|
||||
"uses_allowed": uses_allowed,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": expiry_time,
|
||||
}
|
||||
return 200, resp
|
||||
|
||||
|
||||
class RegistrationTokenRestServlet(RestServlet):
|
||||
"""Retrieve, update, or delete the given token.
|
||||
|
||||
For example,
|
||||
|
||||
to retrieve a token:
|
||||
|
||||
GET /_synapse/admin/v1/registration_tokens/abcd
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "abcd",
|
||||
"uses_allowed": 3,
|
||||
"pending": 0,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
}
|
||||
|
||||
|
||||
to update a token:
|
||||
|
||||
PUT /_synapse/admin/v1/registration_tokens/defg
|
||||
|
||||
{
|
||||
"uses_allowed": 5,
|
||||
"expiry_time": 4781243146000
|
||||
}
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 5,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": 4781243146000
|
||||
}
|
||||
|
||||
|
||||
to delete a token:
|
||||
|
||||
DELETE /_synapse/admin/v1/registration_tokens/wxyz
|
||||
|
||||
200 OK
|
||||
|
||||
{}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
|
||||
"""Retrieve a registration token."""
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
token_info = await self.store.get_one_registration_token(token)
|
||||
|
||||
# If no result return a 404
|
||||
if token_info is None:
|
||||
raise NotFoundError(f"No such registration token: {token}")
|
||||
|
||||
return 200, token_info
|
||||
|
||||
async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
|
||||
"""Update a registration token."""
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
body = parse_json_object_from_request(request)
|
||||
new_attributes = {}
|
||||
|
||||
# Only add uses_allowed to new_attributes if it is present and valid
|
||||
if "uses_allowed" in body:
|
||||
uses_allowed = body["uses_allowed"]
|
||||
if not (
|
||||
uses_allowed is None
|
||||
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
||||
):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"uses_allowed must be a non-negative integer or null",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
new_attributes["uses_allowed"] = uses_allowed
|
||||
|
||||
if "expiry_time" in body:
|
||||
expiry_time = body["expiry_time"]
|
||||
if not isinstance(expiry_time, (int, type(None))):
|
||||
raise SynapseError(
|
||||
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
|
||||
)
|
||||
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
||||
raise SynapseError(
|
||||
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
|
||||
)
|
||||
new_attributes["expiry_time"] = expiry_time
|
||||
|
||||
if len(new_attributes) == 0:
|
||||
# Nothing to update, get token info to return
|
||||
token_info = await self.store.get_one_registration_token(token)
|
||||
else:
|
||||
token_info = await self.store.update_registration_token(
|
||||
token, new_attributes
|
||||
)
|
||||
|
||||
# If no result return a 404
|
||||
if token_info is None:
|
||||
raise NotFoundError(f"No such registration token: {token}")
|
||||
|
||||
return 200, token_info
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, token: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
"""Delete a registration token."""
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if await self.store.delete_registration_token(token):
|
||||
return 200, {}
|
||||
|
||||
raise NotFoundError(f"No such registration token: {token}")
|
|
@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
|
|||
self.registration_handler = hs.get_registration_handler()
|
||||
self.recaptcha_template = hs.config.recaptcha_template
|
||||
self.terms_template = hs.config.terms_template
|
||||
self.registration_token_template = hs.config.registration_token_template
|
||||
self.success_template = hs.config.fallback_success_template
|
||||
|
||||
async def on_GET(self, request, stagetype):
|
||||
|
@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
|
|||
# re-authenticate with their SSO provider.
|
||||
html = await self.auth_handler.start_sso_ui_auth(request, session)
|
||||
|
||||
elif stagetype == LoginType.REGISTRATION_TOKEN:
|
||||
html = self.registration_token_template.render(
|
||||
session=session,
|
||||
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
|
||||
)
|
||||
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
|
@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
|
|||
# The SSO fallback workflow should not post here,
|
||||
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
|
||||
|
||||
elif stagetype == LoginType.REGISTRATION_TOKEN:
|
||||
token = parse_string(request, "token", required=True)
|
||||
authdict = {"session": session, "token": token}
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
|
||||
)
|
||||
except LoginError as e:
|
||||
html = self.registration_token_template.render(
|
||||
session=session,
|
||||
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
|
||||
error=e.msg,
|
||||
)
|
||||
else:
|
||||
html = self.success_template.render()
|
||||
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ from synapse.api.errors import (
|
|||
ThreepidValidationError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.captcha import CaptchaConfig
|
||||
from synapse.config.consent import ConsentConfig
|
||||
|
@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
|||
return 200, {"available": True}
|
||||
|
||||
|
||||
class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
"""Check the validity of a registration token.
|
||||
|
||||
Example:
|
||||
|
||||
GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"valid": true
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
|
||||
releases=(),
|
||||
unstable=True,
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
|
||||
)
|
||||
|
||||
async def on_GET(self, request):
|
||||
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
|
||||
|
||||
if not self.hs.config.enable_registration:
|
||||
raise SynapseError(
|
||||
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
token = parse_string(request, "token", required=True)
|
||||
valid = await self.store.registration_token_is_valid(token)
|
||||
|
||||
return 200, {"valid": valid}
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/register$")
|
||||
|
||||
|
@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
if registered:
|
||||
# Check if a token was used to authenticate registration
|
||||
registration_token = await self.auth_handler.get_session_data(
|
||||
session_id,
|
||||
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
|
||||
)
|
||||
if registration_token:
|
||||
# Increment the `completed` counter for the token
|
||||
await self.store.use_registration_token(registration_token)
|
||||
# Indicate that the token has been successfully used so that
|
||||
# pending is not decremented again when expiring old UIA sessions.
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session_id,
|
||||
LoginType.REGISTRATION_TOKEN,
|
||||
True,
|
||||
)
|
||||
|
||||
await self.registration_handler.post_registration_actions(
|
||||
user_id=registered_user_id,
|
||||
auth_result=auth_result,
|
||||
|
@ -868,6 +934,11 @@ def _calculate_registration_flows(
|
|||
for flow in flows:
|
||||
flow.insert(0, LoginType.RECAPTCHA)
|
||||
|
||||
# Prepend registration token to all flows if we're requiring a token
|
||||
if config.registration_requires_token:
|
||||
for flow in flows:
|
||||
flow.insert(0, LoginType.REGISTRATION_TOKEN)
|
||||
|
||||
return flows
|
||||
|
||||
|
||||
|
@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
|
|||
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
UsernameAvailabilityRestServlet(hs).register(http_server)
|
||||
RegistrationSubmitTokenServlet(hs).register(http_server)
|
||||
RegistrationTokenValidityRestServlet(hs).register(http_server)
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="update_access_token_last_validated",
|
||||
)
|
||||
|
||||
async def registration_token_is_valid(self, token: str) -> bool:
|
||||
"""Checks if a token can be used to authenticate a registration.
|
||||
|
||||
Args:
|
||||
token: The registration token to be checked
|
||||
Returns:
|
||||
True if the token is valid, False otherwise.
|
||||
"""
|
||||
res = await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
# Check if the token exists
|
||||
if res is None:
|
||||
return False
|
||||
|
||||
# Check if the token has expired
|
||||
now = self._clock.time_msec()
|
||||
if res["expiry_time"] and res["expiry_time"] < now:
|
||||
return False
|
||||
|
||||
# Check if the token has been used up
|
||||
if (
|
||||
res["uses_allowed"]
|
||||
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
||||
):
|
||||
return False
|
||||
|
||||
# Otherwise, the token is valid
|
||||
return True
|
||||
|
||||
async def set_registration_token_pending(self, token: str) -> None:
|
||||
"""Increment the pending registrations counter for a token.
|
||||
|
||||
Args:
|
||||
token: The registration token pending use
|
||||
"""
|
||||
|
||||
def _set_registration_token_pending_txn(txn):
|
||||
pending = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"pending": pending + 1},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_registration_token_pending", _set_registration_token_pending_txn
|
||||
)
|
||||
|
||||
async def use_registration_token(self, token: str) -> None:
|
||||
"""Complete a use of the given registration token.
|
||||
|
||||
The `pending` counter will be decremented, and the `completed`
|
||||
counter will be incremented.
|
||||
|
||||
Args:
|
||||
token: The registration token to be 'used'
|
||||
"""
|
||||
|
||||
def _use_registration_token_txn(txn):
|
||||
# Normally, res is Optional[Dict[str, Any]].
|
||||
# Override type because the return type is only optional if
|
||||
# allow_none is True, and we don't want mypy throwing errors
|
||||
# about None not being indexable.
|
||||
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
) # type: ignore
|
||||
|
||||
# Decrement pending and increment completed
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={
|
||||
"completed": res["completed"] + 1,
|
||||
"pending": res["pending"] - 1,
|
||||
},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"use_registration_token", _use_registration_token_txn
|
||||
)
|
||||
|
||||
async def get_registration_tokens(
|
||||
self, valid: Optional[bool] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List all registration tokens. Used by the admin API.
|
||||
|
||||
Args:
|
||||
valid: If True, only valid tokens are returned.
|
||||
If False, only invalid tokens are returned.
|
||||
Default is None: return all tokens regardless of validity.
|
||||
|
||||
Returns:
|
||||
A list of dicts, each containing details of a token.
|
||||
"""
|
||||
|
||||
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
|
||||
if valid is None:
|
||||
# Return all tokens regardless of validity
|
||||
txn.execute("SELECT * FROM registration_tokens")
|
||||
|
||||
elif valid:
|
||||
# Select valid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
|
||||
"AND (expiry_time > ? OR expiry_time IS NULL)"
|
||||
)
|
||||
txn.execute(sql, [now])
|
||||
|
||||
else:
|
||||
# Select invalid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"uses_allowed <= pending + completed OR expiry_time <= ?"
|
||||
)
|
||||
txn.execute(sql, [now])
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"select_registration_tokens",
|
||||
select_registration_tokens_txn,
|
||||
self._clock.time_msec(),
|
||||
valid,
|
||||
)
|
||||
|
||||
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get info about the given registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to retrieve information about.
|
||||
|
||||
Returns:
|
||||
A dict, or None if token doesn't exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
desc="get_one_registration_token",
|
||||
)
|
||||
|
||||
async def generate_registration_token(
|
||||
self, length: int, chars: str
|
||||
) -> Optional[str]:
|
||||
"""Generate a random registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
length: The length of the token to generate.
|
||||
chars: A string of the characters allowed in the generated token.
|
||||
|
||||
Returns:
|
||||
The generated token.
|
||||
|
||||
Raises:
|
||||
SynapseError if a unique registration token could still not be
|
||||
generated after a few tries.
|
||||
"""
|
||||
# Make a few attempts at generating a unique token of the required
|
||||
# length before failing.
|
||||
for _i in range(3):
|
||||
# Generate token
|
||||
token = "".join(random.choices(chars, k=length))
|
||||
|
||||
# Check if the token already exists
|
||||
existing_token = await self.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="token",
|
||||
allow_none=True,
|
||||
desc="check_if_registration_token_exists",
|
||||
)
|
||||
|
||||
if existing_token is None:
|
||||
# The generated token doesn't exist yet, return it
|
||||
return token
|
||||
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Unable to generate a unique registration token. Try again with a greater length",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
async def create_registration_token(
|
||||
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
|
||||
) -> bool:
|
||||
"""Create a new registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to create.
|
||||
uses_allowed: The number of times the token can be used to complete
|
||||
a registration before it becomes invalid. A value of None indicates
|
||||
unlimited uses.
|
||||
expiry_time: The latest time the token is valid. Given as the
|
||||
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
|
||||
None indicates that the token does not expire.
|
||||
|
||||
Returns:
|
||||
Whether the row was inserted or not.
|
||||
"""
|
||||
|
||||
def _create_registration_token_txn(txn):
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if row is not None:
|
||||
# Token already exists
|
||||
return False
|
||||
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
values={
|
||||
"token": token,
|
||||
"uses_allowed": uses_allowed,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": expiry_time,
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"create_registration_token", _create_registration_token_txn
|
||||
)
|
||||
|
||||
async def update_registration_token(
|
||||
self, token: str, updatevalues: Dict[str, Optional[int]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update a registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to update.
|
||||
updatevalues: A dict with the fields to update. E.g.:
|
||||
`{"uses_allowed": 3}` to update just uses_allowed, or
|
||||
`{"uses_allowed": 3, "expiry_time": None}` to update both.
|
||||
This is passed straight to simple_update_one.
|
||||
|
||||
Returns:
|
||||
A dict with all info about the token, or None if token doesn't exist.
|
||||
"""
|
||||
|
||||
def _update_registration_token_txn(txn):
|
||||
try:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues=updatevalues,
|
||||
)
|
||||
except StoreError:
|
||||
# Update failed because token does not exist
|
||||
return None
|
||||
|
||||
# Get all info about the token so it can be sent in the response
|
||||
return self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=[
|
||||
"token",
|
||||
"uses_allowed",
|
||||
"pending",
|
||||
"completed",
|
||||
"expiry_time",
|
||||
],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_registration_token", _update_registration_token_txn
|
||||
)
|
||||
|
||||
async def delete_registration_token(self, token: str) -> bool:
|
||||
"""Delete a registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to delete.
|
||||
|
||||
Returns:
|
||||
Whether the token was successfully deleted or not.
|
||||
"""
|
||||
try:
|
||||
await self.db_pool.simple_delete_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
desc="delete_registration_token",
|
||||
)
|
||||
except StoreError:
|
||||
# Deletion failed because token does not exist
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@cached()
|
||||
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||
"""
|
||||
|
|
|
@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
|
@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
keyvalues={},
|
||||
)
|
||||
|
||||
# If a registration token was used, decrement the pending counter
|
||||
# before deleting the session.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions_credentials",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
|
||||
retcols=["result"],
|
||||
)
|
||||
|
||||
# Get the tokens used and how much pending needs to be decremented by.
|
||||
token_counts: Dict[str, int] = {}
|
||||
for r in rows:
|
||||
# If registration was successfully completed, the result of the
|
||||
# registration token stage for that session will be True.
|
||||
# If a token was used to authenticate, but registration was
|
||||
# never completed, the result will be the token used.
|
||||
token = db_to_json(r["result"])
|
||||
if isinstance(token, str):
|
||||
token_counts[token] = token_counts.get(token, 0) + 1
|
||||
|
||||
# Update the `pending` counters.
|
||||
if len(token_counts) > 0:
|
||||
token_rows = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="registration_tokens",
|
||||
column="token",
|
||||
iterable=list(token_counts.keys()),
|
||||
keyvalues={},
|
||||
retcols=["token", "pending"],
|
||||
)
|
||||
for token_row in token_rows:
|
||||
token = token_row["token"]
|
||||
new_pending = token_row["pending"] - token_counts[token]
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"pending": new_pending},
|
||||
)
|
||||
|
||||
# Delete the corresponding completed credentials.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2021 Callum Brown
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS registration_tokens(
|
||||
token TEXT NOT NULL, -- The token that can be used for authentication.
|
||||
uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
|
||||
pending INT NOT NULL, -- The number of in progress registrations using this token.
|
||||
completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
|
||||
expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
|
||||
UNIQUE (token)
|
||||
);
|
|
@ -0,0 +1,710 @@
|
|||
# Copyright 2021 Callum Brown
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest.client import login
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
||||
self.other_user = self.register_user("user", "pass")
|
||||
self.other_user_tok = self.login("user", "pass")
|
||||
|
||||
self.url = "/_synapse/admin/v1/registration_tokens"
|
||||
|
||||
def _new_token(self, **kwargs):
|
||||
"""Helper function to create a token."""
|
||||
token = kwargs.get(
|
||||
"token",
|
||||
"".join(random.choices(string.ascii_letters, k=8)),
|
||||
)
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": kwargs.get("uses_allowed", None),
|
||||
"pending": kwargs.get("pending", 0),
|
||||
"completed": kwargs.get("completed", 0),
|
||||
"expiry_time": kwargs.get("expiry_time", None),
|
||||
},
|
||||
)
|
||||
)
|
||||
return token
|
||||
|
||||
# CREATION
|
||||
|
||||
def test_create_no_auth(self):
|
||||
"""Try to create a token without authentication."""
|
||||
channel = self.make_request("POST", self.url + "/new", {})
|
||||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||
|
||||
def test_create_requester_not_admin(self):
|
||||
"""Try to create a token while not an admin."""
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{},
|
||||
access_token=self.other_user_tok,
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_create_using_defaults(self):
|
||||
"""Create a token using all the defaults."""
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(len(channel.json_body["token"]), 16)
|
||||
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
self.assertEqual(channel.json_body["pending"], 0)
|
||||
self.assertEqual(channel.json_body["completed"], 0)
|
||||
|
||||
def test_create_specifying_fields(self):
|
||||
"""Create a token specifying the value of all fields."""
|
||||
data = {
|
||||
"token": "abcd",
|
||||
"uses_allowed": 1,
|
||||
"expiry_time": self.clock.time_msec() + 1000000,
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["token"], "abcd")
|
||||
self.assertEqual(channel.json_body["uses_allowed"], 1)
|
||||
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
|
||||
self.assertEqual(channel.json_body["pending"], 0)
|
||||
self.assertEqual(channel.json_body["completed"], 0)
|
||||
|
||||
def test_create_with_null_value(self):
|
||||
"""Create a token specifying unlimited uses and no expiry."""
|
||||
data = {
|
||||
"uses_allowed": None,
|
||||
"expiry_time": None,
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(len(channel.json_body["token"]), 16)
|
||||
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
self.assertEqual(channel.json_body["pending"], 0)
|
||||
self.assertEqual(channel.json_body["completed"], 0)
|
||||
|
||||
def test_create_token_too_long(self):
|
||||
"""Check token longer than 64 chars is invalid."""
|
||||
data = {"token": "a" * 65}
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_create_token_invalid_chars(self):
|
||||
"""Check you can't create token with invalid characters."""
|
||||
data = {
|
||||
"token": "abc/def",
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_create_token_already_exists(self):
|
||||
"""Check you can't create token that already exists."""
|
||||
data = {
|
||||
"token": "abcd",
|
||||
}
|
||||
|
||||
channel1 = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
|
||||
|
||||
channel2 = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
|
||||
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_create_unable_to_generate_token(self):
|
||||
"""Check right error is raised when server can't generate unique token."""
|
||||
# Create all possible single character tokens
|
||||
tokens = []
|
||||
for c in string.ascii_letters + string.digits + "-_":
|
||||
tokens.append(
|
||||
{
|
||||
"token": c,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
}
|
||||
)
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_insert_many(
|
||||
"registration_tokens",
|
||||
tokens,
|
||||
"create_all_registration_tokens",
|
||||
)
|
||||
)
|
||||
|
||||
# Check creating a single character token fails with a 500 status code
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"length": 1},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
def test_create_uses_allowed(self):
|
||||
"""Check you can only create a token with good values for uses_allowed."""
|
||||
# Should work with 0 (token is invalid from the start)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"uses_allowed": 0},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["uses_allowed"], 0)
|
||||
|
||||
# Should fail with negative integer
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"uses_allowed": -5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail with float
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"uses_allowed": 1.5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_create_expiry_time(self):
|
||||
"""Check you can't create a token with an invalid expiry_time."""
|
||||
# Should fail with a time in the past
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"expiry_time": self.clock.time_msec() - 10000},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail with float
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"expiry_time": self.clock.time_msec() + 1000000.5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_create_length(self):
|
||||
"""Check you can only generate a token with a valid length."""
|
||||
# Should work with 64
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"length": 64},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(len(channel.json_body["token"]), 64)
|
||||
|
||||
# Should fail with 0
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"length": 0},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail with a negative integer
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"length": -5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail with a float
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"length": 8.5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail with 65
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url + "/new",
|
||||
{"length": 65},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# UPDATING
|
||||
|
||||
def test_update_no_auth(self):
|
||||
"""Try to update a token without authentication."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||
{},
|
||||
)
|
||||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||
|
||||
def test_update_requester_not_admin(self):
|
||||
"""Try to update a token while not an admin."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||
{},
|
||||
access_token=self.other_user_tok,
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_update_non_existent(self):
|
||||
"""Try to update a token that doesn't exist."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/1234",
|
||||
{"uses_allowed": 1},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
def test_update_uses_allowed(self):
|
||||
"""Test updating just uses_allowed."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
|
||||
# Should succeed with 1
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"uses_allowed": 1},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["uses_allowed"], 1)
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
|
||||
# Should succeed with 0 (makes token invalid)
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"uses_allowed": 0},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["uses_allowed"], 0)
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
|
||||
# Should succeed with null
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"uses_allowed": None},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
|
||||
# Should fail with a float
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"uses_allowed": 1.5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail with a negative integer
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"uses_allowed": -5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_update_expiry_time(self):
|
||||
"""Test updating just expiry_time."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
new_expiry_time = self.clock.time_msec() + 1000000
|
||||
|
||||
# Should succeed with a time in the future
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"expiry_time": new_expiry_time},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
|
||||
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||
|
||||
# Should succeed with null
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"expiry_time": None},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||
|
||||
# Should fail with a time in the past
|
||||
past_time = self.clock.time_msec() - 10000
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"expiry_time": past_time},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# Should fail a float
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
{"expiry_time": new_expiry_time + 0.5},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
def test_update_both(self):
|
||||
"""Test updating both uses_allowed and expiry_time."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
new_expiry_time = self.clock.time_msec() + 1000000
|
||||
|
||||
data = {
|
||||
"uses_allowed": 1,
|
||||
"expiry_time": new_expiry_time,
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["uses_allowed"], 1)
|
||||
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
|
||||
|
||||
def test_update_invalid_type(self):
|
||||
"""Test using invalid types doesn't work."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
|
||||
data = {
|
||||
"uses_allowed": False,
|
||||
"expiry_time": "1626430124000",
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url + "/" + token,
|
||||
data,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
|
||||
# DELETING
|
||||
|
||||
def test_delete_no_auth(self):
|
||||
"""Try to delete a token without authentication."""
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||
{},
|
||||
)
|
||||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||
|
||||
def test_delete_requester_not_admin(self):
|
||||
"""Try to delete a token while not an admin."""
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||
{},
|
||||
access_token=self.other_user_tok,
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_delete_non_existent(self):
|
||||
"""Try to delete a token that doesn't exist."""
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
self.url + "/1234",
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
def test_delete(self):
|
||||
"""Test deleting a token."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
self.url + "/" + token,
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
# GETTING ONE
|
||||
|
||||
def test_get_no_auth(self):
|
||||
"""Try to get a token without authentication."""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||
{},
|
||||
)
|
||||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||
|
||||
def test_get_requester_not_admin(self):
|
||||
"""Try to get a token while not an admin."""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "/1234", # Token doesn't exist but that doesn't matter
|
||||
{},
|
||||
access_token=self.other_user_tok,
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_get_non_existent(self):
|
||||
"""Try to get a token that doesn't exist."""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "/1234",
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
def test_get(self):
|
||||
"""Test getting a token."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "/" + token,
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["token"], token)
|
||||
self.assertIsNone(channel.json_body["uses_allowed"])
|
||||
self.assertIsNone(channel.json_body["expiry_time"])
|
||||
self.assertEqual(channel.json_body["pending"], 0)
|
||||
self.assertEqual(channel.json_body["completed"], 0)
|
||||
|
||||
# LISTING
|
||||
|
||||
def test_list_no_auth(self):
|
||||
"""Try to list tokens without authentication."""
|
||||
channel = self.make_request("GET", self.url, {})
|
||||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||
|
||||
def test_list_requester_not_admin(self):
|
||||
"""Try to list tokens while not an admin."""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url,
|
||||
{},
|
||||
access_token=self.other_user_tok,
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_list_all(self):
|
||||
"""Test listing all tokens."""
|
||||
# Create new token using default values
|
||||
token = self._new_token()
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url,
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
|
||||
token_info = channel.json_body["registration_tokens"][0]
|
||||
self.assertEqual(token_info["token"], token)
|
||||
self.assertIsNone(token_info["uses_allowed"])
|
||||
self.assertIsNone(token_info["expiry_time"])
|
||||
self.assertEqual(token_info["pending"], 0)
|
||||
self.assertEqual(token_info["completed"], 0)
|
||||
|
||||
def test_list_invalid_query_parameter(self):
|
||||
"""Test with `valid` query parameter not `true` or `false`."""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?valid=x",
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
def _test_list_query_parameter(self, valid: str):
|
||||
"""Helper used to test both valid=true and valid=false."""
|
||||
# Create 2 valid and 2 invalid tokens.
|
||||
now = self.hs.get_clock().time_msec()
|
||||
# Create always valid token
|
||||
valid1 = self._new_token()
|
||||
# Create token that hasn't been used up
|
||||
valid2 = self._new_token(uses_allowed=1)
|
||||
# Create token that has expired
|
||||
invalid1 = self._new_token(expiry_time=now - 10000)
|
||||
# Create token that has been used up but hasn't expired
|
||||
invalid2 = self._new_token(
|
||||
uses_allowed=2,
|
||||
pending=1,
|
||||
completed=1,
|
||||
expiry_time=now + 1000000,
|
||||
)
|
||||
|
||||
if valid == "true":
|
||||
tokens = [valid1, valid2]
|
||||
else:
|
||||
tokens = [invalid1, invalid2]
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?valid=" + valid,
|
||||
{},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
|
||||
token_info_1 = channel.json_body["registration_tokens"][0]
|
||||
token_info_2 = channel.json_body["registration_tokens"][1]
|
||||
self.assertIn(token_info_1["token"], tokens)
|
||||
self.assertIn(token_info_2["token"], tokens)
|
||||
|
||||
def test_list_valid(self):
|
||||
"""Test listing just valid tokens."""
|
||||
self._test_list_query_parameter(valid="true")
|
||||
|
||||
def test_list_invalid(self):
|
||||
"""Test listing just invalid tokens."""
|
||||
self._test_list_query_parameter(valid="false")
|
|
@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
|||
from synapse.api.errors import Codes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
||||
from synapse.storage._base import db_to_json
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_requires_token(self):
|
||||
username = "kermit"
|
||||
device_id = "frogfone"
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
params = {
|
||||
"username": username,
|
||||
"password": "monkey",
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
# Request without auth to get flows and session
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
flows = channel.json_body["flows"]
|
||||
# Synapse adds a dummy stage to differentiate flows where otherwise one
|
||||
# flow would be a subset of another flow.
|
||||
self.assertCountEqual(
|
||||
[[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
|
||||
(f["stages"] for f in flows),
|
||||
)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Do the registration token stage and check it has completed
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
completed = channel.json_body["completed"]
|
||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||
|
||||
# Do the m.login.dummy stage and check registration was successful
|
||||
params["auth"] = {
|
||||
"type": LoginType.DUMMY,
|
||||
"session": session,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
det_data = {
|
||||
"user_id": f"@{username}:{self.hs.hostname}",
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||
|
||||
# Check the `completed` counter has been incremented and pending is 0
|
||||
res = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEquals(res["completed"], 1)
|
||||
self.assertEquals(res["pending"], 0)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_invalid(self):
|
||||
params = {
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
}
|
||||
# Request without auth to get session
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Test with token param missing (invalid)
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Test with non-string (invalid)
|
||||
params["auth"]["token"] = 1234
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Test with unknown token (invalid)
|
||||
params["auth"]["token"] = "1234"
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_limit_uses(self):
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
# Create token that can be used once
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": 1,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
params1 = {"username": "bert", "password": "monkey"}
|
||||
params2 = {"username": "ernie", "password": "monkey"}
|
||||
# Do 2 requests without auth to get two session IDs
|
||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
session1 = channel1.json_body["session"]
|
||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
session2 = channel2.json_body["session"]
|
||||
|
||||
# Use token with session1 and check `pending` is 1
|
||||
params1["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session1,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
# Repeat request to make sure pending isn't increased again
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
pending = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
)
|
||||
self.assertEquals(pending, 1)
|
||||
|
||||
# Check auth fails when using token with session2
|
||||
params2["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session2,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Complete registration with session1
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
# Check pending=0 and completed=1
|
||||
res = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
)
|
||||
)
|
||||
self.assertEquals(res["pending"], 0)
|
||||
self.assertEquals(res["completed"], 1)
|
||||
|
||||
# Check auth still fails when using token with session2
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_expiry(self):
|
||||
token = "abcd"
|
||||
now = self.hs.get_clock().time_msec()
|
||||
store = self.hs.get_datastore()
|
||||
# Create token that expired yesterday
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": now - 24 * 60 * 60 * 1000,
|
||||
},
|
||||
)
|
||||
)
|
||||
params = {"username": "kermit", "password": "monkey"}
|
||||
# Request without auth to get session
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Check authentication fails with expired token
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEquals(channel.json_body["completed"], [])
|
||||
|
||||
# Update token so it expires tomorrow
|
||||
self.get_success(
|
||||
store.db_pool.simple_update_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
|
||||
)
|
||||
)
|
||||
|
||||
# Check authentication succeeds
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
completed = channel.json_body["completed"]
|
||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_session_expiry(self):
|
||||
"""Test `pending` is decremented when an uncompleted session expires."""
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Do 2 requests without auth to get two session IDs
|
||||
params1 = {"username": "bert", "password": "monkey"}
|
||||
params2 = {"username": "ernie", "password": "monkey"}
|
||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
session1 = channel1.json_body["session"]
|
||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
session2 = channel2.json_body["session"]
|
||||
|
||||
# Use token with both sessions
|
||||
params1["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session1,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
|
||||
params2["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session2,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
|
||||
# Complete registration with session1
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
|
||||
# Check `result` of registration token stage for session1 is `True`
|
||||
result1 = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"ui_auth_sessions_credentials",
|
||||
keyvalues={
|
||||
"session_id": session1,
|
||||
"stage_type": LoginType.REGISTRATION_TOKEN,
|
||||
},
|
||||
retcol="result",
|
||||
)
|
||||
)
|
||||
self.assertTrue(db_to_json(result1))
|
||||
|
||||
# Check `result` for session2 is the token used
|
||||
result2 = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"ui_auth_sessions_credentials",
|
||||
keyvalues={
|
||||
"session_id": session2,
|
||||
"stage_type": LoginType.REGISTRATION_TOKEN,
|
||||
},
|
||||
retcol="result",
|
||||
)
|
||||
)
|
||||
self.assertEquals(db_to_json(result2), token)
|
||||
|
||||
# Delete both sessions (mimics expiry)
|
||||
self.get_success(
|
||||
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||
)
|
||||
|
||||
# Check pending is now 0
|
||||
pending = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
)
|
||||
self.assertEquals(pending, 0)
|
||||
|
||||
@override_config({"registration_requires_token": True})
|
||||
def test_POST_registration_token_session_expiry_deleted_token(self):
|
||||
"""Test session expiry doesn't break when the token is deleted.
|
||||
|
||||
1. Start but don't complete UIA with a registration token
|
||||
2. Delete the token from the database
|
||||
3. Expire the session
|
||||
"""
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Do request without auth to get a session ID
|
||||
params = {"username": "kermit", "password": "monkey"}
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Use token
|
||||
params["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
self.make_request(b"POST", self.url, json.dumps(params))
|
||||
|
||||
# Delete token
|
||||
self.get_success(
|
||||
store.db_pool.simple_delete_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
)
|
||||
)
|
||||
|
||||
# Delete session (mimics expiry)
|
||||
self.get_success(
|
||||
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
|
||||
)
|
||||
|
||||
def test_advertised_flows(self):
|
||||
channel = self.make_request(b"POST", self.url, b"{}")
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
|
@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
|
||||
self.assertLessEqual(res, now_ms + self.validity_period)
|
||||
|
||||
|
||||
class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [register.register_servlets]
|
||||
url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["registration_requires_token"] = True
|
||||
return config
|
||||
|
||||
def test_GET_token_valid(self):
|
||||
token = "abcd"
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.db_pool.simple_insert(
|
||||
"registration_tokens",
|
||||
{
|
||||
"token": token,
|
||||
"uses_allowed": None,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertEquals(channel.json_body["valid"], True)
|
||||
|
||||
def test_GET_token_invalid(self):
|
||||
token = "1234"
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertEquals(channel.json_body["valid"], False)
|
||||
|
||||
@override_config(
|
||||
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
|
||||
)
|
||||
def test_GET_ratelimiting(self):
|
||||
token = "1234"
|
||||
|
||||
for i in range(0, 6):
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
|
||||
if i == 5:
|
||||
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||
else:
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||
|
||||
channel = self.make_request(
|
||||
b"GET",
|
||||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
|
Loading…
Reference in New Issue