From 07a5623059961afc3adec4534fb24b49db1a39c4 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Sat, 15 Apr 2023 02:43:04 +0100 Subject: [PATCH] De-localpart `{Filtering,FilteringWorkerStore}.add_user_filter()` Signed-off-by: Sean Quah --- synapse/api/filtering.py | 6 ++---- synapse/rest/client/filter.py | 2 +- synapse/storage/databases/main/filtering.py | 10 +++++----- tests/api/test_filtering.py | 13 ++++++------- tests/rest/client/test_filter.py | 2 +- 5 files changed, 15 insertions(+), 18 deletions(-) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 870baff2c4..84e57ac2c1 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -170,11 +170,9 @@ class Filtering: result = await self.store.get_user_filter(user_id, filter_id) return FilterCollection(self._hs, result) - def add_user_filter( - self, user_localpart: str, user_filter: JsonDict - ) -> Awaitable[int]: + def add_user_filter(self, user_id: str, user_filter: JsonDict) -> Awaitable[int]: self.check_valid_filter(user_filter) - return self.store.add_user_filter(user_localpart, user_filter) + return self.store.add_user_filter(user_id, user_filter) # TODO(paul): surely we should probably add a delete_user_filter or # replace_user_filter at some point? There's no REST API specified for diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 2db8dacf7c..156ae61d21 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet): set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit) filter_id = await self.filtering.add_user_filter( - user_localpart=target_user.localpart, user_filter=content + user_id=user_id, user_filter=content ) return 200, {"filter_id": str(filter_id)} diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 24a60ab6f3..e4de5000d0 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -44,7 +44,6 @@ class FilteringWorkerStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - user_localpart = UserID.from_string(user_id).localpart try: def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", @@ -68,7 +67,8 @@ class FilteringWorkerStore(SQLBaseStore): return db_to_json(def_json) - async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: + async def add_user_filter(self, user_id: str, user_filter: JsonDict) -> int: + user_localpart = UserID.from_string(user_id).localpart def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -92,10 +92,10 @@ class FilteringWorkerStore(SQLBaseStore): filter_id = max_id + 1 sql = ( - "INSERT INTO user_filters (user_id, filter_id, filter_json)" - "VALUES(?, ?, ?)" + "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)" + "VALUES(?, ?, ?, ?)" ) - txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) + txn.execute(sql, (user_id, user_localpart, filter_id, bytearray(def_json))) return filter_id diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 48f40da176..4a55aa96a5 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -34,7 +34,6 @@ from tests import unittest from tests.events.test_utils import MockEvent user_id = "@test_user:test" -user_localpart = "test_user" user2_id = "@test_user2:test" @@ -439,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): user_filter_json = {"presence": {"senders": ["@foo:bar"]}} filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) presence_states = [ @@ -467,7 +466,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart + "2", user_filter=user_filter_json + user_id=user2_id, user_filter=user_filter_json ) ) presence_states = [ @@ -493,7 +492,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") @@ -510,7 +509,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) event = MockEvent( @@ -592,7 +591,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): filter_id = self.get_success( self.filtering.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) @@ -611,7 +610,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): filter_id = self.get_success( self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + user_id=user_id, user_filter=user_filter_json ) ) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 436a186f23..65eff4fe10 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -76,7 +76,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_get_filter(self) -> None: filter_id = self.get_success( self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER + user_id="@apple:test", user_filter=self.EXAMPLE_FILTER ) ) self.reactor.advance(1)