Initial stab at real SQL storage implementation of user filter definitions
This commit is contained in:
parent
0c14a699bb
commit
06cc147012
|
@ -61,6 +61,7 @@ SCHEMAS = [
|
||||||
"event_edges",
|
"event_edges",
|
||||||
"event_signatures",
|
"event_signatures",
|
||||||
"media_repository",
|
"media_repository",
|
||||||
|
"filtering",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul)
|
# TODO(paul)
|
||||||
_filters_for_user = {}
|
_filters_for_user = {}
|
||||||
|
@ -25,22 +27,41 @@ _filters_for_user = {}
|
||||||
class FilteringStore(SQLBaseStore):
|
class FilteringStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_user_filter(self, user_localpart, filter_id):
|
def get_user_filter(self, user_localpart, filter_id):
|
||||||
filters = _filters_for_user.get(user_localpart, None)
|
def_json = yield self._simple_select_one_onecol(
|
||||||
|
table="user_filters",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_localpart,
|
||||||
|
"filter_id": filter_id,
|
||||||
|
},
|
||||||
|
retcol="definition",
|
||||||
|
allow_none=False,
|
||||||
|
)
|
||||||
|
|
||||||
if not filters or filter_id >= len(filters):
|
defer.returnValue(json.loads(def_json))
|
||||||
raise KeyError()
|
|
||||||
|
|
||||||
# trivial yield to make it a generator so d.iC works
|
|
||||||
yield
|
|
||||||
defer.returnValue(filters[filter_id])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def add_user_filter(self, user_localpart, definition):
|
def add_user_filter(self, user_localpart, definition):
|
||||||
filters = _filters_for_user.setdefault(user_localpart, [])
|
def_json = json.dumps(definition)
|
||||||
|
|
||||||
filter_id = len(filters)
|
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||||
filters.append(definition)
|
# INSERT a new one
|
||||||
|
def _do_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT MAX(filter_id) FROM user_filters "
|
||||||
|
"WHERE user_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_localpart,))
|
||||||
|
max_id = txn.fetchone()[0]
|
||||||
|
if max_id is None:
|
||||||
|
filter_id = 0
|
||||||
|
else:
|
||||||
|
filter_id = max_id + 1
|
||||||
|
|
||||||
# trivial yield, see above
|
sql = (
|
||||||
yield
|
"INSERT INTO user_filters (user_id, filter_id, definition)"
|
||||||
defer.returnValue(filter_id)
|
"VALUES(?, ?, ?)"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_localpart, filter_id, def_json))
|
||||||
|
|
||||||
|
return filter_id
|
||||||
|
|
||||||
|
return self.runInteraction("add_user_filter", _do_txn)
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
CREATE TABLE IF NOT EXISTS user_filters(
|
||||||
|
user_id TEXT,
|
||||||
|
filter_id INTEGER,
|
||||||
|
definition TEXT,
|
||||||
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
|
||||||
|
user_id, filter_id
|
||||||
|
);
|
|
@ -53,16 +53,33 @@ class FilteringTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.filtering = hs.get_filtering()
|
self.filtering = hs.get_filtering()
|
||||||
|
|
||||||
|
self.datastore = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filter(self):
|
def test_add_filter(self):
|
||||||
filter_id = yield self.filtering.add_user_filter(
|
filter_id = yield self.filtering.add_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart,
|
||||||
definition={"type": ["m.*"]},
|
definition={"type": ["m.*"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(filter_id, 0)
|
self.assertEquals(filter_id, 0)
|
||||||
|
self.assertEquals({"type": ["m.*"]},
|
||||||
|
(yield self.datastore.get_user_filter(
|
||||||
|
user_localpart=user_localpart,
|
||||||
|
filter_id=0,
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_filter(self):
|
||||||
|
filter_id = yield self.datastore.add_user_filter(
|
||||||
|
user_localpart=user_localpart,
|
||||||
|
definition={"type": ["m.*"]},
|
||||||
|
)
|
||||||
|
|
||||||
filter = yield self.filtering.get_user_filter(
|
filter = yield self.filtering.get_user_filter(
|
||||||
user_localpart=user_localpart,
|
user_localpart=user_localpart,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(filter, {"type": ["m.*"]})
|
self.assertEquals(filter, {"type": ["m.*"]})
|
||||||
|
|
Loading…
Reference in New Issue