From d4cb3edba8c03b7a4d9a647906c58269bcb3681f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 24 Apr 2017 09:46:24 +0100 Subject: [PATCH] Moar stuff --- synapse/federation/transaction_queue.py | 6 +- synapse/push/bulk_push_rule_evaluator.py | 148 +++++++++++++++++++++-- synapse/state.py | 11 ++ synapse/storage/_base.py | 8 +- synapse/storage/events.py | 8 +- synapse/storage/push_rule.py | 2 +- synapse/storage/roommember.py | 103 +++++++++++++--- synapse/storage/state.py | 8 ++ 8 files changed, 258 insertions(+), 36 deletions(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index dee387eb7f..695f1a7375 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -24,7 +24,6 @@ from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn, preserve_fn from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func -from synapse.types import get_domain_from_id from synapse.handlers.presence import format_user_presence_state, get_interested_remotes import synapse.metrics @@ -183,15 +182,12 @@ class TransactionQueue(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. - users_in_room = yield self.state.get_current_user_in_room( + destinations = yield self.state.get_current_hosts_in_room( event.room_id, latest_event_ids=[ prev_id for prev_id, _ in event.prev_events ], ) - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - ) if send_on_behalf_of is not None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index cb13874ccf..72820ec42a 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,17 +19,21 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership logger = logging.getLogger(__name__) +rules_by_room = {} + + @defer.inlineCallbacks def evaluator_for_event(event, hs, store, context): - rules_by_user = yield store.bulk_get_push_rules_for_room( - event, context - ) + room_id = event.room_id + rules_for_room = rules_by_room.setdefault(room_id, RulesForRoom(hs, room_id)) + + rules_by_user = yield rules_for_room.get_rules(context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited @@ -66,11 +70,14 @@ class BulkPushRuleEvaluator: def action_for_event_by_user(self, event, context): actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context( - event, context - ) + # room_members = yield self.store.get_joined_users_from_context( + # event, context + # ) + room_members = {} - evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) + num_room_members = yield self.store.get_number_of_users_in_rooms(event.room_id) + + evaluator = PushRuleEvaluatorForEvent(event, num_room_members) condition_cache = {} @@ -127,3 +134,128 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): return False return True + + +class RulesForRoom(object): + def __init__(self, hs, room_id): + self.room_id = room_id + self.is_mine_id = hs.is_mine_id + self.store = hs.get_datastore() + + self.member_map = {} # event_id -> (user_id, state) + self.rules_by_user = {} # user_id -> rules + self.state_group = object() + + self.sequence = 0 + + @defer.inlineCallbacks + def get_rules(self, context): + state_group = context.state_group + current_state_ids = context.current_state_ids + + if state_group and self.state_group == state_group: + defer.returnValue(self.rules_by_user) + + ret_rules_by_user = {} + missing_member_event_ids = {} + for key, event_id in current_state_ids.iteritems(): + res = self.member_map.get(event_id, None) + if res: + user_id, state = res + if state == Membership.JOIN: + rules = self.rules_by_user.get(user_id, None) + if rules: + ret_rules_by_user[user_id] = rules + continue + + if key[0] != EventTypes.Member: + continue + + user_id = key[1] + if not self.is_mine_id(user_id): + continue + + if self.store.get_if_app_services_interested_in_user(user_id): + continue + + missing_member_event_ids[user_id] = event_id + + if missing_member_event_ids: + missing_rules = yield self.get_rules_for_member_event_ids( + missing_member_event_ids, state_group + ) + ret_rules_by_user.update(missing_rules) + + defer.returnValue(ret_rules_by_user) + + @defer.inlineCallbacks + def get_rules_for_member_event_ids(self, member_event_ids, state_group): + sequence = self.sequence + + rows = yield self.store._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids.values(), + retcols=['user_id', 'membership', 'event_id'], + keyvalues={}, + batch_size=500, + desc="get_rules_for_member_event_ids", + ) + + members = { + row["event_id"]: (row["user_id"], row["membership"]) + for row in rows + } + + interested_in_user_ids = set(user_id for user_id, _ in members.itervalues()) + + if_users_with_pushers = yield self.store.get_if_users_have_pushers( + interested_in_user_ids, + on_invalidate=self.invalidate_all, + ) + + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + ) + + users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + self.room_id, on_invalidate=self.invalidate_all, + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in interested_in_user_ids: + user_ids.add(uid) + + forgotten = yield self.store.who_forgot_in_room( + self.room_id, on_invalidate=self.invalidate_all, + ) + + for row in forgotten: + user_id = row["user_id"] + event_id = row["event_id"] + + mem_id = member_event_ids.get((user_id), None) + if event_id == mem_id: + user_ids.discard(user_id) + + rules_by_user = yield self.store.bulk_get_push_rules( + user_ids, on_invalidate=self.invalidate_all, + ) + + rules_by_user = {k: v for k, v in rules_by_user.iteritems() if v is not None} + + self.update_cache(sequence, members, rules_by_user, state_group) + + defer.returnValue(rules_by_user) + + def invalidate_all(self): + self.sequence += 1 + self.member_map = {} + self.rules_by_user = {} + + def update_cache(self, sequence, members, rules_by_user, state_group): + if sequence == self.sequence: + self.member_map.update(members) + self.rules_by_user.update(rules_by_user) + self.state_group = state_group diff --git a/synapse/state.py b/synapse/state.py index f6b83d888a..f8b18a4a2d 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -175,6 +175,17 @@ class StateHandler(object): ) defer.returnValue(joined_users) + @defer.inlineCallbacks + def get_current_hosts_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + logger.debug("calling resolve_state_groups from get_current_user_in_room") + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_hosts = yield self.store.get_joined_hosts( + room_id, entry.state_id, entry.state + ) + defer.returnValue(joined_hosts) + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """Build an EventContext structure for the event. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c659004e8d..58b73af7d2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -60,12 +60,12 @@ class LoggingTransaction(object): object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) - def call_after(self, callback, *args): + def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ - self.after_callbacks.append((callback, args)) + self.after_callbacks.append((callback, args, kwargs)) def __getattr__(self, name): return getattr(self.txn, name) @@ -319,8 +319,8 @@ class SQLBaseStore(object): inner_func, *args, **kwargs ) finally: - for after_callback, after_args in after_callbacks: - after_callback(*after_args) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) defer.returnValue(result) @defer.inlineCallbacks diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a3790419dd..a8d1b93d99 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -370,6 +370,10 @@ class EventsStore(SQLBaseStore): new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) + for room_id, (_, _, new_state) in current_state_for_room.iteritems(): + self.get_current_state_ids.prefill( + (room_id, ), new_state + ) @defer.inlineCallbacks def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): @@ -529,7 +533,7 @@ class EventsStore(SQLBaseStore): if ev_id in events_to_insert } - defer.returnValue((to_delete, to_insert)) + defer.returnValue((to_delete, to_insert, current_state)) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -682,7 +686,7 @@ class EventsStore(SQLBaseStore): def _update_current_state_txn(self, txn, state_delta_by_room): for room_id, current_state_tuple in state_delta_by_room.iteritems(): - to_delete, to_insert = current_state_tuple + to_delete, to_insert, _ = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", [(ev_id,) for ev_id in to_delete.itervalues()], diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 353a135c4e..b80af941f2 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -48,7 +48,7 @@ def _load_rules(rawrules, enabled_map): class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=50000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 7ad2198d96..bfd1955376 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -417,25 +417,47 @@ class RoomMemberStore(SQLBaseStore): if key[0] == EventTypes.Member ] - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=['user_id', 'display_name', 'avatar_url'], - keyvalues={ - "membership": Membership.JOIN, - }, - batch_size=500, - desc="_get_joined_users_from_context", + event_map = self._get_events_from_cache( + member_event_ids, + allow_rejected=False, ) - users_in_room = { - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), + missing_member_event_ids = [] + users_in_room = {} + for event_id, ev_entry in event_map.iteritems(): + if event_id: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=('user_id', 'display_name', 'avatar_url',), + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=500, + desc="_get_joined_users_from_context", ) - for row in rows - } + + users_in_room.update({ + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + }) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -447,6 +469,17 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(users_in_room) + def get_number_of_users_in_rooms(self, room_id): + sql = """SELECT coalesce(count(*), 0) FROM current_state_events + INNER JOIN room_memberships USING (room_id, event_id) + WHERE room_id = ? AND membership = 'join' + """ + return self._execute( + "get_number_of_users_in_rooms", + lambda txn: txn.fetchone()[0], + sql, room_id + ) + def is_host_joined(self, room_id, host, state_group, state_ids): if not state_group: # If state_group is None it means it has yet to be assigned a @@ -482,6 +515,44 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(False) + def get_joined_hosts(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + joined_hosts = set() + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + host = get_domain_from_id(state_key) + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + if host in joined_hosts: + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + joined_hosts.add(host) + + defer.returnValue(joined_hosts) + @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a16afa8df5..1e1ce87e0e 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -227,6 +227,14 @@ class StateStore(SQLBaseStore): ], ) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=context.state_group, + value=context.current_state_ids, + full=True, + ) + self._simple_insert_many_txn( txn, table="event_to_state_groups",