Speed up updating state in large rooms (#15971)
This should speed up updating state in rooms with lots of state.
This commit is contained in:
parent
835174180b
commit
fc1e534e41
|
@ -0,0 +1 @@
|
|||
Speed up updating state in large rooms.
|
|
@ -1565,12 +1565,11 @@ class EventCreationHandler:
|
|||
if state_entry.state_group in self._external_cache_joined_hosts_updates:
|
||||
return
|
||||
|
||||
state = await state_entry.get_state(
|
||||
self._storage_controllers.state, StateFilter.all()
|
||||
)
|
||||
with opentracing.start_active_span("get_joined_hosts"):
|
||||
joined_hosts = await self.store.get_joined_hosts(
|
||||
event.room_id, state, state_entry
|
||||
joined_hosts = (
|
||||
await self._storage_controllers.state.get_joined_hosts(
|
||||
event.room_id, state_entry
|
||||
)
|
||||
)
|
||||
|
||||
# Note that the expiry times must be larger than the expiry time in
|
||||
|
|
|
@ -268,8 +268,7 @@ class StateHandler:
|
|||
The hosts in the room at the given events
|
||||
"""
|
||||
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
|
||||
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
|
||||
return await self.store.get_joined_hosts(room_id, state, entry)
|
||||
return await self._state_storage_controller.get_joined_hosts(room_id, entry)
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
|
@ -19,14 +20,16 @@ from typing import (
|
|||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.opentracing import tag_args, trace
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
|
@ -34,14 +37,20 @@ from synapse.storage.util.partial_state_events_tracker import (
|
|||
PartialCurrentStateTracker,
|
||||
PartialStateEventsTracker,
|
||||
)
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.state import _StateCacheEntry
|
||||
from synapse.storage.databases import Databases
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -52,10 +61,15 @@ class StateStorageController:
|
|||
|
||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
self._clock = hs.get_clock()
|
||||
self.stores = stores
|
||||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
||||
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
|
||||
|
||||
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
|
||||
# at a time. Keyed by room_id.
|
||||
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
|
||||
|
||||
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
||||
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
||||
|
||||
|
@ -627,3 +641,122 @@ class StateStorageController:
|
|||
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||
|
||||
return await self.stores.main.get_users_in_room_with_profiles(room_id)
|
||||
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
state_group: Union[object, int] = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
assert state_group is not None
|
||||
with Measure(self._clock, "get_joined_hosts"):
|
||||
return await self._get_joined_hosts(
|
||||
room_id, state_group, state_entry=state_entry
|
||||
)
|
||||
|
||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||
async def _get_joined_hosts(
|
||||
self,
|
||||
room_id: str,
|
||||
state_group: Union[object, int],
|
||||
state_entry: "_StateCacheEntry",
|
||||
) -> FrozenSet[str]:
|
||||
# We don't use `state_group`, it's there so that we can cache based on
|
||||
# it. However, its important that its never None, since two
|
||||
# current_state's with a state_group of None are likely to be different.
|
||||
#
|
||||
# The `state_group` must match the `state_entry.state_group` (if not None).
|
||||
assert state_group is not None
|
||||
assert state_entry.state_group is None or state_entry.state_group == state_group
|
||||
|
||||
# We use a secondary cache of previous work to allow us to build up the
|
||||
# joined hosts for the given state group based on previous state groups.
|
||||
#
|
||||
# We cache one object per room containing the results of the last state
|
||||
# group we got joined hosts for. The idea is that generally
|
||||
# `get_joined_hosts` is called with the "current" state group for the
|
||||
# room, and so consecutive calls will be for consecutive state groups
|
||||
# which point to the previous state group.
|
||||
cache = await self.stores.main._get_joined_hosts_cache(room_id)
|
||||
|
||||
# If the state group in the cache matches, we already have the data we need.
|
||||
if state_entry.state_group == cache.state_group:
|
||||
return frozenset(cache.hosts_to_joined_users)
|
||||
|
||||
# Since we'll mutate the cache we need to lock.
|
||||
async with self._joined_host_linearizer.queue(room_id):
|
||||
if state_entry.state_group == cache.state_group:
|
||||
# Same state group, so nothing to do. We've already checked for
|
||||
# this above, but the cache may have changed while waiting on
|
||||
# the lock.
|
||||
pass
|
||||
elif state_entry.prev_group == cache.state_group:
|
||||
# The cached work is for the previous state group, so we work out
|
||||
# the delta.
|
||||
assert state_entry.delta_ids is not None
|
||||
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
||||
if typ != EventTypes.Member:
|
||||
continue
|
||||
|
||||
host = intern_string(get_domain_from_id(state_key))
|
||||
user_id = state_key
|
||||
known_joins = cache.hosts_to_joined_users.setdefault(host, set())
|
||||
|
||||
event = await self.stores.main.get_event(event_id)
|
||||
if event.membership == Membership.JOIN:
|
||||
known_joins.add(user_id)
|
||||
else:
|
||||
known_joins.discard(user_id)
|
||||
|
||||
if not known_joins:
|
||||
cache.hosts_to_joined_users.pop(host, None)
|
||||
else:
|
||||
# The cache doesn't match the state group or prev state group,
|
||||
# so we calculate the result from first principles.
|
||||
#
|
||||
# We need to fetch all hosts joined to the room according to `state` by
|
||||
# inspecting all join memberships in `state`. However, if the `state` is
|
||||
# relatively recent then many of its events are likely to be held in
|
||||
# the current state of the room, which is easily available and likely
|
||||
# cached.
|
||||
#
|
||||
# We therefore compute the set of `state` events not in the
|
||||
# current state and only fetch those.
|
||||
current_memberships = (
|
||||
await self.stores.main._get_approximate_current_memberships_in_room(
|
||||
room_id
|
||||
)
|
||||
)
|
||||
unknown_state_events = {}
|
||||
joined_users_in_current_state = []
|
||||
|
||||
state = await state_entry.get_state(
|
||||
self, StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
|
||||
for (type, state_key), event_id in state.items():
|
||||
if event_id not in current_memberships:
|
||||
unknown_state_events[type, state_key] = event_id
|
||||
elif current_memberships[event_id] == Membership.JOIN:
|
||||
joined_users_in_current_state.append(state_key)
|
||||
|
||||
joined_user_ids = await self.stores.main.get_joined_user_ids_from_state(
|
||||
room_id, unknown_state_events
|
||||
)
|
||||
|
||||
cache.hosts_to_joined_users = {}
|
||||
for user_id in chain(joined_user_ids, joined_users_in_current_state):
|
||||
host = intern_string(get_domain_from_id(user_id))
|
||||
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
|
||||
|
||||
if state_entry.state_group:
|
||||
cache.state_group = state_entry.state_group
|
||||
else:
|
||||
cache.state_group = object()
|
||||
|
||||
return frozenset(cache.hosts_to_joined_users)
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
|
@ -57,15 +56,12 @@ from synapse.types import (
|
|||
StrCollection,
|
||||
get_domain_from_id,
|
||||
)
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.state import _StateCacheEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -91,10 +87,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
|
||||
# at a time. Keyed by room_id.
|
||||
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
|
||||
|
||||
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
|
||||
|
||||
if (
|
||||
|
@ -1057,120 +1049,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
|
||||
)
|
||||
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
state_group: Union[object, int] = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
assert state_group is not None
|
||||
with Measure(self._clock, "get_joined_hosts"):
|
||||
return await self._get_joined_hosts(
|
||||
room_id, state_group, state, state_entry=state_entry
|
||||
)
|
||||
|
||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||
async def _get_joined_hosts(
|
||||
self,
|
||||
room_id: str,
|
||||
state_group: Union[object, int],
|
||||
state: StateMap[str],
|
||||
state_entry: "_StateCacheEntry",
|
||||
) -> FrozenSet[str]:
|
||||
# We don't use `state_group`, it's there so that we can cache based on
|
||||
# it. However, its important that its never None, since two
|
||||
# current_state's with a state_group of None are likely to be different.
|
||||
#
|
||||
# The `state_group` must match the `state_entry.state_group` (if not None).
|
||||
assert state_group is not None
|
||||
assert state_entry.state_group is None or state_entry.state_group == state_group
|
||||
|
||||
# We use a secondary cache of previous work to allow us to build up the
|
||||
# joined hosts for the given state group based on previous state groups.
|
||||
#
|
||||
# We cache one object per room containing the results of the last state
|
||||
# group we got joined hosts for. The idea is that generally
|
||||
# `get_joined_hosts` is called with the "current" state group for the
|
||||
# room, and so consecutive calls will be for consecutive state groups
|
||||
# which point to the previous state group.
|
||||
cache = await self._get_joined_hosts_cache(room_id)
|
||||
|
||||
# If the state group in the cache matches, we already have the data we need.
|
||||
if state_entry.state_group == cache.state_group:
|
||||
return frozenset(cache.hosts_to_joined_users)
|
||||
|
||||
# Since we'll mutate the cache we need to lock.
|
||||
async with self._joined_host_linearizer.queue(room_id):
|
||||
if state_entry.state_group == cache.state_group:
|
||||
# Same state group, so nothing to do. We've already checked for
|
||||
# this above, but the cache may have changed while waiting on
|
||||
# the lock.
|
||||
pass
|
||||
elif state_entry.prev_group == cache.state_group:
|
||||
# The cached work is for the previous state group, so we work out
|
||||
# the delta.
|
||||
assert state_entry.delta_ids is not None
|
||||
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
||||
if typ != EventTypes.Member:
|
||||
continue
|
||||
|
||||
host = intern_string(get_domain_from_id(state_key))
|
||||
user_id = state_key
|
||||
known_joins = cache.hosts_to_joined_users.setdefault(host, set())
|
||||
|
||||
event = await self.get_event(event_id)
|
||||
if event.membership == Membership.JOIN:
|
||||
known_joins.add(user_id)
|
||||
else:
|
||||
known_joins.discard(user_id)
|
||||
|
||||
if not known_joins:
|
||||
cache.hosts_to_joined_users.pop(host, None)
|
||||
else:
|
||||
# The cache doesn't match the state group or prev state group,
|
||||
# so we calculate the result from first principles.
|
||||
#
|
||||
# We need to fetch all hosts joined to the room according to `state` by
|
||||
# inspecting all join memberships in `state`. However, if the `state` is
|
||||
# relatively recent then many of its events are likely to be held in
|
||||
# the current state of the room, which is easily available and likely
|
||||
# cached.
|
||||
#
|
||||
# We therefore compute the set of `state` events not in the
|
||||
# current state and only fetch those.
|
||||
current_memberships = (
|
||||
await self._get_approximate_current_memberships_in_room(room_id)
|
||||
)
|
||||
unknown_state_events = {}
|
||||
joined_users_in_current_state = []
|
||||
|
||||
for (type, state_key), event_id in state.items():
|
||||
if event_id not in current_memberships:
|
||||
unknown_state_events[type, state_key] = event_id
|
||||
elif current_memberships[event_id] == Membership.JOIN:
|
||||
joined_users_in_current_state.append(state_key)
|
||||
|
||||
joined_user_ids = await self.get_joined_user_ids_from_state(
|
||||
room_id, unknown_state_events
|
||||
)
|
||||
|
||||
cache.hosts_to_joined_users = {}
|
||||
for user_id in chain(joined_user_ids, joined_users_in_current_state):
|
||||
host = intern_string(get_domain_from_id(user_id))
|
||||
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
|
||||
|
||||
if state_entry.state_group:
|
||||
cache.state_group = state_entry.state_group
|
||||
else:
|
||||
cache.state_group = object()
|
||||
|
||||
return frozenset(cache.hosts_to_joined_users)
|
||||
|
||||
async def _get_approximate_current_memberships_in_room(
|
||||
self, room_id: str
|
||||
) -> Mapping[str, Optional[str]]:
|
||||
|
|
Loading…
Reference in New Issue