Improvements to bundling aggregations. (#11815)
This is some odds and ends found during the review of #11791 and while continuing to work in this code: * Return attrs classes instead of dictionaries from some methods to improve type safety. * Call `get_bundled_aggregations` fewer times. * Adds a missing assertion in the tests. * Do not return empty bundled aggregations for an event (preferring to not include the bundle at all, as the docstring states).
This commit is contained in:
parent
d8df8e6c14
commit
2897fb6b4f
|
@ -0,0 +1 @@
|
||||||
|
Improve type safety of bundled aggregations code.
|
|
@ -14,7 +14,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
|
||||||
|
|
||||||
from . import EventBase
|
from . import EventBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.storage.databases.main.relations import BundledAggregations
|
||||||
|
|
||||||
|
|
||||||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||||
# (?<!stuff) matches if the current position in the string is not preceded
|
# (?<!stuff) matches if the current position in the string is not preceded
|
||||||
# by a match for 'stuff'.
|
# by a match for 'stuff'.
|
||||||
|
@ -376,7 +390,7 @@ class EventClientSerializer:
|
||||||
event: Union[JsonDict, EventBase],
|
event: Union[JsonDict, EventBase],
|
||||||
time_now: int,
|
time_now: int,
|
||||||
*,
|
*,
|
||||||
bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
|
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Serializes a single event.
|
"""Serializes a single event.
|
||||||
|
@ -415,7 +429,7 @@ class EventClientSerializer:
|
||||||
self,
|
self,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
time_now: int,
|
time_now: int,
|
||||||
aggregations: JsonDict,
|
aggregations: "BundledAggregations",
|
||||||
serialized_event: JsonDict,
|
serialized_event: JsonDict,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
|
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
|
||||||
|
@ -427,13 +441,18 @@ class EventClientSerializer:
|
||||||
serialized_event: The serialized event which may be modified.
|
serialized_event: The serialized event which may be modified.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Make a copy in-case the object is cached.
|
serialized_aggregations = {}
|
||||||
aggregations = aggregations.copy()
|
|
||||||
|
|
||||||
if RelationTypes.REPLACE in aggregations:
|
if aggregations.annotations:
|
||||||
|
serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
|
||||||
|
|
||||||
|
if aggregations.references:
|
||||||
|
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
|
||||||
|
|
||||||
|
if aggregations.replace:
|
||||||
# If there is an edit replace the content, preserving existing
|
# If there is an edit replace the content, preserving existing
|
||||||
# relations.
|
# relations.
|
||||||
edit = aggregations[RelationTypes.REPLACE]
|
edit = aggregations.replace
|
||||||
|
|
||||||
# Ensure we take copies of the edit content, otherwise we risk modifying
|
# Ensure we take copies of the edit content, otherwise we risk modifying
|
||||||
# the original event.
|
# the original event.
|
||||||
|
@ -451,24 +470,28 @@ class EventClientSerializer:
|
||||||
else:
|
else:
|
||||||
serialized_event["content"].pop("m.relates_to", None)
|
serialized_event["content"].pop("m.relates_to", None)
|
||||||
|
|
||||||
aggregations[RelationTypes.REPLACE] = {
|
serialized_aggregations[RelationTypes.REPLACE] = {
|
||||||
"event_id": edit.event_id,
|
"event_id": edit.event_id,
|
||||||
"origin_server_ts": edit.origin_server_ts,
|
"origin_server_ts": edit.origin_server_ts,
|
||||||
"sender": edit.sender,
|
"sender": edit.sender,
|
||||||
}
|
}
|
||||||
|
|
||||||
# If this event is the start of a thread, include a summary of the replies.
|
# If this event is the start of a thread, include a summary of the replies.
|
||||||
if RelationTypes.THREAD in aggregations:
|
if aggregations.thread:
|
||||||
# Serialize the latest thread event.
|
serialized_aggregations[RelationTypes.THREAD] = {
|
||||||
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
|
# Don't bundle aggregations as this could recurse forever.
|
||||||
|
"latest_event": self.serialize_event(
|
||||||
# Don't bundle aggregations as this could recurse forever.
|
aggregations.thread.latest_event, time_now, bundle_aggregations=None
|
||||||
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
|
),
|
||||||
latest_thread_event, time_now, bundle_aggregations=None
|
"count": aggregations.thread.count,
|
||||||
)
|
"current_user_participated": aggregations.thread.current_user_participated,
|
||||||
|
}
|
||||||
|
|
||||||
# Include the bundled aggregations in the event.
|
# Include the bundled aggregations in the event.
|
||||||
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
|
if serialized_aggregations:
|
||||||
|
serialized_event["unsigned"].setdefault("m.relations", {}).update(
|
||||||
|
serialized_aggregations
|
||||||
|
)
|
||||||
|
|
||||||
def serialize_events(
|
def serialize_events(
|
||||||
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
|
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
|
||||||
|
|
|
@ -30,6 +30,7 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import attr
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
|
@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
|
||||||
from synapse.federation.federation_client import InvalidResponseError
|
from synapse.federation.federation_client import InvalidResponseError
|
||||||
from synapse.handlers.federation import get_domains_from_state
|
from synapse.handlers.federation import get_domains_from_state
|
||||||
from synapse.rest.admin._base import assert_user_is_admin
|
from synapse.rest.admin._base import assert_user_is_admin
|
||||||
|
from synapse.storage.databases.main.relations import BundledAggregations
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.streams import EventSource
|
from synapse.streams import EventSource
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -90,6 +92,17 @@ id_server_scheme = "https://"
|
||||||
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
|
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class EventContext:
|
||||||
|
events_before: List[EventBase]
|
||||||
|
event: EventBase
|
||||||
|
events_after: List[EventBase]
|
||||||
|
state: List[EventBase]
|
||||||
|
aggregations: Dict[str, BundledAggregations]
|
||||||
|
start: str
|
||||||
|
end: str
|
||||||
|
|
||||||
|
|
||||||
class RoomCreationHandler:
|
class RoomCreationHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -1119,7 +1132,7 @@ class RoomContextHandler:
|
||||||
limit: int,
|
limit: int,
|
||||||
event_filter: Optional[Filter],
|
event_filter: Optional[Filter],
|
||||||
use_admin_priviledge: bool = False,
|
use_admin_priviledge: bool = False,
|
||||||
) -> Optional[JsonDict]:
|
) -> Optional[EventContext]:
|
||||||
"""Retrieves events, pagination tokens and state around a given event
|
"""Retrieves events, pagination tokens and state around a given event
|
||||||
in a room.
|
in a room.
|
||||||
|
|
||||||
|
@ -1167,38 +1180,28 @@ class RoomContextHandler:
|
||||||
results = await self.store.get_events_around(
|
results = await self.store.get_events_around(
|
||||||
room_id, event_id, before_limit, after_limit, event_filter
|
room_id, event_id, before_limit, after_limit, event_filter
|
||||||
)
|
)
|
||||||
|
events_before = results.events_before
|
||||||
|
events_after = results.events_after
|
||||||
|
|
||||||
if event_filter:
|
if event_filter:
|
||||||
results["events_before"] = await event_filter.filter(
|
events_before = await event_filter.filter(events_before)
|
||||||
results["events_before"]
|
events_after = await event_filter.filter(events_after)
|
||||||
)
|
|
||||||
results["events_after"] = await event_filter.filter(results["events_after"])
|
|
||||||
|
|
||||||
results["events_before"] = await filter_evts(results["events_before"])
|
events_before = await filter_evts(events_before)
|
||||||
results["events_after"] = await filter_evts(results["events_after"])
|
events_after = await filter_evts(events_after)
|
||||||
# filter_evts can return a pruned event in case the user is allowed to see that
|
# filter_evts can return a pruned event in case the user is allowed to see that
|
||||||
# there's something there but not see the content, so use the event that's in
|
# there's something there but not see the content, so use the event that's in
|
||||||
# `filtered` rather than the event we retrieved from the datastore.
|
# `filtered` rather than the event we retrieved from the datastore.
|
||||||
results["event"] = filtered[0]
|
event = filtered[0]
|
||||||
|
|
||||||
# Fetch the aggregations.
|
# Fetch the aggregations.
|
||||||
aggregations = await self.store.get_bundled_aggregations(
|
aggregations = await self.store.get_bundled_aggregations(
|
||||||
[results["event"]], user.to_string()
|
itertools.chain(events_before, (event,), events_after),
|
||||||
|
user.to_string(),
|
||||||
)
|
)
|
||||||
aggregations.update(
|
|
||||||
await self.store.get_bundled_aggregations(
|
|
||||||
results["events_before"], user.to_string()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
aggregations.update(
|
|
||||||
await self.store.get_bundled_aggregations(
|
|
||||||
results["events_after"], user.to_string()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
results["aggregations"] = aggregations
|
|
||||||
|
|
||||||
if results["events_after"]:
|
if events_after:
|
||||||
last_event_id = results["events_after"][-1].event_id
|
last_event_id = events_after[-1].event_id
|
||||||
else:
|
else:
|
||||||
last_event_id = event_id
|
last_event_id = event_id
|
||||||
|
|
||||||
|
@ -1206,9 +1209,9 @@ class RoomContextHandler:
|
||||||
state_filter = StateFilter.from_lazy_load_member_list(
|
state_filter = StateFilter.from_lazy_load_member_list(
|
||||||
ev.sender
|
ev.sender
|
||||||
for ev in itertools.chain(
|
for ev in itertools.chain(
|
||||||
results["events_before"],
|
events_before,
|
||||||
(results["event"],),
|
(event,),
|
||||||
results["events_after"],
|
events_after,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -1226,21 +1229,23 @@ class RoomContextHandler:
|
||||||
if event_filter:
|
if event_filter:
|
||||||
state_events = await event_filter.filter(state_events)
|
state_events = await event_filter.filter(state_events)
|
||||||
|
|
||||||
results["state"] = await filter_evts(state_events)
|
|
||||||
|
|
||||||
# We use a dummy token here as we only care about the room portion of
|
# We use a dummy token here as we only care about the room portion of
|
||||||
# the token, which we replace.
|
# the token, which we replace.
|
||||||
token = StreamToken.START
|
token = StreamToken.START
|
||||||
|
|
||||||
results["start"] = await token.copy_and_replace(
|
return EventContext(
|
||||||
"room_key", results["start"]
|
events_before=events_before,
|
||||||
).to_string(self.store)
|
event=event,
|
||||||
|
events_after=events_after,
|
||||||
results["end"] = await token.copy_and_replace(
|
state=await filter_evts(state_events),
|
||||||
"room_key", results["end"]
|
aggregations=aggregations,
|
||||||
).to_string(self.store)
|
start=await token.copy_and_replace("room_key", results.start).to_string(
|
||||||
|
self.store
|
||||||
return results
|
),
|
||||||
|
end=await token.copy_and_replace("room_key", results.end).to_string(
|
||||||
|
self.store
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TimestampLookupHandler:
|
class TimestampLookupHandler:
|
||||||
|
|
|
@ -361,36 +361,37 @@ class SearchHandler:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Context for search returned %d and %d events",
|
"Context for search returned %d and %d events",
|
||||||
len(res["events_before"]),
|
len(res.events_before),
|
||||||
len(res["events_after"]),
|
len(res.events_after),
|
||||||
)
|
)
|
||||||
|
|
||||||
res["events_before"] = await filter_events_for_client(
|
events_before = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), res["events_before"]
|
self.storage, user.to_string(), res.events_before
|
||||||
)
|
)
|
||||||
|
|
||||||
res["events_after"] = await filter_events_for_client(
|
events_after = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), res["events_after"]
|
self.storage, user.to_string(), res.events_after
|
||||||
)
|
)
|
||||||
|
|
||||||
res["start"] = await now_token.copy_and_replace(
|
context = {
|
||||||
"room_key", res["start"]
|
"events_before": events_before,
|
||||||
).to_string(self.store)
|
"events_after": events_after,
|
||||||
|
"start": await now_token.copy_and_replace(
|
||||||
res["end"] = await now_token.copy_and_replace(
|
"room_key", res.start
|
||||||
"room_key", res["end"]
|
).to_string(self.store),
|
||||||
).to_string(self.store)
|
"end": await now_token.copy_and_replace(
|
||||||
|
"room_key", res.end
|
||||||
|
).to_string(self.store),
|
||||||
|
}
|
||||||
|
|
||||||
if include_profile:
|
if include_profile:
|
||||||
senders = {
|
senders = {
|
||||||
ev.sender
|
ev.sender
|
||||||
for ev in itertools.chain(
|
for ev in itertools.chain(events_before, [event], events_after)
|
||||||
res["events_before"], [event], res["events_after"]
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if res["events_after"]:
|
if events_after:
|
||||||
last_event_id = res["events_after"][-1].event_id
|
last_event_id = events_after[-1].event_id
|
||||||
else:
|
else:
|
||||||
last_event_id = event.event_id
|
last_event_id = event.event_id
|
||||||
|
|
||||||
|
@ -402,7 +403,7 @@ class SearchHandler:
|
||||||
last_event_id, state_filter
|
last_event_id, state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
res["profile_info"] = {
|
context["profile_info"] = {
|
||||||
s.state_key: {
|
s.state_key: {
|
||||||
"displayname": s.content.get("displayname", None),
|
"displayname": s.content.get("displayname", None),
|
||||||
"avatar_url": s.content.get("avatar_url", None),
|
"avatar_url": s.content.get("avatar_url", None),
|
||||||
|
@ -411,7 +412,7 @@ class SearchHandler:
|
||||||
if s.type == EventTypes.Member and s.state_key in senders
|
if s.type == EventTypes.Member and s.state_key in senders
|
||||||
}
|
}
|
||||||
|
|
||||||
contexts[event.event_id] = res
|
contexts[event.event_id] = context
|
||||||
else:
|
else:
|
||||||
contexts = {}
|
contexts = {}
|
||||||
|
|
||||||
|
@ -421,10 +422,10 @@ class SearchHandler:
|
||||||
|
|
||||||
for context in contexts.values():
|
for context in contexts.values():
|
||||||
context["events_before"] = self._event_serializer.serialize_events(
|
context["events_before"] = self._event_serializer.serialize_events(
|
||||||
context["events_before"], time_now
|
context["events_before"], time_now # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
context["events_after"] = self._event_serializer.serialize_events(
|
context["events_after"] = self._event_serializer.serialize_events(
|
||||||
context["events_after"], time_now
|
context["events_after"], time_now # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
state_results = {}
|
state_results = {}
|
||||||
|
|
|
@ -37,6 +37,7 @@ from synapse.logging.context import current_context
|
||||||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
from synapse.storage.databases.main.event_push_actions import NotifCounts
|
from synapse.storage.databases.main.event_push_actions import NotifCounts
|
||||||
|
from synapse.storage.databases.main.relations import BundledAggregations
|
||||||
from synapse.storage.roommember import MemberSummary
|
from synapse.storage.roommember import MemberSummary
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -100,7 +101,7 @@ class TimelineBatch:
|
||||||
limited: bool
|
limited: bool
|
||||||
# A mapping of event ID to the bundled aggregations for the above events.
|
# A mapping of event ID to the bundled aggregations for the above events.
|
||||||
# This is only calculated if limited is true.
|
# This is only calculated if limited is true.
|
||||||
bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
|
bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
"""Make the result appear empty if there are no updates. This is used
|
"""Make the result appear empty if there are no updates. This is used
|
||||||
|
|
|
@ -455,7 +455,7 @@ class Mailer:
|
||||||
}
|
}
|
||||||
|
|
||||||
the_events = await filter_events_for_client(
|
the_events = await filter_events_for_client(
|
||||||
self.storage, user_id, results["events_before"]
|
self.storage, user_id, results.events_before
|
||||||
)
|
)
|
||||||
the_events.append(notif_event)
|
the_events.append(notif_event)
|
||||||
|
|
||||||
|
|
|
@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet):
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
||||||
results = await self.room_context_handler.get_event_context(
|
event_context = await self.room_context_handler.get_event_context(
|
||||||
requester,
|
requester,
|
||||||
room_id,
|
room_id,
|
||||||
event_id,
|
event_id,
|
||||||
|
@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet):
|
||||||
use_admin_priviledge=True,
|
use_admin_priviledge=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not event_context:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
|
HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
aggregations = results.pop("aggregations", None)
|
results = {
|
||||||
results["events_before"] = self._event_serializer.serialize_events(
|
"events_before": self._event_serializer.serialize_events(
|
||||||
results["events_before"], time_now, bundle_aggregations=aggregations
|
event_context.events_before,
|
||||||
)
|
time_now,
|
||||||
results["event"] = self._event_serializer.serialize_event(
|
bundle_aggregations=event_context.aggregations,
|
||||||
results["event"], time_now, bundle_aggregations=aggregations
|
),
|
||||||
)
|
"event": self._event_serializer.serialize_event(
|
||||||
results["events_after"] = self._event_serializer.serialize_events(
|
event_context.event,
|
||||||
results["events_after"], time_now, bundle_aggregations=aggregations
|
time_now,
|
||||||
)
|
bundle_aggregations=event_context.aggregations,
|
||||||
results["state"] = self._event_serializer.serialize_events(
|
),
|
||||||
results["state"], time_now
|
"events_after": self._event_serializer.serialize_events(
|
||||||
)
|
event_context.events_after,
|
||||||
|
time_now,
|
||||||
|
bundle_aggregations=event_context.aggregations,
|
||||||
|
),
|
||||||
|
"state": self._event_serializer.serialize_events(
|
||||||
|
event_context.state, time_now
|
||||||
|
),
|
||||||
|
"start": event_context.start,
|
||||||
|
"end": event_context.end,
|
||||||
|
}
|
||||||
|
|
||||||
return HTTPStatus.OK, results
|
return HTTPStatus.OK, results
|
||||||
|
|
||||||
|
|
|
@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
||||||
results = await self.room_context_handler.get_event_context(
|
event_context = await self.room_context_handler.get_event_context(
|
||||||
requester, room_id, event_id, limit, event_filter
|
requester, room_id, event_id, limit, event_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not event_context:
|
||||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
aggregations = results.pop("aggregations", None)
|
results = {
|
||||||
results["events_before"] = self._event_serializer.serialize_events(
|
"events_before": self._event_serializer.serialize_events(
|
||||||
results["events_before"], time_now, bundle_aggregations=aggregations
|
event_context.events_before,
|
||||||
)
|
time_now,
|
||||||
results["event"] = self._event_serializer.serialize_event(
|
bundle_aggregations=event_context.aggregations,
|
||||||
results["event"], time_now, bundle_aggregations=aggregations
|
),
|
||||||
)
|
"event": self._event_serializer.serialize_event(
|
||||||
results["events_after"] = self._event_serializer.serialize_events(
|
event_context.event,
|
||||||
results["events_after"], time_now, bundle_aggregations=aggregations
|
time_now,
|
||||||
)
|
bundle_aggregations=event_context.aggregations,
|
||||||
results["state"] = self._event_serializer.serialize_events(
|
),
|
||||||
results["state"], time_now
|
"events_after": self._event_serializer.serialize_events(
|
||||||
)
|
event_context.events_after,
|
||||||
|
time_now,
|
||||||
|
bundle_aggregations=event_context.aggregations,
|
||||||
|
),
|
||||||
|
"state": self._event_serializer.serialize_events(
|
||||||
|
event_context.state, time_now
|
||||||
|
),
|
||||||
|
"start": event_context.start,
|
||||||
|
"end": event_context.end,
|
||||||
|
}
|
||||||
|
|
||||||
return 200, results
|
return 200, results
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
|
from synapse.storage.databases.main.relations import BundledAggregations
|
||||||
from synapse.types import JsonDict, StreamToken
|
from synapse.types import JsonDict, StreamToken
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
|
@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
def serialize(
|
def serialize(
|
||||||
events: Iterable[EventBase],
|
events: Iterable[EventBase],
|
||||||
aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
|
aggregations: Optional[Dict[str, BundledAggregations]] = None,
|
||||||
) -> List[JsonDict]:
|
) -> List[JsonDict]:
|
||||||
return self._event_serializer.serialize_events(
|
return self._event_serializer.serialize_events(
|
||||||
events,
|
events,
|
||||||
|
|
|
@ -13,17 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
@ -43,6 +33,7 @@ from synapse.storage.relations import (
|
||||||
PaginationChunk,
|
PaginationChunk,
|
||||||
RelationPaginationToken,
|
RelationPaginationToken,
|
||||||
)
|
)
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -51,6 +42,30 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class _ThreadAggregation:
|
||||||
|
latest_event: EventBase
|
||||||
|
count: int
|
||||||
|
current_user_participated: bool
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class BundledAggregations:
|
||||||
|
"""
|
||||||
|
The bundled aggregations for an event.
|
||||||
|
|
||||||
|
Some values require additional processing during serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
annotations: Optional[JsonDict] = None
|
||||||
|
references: Optional[JsonDict] = None
|
||||||
|
replace: Optional[EventBase] = None
|
||||||
|
thread: Optional[_ThreadAggregation] = None
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.annotations or self.references or self.replace or self.thread)
|
||||||
|
|
||||||
|
|
||||||
class RelationsWorkerStore(SQLBaseStore):
|
class RelationsWorkerStore(SQLBaseStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
async def _get_bundled_aggregation_for_event(
|
async def _get_bundled_aggregation_for_event(
|
||||||
self, event: EventBase, user_id: str
|
self, event: EventBase, user_id: str
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[BundledAggregations]:
|
||||||
"""Generate bundled aggregations for an event.
|
"""Generate bundled aggregations for an event.
|
||||||
|
|
||||||
Note that this does not use a cache, but depends on cached methods.
|
Note that this does not use a cache, but depends on cached methods.
|
||||||
|
@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
# The bundled aggregations to include, a mapping of relation type to a
|
# The bundled aggregations to include, a mapping of relation type to a
|
||||||
# type-specific value. Some types include the direct return type here
|
# type-specific value. Some types include the direct return type here
|
||||||
# while others need more processing during serialization.
|
# while others need more processing during serialization.
|
||||||
aggregations: Dict[str, Any] = {}
|
aggregations = BundledAggregations()
|
||||||
|
|
||||||
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
|
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
|
||||||
if annotations.chunk:
|
if annotations.chunk:
|
||||||
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
|
aggregations.annotations = annotations.to_dict()
|
||||||
|
|
||||||
references = await self.get_relations_for_event(
|
references = await self.get_relations_for_event(
|
||||||
event_id, room_id, RelationTypes.REFERENCE, direction="f"
|
event_id, room_id, RelationTypes.REFERENCE, direction="f"
|
||||||
)
|
)
|
||||||
if references.chunk:
|
if references.chunk:
|
||||||
aggregations[RelationTypes.REFERENCE] = references.to_dict()
|
aggregations.references = references.to_dict()
|
||||||
|
|
||||||
edit = None
|
edit = None
|
||||||
if event.type == EventTypes.Message:
|
if event.type == EventTypes.Message:
|
||||||
edit = await self.get_applicable_edit(event_id, room_id)
|
edit = await self.get_applicable_edit(event_id, room_id)
|
||||||
|
|
||||||
if edit:
|
if edit:
|
||||||
aggregations[RelationTypes.REPLACE] = edit
|
aggregations.replace = edit
|
||||||
|
|
||||||
# If this event is the start of a thread, include a summary of the replies.
|
# If this event is the start of a thread, include a summary of the replies.
|
||||||
if self._msc3440_enabled:
|
if self._msc3440_enabled:
|
||||||
|
@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
event_id, room_id, user_id
|
event_id, room_id, user_id
|
||||||
)
|
)
|
||||||
if latest_thread_event:
|
if latest_thread_event:
|
||||||
aggregations[RelationTypes.THREAD] = {
|
aggregations.thread = _ThreadAggregation(
|
||||||
"latest_event": latest_thread_event,
|
latest_event=latest_thread_event,
|
||||||
"count": thread_count,
|
count=thread_count,
|
||||||
"current_user_participated": participated,
|
current_user_participated=participated,
|
||||||
}
|
)
|
||||||
|
|
||||||
# Store the bundled aggregations in the event metadata for later use.
|
# Store the bundled aggregations in the event metadata for later use.
|
||||||
return aggregations
|
return aggregations
|
||||||
|
@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
self,
|
self,
|
||||||
events: Iterable[EventBase],
|
events: Iterable[EventBase],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> Dict[str, Dict[str, Any]]:
|
) -> Dict[str, BundledAggregations]:
|
||||||
"""Generate bundled aggregations for events.
|
"""Generate bundled aggregations for events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
results = {}
|
results = {}
|
||||||
for event in events:
|
for event in events:
|
||||||
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
||||||
if event_result is not None:
|
if event_result:
|
||||||
results[event.event_id] = event_result
|
results[event.event_id] = event_result
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -81,6 +81,14 @@ class _EventDictReturn:
|
||||||
stream_ordering: int
|
stream_ordering: int
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class _EventsAround:
|
||||||
|
events_before: List[EventBase]
|
||||||
|
events_after: List[EventBase]
|
||||||
|
start: RoomStreamToken
|
||||||
|
end: RoomStreamToken
|
||||||
|
|
||||||
|
|
||||||
def generate_pagination_where_clause(
|
def generate_pagination_where_clause(
|
||||||
direction: str,
|
direction: str,
|
||||||
column_names: Tuple[str, str],
|
column_names: Tuple[str, str],
|
||||||
|
@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
before_limit: int,
|
before_limit: int,
|
||||||
after_limit: int,
|
after_limit: int,
|
||||||
event_filter: Optional[Filter] = None,
|
event_filter: Optional[Filter] = None,
|
||||||
) -> dict:
|
) -> _EventsAround:
|
||||||
"""Retrieve events and pagination tokens around a given event in a
|
"""Retrieve events and pagination tokens around a given event in a
|
||||||
room.
|
room.
|
||||||
"""
|
"""
|
||||||
|
@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
list(results["after"]["event_ids"]), get_prev_content=True
|
list(results["after"]["event_ids"]), get_prev_content=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return _EventsAround(
|
||||||
"events_before": events_before,
|
events_before=events_before,
|
||||||
"events_after": events_after,
|
events_after=events_after,
|
||||||
"start": results["before"]["token"],
|
start=results["before"]["token"],
|
||||||
"end": results["after"]["token"],
|
end=results["after"]["token"],
|
||||||
}
|
)
|
||||||
|
|
||||||
def _get_events_around_txn(
|
def _get_events_around_txn(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(200, channel.code, channel.json_body)
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
|
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
|
||||||
self.assertTrue(room_timeline["limited"])
|
self.assertTrue(room_timeline["limited"])
|
||||||
self._find_event_in_chunk(room_timeline["events"])
|
assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
|
||||||
|
|
||||||
def test_aggregation_get_event_for_annotation(self):
|
def test_aggregation_get_event_for_annotation(self):
|
||||||
"""Test that annotations do not get bundled aggregations included
|
"""Test that annotations do not get bundled aggregations included
|
||||||
|
|
Loading…
Reference in New Issue