Merge pull request #2722 from matrix-org/rav/delete_device_on_logout

Delete devices and pushers on logouts etc
This commit is contained in:
Richard van der Hoff 2017-11-29 17:56:46 +00:00 committed by GitHub
commit 7a48a6b63e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 226 additions and 64 deletions

View File

@ -649,41 +649,6 @@ class AuthHandler(BaseHandler):
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
)
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
yield self.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
"""Invalidate a single access token """Invalidate a single access token
@ -706,6 +671,12 @@ class AuthHandler(BaseHandler):
access_token=access_token, access_token=access_token,
) )
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"], )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_tokens_for_user(self, user_id, except_token_id=None, def delete_access_tokens_for_user(self, user_id, except_token_id=None,
device_id=None): device_id=None):
@ -728,13 +699,18 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this # see if any of our auth providers want to know about this
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"): if hasattr(provider, "on_logged_out"):
for token, device_id in tokens_and_devices: for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out( yield provider.on_logged_out(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
access_token=token, access_token=token,
) )
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case. # 'Canonicalise' email addresses down to lower case.

View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
# first delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)

View File

@ -170,13 +170,31 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id]) yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
"""Delete all of the user's devices
Args:
user_id (str):
except_device_id (str|None): optional device id which should not
be deleted
Returns:
defer.Deferred:
"""
device_map = yield self.store.get_devices_by_user(user_id)
device_ids = device_map.keys()
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_devices(self, user_id, device_ids): def delete_devices(self, user_id, device_ids):
""" Delete several devices """ Delete several devices
Args: Args:
user_id (str): user_id (str):
device_ids (str): The list of device IDs to delete device_ids (List[str]): The list of device IDs to delete
Returns: Returns:
defer.Deferred: defer.Deferred:

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
from ._base import BaseHandler
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
# we want to log out all of the user's other sessions. First delete
# all his other devices.
yield self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id,
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
@ -81,6 +82,7 @@ class ModuleApi(object):
reg = self.hs.get_handlers().registration_handler reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart) return reg.register(localpart=localpart)
@defer.inlineCallbacks
def invalidate_access_token(self, access_token): def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user """Invalidate an access token for a user
@ -94,8 +96,16 @@ class ModuleApi(object):
Raises: Raises:
synapse.api.errors.AuthError: the access token is invalid synapse.api.errors.AuthError: the access token is invalid
""" """
# see if the access token corresponds to a device
return self._auth_handler.delete_access_token(access_token) user_info = yield self._auth.get_user_by_access_token(access_token)
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
yield self.hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
def run_db_interaction(self, desc, func, *args, **kwargs): def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection """Run a function with a database connection

View File

@ -103,19 +103,25 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_access_token_id=None): def remove_pushers_by_access_token(self, user_id, access_tokens):
all = yield self.store.get_all_pushers() """Remove the pushers for a given user corresponding to a set of
logger.info( access_tokens.
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id Args:
) user_id (str): user to remove pushers for
for p in all: access_tokens (Iterable[int]): access token *ids* to remove pushers
if p['user_name'] == user_id and p['access_token'] != except_access_token_id: for
"""
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p['access_token'] in tokens:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
) )
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(
p['app_id'], p['pushkey'], p['user_name'],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id): def on_new_notifications(self, min_stream_id, max_stream_id):

View File

@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs) super(DeactivateAccountRestServlet, self).__init__(hs)
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, target_user_id): def on_POST(self, request, target_user_id):
@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self._auth_handler.deactivate_account(target_user_id) yield self._deactivate_account_handler.deactivate_account(target_user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs) super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, target_user_id): def on_POST(self, request, target_user_id):
@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
logger.info("new_password: %r", new_password) logger.info("new_password: %r", new_password)
yield self.auth_handler.set_password( yield self._set_password_handler.set_password(
target_user_id, new_password, requester target_user_id, new_password, requester
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__(hs)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
access_token = get_access_token_from_request(request) try:
yield self._auth_handler.delete_access_token(access_token) requester = yield self.auth.get_user_by_req(request)
except AuthError:
# this implies the access token has already been deleted.
pass
else:
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
requester.user.to_string(), requester.device_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# first delete all of the user's devices
yield self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -98,6 +98,7 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -147,7 +148,7 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self.auth_handler.set_password( yield self._set_password_handler.set_password(
user_id, new_password, requester user_id, new_password, requester
) )
@ -161,10 +162,11 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$") PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs): def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -179,7 +181,7 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to dectivate their own users # allow ASes to dectivate their own users
if requester and requester.app_service: if requester and requester.app_service:
yield self.auth_handler.deactivate_account( yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string() requester.user.to_string()
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -206,7 +208,7 @@ class DeactivateAccountRestServlet(RestServlet):
logger.error("Auth succeeded but no known type!", result.keys()) logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN) raise SynapseError(500, "", Codes.UNKNOWN)
yield self.auth_handler.deactivate_account(user_id) yield self._deactivate_account_handler.deactivate_account(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -39,11 +39,13 @@ from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.events import EventHandler, EventStreamHandler
@ -115,6 +117,8 @@ class HomeServer(object):
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler', 'profile_handler',
'deactivate_account_handler',
'set_password_handler',
'notifier', 'notifier',
'event_sources', 'event_sources',
'keyring', 'keyring',
@ -268,6 +272,12 @@ class HomeServer(object):
def build_profile_handler(self): def build_profile_handler(self):
return ProfileHandler(self) return ProfileHandler(self)
def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self)
def build_set_password_handler(self):
return SetPasswordHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View File

@ -3,11 +3,14 @@ import synapse.federation.transaction_queue
import synapse.federation.transport.client import synapse.federation.transport.client
import synapse.handlers import synapse.handlers
import synapse.handlers.auth import synapse.handlers.auth
import synapse.handlers.deactivate_account
import synapse.handlers.device import synapse.handlers.device
import synapse.handlers.e2e_keys import synapse.handlers.e2e_keys
import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository import synapse.rest.media.v1.media_repository
import synapse.storage
import synapse.state import synapse.state
import synapse.storage
class HomeServer(object): class HomeServer(object):
def get_auth(self) -> synapse.api.auth.Auth: def get_auth(self) -> synapse.api.auth.Auth:
@ -31,6 +34,12 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler: def get_state_handler(self) -> synapse.state.StateHandler:
pass pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
pass pass

View File

@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
be deleted be deleted
Returns: Returns:
defer.Deferred[list[str, str|None]]: a list of the deleted tokens defer.Deferred[list[str, int, str|None, int]]: a list of
and device IDs (token, token id, device id) for each of the deleted tokens
""" """
def f(txn): def f(txn):
keyvalues = { keyvalues = {
@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values.append(except_token_id) values.append(except_token_id)
txn.execute( txn.execute(
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause, "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
values values
) )
tokens_and_devices = [(r[0], r[1]) for r in txn] tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
for token, _ in tokens_and_devices: for token, _, _ in tokens_and_devices:
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,) txn, self.get_user_by_access_token, (token,)
) )