Add a stream for push rule updates

This commit is contained in:
Mark Haines 2016-03-01 13:35:37 +00:00
parent a612ce6659
commit a1cf9e3bf3
5 changed files with 251 additions and 74 deletions

View File

@ -45,7 +45,7 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict(

View File

@ -766,6 +766,19 @@ class SQLBaseStore(object):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(
desc, self._simple_delete_one_txn, table, keyvalues
)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
@ -775,13 +788,11 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
txn.execute(sql, keyvalues.values())
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
txn.execute(sql, keyvalues.values())
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):

View File

@ -99,30 +99,31 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results)
@defer.inlineCallbacks
def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions,
before=None, after=None
):
conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions)
if before or after:
return self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
user_id, rule_id, priority_class,
conditions_json, actions_json, before, after,
)
else:
return self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
user_id, rule_id, priority_class,
conditions_json, actions_json,
)
with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
if before or after:
yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id, stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after,
)
else:
yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id, stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json,
)
def _add_push_rule_relative_txn(
self, txn, user_id, rule_id, priority_class,
self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after
):
# Lock the table since otherwise we'll have annoying races between the
@ -174,12 +175,12 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn(
txn, user_id, rule_id, priority_class, new_rule_priority,
conditions_json, actions_json,
txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
new_rule_priority, conditions_json, actions_json,
)
def _add_push_rule_highest_priority_txn(
self, txn, user_id, rule_id, priority_class,
self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json
):
# Lock the table since otherwise we'll have annoying races between the
@ -201,13 +202,13 @@ class PushRuleStore(SQLBaseStore):
self._upsert_push_rule_txn(
txn,
user_id, rule_id, priority_class, new_prio,
stream_id, stream_ordering, user_id, rule_id, priority_class, new_prio,
conditions_json, actions_json,
)
def _upsert_push_rule_txn(
self, txn, user_id, rule_id, priority_class,
priority, conditions_json, actions_json
self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
priority, conditions_json, actions_json, update_stream=True
):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
@ -242,6 +243,23 @@ class PushRuleStore(SQLBaseStore):
},
)
if update_stream:
self._simple_insert_txn(
txn,
table="push_rules_stream",
values={
"stream_id": stream_id,
"stream_ordering": stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": "ADD",
"priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
}
)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
@ -260,25 +278,47 @@ class PushRuleStore(SQLBaseStore):
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(
"push_rules",
{'user_name': user_id, 'rule_id': rule_id},
desc="delete_push_rule",
)
def delete_push_rule_txn(txn, stream_id, stream_ordering):
self._simple_delete_one_txn(
txn,
"push_rules",
{'user_name': user_id, 'rule_id': rule_id},
)
self._simple_insert_txn(
txn,
table="push_rules_stream",
values={
"stream_id": stream_id,
"stream_ordering": stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": "DELETE",
}
)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
yield self.runInteraction(
"delete_push_rule", delete_push_rule_txn, stream_id, stream_ordering
)
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
ret = yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
user_id, rule_id, enabled
)
defer.returnValue(ret)
with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id, stream_ordering, user_id, rule_id, enabled
)
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
def _set_push_rule_enabled_txn(
self, txn, stream_id, stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn(
txn,
@ -287,6 +327,19 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
self._simple_insert_txn(
txn,
"push_rules_stream",
values={
"stream_id": stream_id,
"stream_ordering": stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": "ENABLE" if enabled else "DISABLE",
}
)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
@ -294,18 +347,20 @@ class PushRuleStore(SQLBaseStore):
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
actions_json = json.dumps(actions)
def set_push_rule_actions_txn(txn):
def set_push_rule_actions_txn(txn, stream_id, stream_ordering):
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
priority_class = -1
priority = 1
self._upsert_push_rule_txn(
txn, user_id, rule_id, priority_class, priority,
"[]", actions_json
txn, stream_id, stream_ordering, user_id, rule_id,
priority_class, priority, "[]", actions_json,
update_stream=False
)
else:
self._simple_update_one_txn(
@ -315,8 +370,46 @@ class PushRuleStore(SQLBaseStore):
{'actions': actions_json},
)
self._simple_insert_txn(
txn,
"push_rules_stream",
values={
"stream_id": stream_id,
"stream_ordering": stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": "ACTIONS",
"actions": actions_json,
}
)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
yield self.runInteraction(
"set_push_rule_actions", set_push_rule_actions_txn,
stream_id, stream_ordering
)
def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, stream_ordering, user_id, rule_id,"
" op, priority_class, priority, conditions, actions"
" FROM push_rules_stream"
" WHERE ? < stream_id and stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.runInteraction(
"set_push_rule_actions", set_push_rule_actions_txn,
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)

View File

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE push_rules_stream(
stream_id BIGINT NOT NULL,
stream_ordering BIGINT NOT NULL,
user_id TEXT NOT NULL,
rule_id TEXT NOT NULL,
op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE"
priority_class SMALLINT,
priority INTEGER,
conditions TEXT,
actions TEXT
);
-- The extra data for each operation is:
-- * ENABLE, DISABLE, DELETE: []
-- * ACTIONS: ["actions"]
-- * ADD: ["priority_class", "priority", "actions", "conditions"]
-- Index for replication queries.
CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id);
-- Index for /sync queries.
CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id);

View File

@ -20,23 +20,21 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
cur = db_conn.cursor()
self._next_id = self._load_next_id(cur)
cur.close()
def _load_next_id(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
val, = txn.fetchone()
return val + 1 if val else 1
self._next_id = _load_max_id(db_conn, table, column)
def get_next(self):
with self._lock:
i = self._next_id
self._next_id += 1
return i
return self._next_id
def _load_max_id(db_conn, table, column):
cur = db_conn.cursor()
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
return val if val else 1
class StreamIdGenerator(object):
@ -52,23 +50,10 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
cur = db_conn.cursor()
self._current_max = self._load_current_max(cur)
cur.close()
self._current_max = _load_max_id(db_conn, table, column)
self._unfinished_ids = deque()
def _load_current_max(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
return int(val) if val else 1
def get_next(self):
"""
Usage:
@ -124,3 +109,50 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1
return self._current_max
class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync
with another stream. It generates pairs of IDs, the first element is an
integer ID for this stream, the second element is the ID for the stream
that this stream needs to be kept in sync with."""
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
"""
Usage:
with stream_id_gen.get_next() as (stream_id, chained_id):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_max_token()
self._unfinished_ids.append((next_id, chained_id))
@contextlib.contextmanager
def manager():
try:
yield (next_id, chained_id)
finally:
with self._lock:
self._unfinished_ids.remove((next_id, chained_id))
return manager()
def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token())