Fix limit logic for AccountDataStream (#7384)
Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream.
This commit is contained in:
parent
34a43f0084
commit
6c1f7c722f
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
|
|
@ -14,14 +14,27 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import heapq
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import synapse.server
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# the number of rows to request from an update_function.
|
# the number of rows to request from an update_function.
|
||||||
|
@ -37,7 +50,7 @@ Token = int
|
||||||
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
||||||
# just a row from a database query, though this is dependent on the stream in question.
|
# just a row from a database query, though this is dependent on the stream in question.
|
||||||
#
|
#
|
||||||
StreamRow = Tuple
|
StreamRow = TypeVar("StreamRow", bound=Tuple)
|
||||||
|
|
||||||
# The type returned by the update_function of a stream, as well as get_updates(),
|
# The type returned by the update_function of a stream, as well as get_updates(),
|
||||||
# get_updates_since, etc.
|
# get_updates_since, etc.
|
||||||
|
@ -533,32 +546,63 @@ class AccountDataStream(Stream):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
AccountDataStreamRow = namedtuple(
|
AccountDataStreamRow = namedtuple(
|
||||||
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
|
"AccountDataStream",
|
||||||
|
("user_id", "room_id", "data_type"), # str # Optional[str] # str
|
||||||
)
|
)
|
||||||
|
|
||||||
NAME = "account_data"
|
NAME = "account_data"
|
||||||
ROW_TYPE = AccountDataStreamRow
|
ROW_TYPE = AccountDataStreamRow
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
current_token_without_instance(self.store.get_max_account_data_stream_id),
|
current_token_without_instance(self.store.get_max_account_data_stream_id),
|
||||||
db_query_to_update_function(self._update_function),
|
self._update_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _update_function(self, from_token, to_token, limit):
|
async def _update_function(
|
||||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
self, instance_name: str, from_token: int, to_token: int, limit: int
|
||||||
from_token, from_token, to_token, limit
|
) -> StreamUpdateResult:
|
||||||
|
limited = False
|
||||||
|
global_results = await self.store.get_updated_global_account_data(
|
||||||
|
from_token, to_token, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
results = list(room_results)
|
# if the global results hit the limit, we'll need to limit the room results to
|
||||||
results.extend(
|
# the same stream token.
|
||||||
(stream_id, user_id, None, account_data_type)
|
if len(global_results) >= limit:
|
||||||
|
to_token = global_results[-1][0]
|
||||||
|
limited = True
|
||||||
|
|
||||||
|
room_results = await self.store.get_updated_room_account_data(
|
||||||
|
from_token, to_token, limit
|
||||||
|
)
|
||||||
|
|
||||||
|
# likewise, if the room results hit the limit, limit the global results to
|
||||||
|
# the same stream token.
|
||||||
|
if len(room_results) >= limit:
|
||||||
|
to_token = room_results[-1][0]
|
||||||
|
limited = True
|
||||||
|
|
||||||
|
# convert the global results to the right format, and limit them to the to_token
|
||||||
|
# at the same time
|
||||||
|
global_rows = (
|
||||||
|
(stream_id, (user_id, None, account_data_type))
|
||||||
for stream_id, user_id, account_data_type in global_results
|
for stream_id, user_id, account_data_type in global_results
|
||||||
|
if stream_id <= to_token
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
# we know that the room_results are already limited to `to_token` so no need
|
||||||
|
# for a check on `stream_id` here.
|
||||||
|
room_rows = (
|
||||||
|
(stream_id, (user_id, room_id, account_data_type))
|
||||||
|
for stream_id, user_id, room_id, account_data_type in room_results
|
||||||
|
)
|
||||||
|
|
||||||
|
# we need to return a sorted list, so merge them together.
|
||||||
|
updates = list(heapq.merge(room_rows, global_rows))
|
||||||
|
return updates, to_token, limited
|
||||||
|
|
||||||
|
|
||||||
class GroupServerStream(Stream):
|
class GroupServerStream(Stream):
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
|
@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_updated_account_data(
|
async def get_updated_global_account_data(
|
||||||
self, last_global_id, last_room_id, current_id, limit
|
self, last_id: int, current_id: int, limit: int
|
||||||
):
|
) -> List[Tuple[int, str, str]]:
|
||||||
"""Get all the client account_data that has changed on the server
|
"""Get the global account_data that has changed, for the account_data stream
|
||||||
Args:
|
|
||||||
last_global_id(int): The position to fetch from for top level data
|
|
||||||
last_room_id(int): The position to fetch from for per room data
|
|
||||||
current_id(int): The position to fetch up to.
|
|
||||||
Returns:
|
|
||||||
A deferred pair of lists of tuples of stream_id int, user_id string,
|
|
||||||
room_id string, and type string.
|
|
||||||
"""
|
|
||||||
if last_room_id == current_id and last_global_id == current_id:
|
|
||||||
return defer.succeed(([], []))
|
|
||||||
|
|
||||||
def get_updated_account_data_txn(txn):
|
Args:
|
||||||
|
last_id: the last stream_id from the previous batch.
|
||||||
|
current_id: the maximum stream_id to return up to
|
||||||
|
limit: the maximum number of rows to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples of stream_id int, user_id string,
|
||||||
|
and type string.
|
||||||
|
"""
|
||||||
|
if last_id == current_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_updated_global_account_data_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_id, user_id, account_data_type"
|
"SELECT stream_id, user_id, account_data_type"
|
||||||
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
|
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
|
||||||
" ORDER BY stream_id ASC LIMIT ?"
|
" ORDER BY stream_id ASC LIMIT ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (last_global_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
global_results = txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
|
return await self.db.runInteraction(
|
||||||
|
"get_updated_global_account_data", get_updated_global_account_data_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_updated_room_account_data(
|
||||||
|
self, last_id: int, current_id: int, limit: int
|
||||||
|
) -> List[Tuple[int, str, str, str]]:
|
||||||
|
"""Get the global account_data that has changed, for the account_data stream
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_id: the last stream_id from the previous batch.
|
||||||
|
current_id: the maximum stream_id to return up to
|
||||||
|
limit: the maximum number of rows to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples of stream_id int, user_id string,
|
||||||
|
room_id string and type string.
|
||||||
|
"""
|
||||||
|
if last_id == current_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_updated_room_account_data_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_id, user_id, room_id, account_data_type"
|
"SELECT stream_id, user_id, room_id, account_data_type"
|
||||||
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
|
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
|
||||||
" ORDER BY stream_id ASC LIMIT ?"
|
" ORDER BY stream_id ASC LIMIT ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (last_room_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
room_results = txn.fetchall()
|
return txn.fetchall()
|
||||||
return global_results, room_results
|
|
||||||
|
|
||||||
return self.db.runInteraction(
|
return await self.db.runInteraction(
|
||||||
"get_all_updated_account_data_txn", get_updated_account_data_txn
|
"get_updated_room_account_data", get_updated_room_account_data_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
def get_updated_account_data_for_user(self, user_id, stream_id):
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.replication.tcp.streams._base import (
|
||||||
|
_STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||||
|
AccountDataStream,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tests.replication._base import BaseStreamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class AccountDataStreamTestCase(BaseStreamTestCase):
|
||||||
|
def test_update_function_room_account_data_limit(self):
|
||||||
|
"""Test replication with many room account data updates
|
||||||
|
"""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
|
||||||
|
# generate lots of account data updates
|
||||||
|
updates = []
|
||||||
|
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
|
||||||
|
update = "m.test_type.%i" % (i,)
|
||||||
|
self.get_success(
|
||||||
|
store.add_account_data_to_room("test_user", "test_room", update, {})
|
||||||
|
)
|
||||||
|
updates.append(update)
|
||||||
|
|
||||||
|
# also one global update
|
||||||
|
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
|
||||||
|
|
||||||
|
# tell the notifier to catch up to avoid duplicate rows.
|
||||||
|
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||||
|
# FIXME remove this when the above is fixed
|
||||||
|
self.replicate()
|
||||||
|
|
||||||
|
# check we're testing what we think we are: no rows should yet have been
|
||||||
|
# received
|
||||||
|
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||||
|
|
||||||
|
# now reconnect to pull the updates
|
||||||
|
self.reconnect()
|
||||||
|
self.replicate()
|
||||||
|
|
||||||
|
# we should have received all the expected rows in the right order
|
||||||
|
received_rows = self.test_handler.received_rdata_rows
|
||||||
|
|
||||||
|
for t in updates:
|
||||||
|
(stream_name, token, row) = received_rows.pop(0)
|
||||||
|
self.assertEqual(stream_name, AccountDataStream.NAME)
|
||||||
|
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||||
|
self.assertEqual(row.data_type, t)
|
||||||
|
self.assertEqual(row.room_id, "test_room")
|
||||||
|
|
||||||
|
(stream_name, token, row) = received_rows.pop(0)
|
||||||
|
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||||
|
self.assertEqual(row.data_type, "m.global")
|
||||||
|
self.assertIsNone(row.room_id)
|
||||||
|
|
||||||
|
self.assertEqual([], received_rows)
|
||||||
|
|
||||||
|
def test_update_function_global_account_data_limit(self):
|
||||||
|
"""Test replication with many global account data updates
|
||||||
|
"""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
|
||||||
|
# generate lots of account data updates
|
||||||
|
updates = []
|
||||||
|
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
|
||||||
|
update = "m.test_type.%i" % (i,)
|
||||||
|
self.get_success(store.add_account_data_for_user("test_user", update, {}))
|
||||||
|
updates.append(update)
|
||||||
|
|
||||||
|
# also one per-room update
|
||||||
|
self.get_success(
|
||||||
|
store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
# tell the notifier to catch up to avoid duplicate rows.
|
||||||
|
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||||
|
# FIXME remove this when the above is fixed
|
||||||
|
self.replicate()
|
||||||
|
|
||||||
|
# check we're testing what we think we are: no rows should yet have been
|
||||||
|
# received
|
||||||
|
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||||
|
|
||||||
|
# now reconnect to pull the updates
|
||||||
|
self.reconnect()
|
||||||
|
self.replicate()
|
||||||
|
|
||||||
|
# we should have received all the expected rows in the right order
|
||||||
|
received_rows = self.test_handler.received_rdata_rows
|
||||||
|
|
||||||
|
for t in updates:
|
||||||
|
(stream_name, token, row) = received_rows.pop(0)
|
||||||
|
self.assertEqual(stream_name, AccountDataStream.NAME)
|
||||||
|
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||||
|
self.assertEqual(row.data_type, t)
|
||||||
|
self.assertIsNone(row.room_id)
|
||||||
|
|
||||||
|
(stream_name, token, row) = received_rows.pop(0)
|
||||||
|
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||||
|
self.assertEqual(row.data_type, "m.per_room")
|
||||||
|
self.assertEqual(row.room_id, "test_room")
|
||||||
|
|
||||||
|
self.assertEqual([], received_rows)
|
Loading…
Reference in New Issue