This commit is contained in:
Michael Telatynski 2024-01-07 06:41:56 +02:00 committed by GitHub
commit 0d5fa25436
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 5 deletions

View File

@ -587,6 +587,7 @@ class SyncHandler:
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
limit=timeline_limit + 1,
)
log_kv({"recents_after_visibility_filtering": len(recents)})
else:
@ -668,6 +669,7 @@ class SyncHandler:
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
limit=timeline_limit + 1 - len(recents),
)
log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)})

View File

@ -70,6 +70,7 @@ async def filter_events_for_client(
is_peeking: bool = False,
always_include_ids: FrozenSet[str] = frozenset(),
filter_send_to_client: bool = True,
limit: int = -1,
) -> List[EventBase]:
"""
Check which events a user is allowed to see. If the user can see the event but its
@ -88,6 +89,7 @@ async def filter_events_for_client(
filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.
limit: The number of events to bail at, as a hot path optimisation.
Returns:
The filtered events.
@ -140,12 +142,15 @@ async def filter_events_for_client(
sender_erased=erased_senders.get(event.sender, False),
)
# Check each event: gives an iterable of None or (a potentially modified)
# EventBase.
filtered_events = map(allowed, events)
filtered_events: List[EventBase] = []
for event in events:
checked = allowed(event)
if checked is not None:
filtered_events.append(checked)
if len(filtered_events) >= limit >= 0:
break
# Turn it into a list and remove None entries before returning.
return [ev for ev in filtered_events if ev]
return filtered_events
async def filter_event_for_clients_with_state(

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import defaultdict
from typing import Optional
from unittest.mock import patch
@ -20,8 +21,11 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, create_requester
from synapse.visibility import filter_events_for_client, filter_events_for_server
from synapse.storage.controllers.state import StateStorageController
import synapse.visibility
from tests import unittest
from tests.events.test_utils import MockEvent
from tests.utils import create_room
logger = logging.getLogger(__name__)
@ -357,3 +361,47 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
),
[],
)
@patch.object(StateStorageController, "get_state_for_events",
return_value=defaultdict(None))
@patch.object(synapse.visibility, "_check_client_allowed_to_see_event")
def test_limit_early_return(self, check_mock, state_mock):
def _check_client_allowed_to_see_event(
user_id: str,
event: EventBase,
**kwargs,
) -> Optional[EventBase]:
return event
check_mock.side_effect = _check_client_allowed_to_see_event
ev1 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message")
ev2 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message")
ev3 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message")
ev4 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message")
ev5 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message")
events = [ev1, ev2, ev3, ev4, ev5]
self.assertEqual(
self.get_success(
filter_events_for_client(
self.hs.get_storage_controllers(),
"@user:test",
events,
)
),
events,
)
self.assertEqual(check_mock.call_count, 5)
self.assertEqual(
self.get_success(
filter_events_for_client(
self.hs.get_storage_controllers(),
"@user:test",
events,
limit=2,
)
),
events[0:2],
)
self.assertEqual(check_mock.call_count, 7)