Initial stab at real SQL storage implementation of user filter definitions
This commit is contained in:
parent
0c14a699bb
commit
06cc147012
|
@ -61,6 +61,7 @@ SCHEMAS = [
|
|||
"event_edges",
|
||||
"event_signatures",
|
||||
"media_repository",
|
||||
"filtering",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@ from twisted.internet import defer
|
|||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
import json
|
||||
|
||||
|
||||
# TODO(paul)
|
||||
_filters_for_user = {}
|
||||
|
@ -25,22 +27,41 @@ _filters_for_user = {}
|
|||
class FilteringStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def get_user_filter(self, user_localpart, filter_id):
|
||||
filters = _filters_for_user.get(user_localpart, None)
|
||||
def_json = yield self._simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={
|
||||
"user_id": user_localpart,
|
||||
"filter_id": filter_id,
|
||||
},
|
||||
retcol="definition",
|
||||
allow_none=False,
|
||||
)
|
||||
|
||||
if not filters or filter_id >= len(filters):
|
||||
raise KeyError()
|
||||
defer.returnValue(json.loads(def_json))
|
||||
|
||||
# trivial yield to make it a generator so d.iC works
|
||||
yield
|
||||
defer.returnValue(filters[filter_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_user_filter(self, user_localpart, definition):
|
||||
filters = _filters_for_user.setdefault(user_localpart, [])
|
||||
def_json = json.dumps(definition)
|
||||
|
||||
filter_id = len(filters)
|
||||
filters.append(definition)
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
# INSERT a new one
|
||||
def _do_txn(txn):
|
||||
sql = (
|
||||
"SELECT MAX(filter_id) FROM user_filters "
|
||||
"WHERE user_id = ?"
|
||||
)
|
||||
txn.execute(sql, (user_localpart,))
|
||||
max_id = txn.fetchone()[0]
|
||||
if max_id is None:
|
||||
filter_id = 0
|
||||
else:
|
||||
filter_id = max_id + 1
|
||||
|
||||
# trivial yield, see above
|
||||
yield
|
||||
defer.returnValue(filter_id)
|
||||
sql = (
|
||||
"INSERT INTO user_filters (user_id, filter_id, definition)"
|
||||
"VALUES(?, ?, ?)"
|
||||
)
|
||||
txn.execute(sql, (user_localpart, filter_id, def_json))
|
||||
|
||||
return filter_id
|
||||
|
||||
return self.runInteraction("add_user_filter", _do_txn)
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
/* Copyright 2015 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS user_filters(
|
||||
user_id TEXT,
|
||||
filter_id INTEGER,
|
||||
definition TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
|
||||
user_id, filter_id
|
||||
);
|
|
@ -53,16 +53,33 @@ class FilteringTestCase(unittest.TestCase):
|
|||
|
||||
self.filtering = hs.get_filtering()
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter(self):
|
||||
def test_add_filter(self):
|
||||
filter_id = yield self.filtering.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
definition={"type": ["m.*"]},
|
||||
)
|
||||
|
||||
self.assertEquals(filter_id, 0)
|
||||
self.assertEquals({"type": ["m.*"]},
|
||||
(yield self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=0,
|
||||
))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter(self):
|
||||
filter_id = yield self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
definition={"type": ["m.*"]},
|
||||
)
|
||||
|
||||
filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=filter_id,
|
||||
)
|
||||
|
||||
self.assertEquals(filter, {"type": ["m.*"]})
|
||||
|
|
Loading…
Reference in New Issue