Implement _simple_delete_many_txn, use it to delete devices

(But this doesn't implement the same for deleting access tokens or e2e keys.

Also respond to code review.
This commit is contained in:
Luke Barnard 2017-03-13 17:53:23 +00:00
parent c077c3277b
commit bbeeb97f75
4 changed files with 101 additions and 11 deletions

View File

@ -169,6 +169,40 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id]) yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
""" Delete several devices
Args:
user_id (str):
device_ids (str): The list of device IDs to delete
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_devices(user_id, device_ids)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, device_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_device(self, user_id, device_id, content): def update_device(self, user_id, device_id, content):
""" Update the given device """ Update the given device

View File

@ -47,13 +47,13 @@ class DevicesRestServlet(servlet.RestServlet):
class DeleteDevicesRestServlet(servlet.RestServlet): class DeleteDevicesRestServlet(servlet.RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DeleteDevicesRestServlet, self).__init__() super(DeleteDevicesRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -64,14 +64,13 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
def on_POST(self, request): def on_POST(self, request):
try: try:
body = servlet.parse_json_object_from_request(request) body = servlet.parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a J*DELETESON dict # deal with older clients which didn't pass a J*DELETESON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = {}
else: else:
raise raise e
if 'devices' not in body: if 'devices' not in body:
raise errors.SynapseError( raise errors.SynapseError(
@ -86,11 +85,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((401, result)) defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
for d_id in body['devices']: yield self.device_handler.delete_devices(
yield self.device_handler.delete_device( requester.user.to_string(),
requester.user.to_string(), body['devices'],
d_id, )
)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -840,6 +840,47 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values()) return txn.execute(sql, keyvalues.values())
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
Args:
txn : Transaction object
table : string giving the table name
column : column name to test for inclusion against `iterable`
iterable : list
keyvalues : dict of column names and values to select the rows with
"""
if not iterable:
return
sql = "DELETE FROM %s" % table
clauses = []
values = []
clauses.append(
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
)
values.extend(iterable)
for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
if clauses:
sql = "%s WHERE %s" % (
sql,
" AND ".join(clauses),
)
return txn.execute(sql, values)
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value, limit=100000): max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.

View File

@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
desc="delete_device", desc="delete_device",
) )
def delete_devices(self, user_id, device_ids):
"""Deletes several devices.
Args:
user_id (str): The ID of the user which owns the devices
device_ids (list): The IDs of the devices to delete
Returns:
defer.Deferred
"""
return self._simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
keyvalues={"user_id": user_id},
desc="delete_devices",
)
def update_device(self, user_id, device_id, new_display_name=None): def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device. """Update a device.