diff --git a/tests/test_visibility.py b/tests/test_visibility.py index c385b2f8d4..02cdbf53de 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import defaultdict from typing import Optional from unittest.mock import patch @@ -20,8 +21,11 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.types import JsonDict, create_requester from synapse.visibility import filter_events_for_client, filter_events_for_server +from synapse.storage.controllers.state import StateStorageController +import synapse.visibility from tests import unittest +from tests.events.test_utils import MockEvent from tests.utils import create_room logger = logging.getLogger(__name__) @@ -328,3 +332,47 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): ), [], ) + + @patch.object(StateStorageController, "get_state_for_events", + return_value=defaultdict(None)) + @patch.object(synapse.visibility, "_check_client_allowed_to_see_event") + def test_limit_early_return(self, check_mock, state_mock): + def _check_client_allowed_to_see_event( + user_id: str, + event: EventBase, + **kwargs, + ) -> Optional[EventBase]: + return event + check_mock.side_effect = _check_client_allowed_to_see_event + + ev1 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message") + ev2 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message") + ev3 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message") + ev4 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message") + ev5 = MockEvent(room_id=TEST_ROOM_ID, sender="@foo:bar", type="m.room.message") + events = [ev1, ev2, ev3, ev4, ev5] + + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@user:test", + events, + ) + ), + events, + ) + self.assertEqual(check_mock.call_count, 5) + + self.assertEqual( + self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@user:test", + events, + limit=2, + ) + ), + events[0:2], + ) + self.assertEqual(check_mock.call_count, 7)