diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 433237c204..6688fa8fa0 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import ( push_rule, register as v1_register, login as v1_login, + logout, ) from synapse.rest.client.v2_alpha import ( @@ -72,6 +73,7 @@ class ClientRestResource(JsonResource): admin.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) + logout.register_servlets(hs, client_resource) # "v2" sync.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py new file mode 100644 index 0000000000..9bff02ee4e --- /dev/null +++ b/synapse/rest/client/v1/logout.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, Codes + +from .base import ClientV1RestServlet, client_path_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class LogoutRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/logout$") + + def __init__(self, hs): + super(LogoutRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + try: + access_token = request.args["access_token"][0] + except KeyError: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + yield self.store.delete_access_token(access_token) + defer.returnValue((200, {})) + + +class LogoutAllRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/logout/all$") + + def __init__(self, hs): + super(LogoutAllRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + yield self.store.user_delete_access_tokens(user_id) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + LogoutRestServlet(hs).register(http_server) + LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5e7a4e371d..bd4eb88a92 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -195,24 +195,48 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids): + def user_delete_access_tokens(self, user_id, except_token_ids=[]): def f(txn): - txn.execute( - "SELECT id, token FROM access_tokens " - "WHERE user_id = ? AND id NOT IN ? LIMIT 50", - (user_id, except_token_ids) - ) + sql = "SELECT token FROM access_tokens WHERE user_id = ?" + clauses = [user_id] + + if except_token_ids: + sql += " AND id NOT IN (%s)" % ( + ",".join(["?" for _ in except_token_ids]), + ) + clauses += except_token_ids + + txn.execute(sql, clauses) + rows = txn.fetchall() - for r in rows: - txn.call_after(self.get_user_by_access_token.invalidate, (r[1],)) - txn.execute( - "DELETE FROM access_tokens WHERE id in (%s)" % ",".join( - ["?" for _ in rows] - ), [r[0] for r in rows] + + n = 100 + chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] + for chunk in chunks: + for row in chunk: + txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + + txn.execute( + "DELETE FROM access_tokens WHERE token in (%s)" % ( + ",".join(["?" for _ in chunk]), + ), [r[0] for r in chunk] + ) + + yield self.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self._simple_delete_one_txn( + txn, + table="access_tokens", + keyvalues={ + "token": access_token + }, ) - return len(rows) == 50 - while (yield self.runInteraction("user_delete_access_tokens", f)): - pass + + txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + + return self.runInteraction("delete_access_token", f) @cached() def get_user_by_access_token(self, token):