Simplify `_locally_reject_invite`
Update `EventCreationHandler.create_event` to accept an auth_events param, and use it in `_locally_reject_invite` instead of reinventing the wheel.
This commit is contained in:
parent
d9d86c2996
commit
a34b17e492
|
@ -97,32 +97,37 @@ class EventBuilder:
|
||||||
def is_state(self):
|
def is_state(self):
|
||||||
return self._state_key is not None
|
return self._state_key is not None
|
||||||
|
|
||||||
async def build(self, prev_event_ids: List[str]) -> EventBase:
|
async def build(
|
||||||
|
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
|
||||||
|
) -> EventBase:
|
||||||
"""Transform into a fully signed and hashed event
|
"""Transform into a fully signed and hashed event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prev_event_ids: The event IDs to use as the prev events
|
prev_event_ids: The event IDs to use as the prev events
|
||||||
|
auth_event_ids: The event IDs to use as the auth events.
|
||||||
|
Should normally be set to None, which will cause them to be calculated
|
||||||
|
based on the room state at the prev_events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The signed and hashed event.
|
The signed and hashed event.
|
||||||
"""
|
"""
|
||||||
|
if auth_event_ids is None:
|
||||||
state_ids = await self._state.get_current_state_ids(
|
state_ids = await self._state.get_current_state_ids(
|
||||||
self.room_id, prev_event_ids
|
self.room_id, prev_event_ids
|
||||||
)
|
)
|
||||||
auth_ids = self._auth.compute_auth_events(self, state_ids)
|
auth_event_ids = self._auth.compute_auth_events(self, state_ids)
|
||||||
|
|
||||||
format_version = self.room_version.event_format
|
format_version = self.room_version.event_format
|
||||||
if format_version == EventFormatVersions.V1:
|
if format_version == EventFormatVersions.V1:
|
||||||
# The types of auth/prev events changes between event versions.
|
# The types of auth/prev events changes between event versions.
|
||||||
auth_events = await self._store.add_event_hashes(
|
auth_events = await self._store.add_event_hashes(
|
||||||
auth_ids
|
auth_event_ids
|
||||||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||||
prev_events = await self._store.add_event_hashes(
|
prev_events = await self._store.add_event_hashes(
|
||||||
prev_event_ids
|
prev_event_ids
|
||||||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||||
else:
|
else:
|
||||||
auth_events = auth_ids
|
auth_events = auth_event_ids
|
||||||
prev_events = prev_event_ids
|
prev_events = prev_event_ids
|
||||||
|
|
||||||
old_depth = await self._store.get_max_depth_of(prev_event_ids)
|
old_depth = await self._store.get_max_depth_of(prev_event_ids)
|
||||||
|
|
|
@ -439,6 +439,7 @@ class EventCreationHandler:
|
||||||
event_dict: dict,
|
event_dict: dict,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str] = None,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
|
auth_event_ids: Optional[List[str]] = None,
|
||||||
require_consent: bool = True,
|
require_consent: bool = True,
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""
|
"""
|
||||||
|
@ -458,6 +459,12 @@ class EventCreationHandler:
|
||||||
new event.
|
new event.
|
||||||
|
|
||||||
If None, they will be requested from the database.
|
If None, they will be requested from the database.
|
||||||
|
|
||||||
|
auth_event_ids:
|
||||||
|
The event ids to use as the auth_events for the new event.
|
||||||
|
Should normally be left as None, which will cause them to be calculated
|
||||||
|
based on the room state at the prev_events.
|
||||||
|
|
||||||
require_consent: Whether to check if the requester has
|
require_consent: Whether to check if the requester has
|
||||||
consented to the privacy policy.
|
consented to the privacy policy.
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -516,7 +523,10 @@ class EventCreationHandler:
|
||||||
builder.internal_metadata.txn_id = txn_id
|
builder.internal_metadata.txn_id = txn_id
|
||||||
|
|
||||||
event, context = await self.create_new_client_event(
|
event, context = await self.create_new_client_event(
|
||||||
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
|
builder=builder,
|
||||||
|
requester=requester,
|
||||||
|
prev_event_ids=prev_event_ids,
|
||||||
|
auth_event_ids=auth_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# In an ideal world we wouldn't need the second part of this condition. However,
|
# In an ideal world we wouldn't need the second part of this condition. However,
|
||||||
|
@ -755,6 +765,7 @@ class EventCreationHandler:
|
||||||
builder: EventBuilder,
|
builder: EventBuilder,
|
||||||
requester: Optional[Requester] = None,
|
requester: Optional[Requester] = None,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
|
auth_event_ids: Optional[List[str]] = None,
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""Create a new event for a local client
|
"""Create a new event for a local client
|
||||||
|
|
||||||
|
@ -767,6 +778,11 @@ class EventCreationHandler:
|
||||||
|
|
||||||
If None, they will be requested from the database.
|
If None, they will be requested from the database.
|
||||||
|
|
||||||
|
auth_event_ids:
|
||||||
|
The event ids to use as the auth_events for the new event.
|
||||||
|
Should normally be left as None, which will cause them to be calculated
|
||||||
|
based on the room state at the prev_events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of created event, context
|
Tuple of created event, context
|
||||||
"""
|
"""
|
||||||
|
@ -788,7 +804,9 @@ class EventCreationHandler:
|
||||||
builder.type == EventTypes.Create or len(prev_event_ids) > 0
|
builder.type == EventTypes.Create or len(prev_event_ids) > 0
|
||||||
), "Attempting to create an event with no prev_events"
|
), "Attempting to create an event with no prev_events"
|
||||||
|
|
||||||
event = await builder.build(prev_event_ids=prev_event_ids)
|
event = await builder.build(
|
||||||
|
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
|
||||||
|
)
|
||||||
context = await self.state.compute_event_context(event)
|
context = await self.state.compute_event_context(event)
|
||||||
if requester:
|
if requester:
|
||||||
context.app_service = requester.app_service
|
context.app_service = requester.app_service
|
||||||
|
|
|
@ -17,12 +17,10 @@ import abc
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
|
||||||
|
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
|
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
Codes,
|
Codes,
|
||||||
|
@ -31,12 +29,8 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.api.room_versions import EventFormatVersions
|
|
||||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.builder import create_local_event_from_event_dict
|
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.validator import EventValidator
|
|
||||||
from synapse.storage.roommember import RoomsForUser
|
from synapse.storage.roommember import RoomsForUser
|
||||||
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
|
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
@ -1132,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||||
|
|
||||||
room_id = invite_event.room_id
|
room_id = invite_event.room_id
|
||||||
target_user = invite_event.state_key
|
target_user = invite_event.state_key
|
||||||
room_version = await self.store.get_room_version(room_id)
|
|
||||||
|
|
||||||
content["membership"] = Membership.LEAVE
|
content["membership"] = Membership.LEAVE
|
||||||
|
|
||||||
# the auth events for the new event are the same as that of the invite, plus
|
|
||||||
# the invite itself.
|
|
||||||
#
|
|
||||||
# the prev_events are just the invite.
|
|
||||||
invite_hash = invite_event.event_id # type: Union[str, Tuple]
|
|
||||||
if room_version.event_format == EventFormatVersions.V1:
|
|
||||||
alg, h = compute_event_reference_hash(invite_event)
|
|
||||||
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
|
|
||||||
|
|
||||||
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
|
|
||||||
prev_events = (invite_hash,)
|
|
||||||
|
|
||||||
# we cap depth of generated events, to ensure that they are not
|
|
||||||
# rejected by other servers (and so that they can be persisted in
|
|
||||||
# the db)
|
|
||||||
depth = min(invite_event.depth + 1, MAX_DEPTH)
|
|
||||||
|
|
||||||
event_dict = {
|
event_dict = {
|
||||||
"depth": depth,
|
|
||||||
"auth_events": auth_events,
|
|
||||||
"prev_events": prev_events,
|
|
||||||
"type": EventTypes.Member,
|
"type": EventTypes.Member,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": target_user,
|
"sender": target_user,
|
||||||
|
@ -1164,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||||
"state_key": target_user,
|
"state_key": target_user,
|
||||||
}
|
}
|
||||||
|
|
||||||
event = create_local_event_from_event_dict(
|
# the auth events for the new event are the same as that of the invite, plus
|
||||||
clock=self.clock,
|
# the invite itself.
|
||||||
hostname=self.hs.hostname,
|
#
|
||||||
signing_key=self.hs.signing_key,
|
# the prev_events are just the invite.
|
||||||
room_version=room_version,
|
prev_event_ids = [invite_event.event_id]
|
||||||
event_dict=event_dict,
|
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
|
||||||
|
|
||||||
|
event, context = await self.event_creation_handler.create_event(
|
||||||
|
requester,
|
||||||
|
event_dict,
|
||||||
|
txn_id=txn_id,
|
||||||
|
prev_event_ids=prev_event_ids,
|
||||||
|
auth_event_ids=auth_event_ids,
|
||||||
)
|
)
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
event.internal_metadata.out_of_band_membership = True
|
event.internal_metadata.out_of_band_membership = True
|
||||||
if txn_id is not None:
|
|
||||||
event.internal_metadata.txn_id = txn_id
|
|
||||||
if requester.access_token_id is not None:
|
|
||||||
event.internal_metadata.token_id = requester.access_token_id
|
|
||||||
|
|
||||||
EventValidator().validate_new(event, self.config)
|
|
||||||
|
|
||||||
context = await self.state_handler.compute_event_context(event)
|
|
||||||
context.app_service = requester.app_service
|
|
||||||
result_event = await self.event_creation_handler.handle_new_client_event(
|
result_event = await self.event_creation_handler.handle_new_client_event(
|
||||||
requester, event, context, extra_users=[UserID.from_string(target_user)],
|
requester, event, context, extra_users=[UserID.from_string(target_user)],
|
||||||
)
|
)
|
||||||
|
|
|
@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_latest_event_ids_in_room(room_id)
|
self.store.get_latest_event_ids_in_room(room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
event = self.get_success(builder.build(prev_event_ids))
|
event = self.get_success(builder.build(prev_event_ids, None))
|
||||||
|
|
||||||
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
|
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
|
||||||
|
|
||||||
|
|
|
@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
builder = factory.for_room_version(room_version, event_dict)
|
builder = factory.for_room_version(room_version, event_dict)
|
||||||
join_event = self.get_success(builder.build(prev_event_ids))
|
join_event = self.get_success(builder.build(prev_event_ids, None))
|
||||||
|
|
||||||
self.get_success(federation.on_send_join_request(remote_server, join_event))
|
self.get_success(federation.on_send_join_request(remote_server, join_event))
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
|
@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self._event_id = event_id
|
self._event_id = event_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def build(self, prev_event_ids):
|
def build(self, prev_event_ids, auth_event_ids):
|
||||||
built_event = yield defer.ensureDeferred(
|
built_event = yield defer.ensureDeferred(
|
||||||
self._base_builder.build(prev_event_ids)
|
self._base_builder.build(prev_event_ids, auth_event_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
built_event._event_id = self._event_id
|
built_event._event_id = self._event_id
|
||||||
|
|
Loading…
Reference in New Issue