Split out registration to worker

This allows registration to be handled by a worker, though the actual
write to the database still happens on master.

Note: due to the in-memory session map all registration requests must be
handled by the same worker.
This commit is contained in:
Erik Johnston 2019-02-18 12:12:57 +00:00
parent 4151111d95
commit eb2b8523ae
7 changed files with 401 additions and 147 deletions

View File

@ -47,6 +47,7 @@ from synapse.rest.client.v1.room import (
RoomMemberListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@ -92,6 +93,7 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,

View File

@ -27,6 +27,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.http.client import CaptchaServerHttpClient
from synapse.replication.http.register import ReplicationRegisterServlet
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.threepids import check_3pid_allowed
@ -61,6 +62,9 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
@ -185,7 +189,7 @@ class RegistrationHandler(BaseHandler):
token = None
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -217,7 +221,7 @@ class RegistrationHandler(BaseHandler):
if default_display_name is None:
default_display_name = localpart
try:
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -316,7 +320,7 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
password_hash="",
appservice_id=service_id,
@ -494,7 +498,7 @@ class RegistrationHandler(BaseHandler):
token = self.macaroon_gen.generate_access_token(user_id)
if need_register:
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -573,3 +577,54 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
def _register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False,
user_type=None):
"""Register user in the datastore.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
Returns:
Deferred
"""
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
else:
return self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
from synapse.http.server import JsonResource
from synapse.replication.http import federation, membership, send_event
from synapse.replication.http import federation, login, membership, register, send_event
REPLICATION_PREFIX = "/_synapse/replication"
@ -28,3 +28,5 @@ class ReplicationRestResource(JsonResource):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)

View File

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"""Ensure a device is registered, generating a new access token for the
device.
Used during registration and login.
"""
NAME = "device_check_registered"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(RegisterDeviceReplicationServlet, self).__init__(hs)
self.auth_handler = hs.get_auth_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@staticmethod
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
generated.
initial_display_name (str|None)
is_guest (bool)
"""
return {
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name,
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((200, {
"device_id": device_id,
"access_token": access_token,
}))
def register_servlets(hs, http_server):
RegisterDeviceReplicationServlet(hs).register(http_server)

View File

@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationRegisterServlet(ReplicationEndpoint):
"""Register a new user
"""
NAME = "register_user"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(ReplicationRegisterServlet, self).__init__(hs)
self.store = hs.get_datastore()
@staticmethod
def _serialize_payload(
user_id, token, password_hash, was_guest, make_guest, appservice_id,
create_profile_with_displayname, admin, user_type,
):
"""
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
"""
return {
"token": token,
"password_hash": password_hash,
"was_guest": was_guest,
"make_guest": make_guest,
"appservice_id": appservice_id,
"create_profile_with_displayname": create_profile_with_displayname,
"admin": admin,
"user_type": user_type,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
yield self.store.register(
user_id=user_id,
token=content["token"],
password_hash=content["password_hash"],
was_guest=content["was_guest"],
make_guest=content["make_guest"],
appservice_id=content["appservice_id"],
create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"],
user_type=content["user_type"],
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationRegisterServlet(hs).register(http_server)

View File

@ -33,6 +33,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed
@ -190,9 +191,15 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
if self.hs.config.worker_app:
self._register_device_client = (
RegisterDeviceReplicationServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
@ -633,12 +640,10 @@ class RegisterRestServlet(RestServlet):
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False):
device_id = yield self._register_device(user_id, params)
access_token = (
yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self._register_device(
user_id, device_id, initial_display_name, is_guest=False,
)
result.update({
@ -647,25 +652,42 @@ class RegisterRestServlet(RestServlet):
})
defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
@defer.inlineCallbacks
def _register_device(self, user_id, device_id, initial_display_name,
is_guest):
"""Register a device for a user and generate an access token.
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
Returns:
defer.Deferred: (str) device_id
defer.Deferred[(str, str)]: Tuple of device ID and access token
"""
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if self.hs.config.worker_app:
r = yield self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
def _do_guest_registration(self, params):
@ -680,13 +702,10 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
device_id, access_token = yield self._register_device(
user_id, device_id, initial_display_name, is_guest=True,
)
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
defer.returnValue((200, {
"user_id": user_id,
"device_id": device_id,

View File

@ -139,6 +139,121 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_daily_user_type(self):
"""
Counts 1) native non guest users
2) native guests users
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
SELECT user_type, COALESCE(count(*), 0) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM users
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
results = {'native': 0, 'guest': 0, 'bridged': 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
count, = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
regex = re.compile(r"^@(\d+):")
found = set()
for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in range(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
)
if ret:
defer.returnValue(ret["guest_access_token"])
defer.returnValue(None)
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@ -326,20 +441,6 @@ class RegistrationStore(RegistrationWorkerStore,
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
@ -564,107 +665,6 @@ class RegistrationStore(RegistrationWorkerStore,
desc="user_delete_threepids",
)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_daily_user_type(self):
"""
Counts 1) native non guest users
2) native guests users
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
SELECT user_type, COALESCE(count(*), 0) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM users
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
results = {'native': 0, 'guest': 0, 'bridged': 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
count, = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
regex = re.compile(r"^@(\d+):")
found = set()
for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in range(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
)
if ret:
defer.returnValue(ret["guest_access_token"])
defer.returnValue(None)
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id