From 546b9c9e648f5e2b25bb7c8350570787ff9befae Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 22 Feb 2022 11:44:11 +0000 Subject: [PATCH] Add more tests for in-flight state query duplication. (#12033) --- changelog.d/12033.misc | 1 + tests/storage/databases/test_state_store.py | 196 +++++++++++++++++--- 2 files changed, 174 insertions(+), 23 deletions(-) create mode 100644 changelog.d/12033.misc diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc new file mode 100644 index 0000000000..3af049b969 --- /dev/null +++ b/changelog.d/12033.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py index cf126ee62d..3a4a4a3a29 100644 --- a/tests/storage/databases/test_state_store.py +++ b/tests/storage/databases/test_state_store.py @@ -18,8 +18,9 @@ from unittest.mock import patch from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes from synapse.storage.state import StateFilter -from synapse.types import MutableStateMap, StateMap +from synapse.types import StateMap from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -27,6 +28,21 @@ from tests.unittest import HomeserverTestCase if typing.TYPE_CHECKING: from synapse.server import HomeServer +# StateFilter for ALL non-m.room.member state events +ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( + types={EventTypes.Member: set()}, + include_others=True, +) + +FAKE_STATE = { + (EventTypes.Member, "@alice:test"): "join", + (EventTypes.Member, "@bob:test"): "leave", + (EventTypes.Member, "@charlie:test"): "invite", + ("test.type", "a"): "AAA", + ("test.type", "b"): "BBB", + ("other.event.type", "state.key"): "123", +} + class StateGroupInflightCachingTestCase(HomeserverTestCase): def prepare( @@ -65,24 +81,8 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): Assemble a fake database response and complete the database request. """ - result: Dict[int, StateMap[str]] = {} - - for group in groups: - group_result: MutableStateMap[str] = {} - result[group] = group_result - - for state_type, state_keys in state_filter.types.items(): - if state_keys is None: - group_result[(state_type, "a")] = "xyz" - group_result[(state_type, "b")] = "xyz" - else: - for state_key in state_keys: - group_result[(state_type, state_key)] = "abc" - - if state_filter.include_others: - group_result[("other.event.type", "state.key")] = "123" - - d.callback(result) + # Return a filtered copy of the fake state + d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) def test_duplicate_requests_deduplicated(self) -> None: """ @@ -125,9 +125,159 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): # Now we can complete the request self._complete_request_fake(groups, sf, d) - self.assertEqual( - self.get_success(req1), {("other.event.type", "state.key"): "123"} + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_smaller_request_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests a subset of that state, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the correct retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) ) - self.assertEqual( - self.get_success(req2), {("other.event.type", "state.key"): "123"} + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", "b"),)) + ) ) + self.pump(by=0.1) + + # No more calls should have gone to the database, because the second + # request was already in the in-flight cache! + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) + + def test_partially_overlapping_request_deduplicated(self) -> None: + """ + Tests that partially-overlapping requests are partially deduplicated. + + This test: + - requests a single type of wildcard state + (This is internally expanded to be all non-member state) + - requests the entire state in parallel + - checks to see that two database queries were made, but that the second + one is only for member state. + - completes the database queries + - checks that both requests have the correct result. + """ + + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # Because it only partially overlaps, this also went to the database + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + # First request: + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + self._complete_request_fake(groups, sf, d) + + # Second request: + groups, sf, d = self.get_state_group_calls[1] + self.assertEqual(groups, (42,)) + # The state filter is narrowed to only request membership state, because + # the remainder of the state is already being queried in the first request! + self.assertEqual( + sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) + ) + self._complete_request_fake(groups, sf, d) + + # Check the results are correct + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_in_flight_requests_stop_being_in_flight(self) -> None: + """ + Tests that in-flight request deduplication doesn't somehow 'hold on' + to completed requests: once they're done, they're taken out of the + in-flight cache. + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[0]) + self.assertTrue(req1.called) + + # Send off another request + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # It should have gone to the database again, because the previous request + # isn't in-flight and therefore isn't available for deduplication. + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req2.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[1]) + self.assertTrue(req2.called) + groups, sf, d = self.get_state_group_calls[0] + + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE)