Merge pull request #949 from matrix-org/rav/update_devices
Implement updates and deletes for devices
This commit is contained in:
commit
d34e9f93b7
|
@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
|
||||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -374,7 +375,8 @@ class AuthHandler(BaseHandler):
|
||||||
return self._check_password(user_id, password)
|
return self._check_password(user_id, password)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_login_tuple_for_user_id(self, user_id, device_id=None):
|
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||||
|
initial_display_name=None):
|
||||||
"""
|
"""
|
||||||
Gets login tuple for the user with the given user ID.
|
Gets login tuple for the user with the given user ID.
|
||||||
|
|
||||||
|
@ -383,9 +385,15 @@ class AuthHandler(BaseHandler):
|
||||||
The user is assumed to have been authenticated by some other
|
The user is assumed to have been authenticated by some other
|
||||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||||
|
|
||||||
|
The device will be recorded in the table if it is not there already.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): canonical User ID
|
user_id (str): canonical User ID
|
||||||
device_id (str): the device ID to associate with the access token
|
device_id (str|None): the device ID to associate with the tokens.
|
||||||
|
None to leave the tokens unassociated with a device (deprecated:
|
||||||
|
we should always have a device ID)
|
||||||
|
initial_display_name (str): display name to associate with the
|
||||||
|
device if it needs re-registering
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
A tuple of:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
|
@ -397,6 +405,16 @@ class AuthHandler(BaseHandler):
|
||||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
access_token = yield self.issue_access_token(user_id, device_id)
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||||
|
|
||||||
|
# the device *should* have been registered before we got here; however,
|
||||||
|
# it's possible we raced against a DELETE operation. The thing we
|
||||||
|
# really don't want is active access_tokens without a record of the
|
||||||
|
# device, so we double-check it here.
|
||||||
|
if device_id is not None:
|
||||||
|
yield self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((access_token, refresh_token))
|
defer.returnValue((access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str):
|
user_id (str):
|
||||||
device_id (str)
|
device_id (str):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: dict[str, X]: info on the device
|
defer.Deferred: dict[str, X]: info on the device
|
||||||
|
@ -117,6 +117,55 @@ class DeviceHandler(BaseHandler):
|
||||||
_update_device_from_client_ips(device, ips)
|
_update_device_from_client_ips(device, ips)
|
||||||
defer.returnValue(device)
|
defer.returnValue(device)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_device(self, user_id, device_id):
|
||||||
|
""" Delete the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.delete_device(user_id, device_id)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
# no match
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
yield self.store.user_delete_access_tokens(user_id,
|
||||||
|
device_id=device_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_device(self, user_id, device_id, content):
|
||||||
|
""" Update the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
content (dict): body of update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.update_device(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
new_display_name=content.get("display_name")
|
||||||
|
)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
raise errors.NotFoundError()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _update_device_from_client_ips(device, client_ips):
|
def _update_device_from_client_ips(device, client_ips):
|
||||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||||
|
|
|
@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback):
|
def register_paths(self, method, path_patterns, callback):
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
self._PathEntry(path_pattern, callback)
|
self._PathEntry(path_pattern, callback)
|
||||||
)
|
)
|
||||||
|
|
|
@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
|
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
|
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
registered_user_id, device_id
|
registered_user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
|
|
|
@ -13,19 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet
|
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http import servlet
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DevicesRestServlet(RestServlet):
|
class DevicesRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
|
||||||
defer.returnValue((200, {"devices": devices}))
|
defer.returnValue((200, {"devices": devices}))
|
||||||
|
|
||||||
|
|
||||||
class DeviceRestServlet(RestServlet):
|
class DeviceRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||||
releases=[], v2_alpha=False)
|
releases=[], v2_alpha=False)
|
||||||
|
|
||||||
|
@ -70,6 +68,32 @@ class DeviceRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
defer.returnValue((200, device))
|
defer.returnValue((200, device))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, device_id):
|
||||||
|
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||||
|
# It allows the client to delete access tokens, which feels like a
|
||||||
|
# thing which merits extra auth. But if we want to do the interactive-
|
||||||
|
# auth dance, we should really make it possible to delete more than one
|
||||||
|
# device at a time.
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
yield self.device_handler.delete_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, device_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
body = servlet.parse_json_object_from_request(request)
|
||||||
|
yield self.device_handler.update_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
body
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
DevicesRestServlet(hs).register(http_server)
|
DevicesRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
device_id = yield self._register_device(user_id, params)
|
device_id = yield self._register_device(user_id, params)
|
||||||
|
|
||||||
access_token = yield self.auth_handler.issue_access_token(
|
access_token, refresh_token = (
|
||||||
user_id, device_id=device_id
|
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
initial_display_name=params.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
|
||||||
user_id, device_id=device_id
|
|
||||||
)
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
|
|
@ -76,6 +76,46 @@ class DeviceStore(SQLBaseStore):
|
||||||
desc="get_device",
|
desc="get_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_device(self, user_id, device_id):
|
||||||
|
"""Delete a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to delete
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
return self._simple_delete_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_device(self, user_id, device_id, new_display_name=None):
|
||||||
|
"""Update a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to update
|
||||||
|
new_display_name (str|None): new displayname for device; None
|
||||||
|
to leave unchanged
|
||||||
|
Raises:
|
||||||
|
StoreError: if the device is not found
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
updates = {}
|
||||||
|
if new_display_name is not None:
|
||||||
|
updates["display_name"] = new_display_name
|
||||||
|
if not updates:
|
||||||
|
return defer.succeed(None)
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
updatevalues=updates,
|
||||||
|
desc="update_device",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_devices_by_user(self, user_id):
|
def get_devices_by_user(self, user_id):
|
||||||
"""Retrieve all of a user's registered devices.
|
"""Retrieve all of a user's registered devices.
|
||||||
|
|
|
@ -18,18 +18,31 @@ import re
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
from synapse.storage import background_updates
|
||||||
from ._base import SQLBaseStore
|
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RegistrationStore, self).__init__(hs)
|
super(RegistrationStore, self).__init__(hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"access_tokens_device_index",
|
||||||
|
index_name="access_tokens_device_id",
|
||||||
|
table="access_tokens",
|
||||||
|
columns=["user_id", "device_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"refresh_tokens_device_index",
|
||||||
|
index_name="refresh_tokens_device_id",
|
||||||
|
table="refresh_tokens",
|
||||||
|
columns=["user_id", "device_id"],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_access_token_to_user(self, user_id, token, device_id=None):
|
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||||
"""Adds an access token for the given user.
|
"""Adds an access token for the given user.
|
||||||
|
@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore):
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
self.get_user_by_id.invalidate((user_id,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
|
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||||
|
device_id=None):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
|
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
|
||||||
clauses = [user_id]
|
clauses = [user_id]
|
||||||
|
|
||||||
|
if device_id is not None:
|
||||||
|
sql += " AND device_id = ?"
|
||||||
|
clauses.append(device_id)
|
||||||
|
|
||||||
if except_token_ids:
|
if except_token_ids:
|
||||||
sql += " AND id NOT IN (%s)" % (
|
sql += " AND id NOT IN (%s)" % (
|
||||||
",".join(["?" for _ in except_token_ids]),
|
",".join(["?" for _ in except_token_ids]),
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('access_tokens_device_index', '{}');
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('refresh_tokens_device_index', '{}');
|
|
@ -12,11 +12,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse import types
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
import synapse.handlers.device
|
import synapse.handlers.device
|
||||||
|
|
||||||
import synapse.storage
|
import synapse.storage
|
||||||
|
from synapse import types
|
||||||
from tests import unittest, utils
|
from tests import unittest, utils
|
||||||
|
|
||||||
user1 = "@boris:aaa"
|
user1 = "@boris:aaa"
|
||||||
|
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
||||||
self.store = None # type: synapse.storage.DataStore
|
self.store = None # type: synapse.storage.DataStore
|
||||||
self.handler = None # type: device.DeviceHandler
|
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
||||||
self.clock = None # type: utils.MockClock
|
self.clock = None # type: utils.MockClock
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -123,6 +126,37 @@ class DeviceTestCase(unittest.TestCase):
|
||||||
"last_seen_ts": 3000000,
|
"last_seen_ts": 3000000,
|
||||||
}, res)
|
}, res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_delete_device(self):
|
||||||
|
yield self._record_users()
|
||||||
|
|
||||||
|
# delete the device
|
||||||
|
yield self.handler.delete_device(user1, "abc")
|
||||||
|
|
||||||
|
# check the device was deleted
|
||||||
|
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||||
|
yield self.handler.get_device(user1, "abc")
|
||||||
|
|
||||||
|
# we'd like to check the access token was invalidated, but that's a
|
||||||
|
# bit of a PITA.
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_device(self):
|
||||||
|
yield self._record_users()
|
||||||
|
|
||||||
|
update = {"display_name": "new display"}
|
||||||
|
yield self.handler.update_device(user1, "abc", update)
|
||||||
|
|
||||||
|
res = yield self.handler.get_device(user1, "abc")
|
||||||
|
self.assertEqual(res["display_name"], "new display")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_unknown_device(self):
|
||||||
|
update = {"display_name": "new_display"}
|
||||||
|
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||||
|
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||||
|
update)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _record_users(self):
|
def _record_users(self):
|
||||||
# check this works for both devices which have a recorded client_ip,
|
# check this works for both devices which have a recorded client_ip,
|
||||||
|
|
|
@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.registration_handler.appservice_register = Mock(
|
self.registration_handler.appservice_register = Mock(
|
||||||
return_value=user_id
|
return_value=user_id
|
||||||
)
|
)
|
||||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||||
|
return_value=(token, "kermits_refresh_token")
|
||||||
|
)
|
||||||
|
|
||||||
(code, result) = yield self.servlet.on_POST(self.request)
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
self.assertEquals(code, 200)
|
self.assertEquals(code, 200)
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
|
"refresh_token": "kermits_refresh_token",
|
||||||
"home_server": self.hs.hostname
|
"home_server": self.hs.hostname
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
|
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||||
|
return_value=(token, "kermits_refresh_token")
|
||||||
|
)
|
||||||
self.device_handler.check_device_registered = \
|
self.device_handler.check_device_registered = \
|
||||||
Mock(return_value=device_id)
|
Mock(return_value=device_id)
|
||||||
|
|
||||||
|
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
|
"refresh_token": "kermits_refresh_token",
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
self.assertIn("refresh_token", result)
|
self.assertIn("refresh_token", result)
|
||||||
self.auth_handler.issue_access_token.assert_called_once_with(
|
self.auth_handler.get_login_tuple_for_user_id(
|
||||||
user_id, device_id=device_id)
|
user_id, device_id=device_id, initial_device_display_name=None)
|
||||||
|
|
||||||
def test_POST_disabled_registration(self):
|
def test_POST_disabled_registration(self):
|
||||||
self.hs.config.enable_registration = False
|
self.hs.config.enable_registration = False
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
import tests.utils
|
import tests.utils
|
||||||
|
|
||||||
|
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
"device_id": "device2",
|
"device_id": "device2",
|
||||||
"display_name": "display_name 2",
|
"display_name": "display_name 2",
|
||||||
}, res["device2"])
|
}, res["device2"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_device(self):
|
||||||
|
yield self.store.store_device(
|
||||||
|
"user_id", "device_id", "display_name 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
|
# do a no-op first
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "device_id",
|
||||||
|
)
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
|
# do the update
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "device_id",
|
||||||
|
new_display_name="display_name 2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# check it worked
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 2", res["display_name"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_unknown_device(self):
|
||||||
|
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "unknown_device_id",
|
||||||
|
new_display_name="display_name 2",
|
||||||
|
)
|
||||||
|
self.assertEqual(404, cm.exception.code)
|
||||||
|
|
Loading…
Reference in New Issue