Fix performance regression in `get_users_in_room` (#13972)

Fixes #13942. Introduced in #13575.

Basically, let's only get the ordered set of hosts out of the DB if we need an ordered set of hosts. Since we split the function up the caching won't be as good, but I think it will still be fine as e.g. multiple backfill requests for the same room will hit the cache.
This commit is contained in:
Erik Johnston 2022-09-30 13:15:32 +01:00 committed by GitHub
parent e8f30a76ca
commit 3dfc4a08dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 69 deletions

1
changelog.d/13972.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a performance regression in the `get_users_in_room` database query. Introduced in v1.67.0.

View File

@ -412,7 +412,9 @@ class FederationHandler:
# First we try hosts that are already in the room.
# TODO: HEURISTIC ALERT.
likely_domains = (
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
room_id
)
)
async def try_backfill(domains: Collection[str]) -> bool:

View File

@ -1540,7 +1540,9 @@ class TimestampLookupHandler:
)
likely_domains = (
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
room_id
)
)
# Loop through each homeserver candidate until we get a succesful response

View File

@ -23,7 +23,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
@ -529,7 +529,18 @@ class StateStorageController:
)
return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
with partial state.
"""
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@ -542,11 +553,11 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
return await self.stores.main.get_current_hosts_in_room_ordered(room_id)
async def get_current_hosts_in_room_or_partial_state_approximation(
self, room_id: str
) -> Sequence[str]:
) -> Collection[str]:
"""Get approximation of current hosts in room based on current state.
For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
@ -566,14 +577,9 @@ class StateStorageController:
)
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
hosts_from_state_set = set(hosts_from_state)
# First take the list of hosts based on the current state.
# For rooms with partial state, this will be missing most hosts.
hosts = list(hosts_from_state)
# Then add in the list of hosts in the room at the time we joined.
# This will be an empty list for rooms with full state.
hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
hosts = set(hosts_at_join)
hosts.update(hosts_from_state)
return hosts

View File

@ -146,42 +146,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
"""
Returns a list of users in the room sorted by longest in the room first
(aka. with the lowest depth). This is done to match the sort in
`get_current_hosts_in_room()` and so we can re-use the cache but it's
not horrible to have here either.
Uses `m.room.member`s in the room state at the current forward extremities to
determine which users are in the room.
"""Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for
the forward extremities of those rooms will exclude most members. We may also
calculate room state incorrectly for such rooms and believe that a member is or
is not in the room when the opposite is true.
"""
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
return await self.db_pool.simple_select_onecol(
table="current_state_events",
keyvalues={
"type": EventTypes.Member,
"room_id": room_id,
"membership": Membership.JOIN,
},
retcol="state_key",
desc="get_users_in_room",
)
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
"""
Returns a list of users in the room sorted by longest in the room first
(aka. with the lowest depth). This is done to match the sort in
`get_current_hosts_in_room()` and so we can re-use the cache but it's
not horrible to have here either.
"""
sql = """
SELECT c.state_key FROM current_state_events as c
/* Get the depth of the event from the events table */
INNER JOIN events AS e USING (event_id)
WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
/* Sorted by lowest depth first */
ORDER BY e.depth ASC;
"""
"""Returns a list of users in the room."""
txn.execute(sql, (room_id, Membership.JOIN))
return [r[0] for r in txn]
return self.db_pool.simple_select_onecol_txn(
txn,
table="current_state_events",
keyvalues={
"type": EventTypes.Member,
"room_id": room_id,
"membership": Membership.JOIN,
},
retcol="state_key",
)
@cached()
def get_user_in_room_with_profile(
@ -931,7 +926,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is not None:
return {get_domain_from_id(u) for u in users}
if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
"""
Get current hosts in room based on current state.
@ -939,48 +971,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
longest is good because they're most likely to have anything we ask
about.
Uses `m.room.member`s in the room state at the current forward extremities to
determine which hosts are in the room.
For SQLite the returned list is not ordered, as SQLite doesn't support
the appropriate SQL.
Will return inaccurate results for rooms with partial state, since the state for
the forward extremities of those rooms will exclude most members. We may also
calculate room state incorrectly for such rooms and believe that a host is or
is not in the room when the opposite is true.
Uses `m.room.member`s in the room state at the current forward
extremities to determine which hosts are in the room.
Will return inaccurate results for rooms with partial state, since the
state for the forward extremities of those rooms will exclude most
members. We may also calculate room state incorrectly for such rooms and
believe that a host is or is not in the room when the opposite is true.
Returns:
Returns a list of servers sorted by longest in the room first. (aka.
sorted by join with the lowest depth first).
"""
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is None and isinstance(self.database_engine, Sqlite3Engine):
if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
if users is not None:
# Because `users` is sorted from lowest -> highest depth, the list
# of domains will also be sorted that way.
domains: List[str] = []
# We use a `Set` just for fast lookups
domain_set: Set[str] = set()
for u in users:
if ":" not in u:
continue
domain = get_domain_from_id(u)
if domain not in domain_set:
domain_set.add(domain)
domains.append(domain)
return domains
domains = await self.get_current_hosts_in_room(room_id)
return list(domains)
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
# Returns a list of servers currently joined in the room sorted by
# longest in the room first (aka. with the lowest depth). The
# heuristic of sorting by servers who have been in the room the
@ -1008,7 +1025,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [d for d, in txn if d is not None]
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
)
async def get_joined_hosts(