Prevent multiple device list updates from breaking a batch send (#5156)
fixes #5153
This commit is contained in:
parent
a11865016e
commit
2d1d7b7e6f
|
@ -0,0 +1 @@
|
||||||
|
Prevent federation device list updates from breaking when processing multiple updates at once.
|
|
@ -349,9 +349,10 @@ class PerDestinationQueue(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_new_device_messages(self, limit):
|
def _get_new_device_messages(self, limit):
|
||||||
last_device_list = self._last_device_list_stream_id
|
last_device_list = self._last_device_list_stream_id
|
||||||
# Will return at most 20 entries
|
|
||||||
|
# Retrieve list of new device updates to send to the destination
|
||||||
now_stream_id, results = yield self._store.get_devices_by_remote(
|
now_stream_id, results = yield self._store.get_devices_by_remote(
|
||||||
self._destination, last_device_list
|
self._destination, last_device_list, limit=limit,
|
||||||
)
|
)
|
||||||
edus = [
|
edus = [
|
||||||
Edu(
|
Edu(
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from six import iteritems, itervalues
|
from six import iteritems
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
|
@ -72,11 +72,14 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue({d["device_id"]: d for d in devices})
|
defer.returnValue({d["device_id"]: d for d in devices})
|
||||||
|
|
||||||
def get_devices_by_remote(self, destination, from_stream_id):
|
@defer.inlineCallbacks
|
||||||
|
def get_devices_by_remote(self, destination, from_stream_id, limit):
|
||||||
"""Get stream of updates to send to remote servers
|
"""Get stream of updates to send to remote servers
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(int, list[dict]): current stream id and list of updates
|
Deferred[tuple[int, list[dict]]]:
|
||||||
|
current stream id (ie, the stream id of the last update included in the
|
||||||
|
response), and the list of updates
|
||||||
"""
|
"""
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -84,55 +87,131 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
destination, int(from_stream_id)
|
destination, int(from_stream_id)
|
||||||
)
|
)
|
||||||
if not has_changed:
|
if not has_changed:
|
||||||
return (now_stream_id, [])
|
defer.returnValue((now_stream_id, []))
|
||||||
|
|
||||||
return self.runInteraction(
|
# We retrieve n+1 devices from the list of outbound pokes where n is
|
||||||
|
# our outbound device update limit. We then check if the very last
|
||||||
|
# device has the same stream_id as the second-to-last device. If so,
|
||||||
|
# then we ignore all devices with that stream_id and only send the
|
||||||
|
# devices with a lower stream_id.
|
||||||
|
#
|
||||||
|
# If when culling the list we end up with no devices afterwards, we
|
||||||
|
# consider the device update to be too large, and simply skip the
|
||||||
|
# stream_id; the rationale being that such a large device list update
|
||||||
|
# is likely an error.
|
||||||
|
updates = yield self.runInteraction(
|
||||||
"get_devices_by_remote",
|
"get_devices_by_remote",
|
||||||
self._get_devices_by_remote_txn,
|
self._get_devices_by_remote_txn,
|
||||||
destination,
|
destination,
|
||||||
from_stream_id,
|
from_stream_id,
|
||||||
now_stream_id,
|
now_stream_id,
|
||||||
|
limit + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_devices_by_remote_txn(
|
# Return an empty list if there are no updates
|
||||||
self, txn, destination, from_stream_id, now_stream_id
|
if not updates:
|
||||||
):
|
defer.returnValue((now_stream_id, []))
|
||||||
sql = """
|
|
||||||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
|
||||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
|
||||||
GROUP BY user_id, device_id
|
|
||||||
LIMIT 20
|
|
||||||
"""
|
|
||||||
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
|
|
||||||
|
|
||||||
|
# if we have exceeded the limit, we need to exclude any results with the
|
||||||
|
# same stream_id as the last row.
|
||||||
|
if len(updates) > limit:
|
||||||
|
stream_id_cutoff = updates[-1][2]
|
||||||
|
now_stream_id = stream_id_cutoff - 1
|
||||||
|
else:
|
||||||
|
stream_id_cutoff = None
|
||||||
|
|
||||||
|
# Perform the equivalent of a GROUP BY
|
||||||
|
#
|
||||||
|
# Iterate through the updates list and copy non-duplicate
|
||||||
|
# (user_id, device_id) entries into a map, with the value being
|
||||||
|
# the max stream_id across each set of duplicate entries
|
||||||
|
#
|
||||||
# maps (user_id, device_id) -> stream_id
|
# maps (user_id, device_id) -> stream_id
|
||||||
query_map = {(r[0], r[1]): r[2] for r in txn}
|
# as long as their stream_id does not match that of the last row
|
||||||
|
query_map = {}
|
||||||
|
for update in updates:
|
||||||
|
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
|
||||||
|
# Stop processing updates
|
||||||
|
break
|
||||||
|
|
||||||
|
key = (update[0], update[1])
|
||||||
|
query_map[key] = max(query_map.get(key, 0), update[2])
|
||||||
|
|
||||||
|
# If we didn't find any updates with a stream_id lower than the cutoff, it
|
||||||
|
# means that there are more than limit updates all of which have the same
|
||||||
|
# stream_id.
|
||||||
|
|
||||||
|
# That should only happen if a client is spamming the server with new
|
||||||
|
# devices, in which case E2E isn't going to work well anyway. We'll just
|
||||||
|
# skip that stream_id and return an empty list, and continue with the next
|
||||||
|
# stream_id next time.
|
||||||
if not query_map:
|
if not query_map:
|
||||||
return (now_stream_id, [])
|
defer.returnValue((stream_id_cutoff, []))
|
||||||
|
|
||||||
if len(query_map) >= 20:
|
results = yield self._get_device_update_edus_by_remote(
|
||||||
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
|
destination,
|
||||||
|
from_stream_id,
|
||||||
|
query_map,
|
||||||
|
)
|
||||||
|
|
||||||
devices = self._get_e2e_device_keys_txn(
|
defer.returnValue((now_stream_id, results))
|
||||||
txn,
|
|
||||||
|
def _get_devices_by_remote_txn(
|
||||||
|
self, txn, destination, from_stream_id, now_stream_id, limit
|
||||||
|
):
|
||||||
|
"""Return device update information for a given remote destination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (LoggingTransaction): The transaction to execute
|
||||||
|
destination (str): The host the device updates are intended for
|
||||||
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||||
|
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
|
||||||
|
limit (int): Maximum number of device updates to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: List of device updates
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
|
||||||
|
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||||
|
ORDER BY stream_id
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
|
||||||
|
|
||||||
|
return list(txn)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_device_update_edus_by_remote(
|
||||||
|
self, destination, from_stream_id, query_map,
|
||||||
|
):
|
||||||
|
"""Returns a list of device update EDUs as well as E2EE keys
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): The host the device updates are intended for
|
||||||
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||||
|
query_map (Dict[(str, str): int]): Dictionary mapping
|
||||||
|
user_id/device_id to update stream_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: List of objects representing a device update EDU
|
||||||
|
|
||||||
|
"""
|
||||||
|
devices = yield self.runInteraction(
|
||||||
|
"_get_e2e_device_keys_txn",
|
||||||
|
self._get_e2e_device_keys_txn,
|
||||||
query_map.keys(),
|
query_map.keys(),
|
||||||
include_all_devices=True,
|
include_all_devices=True,
|
||||||
include_deleted_devices=True,
|
include_deleted_devices=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_sent_id_sql = """
|
|
||||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
|
||||||
FROM device_lists_outbound_last_success
|
|
||||||
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for user_id, user_devices in iteritems(devices):
|
for user_id, user_devices in iteritems(devices):
|
||||||
# The prev_id for the first row is always the last row before
|
# The prev_id for the first row is always the last row before
|
||||||
# `from_stream_id`
|
# `from_stream_id`
|
||||||
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
|
prev_id = yield self._get_last_device_update_for_remote_user(
|
||||||
rows = txn.fetchall()
|
destination, user_id, from_stream_id,
|
||||||
prev_id = rows[0][0]
|
)
|
||||||
for device_id, device in iteritems(user_devices):
|
for device_id, device in iteritems(user_devices):
|
||||||
stream_id = query_map[(user_id, device_id)]
|
stream_id = query_map[(user_id, device_id)]
|
||||||
result = {
|
result = {
|
||||||
|
@ -156,7 +235,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return (now_stream_id, results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
def _get_last_device_update_for_remote_user(
|
||||||
|
self, destination, user_id, from_stream_id,
|
||||||
|
):
|
||||||
|
def f(txn):
|
||||||
|
prev_sent_id_sql = """
|
||||||
|
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||||
|
FROM device_lists_outbound_last_success
|
||||||
|
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
||||||
|
"""
|
||||||
|
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
return rows[0][0]
|
||||||
|
|
||||||
|
return self.runInteraction("get_last_device_update_for_remote_user", f)
|
||||||
|
|
||||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
||||||
"""Mark that updates have successfully been sent to the destination.
|
"""Mark that updates have successfully been sent to the destination.
|
||||||
|
|
|
@ -71,6 +71,75 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
res["device2"],
|
res["device2"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_devices_by_remote(self):
|
||||||
|
device_ids = ["device_id1", "device_id2"]
|
||||||
|
|
||||||
|
# Add two device updates with a single stream_id
|
||||||
|
yield self.store.add_device_change_to_streams(
|
||||||
|
"user_id", device_ids, ["somehost"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all device updates ever meant for this remote
|
||||||
|
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||||
|
"somehost", -1, limit=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check original device_ids are contained within these updates
|
||||||
|
self._check_devices_in_updates(device_ids, device_updates)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_devices_by_remote_limited(self):
|
||||||
|
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
||||||
|
|
||||||
|
# first add one device
|
||||||
|
device_ids1 = ["device_id0"]
|
||||||
|
yield self.store.add_device_change_to_streams(
|
||||||
|
"user_id", device_ids1, ["someotherhost"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# then add 101
|
||||||
|
device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
|
||||||
|
yield self.store.add_device_change_to_streams(
|
||||||
|
"user_id", device_ids2, ["someotherhost"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# then one more
|
||||||
|
device_ids3 = ["newdevice"]
|
||||||
|
yield self.store.add_device_change_to_streams(
|
||||||
|
"user_id", device_ids3, ["someotherhost"],
|
||||||
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# now read them back.
|
||||||
|
#
|
||||||
|
|
||||||
|
# first we should get a single update
|
||||||
|
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||||
|
"someotherhost", -1, limit=100,
|
||||||
|
)
|
||||||
|
self._check_devices_in_updates(device_ids1, device_updates)
|
||||||
|
|
||||||
|
# Then we should get an empty list back as the 101 devices broke the limit
|
||||||
|
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||||
|
"someotherhost", now_stream_id, limit=100,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(device_updates), 0)
|
||||||
|
|
||||||
|
# The 101 devices should've been cleared, so we should now just get one device
|
||||||
|
# update
|
||||||
|
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||||
|
"someotherhost", now_stream_id, limit=100,
|
||||||
|
)
|
||||||
|
self._check_devices_in_updates(device_ids3, device_updates)
|
||||||
|
|
||||||
|
def _check_devices_in_updates(self, expected_device_ids, device_updates):
|
||||||
|
"""Check that specific device ids exist in a list of device update EDUs"""
|
||||||
|
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||||
|
|
||||||
|
received_device_ids = {update["device_id"] for update in device_updates}
|
||||||
|
self.assertEqual(received_device_ids, set(expected_device_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_device(self):
|
def test_update_device(self):
|
||||||
yield self.store.store_device("user_id", "device_id", "display_name 1")
|
yield self.store.store_device("user_id", "device_id", "display_name 1")
|
||||||
|
|
Loading…
Reference in New Issue