Fetch thread summaries for multiple events in a single query (#11752)

This should reduce database usage when fetching bundled aggregations
as the number of individual queries (and round trips to the database) are
reduced.
This commit is contained in:
Patrick Cloke 2022-02-11 09:50:14 -05:00 committed by GitHub
parent bb98c593a5
commit b65acead42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 151 additions and 74 deletions

1
changelog.d/11752.misc Normal file
View File

@ -0,0 +1 @@
Improve performance when fetching bundled aggregations for multiple events.

View File

@ -1812,7 +1812,7 @@ class PersistEventsStore:
# potentially error-prone) so it is always invalidated. # potentially error-prone) so it is always invalidated.
txn.call_after( txn.call_after(
self.store.get_thread_participated.invalidate, self.store.get_thread_participated.invalidate,
(parent_id, event.room_id, event.sender), (parent_id, event.sender),
) )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):

View File

@ -20,6 +20,7 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Set,
Tuple, Tuple,
Union, Union,
cast, cast,
@ -454,106 +455,175 @@ class RelationsWorkerStore(SQLBaseStore):
} }
@cached() @cached()
async def get_thread_summary( def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
self, event_id: str, room_id: str raise NotImplementedError()
) -> Tuple[int, Optional[EventBase]]:
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event. """Get the number of threaded replies and the latest reply (if any) for the given event.
Args: Args:
event_id: Summarize the thread related to this event ID. event_ids: Summarize the thread related to this event ID.
room_id: The room the event belongs to.
Returns: Returns:
The number of items in the thread and the most recent response, if any. A map of the thread summary each event. A missing event implies there
are no threaded replies.
Each summary includes the number of items in the thread and the most
recent response.
""" """
def _get_thread_summary_txn( def _get_thread_summaries_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]: ) -> Tuple[Dict[str, int], Dict[str, str]]:
# Fetch the latest event ID in the thread. # Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events. # TODO Should this only allow m.room.message events.
if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters,
# so ordering by topologica ordering + stream ordering desc will
# ensure we get the latest event in the thread.
sql = """ sql = """
SELECT event_id SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
FROM event_relations INNER JOIN event_relations USING (event_id)
INNER JOIN events USING (event_id) INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE WHERE
relates_to_id = ? %s
AND room_id = ?
AND relation_type = ? AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
LIMIT 1 """
else:
# SQLite uses a simplified query which returns all entries for a
# thread. The first result for each thread is chosen to and subsequent
# results for a thread are ignored.
sql = """
SELECT parent.event_id, child.event_id FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
""" """
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) clause, args = make_in_list_sql_clause(
row = txn.fetchone() txn.database_engine, "relates_to_id", event_ids
if row is None: )
return 0, None args.append(RelationTypes.THREAD)
latest_event_id = row[0] txn.execute(sql % (clause,), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
if parent_event_id not in latest_event_ids:
latest_event_ids[parent_event_id] = child_event_id
# If no threads were found, bail.
if not latest_event_ids:
return {}, latest_event_ids
# Fetch the number of threaded replies. # Fetch the number of threaded replies.
sql = """ sql = """
SELECT COUNT(event_id) SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
FROM event_relations INNER JOIN event_relations USING (event_id)
INNER JOIN events USING (event_id) INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE WHERE
relates_to_id = ? %s
AND room_id = ?
AND relation_type = ? AND relation_type = ?
GROUP BY parent.event_id
""" """
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = cast(Tuple[int], txn.fetchone())[0]
return count, latest_event_id # Regenerate the arguments since only threads found above could
# possibly have any replies.
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)
args.append(RelationTypes.THREAD)
count, latest_event_id = await self.db_pool.runInteraction( txn.execute(sql % (clause,), args)
"get_thread_summary", _get_thread_summary_txn counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
return counts, latest_event_ids
counts, latest_event_ids = await self.db_pool.runInteraction(
"get_thread_summaries", _get_thread_summaries_txn
) )
latest_event = None latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
return count, latest_event # Map to the event IDs to the thread summary.
#
# There might not be a summary due to there not being a thread or
# due to the latest event not being known, either case is treated the same.
summaries = {}
for parent_event_id, latest_event_id in latest_event_ids.items():
latest_event = latest_events.get(latest_event_id)
summary = None
if latest_event:
summary = (counts[parent_event_id], latest_event)
summaries[parent_event_id] = summary
return summaries
@cached() @cached()
async def get_thread_participated( def get_thread_participated(self, event_id: str, user_id: str) -> bool:
self, event_id: str, room_id: str, user_id: str raise NotImplementedError()
) -> bool:
"""Get whether the requesting user participated in a thread.
This is separate from get_thread_summary since that can be cached across @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
all users while this value is specific to the requeser. async def _get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
This is separate from get_thread_summaries since that can be cached across
all users while this value is specific to the requester.
Args: Args:
event_id: The thread related to this event ID. event_ids: The thread related to these event IDs.
room_id: The room the event belongs to.
user_id: The user requesting the summary. user_id: The user requesting the summary.
Returns: Returns:
True if the requesting user participated in the thread, otherwise false. A map of event ID to a boolean which represents if the requesting
user participated in that event's thread, otherwise false.
""" """
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool: def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
# Fetch whether the requester has participated or not. # Fetch whether the requester has participated or not.
sql = """ sql = """
SELECT 1 SELECT DISTINCT relates_to_id
FROM event_relations FROM events AS child
INNER JOIN events USING (event_id) INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE WHERE
relates_to_id = ? %s
AND room_id = ?
AND relation_type = ? AND relation_type = ?
AND sender = ? AND child.sender = ?
""" """
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id)) clause, args = make_in_list_sql_clause(
return bool(txn.fetchone()) txn.database_engine, "relates_to_id", event_ids
)
args.extend((RelationTypes.THREAD, user_id))
return await self.db_pool.runInteraction( txn.execute(sql % (clause,), args)
return {row[0] for row in txn.fetchall()}
participated_threads = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn "get_thread_summary", _get_thread_summary_txn
) )
return {event_id: event_id in participated_threads for event_id in event_ids}
async def events_have_relations( async def events_have_relations(
self, self,
parent_ids: List[str], parent_ids: List[str],
@ -700,21 +770,6 @@ class RelationsWorkerStore(SQLBaseStore):
if references.chunk: if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self)) aggregations.references = await references.to_dict(cast("DataStore", self))
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
thread_count, latest_thread_event = await self.get_thread_summary(
event_id, room_id
)
participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event:
aggregations.thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
current_user_participated=participated,
)
# Store the bundled aggregations in the event metadata for later use. # Store the bundled aggregations in the event metadata for later use.
return aggregations return aggregations
@ -763,6 +818,27 @@ class RelationsWorkerStore(SQLBaseStore):
for event_id, edit in edits.items(): for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
if self._msc3440_enabled:
summaries = await self._get_thread_summaries(seen_event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(
summaries.keys(), user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results return results