Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible (#12791)
This commit is contained in:
parent
2be5a2b07b
commit
71e8afe34d
|
@ -0,0 +1 @@
|
|||
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible.
|
|
@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap
|
|||
if TYPE_CHECKING:
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.state import StateFilter
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
|
@ -196,7 +197,9 @@ class EventContext:
|
|||
|
||||
return self._state_group
|
||||
|
||||
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||
async def get_current_state_ids(
|
||||
self, state_filter: Optional["StateFilter"] = None
|
||||
) -> Optional[StateMap[str]]:
|
||||
"""
|
||||
Gets the room state map, including this event - ie, the state in ``state_group``
|
||||
|
||||
|
@ -204,6 +207,9 @@ class EventContext:
|
|||
not make it into the room state. This method will raise an exception if
|
||||
``rejected`` is set.
|
||||
|
||||
Arg:
|
||||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||
|
||||
Returns:
|
||||
Returns None if state_group is None, which happens when the associated
|
||||
event is an outlier.
|
||||
|
@ -216,7 +222,7 @@ class EventContext:
|
|||
|
||||
assert self._state_delta_due_to_event is not None
|
||||
|
||||
prev_state_ids = await self.get_prev_state_ids()
|
||||
prev_state_ids = await self.get_prev_state_ids(state_filter)
|
||||
|
||||
if self._state_delta_due_to_event:
|
||||
prev_state_ids = dict(prev_state_ids)
|
||||
|
@ -224,12 +230,17 @@ class EventContext:
|
|||
|
||||
return prev_state_ids
|
||||
|
||||
async def get_prev_state_ids(self) -> StateMap[str]:
|
||||
async def get_prev_state_ids(
|
||||
self, state_filter: Optional["StateFilter"] = None
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Gets the room state map, excluding this event.
|
||||
|
||||
For a non-state event, this will be the same as get_current_state_ids().
|
||||
|
||||
Args:
|
||||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||
|
||||
Returns:
|
||||
Returns {} if state_group is None, which happens when the associated
|
||||
event is an outlier.
|
||||
|
@ -239,7 +250,7 @@ class EventContext:
|
|||
"""
|
||||
assert self.state_group_before_event is not None
|
||||
return await self._storage.state.get_state_ids_for_group(
|
||||
self.state_group_before_event
|
||||
self.state_group_before_event, state_filter
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ from synapse.replication.http.federation import (
|
|||
ReplicationStoreRoomOnOutlierMembershipRestServlet,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
@ -1259,7 +1260,9 @@ class FederationHandler:
|
|||
event.content["third_party_invite"]["signed"]["token"],
|
||||
)
|
||||
original_invite = None
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
|
||||
)
|
||||
original_invite_id = prev_state_ids.get(key)
|
||||
if original_invite_id:
|
||||
original_invite = await self.store.get_event(
|
||||
|
@ -1308,7 +1311,9 @@ class FederationHandler:
|
|||
signed = event.content["third_party_invite"]["signed"]
|
||||
token = signed["token"]
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
|
||||
)
|
||||
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
|
||||
|
||||
invite_event = None
|
||||
|
|
|
@ -30,6 +30,7 @@ from typing import (
|
|||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
|
@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
|
|||
)
|
||||
from synapse.state import StateResolutionStore
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
PersistedEventPosition,
|
||||
RoomStreamToken,
|
||||
|
@ -1500,7 +1502,11 @@ class FederationEventHandler:
|
|||
return context
|
||||
|
||||
# now check auth against what we think the auth events *should* be.
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
event_types = event_auth.auth_types_for_event(event.room_version, event)
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types(event_types)
|
||||
)
|
||||
|
||||
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
|
|
|
@ -634,7 +634,9 @@ class EventCreationHandler:
|
|||
# federation as well as those created locally. As of room v3, aliases events
|
||||
# can be created by users that are not in the room, therefore we have to
|
||||
# tolerate them in event_auth.check().
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
||||
prev_event = (
|
||||
await self.store.get_event(prev_event_id, allow_none=True)
|
||||
|
@ -761,7 +763,9 @@ class EventCreationHandler:
|
|||
# This can happen due to out of band memberships
|
||||
return None
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(event.type, None)])
|
||||
)
|
||||
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||
if not prev_event_id:
|
||||
return None
|
||||
|
@ -1547,7 +1551,11 @@ class EventCreationHandler:
|
|||
"Redacting MSC2716 events is not supported in this room version",
|
||||
)
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
event_types = event_auth.auth_types_for_event(event.room_version, event)
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types(event_types)
|
||||
)
|
||||
|
||||
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
|
|
|
@ -303,7 +303,10 @@ class RoomCreationHandler:
|
|||
context=tombstone_context,
|
||||
)
|
||||
|
||||
old_room_state = await tombstone_context.get_current_state_ids()
|
||||
state_filter = StateFilter.from_types(
|
||||
[(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
|
||||
)
|
||||
old_room_state = await tombstone_context.get_current_state_ids(state_filter)
|
||||
|
||||
# We know the tombstone event isn't an outlier so it has current state.
|
||||
assert old_room_state is not None
|
||||
|
|
|
@ -38,6 +38,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
|
|||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
|
@ -362,7 +363,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
historical=historical,
|
||||
)
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
|
||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||
|
||||
|
@ -1160,7 +1163,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
else:
|
||||
requester = types.create_requester(target_user)
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.GuestAccess, None)])
|
||||
)
|
||||
if event.membership == Membership.JOIN:
|
||||
if requester.is_guest:
|
||||
guest_can_join = await self._can_guest_join(prev_state_ids)
|
||||
|
|
|
@ -20,7 +20,7 @@ import attr
|
|||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
||||
from synapse.event_auth import get_user_power_level
|
||||
from synapse.event_auth import auth_types_for_event, get_user_power_level
|
||||
from synapse.events import EventBase, relation_from_event
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import POWER_KEY
|
||||
|
@ -31,6 +31,7 @@ from synapse.util.caches.descriptors import lru_cache
|
|||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.metrics import measure_func
|
||||
|
||||
from ..storage.state import StateFilter
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -168,8 +169,12 @@ class BulkPushRuleEvaluator:
|
|||
async def _get_power_levels_and_sender_level(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Tuple[dict, int]:
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
event_types = auth_types_for_event(event.room_version, event)
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types(event_types)
|
||||
)
|
||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
|
|
|
@ -634,16 +634,19 @@ class StateGroupStorage:
|
|||
|
||||
return group_to_state
|
||||
|
||||
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
|
||||
async def get_state_ids_for_group(
|
||||
self, state_group: int, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group: A state group for which we want to get the state IDs.
|
||||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||
|
||||
Returns:
|
||||
Resolves to a map of (type, state_key) -> event_id
|
||||
"""
|
||||
group_to_state = await self.get_state_for_groups((state_group,))
|
||||
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
|
||||
|
||||
return group_to_state[state_group]
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ class _DummyStore:
|
|||
|
||||
return groups
|
||||
|
||||
async def get_state_ids_for_group(self, state_group):
|
||||
async def get_state_ids_for_group(self, state_group, state_filter=None):
|
||||
return self._group_to_state[state_group]
|
||||
|
||||
async def store_state_group(
|
||||
|
|
Loading…
Reference in New Issue