Move get_bundled_aggregations to relations handler. (#12237)
The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit.
This commit is contained in:
parent
80e0e1f35e
commit
8fe930c215
|
@ -0,0 +1 @@
|
||||||
|
Refactor the relations endpoints to add a `RelationsHandler`.
|
|
@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
|
||||||
from . import EventBase
|
from . import EventBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.relations import BundledAggregations
|
|
||||||
|
|
||||||
|
|
||||||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||||
|
|
|
@ -134,6 +134,7 @@ class PaginationHandler:
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
# IDs of rooms in which there currently an active purge *or delete* operation.
|
# IDs of rooms in which there currently an active purge *or delete* operation.
|
||||||
|
@ -539,7 +540,9 @@ class PaginationHandler:
|
||||||
state_dict = await self.store.get_events(list(state_ids.values()))
|
state_dict = await self.store.get_events(list(state_ids.values()))
|
||||||
state = state_dict.values()
|
state = state_dict.values()
|
||||||
|
|
||||||
aggregations = await self.store.get_bundled_aggregations(events, user_id)
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
|
events, user_id
|
||||||
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
|
|
@ -12,18 +12,53 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from synapse.api.constants import RelationTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.types import JsonDict, Requester, StreamToken
|
from synapse.types import JsonDict, Requester, StreamToken
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class _ThreadAggregation:
|
||||||
|
# The latest event in the thread.
|
||||||
|
latest_event: EventBase
|
||||||
|
# The latest edit to the latest event in the thread.
|
||||||
|
latest_edit: Optional[EventBase]
|
||||||
|
# The total number of events in the thread.
|
||||||
|
count: int
|
||||||
|
# True if the current user has sent an event to the thread.
|
||||||
|
current_user_participated: bool
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class BundledAggregations:
|
||||||
|
"""
|
||||||
|
The bundled aggregations for an event.
|
||||||
|
|
||||||
|
Some values require additional processing during serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
annotations: Optional[JsonDict] = None
|
||||||
|
references: Optional[JsonDict] = None
|
||||||
|
replace: Optional[EventBase] = None
|
||||||
|
thread: Optional[_ThreadAggregation] = None
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.annotations or self.references or self.replace or self.thread)
|
||||||
|
|
||||||
|
|
||||||
class RelationsHandler:
|
class RelationsHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._main_store = hs.get_datastores().main
|
self._main_store = hs.get_datastores().main
|
||||||
|
@ -103,7 +138,7 @@ class RelationsHandler:
|
||||||
)
|
)
|
||||||
# The relations returned for the requested event do include their
|
# The relations returned for the requested event do include their
|
||||||
# bundled aggregations.
|
# bundled aggregations.
|
||||||
aggregations = await self._main_store.get_bundled_aggregations(
|
aggregations = await self.get_bundled_aggregations(
|
||||||
events, requester.user.to_string()
|
events, requester.user.to_string()
|
||||||
)
|
)
|
||||||
serialized_events = self._event_serializer.serialize_events(
|
serialized_events = self._event_serializer.serialize_events(
|
||||||
|
@ -115,3 +150,115 @@ class RelationsHandler:
|
||||||
return_value["original_event"] = original_event
|
return_value["original_event"] = original_event
|
||||||
|
|
||||||
return return_value
|
return return_value
|
||||||
|
|
||||||
|
async def _get_bundled_aggregation_for_event(
|
||||||
|
self, event: EventBase, user_id: str
|
||||||
|
) -> Optional[BundledAggregations]:
|
||||||
|
"""Generate bundled aggregations for an event.
|
||||||
|
|
||||||
|
Note that this does not use a cache, but depends on cached methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event to calculate bundled aggregations for.
|
||||||
|
user_id: The user requesting the bundled aggregations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The bundled aggregations for an event, if bundled aggregations are
|
||||||
|
enabled and the event can have bundled aggregations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Do not bundle aggregations for an event which represents an edit or an
|
||||||
|
# annotation. It does not make sense for them to have related events.
|
||||||
|
relates_to = event.content.get("m.relates_to")
|
||||||
|
if isinstance(relates_to, (dict, frozendict)):
|
||||||
|
relation_type = relates_to.get("rel_type")
|
||||||
|
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
|
||||||
|
return None
|
||||||
|
|
||||||
|
event_id = event.event_id
|
||||||
|
room_id = event.room_id
|
||||||
|
|
||||||
|
# The bundled aggregations to include, a mapping of relation type to a
|
||||||
|
# type-specific value. Some types include the direct return type here
|
||||||
|
# while others need more processing during serialization.
|
||||||
|
aggregations = BundledAggregations()
|
||||||
|
|
||||||
|
annotations = await self._main_store.get_aggregation_groups_for_event(
|
||||||
|
event_id, room_id
|
||||||
|
)
|
||||||
|
if annotations.chunk:
|
||||||
|
aggregations.annotations = await annotations.to_dict(
|
||||||
|
cast("DataStore", self)
|
||||||
|
)
|
||||||
|
|
||||||
|
references = await self._main_store.get_relations_for_event(
|
||||||
|
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
|
||||||
|
)
|
||||||
|
if references.chunk:
|
||||||
|
aggregations.references = await references.to_dict(cast("DataStore", self))
|
||||||
|
|
||||||
|
# Store the bundled aggregations in the event metadata for later use.
|
||||||
|
return aggregations
|
||||||
|
|
||||||
|
async def get_bundled_aggregations(
|
||||||
|
self, events: Iterable[EventBase], user_id: str
|
||||||
|
) -> Dict[str, BundledAggregations]:
|
||||||
|
"""Generate bundled aggregations for events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
events: The iterable of events to calculate bundled aggregations for.
|
||||||
|
user_id: The user requesting the bundled aggregations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A map of event ID to the bundled aggregation for the event. Not all
|
||||||
|
events may have bundled aggregations in the results.
|
||||||
|
"""
|
||||||
|
# De-duplicate events by ID to handle the same event requested multiple times.
|
||||||
|
#
|
||||||
|
# State events do not get bundled aggregations.
|
||||||
|
events_by_id = {
|
||||||
|
event.event_id: event for event in events if not event.is_state()
|
||||||
|
}
|
||||||
|
|
||||||
|
# event ID -> bundled aggregation in non-serialized form.
|
||||||
|
results: Dict[str, BundledAggregations] = {}
|
||||||
|
|
||||||
|
# Fetch other relations per event.
|
||||||
|
for event in events_by_id.values():
|
||||||
|
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
||||||
|
if event_result:
|
||||||
|
results[event.event_id] = event_result
|
||||||
|
|
||||||
|
# Fetch any edits (but not for redacted events).
|
||||||
|
edits = await self._main_store.get_applicable_edits(
|
||||||
|
[
|
||||||
|
event_id
|
||||||
|
for event_id, event in events_by_id.items()
|
||||||
|
if not event.internal_metadata.is_redacted()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for event_id, edit in edits.items():
|
||||||
|
results.setdefault(event_id, BundledAggregations()).replace = edit
|
||||||
|
|
||||||
|
# Fetch thread summaries.
|
||||||
|
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
|
||||||
|
# Only fetch participated for a limited selection based on what had
|
||||||
|
# summaries.
|
||||||
|
participated = await self._main_store.get_threads_participated(
|
||||||
|
[event_id for event_id, summary in summaries.items() if summary], user_id
|
||||||
|
)
|
||||||
|
for event_id, summary in summaries.items():
|
||||||
|
if summary:
|
||||||
|
thread_count, latest_thread_event, edit = summary
|
||||||
|
results.setdefault(
|
||||||
|
event_id, BundledAggregations()
|
||||||
|
).thread = _ThreadAggregation(
|
||||||
|
latest_event=latest_thread_event,
|
||||||
|
latest_edit=edit,
|
||||||
|
count=thread_count,
|
||||||
|
# If there's a thread summary it must also exist in the
|
||||||
|
# participated dictionary.
|
||||||
|
current_user_participated=participated[event_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
|
@ -60,8 +60,8 @@ from synapse.events import EventBase
|
||||||
from synapse.events.utils import copy_power_levels_contents
|
from synapse.events.utils import copy_power_levels_contents
|
||||||
from synapse.federation.federation_client import InvalidResponseError
|
from synapse.federation.federation_client import InvalidResponseError
|
||||||
from synapse.handlers.federation import get_domains_from_state
|
from synapse.handlers.federation import get_domains_from_state
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.rest.admin._base import assert_user_is_admin
|
from synapse.rest.admin._base import assert_user_is_admin
|
||||||
from synapse.storage.databases.main.relations import BundledAggregations
|
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.streams import EventSource
|
from synapse.streams import EventSource
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -1118,6 +1118,7 @@ class RoomContextHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
|
||||||
async def get_event_context(
|
async def get_event_context(
|
||||||
self,
|
self,
|
||||||
|
@ -1190,7 +1191,7 @@ class RoomContextHandler:
|
||||||
event = filtered[0]
|
event = filtered[0]
|
||||||
|
|
||||||
# Fetch the aggregations.
|
# Fetch the aggregations.
|
||||||
aggregations = await self.store.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
itertools.chain(events_before, (event,), events_after),
|
itertools.chain(events_before, (event,), events_after),
|
||||||
user.to_string(),
|
user.to_string(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -54,6 +54,7 @@ class SearchHandler:
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -354,7 +355,7 @@ class SearchHandler:
|
||||||
|
|
||||||
aggregations = None
|
aggregations = None
|
||||||
if self._msc3666_enabled:
|
if self._msc3666_enabled:
|
||||||
aggregations = await self.store.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
# Generate an iterable of EventBase for all the events that will be
|
# Generate an iterable of EventBase for all the events that will be
|
||||||
# returned, including contextual events.
|
# returned, including contextual events.
|
||||||
itertools.chain(
|
itertools.chain(
|
||||||
|
|
|
@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.logging.context import current_context
|
from synapse.logging.context import current_context
|
||||||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
from synapse.storage.databases.main.event_push_actions import NotifCounts
|
from synapse.storage.databases.main.event_push_actions import NotifCounts
|
||||||
from synapse.storage.databases.main.relations import BundledAggregations
|
|
||||||
from synapse.storage.roommember import MemberSummary
|
from synapse.storage.roommember import MemberSummary
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -269,6 +269,7 @@ class SyncHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
@ -638,8 +639,10 @@ class SyncHandler:
|
||||||
# as clients will have all the necessary information.
|
# as clients will have all the necessary information.
|
||||||
bundled_aggregations = None
|
bundled_aggregations = None
|
||||||
if limited or newly_joined_room:
|
if limited or newly_joined_room:
|
||||||
bundled_aggregations = await self.store.get_bundled_aggregations(
|
bundled_aggregations = (
|
||||||
recents, sync_config.user.to_string()
|
await self._relations_handler.get_bundled_aggregations(
|
||||||
|
recents, sync_config.user.to_string()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return TimelineBatch(
|
return TimelineBatch(
|
||||||
|
|
|
@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self.event_handler = hs.get_event_handler()
|
self.event_handler = hs.get_event_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
|
@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
# Ensure there are bundled aggregations available.
|
# Ensure there are bundled aggregations available.
|
||||||
aggregations = await self._store.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
[event], requester.user.to_string()
|
[event], requester.user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
|
||||||
|
|
||||||
from synapse.api.constants import RelationTypes
|
from synapse.api.constants import RelationTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
@ -41,45 +40,15 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
||||||
from synapse.types import JsonDict, RoomStreamToken, StreamToken
|
from synapse.types import RoomStreamToken, StreamToken
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main import DataStore
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
|
||||||
class _ThreadAggregation:
|
|
||||||
# The latest event in the thread.
|
|
||||||
latest_event: EventBase
|
|
||||||
# The latest edit to the latest event in the thread.
|
|
||||||
latest_edit: Optional[EventBase]
|
|
||||||
# The total number of events in the thread.
|
|
||||||
count: int
|
|
||||||
# True if the current user has sent an event to the thread.
|
|
||||||
current_user_participated: bool
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, auto_attribs=True)
|
|
||||||
class BundledAggregations:
|
|
||||||
"""
|
|
||||||
The bundled aggregations for an event.
|
|
||||||
|
|
||||||
Some values require additional processing during serialization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
annotations: Optional[JsonDict] = None
|
|
||||||
references: Optional[JsonDict] = None
|
|
||||||
replace: Optional[EventBase] = None
|
|
||||||
thread: Optional[_ThreadAggregation] = None
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return bool(self.annotations or self.references or self.replace or self.thread)
|
|
||||||
|
|
||||||
|
|
||||||
class RelationsWorkerStore(SQLBaseStore):
|
class RelationsWorkerStore(SQLBaseStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
|
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
|
||||||
async def _get_applicable_edits(
|
async def get_applicable_edits(
|
||||||
self, event_ids: Collection[str]
|
self, event_ids: Collection[str]
|
||||||
) -> Dict[str, Optional[EventBase]]:
|
) -> Dict[str, Optional[EventBase]]:
|
||||||
"""Get the most recent edit (if any) that has happened for the given
|
"""Get the most recent edit (if any) that has happened for the given
|
||||||
|
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
|
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
|
||||||
async def _get_thread_summaries(
|
async def get_thread_summaries(
|
||||||
self, event_ids: Collection[str]
|
self, event_ids: Collection[str]
|
||||||
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
|
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
|
||||||
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
|
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
|
||||||
|
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
|
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# Check to see if any of those events are edited.
|
# Check to see if any of those events are edited.
|
||||||
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
|
latest_edits = await self.get_applicable_edits(latest_event_ids.values())
|
||||||
|
|
||||||
# Map to the event IDs to the thread summary.
|
# Map to the event IDs to the thread summary.
|
||||||
#
|
#
|
||||||
|
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
|
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
|
||||||
async def _get_threads_participated(
|
async def get_threads_participated(
|
||||||
self, event_ids: Collection[str], user_id: str
|
self, event_ids: Collection[str], user_id: str
|
||||||
) -> Dict[str, bool]:
|
) -> Dict[str, bool]:
|
||||||
"""Get whether the requesting user participated in the given threads.
|
"""Get whether the requesting user participated in the given threads.
|
||||||
|
@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_bundled_aggregation_for_event(
|
|
||||||
self, event: EventBase, user_id: str
|
|
||||||
) -> Optional[BundledAggregations]:
|
|
||||||
"""Generate bundled aggregations for an event.
|
|
||||||
|
|
||||||
Note that this does not use a cache, but depends on cached methods.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The event to calculate bundled aggregations for.
|
|
||||||
user_id: The user requesting the bundled aggregations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The bundled aggregations for an event, if bundled aggregations are
|
|
||||||
enabled and the event can have bundled aggregations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Do not bundle aggregations for an event which represents an edit or an
|
|
||||||
# annotation. It does not make sense for them to have related events.
|
|
||||||
relates_to = event.content.get("m.relates_to")
|
|
||||||
if isinstance(relates_to, (dict, frozendict)):
|
|
||||||
relation_type = relates_to.get("rel_type")
|
|
||||||
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
|
|
||||||
return None
|
|
||||||
|
|
||||||
event_id = event.event_id
|
|
||||||
room_id = event.room_id
|
|
||||||
|
|
||||||
# The bundled aggregations to include, a mapping of relation type to a
|
|
||||||
# type-specific value. Some types include the direct return type here
|
|
||||||
# while others need more processing during serialization.
|
|
||||||
aggregations = BundledAggregations()
|
|
||||||
|
|
||||||
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
|
|
||||||
if annotations.chunk:
|
|
||||||
aggregations.annotations = await annotations.to_dict(
|
|
||||||
cast("DataStore", self)
|
|
||||||
)
|
|
||||||
|
|
||||||
references = await self.get_relations_for_event(
|
|
||||||
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
|
|
||||||
)
|
|
||||||
if references.chunk:
|
|
||||||
aggregations.references = await references.to_dict(cast("DataStore", self))
|
|
||||||
|
|
||||||
# Store the bundled aggregations in the event metadata for later use.
|
|
||||||
return aggregations
|
|
||||||
|
|
||||||
async def get_bundled_aggregations(
|
|
||||||
self, events: Iterable[EventBase], user_id: str
|
|
||||||
) -> Dict[str, BundledAggregations]:
|
|
||||||
"""Generate bundled aggregations for events.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
events: The iterable of events to calculate bundled aggregations for.
|
|
||||||
user_id: The user requesting the bundled aggregations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A map of event ID to the bundled aggregation for the event. Not all
|
|
||||||
events may have bundled aggregations in the results.
|
|
||||||
"""
|
|
||||||
# De-duplicate events by ID to handle the same event requested multiple times.
|
|
||||||
#
|
|
||||||
# State events do not get bundled aggregations.
|
|
||||||
events_by_id = {
|
|
||||||
event.event_id: event for event in events if not event.is_state()
|
|
||||||
}
|
|
||||||
|
|
||||||
# event ID -> bundled aggregation in non-serialized form.
|
|
||||||
results: Dict[str, BundledAggregations] = {}
|
|
||||||
|
|
||||||
# Fetch other relations per event.
|
|
||||||
for event in events_by_id.values():
|
|
||||||
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
|
||||||
if event_result:
|
|
||||||
results[event.event_id] = event_result
|
|
||||||
|
|
||||||
# Fetch any edits (but not for redacted events).
|
|
||||||
edits = await self._get_applicable_edits(
|
|
||||||
[
|
|
||||||
event_id
|
|
||||||
for event_id, event in events_by_id.items()
|
|
||||||
if not event.internal_metadata.is_redacted()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
for event_id, edit in edits.items():
|
|
||||||
results.setdefault(event_id, BundledAggregations()).replace = edit
|
|
||||||
|
|
||||||
# Fetch thread summaries.
|
|
||||||
summaries = await self._get_thread_summaries(events_by_id.keys())
|
|
||||||
# Only fetch participated for a limited selection based on what had
|
|
||||||
# summaries.
|
|
||||||
participated = await self._get_threads_participated(
|
|
||||||
[event_id for event_id, summary in summaries.items() if summary], user_id
|
|
||||||
)
|
|
||||||
for event_id, summary in summaries.items():
|
|
||||||
if summary:
|
|
||||||
thread_count, latest_thread_event, edit = summary
|
|
||||||
results.setdefault(
|
|
||||||
event_id, BundledAggregations()
|
|
||||||
).thread = _ThreadAggregation(
|
|
||||||
latest_event=latest_thread_event,
|
|
||||||
latest_edit=edit,
|
|
||||||
count=thread_count,
|
|
||||||
# If there's a thread summary it must also exist in the
|
|
||||||
# participated dictionary.
|
|
||||||
current_user_participated=participated[event_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
class RelationsStore(RelationsWorkerStore):
|
class RelationsStore(RelationsWorkerStore):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue