De-localpart `{Filtering,FilteringWorkerStore}.add_user_filter()`
Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
parent
f98141ceb2
commit
07a5623059
|
@ -170,11 +170,9 @@ class Filtering:
|
|||
result = await self.store.get_user_filter(user_id, filter_id)
|
||||
return FilterCollection(self._hs, result)
|
||||
|
||||
def add_user_filter(
|
||||
self, user_localpart: str, user_filter: JsonDict
|
||||
) -> Awaitable[int]:
|
||||
def add_user_filter(self, user_id: str, user_filter: JsonDict) -> Awaitable[int]:
|
||||
self.check_valid_filter(user_filter)
|
||||
return self.store.add_user_filter(user_localpart, user_filter)
|
||||
return self.store.add_user_filter(user_id, user_filter)
|
||||
|
||||
# TODO(paul): surely we should probably add a delete_user_filter or
|
||||
# replace_user_filter at some point? There's no REST API specified for
|
||||
|
|
|
@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet):
|
|||
set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
|
||||
|
||||
filter_id = await self.filtering.add_user_filter(
|
||||
user_localpart=target_user.localpart, user_filter=content
|
||||
user_id=user_id, user_filter=content
|
||||
)
|
||||
|
||||
return 200, {"filter_id": str(filter_id)}
|
||||
|
|
|
@ -44,7 +44,6 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||
except ValueError:
|
||||
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
||||
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
try:
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
|
@ -68,7 +67,8 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||
|
||||
return db_to_json(def_json)
|
||||
|
||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
|
||||
async def add_user_filter(self, user_id: str, user_filter: JsonDict) -> int:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
def_json = encode_canonical_json(user_filter)
|
||||
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
|
@ -92,10 +92,10 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||
filter_id = max_id + 1
|
||||
|
||||
sql = (
|
||||
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
|
||||
"VALUES(?, ?, ?)"
|
||||
"INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
|
||||
"VALUES(?, ?, ?, ?)"
|
||||
)
|
||||
txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
|
||||
txn.execute(sql, (user_id, user_localpart, filter_id, bytearray(def_json)))
|
||||
|
||||
return filter_id
|
||||
|
||||
|
|
|
@ -34,7 +34,6 @@ from tests import unittest
|
|||
from tests.events.test_utils import MockEvent
|
||||
|
||||
user_id = "@test_user:test"
|
||||
user_localpart = "test_user"
|
||||
user2_id = "@test_user2:test"
|
||||
|
||||
|
||||
|
@ -439,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
presence_states = [
|
||||
|
@ -467,7 +466,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart + "2", user_filter=user_filter_json
|
||||
user_id=user2_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
presence_states = [
|
||||
|
@ -493,7 +492,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||
|
@ -510,7 +509,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
event = MockEvent(
|
||||
|
@ -592,7 +591,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
filter_id = self.get_success(
|
||||
self.filtering.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -611,7 +610,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
filter_id = self.get_success(
|
||||
self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
user_id=user_id, user_filter=user_filter_json
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||
def test_get_filter(self) -> None:
|
||||
filter_id = self.get_success(
|
||||
self.filtering.add_user_filter(
|
||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||
user_id="@apple:test", user_filter=self.EXAMPLE_FILTER
|
||||
)
|
||||
)
|
||||
self.reactor.advance(1)
|
||||
|
|
Loading…
Reference in New Issue