Support pagination tokens from /sync and /messages in the relations API. (#11952)

Patrick Cloke, 2022-02-10 10:52:48 -05:00, committed by GitHub
parent 337f38cac3
commit df36945ff0
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
5 changed files with 217 additions and 53 deletions

changelog.d/11952.bugfix (new file, +1)

@@ -0,0 +1 @@
+Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API.
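
What this looks like from a client's perspective, as a minimal sketch (not part of the commit): the `end` token returned by `/messages` is a full stream token, and after this fix it can be passed straight to `/relations` as `from`. The homeserver URL, access token, room ID, and parent event ID below are placeholders; the unstable prefix matches the endpoint exercised by the tests in this commit.

    import requests

    homeserver = "https://matrix.example.org"  # placeholder
    headers = {"Authorization": "Bearer <access_token>"}  # placeholder
    room_id = "!room:example.org"  # placeholder
    parent_id = "$parent_event_id"  # placeholder

    # Fetch the newest event; "end" is a stream token, not a relation token.
    msgs = requests.get(
        f"{homeserver}/_matrix/client/v3/rooms/{room_id}/messages",
        params={"dir": "b", "limit": 1},
        headers=headers,
    ).json()

    # Before this fix the relations API only accepted its own legacy
    # "<topological>-<stream>" tokens and rejected stream tokens like this one.
    relations = requests.get(
        f"{homeserver}/_matrix/client/unstable/rooms/{room_id}/relations/{parent_id}",
        params={"from": msgs["end"]},
        headers=headers,
    ).json()
    print([ev["event_id"] for ev in relations["chunk"]])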

synapse/rest/client/relations.py

@@ -32,14 +32,45 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken, StreamToken

 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore

 logger = logging.getLogger(__name__)


+async def _parse_token(
+    store: "DataStore", token: Optional[str]
+) -> Optional[StreamToken]:
+    """
+    For backwards compatibility support RelationPaginationToken, but new pagination
+    tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
+    """
+    if not token:
+        return None
+
+    # Luckily the format for StreamToken and RelationPaginationToken differ enough
+    # that they can easily be separated. An "_" appears in the serialization of
+    # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
+    # "-" only for separators.
+    if "_" in token:
+        return await StreamToken.from_string(store, token)
+    else:
+        relation_token = RelationPaginationToken.from_string(token)
+        return StreamToken(
+            room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
+            presence_key=0,
+            typing_key=0,
+            receipt_key=0,
+            account_data_key=0,
+            push_rules_key=0,
+            to_device_key=0,
+            device_list_key=0,
+            groups_key=0,
+        )
+
+
 class RelationPaginationServlet(RestServlet):
     """API to paginate relations on an event by topological ordering, optionally
     filtered by relation type and event type.
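
For reference, this is why the underscore check in `_parse_token` is sufficient. A serialized StreamToken joins its per-stream keys with `_`, while a legacy RelationPaginationToken is just `<topological>-<stream>`. The token values below are illustrative, not real:

    # Illustrative token shapes only; real values come from the server.
    legacy_token = "12-34"                    # RelationPaginationToken
    stream_token = "t12-34_5_0_6_7_8_9_10_0"  # StreamToken (room_key first)

    def looks_like_stream_token(token: str) -> bool:
        # Mirrors the check in _parse_token: only a serialized StreamToken
        # contains an underscore.
        return "_" in token

    assert looks_like_stream_token(stream_token)
    assert not looks_like_stream_token(legacy_token)
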
@@ -88,13 +119,8 @@ class RelationPaginationServlet(RestServlet):
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = RelationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = RelationPaginationToken.from_string(to_token_str)
+            from_token = await _parse_token(self.store, from_token_str)
+            to_token = await _parse_token(self.store, to_token_str)

             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
@@ -125,7 +151,7 @@ class RelationPaginationServlet(RestServlet):
             events, now, bundle_aggregations=aggregations
         )

-        return_value = pagination_chunk.to_dict()
+        return_value = await pagination_chunk.to_dict(self.store)
         return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event
@@ -216,7 +242,7 @@ class RelationAggregationPaginationServlet(RestServlet):
             to_token=to_token,
         )

-        return 200, pagination_chunk.to_dict()
+        return 200, await pagination_chunk.to_dict(self.store)


 class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -287,13 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")

-        from_token = None
-        if from_token_str:
-            from_token = RelationPaginationToken.from_string(from_token_str)
-
-        to_token = None
-        if to_token_str:
-            to_token = RelationPaginationToken.from_string(to_token_str)
+        from_token = await _parse_token(self.store, from_token_str)
+        to_token = await _parse_token(self.store, to_token_str)

         result = await self.store.get_relations_for_event(
             event_id=parent_id,
@@ -313,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         now = self.clock.time_msec()
         serialized_events = self._event_serializer.serialize_events(events, now)

-        return_value = result.to_dict()
+        return_value = await result.to_dict(self.store)
         return_value["chunk"] = serialized_events

         return 200, return_value

synapse/storage/databases/main/relations.py

@@ -39,16 +39,13 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import (
-    AggregationPaginationToken,
-    PaginationChunk,
-    RelationPaginationToken,
-)
-from synapse.types import JsonDict
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList

 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore

 logger = logging.getLogger(__name__)
@@ -98,8 +95,8 @@ class RelationsWorkerStore(SQLBaseStore):
         aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
-        from_token: Optional[RelationPaginationToken] = None,
-        to_token: Optional[RelationPaginationToken] = None,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
     ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
@@ -138,8 +135,10 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
             engine=self.database_engine,
         )
@@ -177,12 +176,27 @@ class RelationsWorkerStore(SQLBaseStore):
                     last_topo_id = row[1]
                     last_stream_id = row[2]

-            next_batch = None
+            # If there are more events, generate the next pagination key.
+            next_token = None
             if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace("room_key", next_key)
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )

             return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+                chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
             )

         return await self.db_pool.runInteraction(
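
A note on the branch above: `copy_and_replace` swaps out only the room position while keeping the caller's other stream positions (presence, receipts, and so on) intact, whereas the zero-filled StreamToken is the fallback when no from_token was supplied. A stripped-down analogue (for illustration; not Synapse code) of what `copy_and_replace` does:

    import attr

    @attr.s(frozen=True, auto_attribs=True)
    class MiniStreamToken:
        room_key: str
        presence_key: int
        to_device_key: int

        def copy_and_replace(self, key: str, new_value: object) -> "MiniStreamToken":
            # attr.evolve returns a copy with the named field replaced.
            return attr.evolve(self, **{key: new_value})

    token = MiniStreamToken(room_key="s10", presence_key=7, to_device_key=3)
    next_token = token.copy_and_replace("room_key", "t5-11")
    assert next_token.presence_key == 7  # the other positions are preserved
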
@@ -676,13 +690,15 @@ class RelationsWorkerStore(SQLBaseStore):

         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations.annotations = annotations.to_dict()
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )

         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations.references = references.to_dict()
+            aggregations.references = await references.to_dict(cast("DataStore", self))

         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
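
Why the `cast("DataStore", self)` here: `RelationsWorkerStore` is one of the mixin classes that Synapse combines into the concrete `DataStore`, so at runtime `self` really is a `DataStore`, but the type checker inside this module only sees the mixin. A rough illustration of the pattern (not Synapse code):

    from typing import cast

    class PaginationMixin:
        def describe(self) -> str:
            # Safe at runtime because instances are always CombinedStore.
            store = cast("CombinedStore", self)
            return store.name

    class CombinedStore(PaginationMixin):
        name = "datastore"

    assert CombinedStore().describe() == "datastore"
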

synapse/storage/relations.py

@@ -13,13 +13,16 @@
 # limitations under the License.

 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import attr

 from synapse.api.errors import SynapseError
 from synapse.types import JsonDict

+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
@@ -39,14 +42,14 @@ class PaginationChunk:
     next_batch: Optional[Any] = None
     prev_batch: Optional[Any] = None

-    def to_dict(self) -> Dict[str, Any]:
+    async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
         d = {"chunk": self.chunk}

         if self.next_batch:
-            d["next_batch"] = self.next_batch.to_string()
+            d["next_batch"] = await self.next_batch.to_string(store)

         if self.prev_batch:
-            d["prev_batch"] = self.prev_batch.to_string()
+            d["prev_batch"] = await self.prev_batch.to_string(store)

         return d
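
For the `/relations` endpoint, the dict produced above is what ends up on the wire. After this change the batch tokens are full stream-token strings, the same format `/sync` and `/messages` emit, rather than legacy `<topological>-<stream>` pairs. Field values here are illustrative only, and the event objects are truncated:

    # Illustrative only; real tokens and events come from the server.
    example_response = {
        "chunk": [{"event_id": "$abc", "type": "m.reaction"}],
        "next_batch": "t3-10_5_0_0_0_0_0_0_0",  # StreamToken format
        "prev_batch": "t2-9_5_0_0_0_0_0_0_0",
    }
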
@@ -75,7 +78,7 @@ class RelationPaginationToken:
         except ValueError:
             raise SynapseError(400, "Invalid relation pagination token")

-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.topological, self.stream)

     def as_tuple(self) -> Tuple[Any, ...]:
@@ -105,7 +108,7 @@ class AggregationPaginationToken:
         except ValueError:
             raise SynapseError(400, "Invalid aggregation pagination token")

-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.count, self.stream)

     def as_tuple(self) -> Tuple[Any, ...]:

tests/rest/client/test_relations.py

@@ -21,7 +21,8 @@ from unittest.mock import patch
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken

 from tests import unittest
 from tests.server import FakeChannel
@@ -200,6 +201,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body.get("next_batch"), str, channel.json_body
         )

+    def _stream_token_to_relation_token(self, token: str) -> str:
+        """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
+        room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
+        return self.get_success(
+            RelationPaginationToken(
+                topological=room_key.topological, stream=room_key.stream
+            ).to_string(self.store)
+        )
+
     def test_repeated_paginate_relations(self):
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
@@ -213,7 +223,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])

-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
@@ -222,8 +232,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):

             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
-                % (self.room, self.parent_id, from_token),
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +250,93 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         found_event_ids.reverse()
         self.assertEquals(found_event_ids, expected_event_ids)

+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+    def test_pagination_from_sync_and_messages(self):
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(
+            '{"room": {"timeline": {"limit": 1}}}'.encode()
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
     def test_aggregation_pagination_groups(self):
         """Test that we can paginate annotation groups correctly."""
@@ -337,7 +433,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)

-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
@@ -347,15 +443,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):

             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s"
-                "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
-                % (
-                    self.room,
-                    self.parent_id,
-                    RelationTypes.ANNOTATION,
-                    encoded_key,
-                    from_token,
-                ),
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)