Replace `EventContext` fields `prev_group` and `delta_ids` with field `state_group_deltas` (#15233)

This commit is contained in:
Shay 2023-06-13 13:22:06 -07:00 committed by GitHub
parent 59ec4a0dc1
commit 553f2f53e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 126 additions and 58 deletions

1
changelog.d/15233.misc Normal file
View File

@ -0,0 +1 @@
Replace `EventContext` fields `prev_group` and `delta_ids` with field `state_group_deltas`.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr import attr
from immutabledict import immutabledict from immutabledict import immutabledict
@ -107,33 +107,32 @@ class EventContext(UnpersistedEventContextBase):
state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
then this is the delta of the state between the two groups. then this is the delta of the state between the two groups.
prev_group: If it is known, ``state_group``'s prev_group. Note that this being state_group_deltas: If not empty, this is a dict collecting a mapping of the state
None does not necessarily mean that ``state_group`` does not have difference between state groups.
a prev_group!
If the event is a state event, this is normally the same as The keys are a tuple of two integers: the initial group and final state group.
``state_group_before_event``. The corresponding value is a state map representing the state delta between
these state groups.
If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` The dictionary is expected to have at most two entries with state groups of:
will always also be ``None``.
Note that this *not* (necessarily) the state group associated with 1. The state group before the event and after the event.
``_prev_state_ids``. 2. The state group preceding the state group before the event and the
state group before the event.
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` This information is collected and stored as part of an optimization for persisting
and ``state_group``. events.
partial_state: if True, we may be storing this event with a temporary, partial_state: if True, we may be storing this event with a temporary,
incomplete state. incomplete state.
""" """
_storage: "StorageControllers" _storage: "StorageControllers"
state_group_deltas: Dict[Tuple[int, int], StateMap[str]]
rejected: Optional[str] = None rejected: Optional[str] = None
_state_group: Optional[int] = None _state_group: Optional[int] = None
state_group_before_event: Optional[int] = None state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None _state_delta_due_to_event: Optional[StateMap[str]] = None
prev_group: Optional[int] = None
delta_ids: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None app_service: Optional[ApplicationService] = None
partial_state: bool = False partial_state: bool = False
@ -145,16 +144,14 @@ class EventContext(UnpersistedEventContextBase):
state_group_before_event: Optional[int], state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]], state_delta_due_to_event: Optional[StateMap[str]],
partial_state: bool, partial_state: bool,
prev_group: Optional[int] = None, state_group_deltas: Dict[Tuple[int, int], StateMap[str]],
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext": ) -> "EventContext":
return EventContext( return EventContext(
storage=storage, storage=storage,
state_group=state_group, state_group=state_group,
state_group_before_event=state_group_before_event, state_group_before_event=state_group_before_event,
state_delta_due_to_event=state_delta_due_to_event, state_delta_due_to_event=state_delta_due_to_event,
prev_group=prev_group, state_group_deltas=state_group_deltas,
delta_ids=delta_ids,
partial_state=partial_state, partial_state=partial_state,
) )
@ -163,7 +160,7 @@ class EventContext(UnpersistedEventContextBase):
storage: "StorageControllers", storage: "StorageControllers",
) -> "EventContext": ) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event""" """Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage) return EventContext(storage=storage, state_group_deltas={})
async def persist(self, event: EventBase) -> "EventContext": async def persist(self, event: EventBase) -> "EventContext":
return self return self
@ -183,13 +180,15 @@ class EventContext(UnpersistedEventContextBase):
"state_group": self._state_group, "state_group": self._state_group,
"state_group_before_event": self.state_group_before_event, "state_group_before_event": self.state_group_before_event,
"rejected": self.rejected, "rejected": self.rejected,
"prev_group": self.prev_group, "state_group_deltas": _encode_state_group_delta(self.state_group_deltas),
"state_delta_due_to_event": _encode_state_dict( "state_delta_due_to_event": _encode_state_dict(
self._state_delta_due_to_event self._state_delta_due_to_event
), ),
"delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None, "app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state, "partial_state": self.partial_state,
# add dummy delta_ids and prev_group for backwards compatibility
"delta_ids": None,
"prev_group": None,
} }
@staticmethod @staticmethod
@ -204,17 +203,24 @@ class EventContext(UnpersistedEventContextBase):
Returns: Returns:
The event context. The event context.
""" """
# workaround for backwards/forwards compatibility: if the input doesn't have a value
# for "state_group_deltas" just assign an empty dict
state_group_deltas = input.get("state_group_deltas", None)
if state_group_deltas:
state_group_deltas = _decode_state_group_delta(state_group_deltas)
else:
state_group_deltas = {}
context = EventContext( context = EventContext(
# We use the state_group and prev_state_id stuff to pull the # We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids. # current_state_ids out of the DB and construct prev_state_ids.
storage=storage, storage=storage,
state_group=input["state_group"], state_group=input["state_group"],
state_group_before_event=input["state_group_before_event"], state_group_before_event=input["state_group_before_event"],
prev_group=input["prev_group"], state_group_deltas=state_group_deltas,
state_delta_due_to_event=_decode_state_dict( state_delta_due_to_event=_decode_state_dict(
input["state_delta_due_to_event"] input["state_delta_due_to_event"]
), ),
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"], rejected=input["rejected"],
partial_state=input.get("partial_state", False), partial_state=input.get("partial_state", False),
) )
@ -349,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
_storage: "StorageControllers" _storage: "StorageControllers"
state_group_before_event: Optional[int] state_group_before_event: Optional[int]
state_group_after_event: Optional[int] state_group_after_event: Optional[int]
state_delta_due_to_event: Optional[dict] state_delta_due_to_event: Optional[StateMap[str]]
prev_group_for_state_group_before_event: Optional[int] prev_group_for_state_group_before_event: Optional[int]
delta_ids_to_state_group_before_event: Optional[StateMap[str]] delta_ids_to_state_group_before_event: Optional[StateMap[str]]
partial_state: bool partial_state: bool
@ -380,26 +386,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
events_and_persisted_context = [] events_and_persisted_context = []
for event, unpersisted_context in amended_events_and_context: for event, unpersisted_context in amended_events_and_context:
if event.is_state(): state_group_deltas = unpersisted_context._build_state_group_deltas()
context = EventContext(
storage=unpersisted_context._storage, context = EventContext(
state_group=unpersisted_context.state_group_after_event, storage=unpersisted_context._storage,
state_group_before_event=unpersisted_context.state_group_before_event, state_group=unpersisted_context.state_group_after_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, state_group_before_event=unpersisted_context.state_group_before_event,
partial_state=unpersisted_context.partial_state, state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
prev_group=unpersisted_context.state_group_before_event, partial_state=unpersisted_context.partial_state,
delta_ids=unpersisted_context.state_delta_due_to_event, state_group_deltas=state_group_deltas,
) )
else:
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.prev_group_for_state_group_before_event,
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
)
events_and_persisted_context.append((event, context)) events_and_persisted_context.append((event, context))
return events_and_persisted_context return events_and_persisted_context
@ -452,11 +448,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
# if the event isn't a state event the state group doesn't change # if the event isn't a state event the state group doesn't change
if not self.state_delta_due_to_event: if not self.state_delta_due_to_event:
state_group_after_event = self.state_group_before_event self.state_group_after_event = self.state_group_before_event
# otherwise if it is a state event we need to get a state group for it # otherwise if it is a state event we need to get a state group for it
else: else:
state_group_after_event = await self._storage.state.store_state_group( self.state_group_after_event = await self._storage.state.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=self.state_group_before_event, prev_group=self.state_group_before_event,
@ -464,16 +460,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
current_state_ids=None, current_state_ids=None,
) )
state_group_deltas = self._build_state_group_deltas()
return EventContext.with_state( return EventContext.with_state(
storage=self._storage, storage=self._storage,
state_group=state_group_after_event, state_group=self.state_group_after_event,
state_group_before_event=self.state_group_before_event, state_group_before_event=self.state_group_before_event,
state_delta_due_to_event=self.state_delta_due_to_event, state_delta_due_to_event=self.state_delta_due_to_event,
state_group_deltas=state_group_deltas,
partial_state=self.partial_state, partial_state=self.partial_state,
prev_group=self.state_group_before_event,
delta_ids=self.state_delta_due_to_event,
) )
def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]:
"""
Collect deltas between the state groups associated with this context
"""
state_group_deltas = {}
# if we know the state group before the event and after the event, add them and the
# state delta between them to state_group_deltas
if self.state_group_before_event and self.state_group_after_event:
# if we have the state groups we should have the delta
assert self.state_delta_due_to_event is not None
state_group_deltas[
(
self.state_group_before_event,
self.state_group_after_event,
)
] = self.state_delta_due_to_event
# the state group before the event may also have a state group which precedes it, if
# we have that and the state group before the event, add them and the state
# delta between them to state_group_deltas
if (
self.prev_group_for_state_group_before_event
and self.state_group_before_event
):
# if we have both state groups we should have the delta between them
assert self.delta_ids_to_state_group_before_event is not None
state_group_deltas[
(
self.prev_group_for_state_group_before_event,
self.state_group_before_event,
)
] = self.delta_ids_to_state_group_before_event
return state_group_deltas
def _encode_state_group_delta(
state_group_delta: Dict[Tuple[int, int], StateMap[str]]
) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
if not state_group_delta:
return []
state_group_delta_encoded = []
for key, value in state_group_delta.items():
state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value)))
return state_group_delta_encoded
def _decode_state_group_delta(
input: List[Tuple[int, int, List[Tuple[str, str, str]]]]
) -> Dict[Tuple[int, int], StateMap[str]]:
if not input:
return {}
state_group_deltas = {}
for state_group_1, state_group_2, state_dict in input:
state_map = _decode_state_dict(state_dict)
assert state_map is not None
state_group_deltas[(state_group_1, state_group_2)] = state_map
return state_group_deltas
def _encode_state_dict( def _encode_state_dict(
state_dict: Optional[StateMap[str]], state_dict: Optional[StateMap[str]],

View File

@ -839,9 +839,8 @@ class EventsPersistenceStorageController:
"group" % (ev.event_id,) "group" % (ev.event_id,)
) )
continue continue
if ctx.state_group_deltas:
if ctx.prev_group: state_group_deltas.update(ctx.state_group_deltas)
state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
# We need to map the event_ids to their state groups. First, let's # We need to map the event_ids to their state groups. First, let's
# check if the event is one we're persisting, in which case we can # check if the event is one we're persisting, in which case we can

View File

@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
context.state_group_before_event, d_context.state_group_before_event context.state_group_before_event, d_context.state_group_before_event
) )
self.assertEqual(context.prev_group, d_context.prev_group) self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
self.assertEqual(context.delta_ids, d_context.delta_ids)
self.assertEqual(context.app_service, d_context.app_service) self.assertEqual(context.app_service, d_context.app_service)
self.assertEqual( self.assertEqual(

View File

@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None assert persist_events_store is not None
persist_events_store._store_event_txn( persist_events_store._store_event_txn(
txn, txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events], [
(e, EventContext(self.hs.get_storage_controllers(), {}))
for e in events
],
) )
# Actually call the function that calculates the auth chain stuff. # Actually call the function that calculates the auth chain stuff.

View File

@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
(e.event_id for e in old_state + [event]), current_state_ids.values() (e.event_id for e in old_state + [event]), current_state_ids.values()
) )
self.assertIsNotNone(context.state_group_before_event) assert context.state_group_before_event is not None
assert context.state_group is not None
self.assertEqual(
context.state_group_deltas.get(
(context.state_group_before_event, context.state_group)
),
{(event.type, event.state_key): event.event_id},
)
self.assertNotEqual(context.state_group_before_event, context.state_group) self.assertNotEqual(context.state_group_before_event, context.state_group)
self.assertEqual(context.state_group_before_event, context.prev_group)
self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message( def test_trivial_annotate_message(