diff --git a/synapse/state.py b/synapse/state.py index 695a5e7ac4..c45bab5859 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -43,14 +43,30 @@ AuthEventTypes = ( ) +class _StateCacheEntry(object): + def __init__(self, state, state_group, ts): + self.state = state + self.state_group = state_group + self.ts = ts + + class StateHandler(object): """ Responsible for doing state conflict resolution. """ def __init__(self, hs): + self.clock = hs.get_clock() self.store = hs.get_datastore() self.hs = hs + # set of event_ids -> _StateCacheEntry. + self._state_cache = {} + + def f(): + self._prune_cache() + + self.clock.looping_call(f, 10*1000) + @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): """ Returns the current state for the room as a list. This is done by @@ -70,6 +86,11 @@ class StateHandler(object): for e_id, _, _ in events ] + cache = self._state_cache.get(set(event_ids), None) + if cache: + cache.ts = self.clock.time_msec() + defer.returnValue(cache.state_group, cache.state) + res = yield self.resolve_state_groups(event_ids) if event_type: @@ -177,6 +198,11 @@ class StateHandler(object): """ logger.debug("resolve_state_groups event_ids %s", event_ids) + cache = self._state_cache.get(set(event_ids), None) + if cache and cache.state_group: + cache.ts = self.clock.time_msec() + defer.returnValue(cache.state_group, cache.state) + state_groups = yield self.store.get_state_groups( event_ids ) @@ -200,6 +226,14 @@ class StateHandler(object): else: prev_states = [] + cache = _StateCacheEntry( + state=state, + state_group=name, + ts=self.clock.time_msec() + ) + + self._state_cache[set(event_ids)] = cache + defer.returnValue((name, state, prev_states)) state = {} @@ -245,6 +279,14 @@ class StateHandler(object): new_state = unconflicted_state new_state.update(resolved_state) + cache = _StateCacheEntry( + state=new_state, + state_group=None, + ts=self.clock.time_msec() + ) + + self._state_cache[set(event_ids)] = cache + defer.returnValue((None, new_state, prev_states)) @log_function @@ -328,3 +370,24 @@ class StateHandler(object): return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() return sorted(events, key=key_func) + + def _prune_cache(self): + now = self.clock.time_msec() + + if len(self._state_cache) > 100: + sorted_entries = sorted( + self._state_cache.items(), + key=lambda k, v: v.ts, + ) + + for k, _ in sorted_entries[100:]: + self._state_cache.pop(k) + + keys_to_delete = set() + + for key, cache_entry in self._state_cache.items(): + if now - cache_entry.ts > 60*1000: + keys_to_delete.add(key) + + for k in keys_to_delete: + self._state_cache.pop(k) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 4e837a918e..1fd5ba5787 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,7 +15,7 @@ from synapse.util.logcontext import LoggingContext -from twisted.internet import reactor +from twisted.internet import reactor, task import time @@ -35,6 +35,14 @@ class Clock(object): """Returns the current system time in miliseconds since epoch.""" return self.time() * 1000 + def looping_call(self, f, msec): + l = task.LoopingCall(f) + l.start(msec/1000.0, now=False) + return l + + def looping_call(self, loop): + loop.stop() + def call_later(self, delay, callback): current_context = LoggingContext.current_context()