Pull out bits of StateStore to a mixin

... so that the slaved stores can simply inherit the read-only methods, rather
than us secretly gut-wrenching them (pulling individual methods out of the
store by hand). I haven't done the other stores yet, but we should: I'm tired
of the workers breaking every time we tweak the stores because I forgot to
gut-wrench the right method.

fixes https://github.com/matrix-org/synapse/issues/2655.
This commit is contained in:
Richard van der Hoff 2017-11-09 19:00:20 +00:00
parent 4dd1bfa8c1
commit 35a4b63240
2 changed files with 226 additions and 237 deletions

View File

@ -12,20 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore import logging
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupReadStore
from synapse.storage.stream import StreamStore from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
import logging from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,7 +37,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore): class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs) super(SlavedEventStore, self).__init__(db_conn, hs)
@ -90,25 +88,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_unread_counts_by_pos_txn = ( _get_unread_counts_by_pos_txn = (
DataStore._get_unread_counts_by_pos_txn.__func__ DataStore._get_unread_counts_by_pos_txn.__func__
) )
_get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"]
)
_get_state_group_for_event = (
StateStore.__dict__["_get_state_group_for_event"]
)
_get_state_groups_from_groups = (
StateStore.__dict__["_get_state_groups_from_groups"]
)
_get_state_groups_from_groups_txn = (
DataStore._get_state_groups_from_groups_txn.__func__
)
get_recent_event_ids_for_room = ( get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"] StreamStore.__dict__["get_recent_event_ids_for_room"]
) )
get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"]
)
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__ has_room_changed_since = DataStore.has_room_changed_since.__func__
@ -134,12 +116,6 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__ DataStore.get_room_events_stream_for_room.__func__
) )
get_events_around = DataStore.get_events_around.__func__ get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = ( _get_joined_users_from_context = (
@ -169,10 +145,7 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = ( _get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
) )
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__ _get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
get_backfill_events = DataStore.get_backfill_events.__func__ get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__ _get_backfill_events = DataStore._get_backfill_events.__func__

View File

@ -18,6 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -41,23 +42,11 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0 return len(self.delta_ids) if self.delta_ids else 0
class StateStore(SQLBaseStore): class StateGroupReadStore(SQLBaseStore):
""" Keeps track of the state at a given event. """The read-only parts of StateGroupStore
This is done by the concept of `state groups`. Every event is a assigned None of these functions write to the state tables, so are suitable for
a state group (identified by an arbitrary string), which references a including in the SlavedStores.
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
""" """
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@ -65,22 +54,7 @@ class StateStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs) super(StateGroupReadStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@ -195,178 +169,6 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.iteritems() for group, event_id_map in group_to_ids.iteritems()
}) })
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
    """Write out the state groups for a batch of events.

    Outliers, and events whose context has no `current_state_ids`, are
    skipped. A rejected event is mapped to its context's `prev_group` and
    nothing else is stored for it. Every other event is mapped to its
    context's `state_group`; if that group has no `state_groups` row yet, the
    group is written out too, either as a delta against `prev_group` or in
    full.

    Args:
        txn: the transaction in which all of the writes are done.
        events_and_contexts: iterable of (event, context) pairs.

    Raises:
        Exception: if a context has a `prev_group` which has no row in
            `state_groups`.
    """
    group_for_event = {}

    for event, context in events_and_contexts:
        # Outliers are skipped: nothing is stored for them here.
        if event.internal_metadata.is_outlier():
            continue

        if context.current_state_ids is None:
            # AFAIK, this can never happen
            logger.error(
                "Non-outlier event %s had current_state_ids==None",
                event.event_id,
            )
            continue

        if context.rejected:
            # A rejected event just gets the same state as its predecessor.
            group_for_event[event.event_id] = context.prev_group
            continue

        group_for_event[event.event_id] = context.state_group

        # Nothing more to write if this group already has a row.
        if self._have_persisted_state_group_txn(txn, context.state_group):
            continue

        self._simple_insert_txn(
            txn,
            table="state_groups",
            values={
                "id": context.state_group,
                "room_id": event.room_id,
                "event_id": event.event_id,
            },
        )

        # Persist as a delta against prev_group if we can, provided that
        # doesn't make the chain of deltas too long (read performance
        # degrades with long chains).
        store_as_delta = False
        if context.prev_group:
            prev_group_row = self._simple_select_one_onecol_txn(
                txn,
                table="state_groups",
                keyvalues={"id": context.prev_group},
                retcol="id",
                allow_none=True,
            )
            if not prev_group_row:
                raise Exception(
                    "Trying to persist state with unpersisted prev_group: %r"
                    % (context.prev_group,)
                )

            hops = self._count_state_group_hops_txn(txn, context.prev_group)
            store_as_delta = hops < MAX_STATE_DELTA_HOPS

        if store_as_delta:
            self._simple_insert_txn(
                txn,
                table="state_group_edges",
                values={
                    "state_group": context.state_group,
                    "prev_state_group": context.prev_group,
                },
            )
            ids_to_store = context.delta_ids
        else:
            ids_to_store = context.current_state_ids

        self._simple_insert_many_txn(
            txn,
            table="state_groups_state",
            values=[
                {
                    "state_group": context.state_group,
                    "room_id": event.room_id,
                    "type": state_key_pair[0],
                    "state_key": state_key_pair[1],
                    "event_id": state_event_id,
                }
                for state_key_pair, state_event_id in ids_to_store.iteritems()
            ],
        )

        # Prefill the state group cache with this group once the
        # transaction is done. Using the current sequence number is fine
        # because the state group map is immutable; if it weren't, this
        # prefill could race with another update.
        txn.call_after(
            self._state_group_cache.update,
            self._state_group_cache.sequence,
            key=context.state_group,
            value=dict(context.current_state_ids),
            full=True,
        )

    self._simple_insert_many_txn(
        txn,
        table="event_to_state_groups",
        values=[
            {
                "state_group": group_id,
                "event_id": ev_id,
            }
            for ev_id, group_id in group_for_event.iteritems()
        ],
    )

    # Likewise prefill the per-event lookup cache for everything we mapped.
    for ev_id, group_id in group_for_event.iteritems():
        txn.call_after(
            self._get_state_group_for_event.prefill,
            (ev_id,), group_id
        )
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
# Args:
#   txn: transaction object. `execute` and `fetchone` are called on it
#     directly in the Postgres branch; otherwise it is handed to
#     `_simple_select_one_onecol_txn`.
#   state_group: id of the state group to start counting from.
# Returns:
#   The count; 0 when nothing is found.
if isinstance(self.database_engine, PostgresEngine):
# Walk the `state_group_edges` chain in one recursive query. The CTE is
# seeded with the starting group itself, and every row of the CTE is
# counted.
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
# A missing row or a falsy count is reported as zero.
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
# Instead, follow `prev_state_group` one lookup at a time until a lookup
# comes back falsy (presumably None when there is no edge row, given
# allow_none=True -- verify against the helper).
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
# NOTE(review): this branch counts only the edges followed, whereas the
# Postgres query above also counts its seed row, so the two engines look
# like they differ by one for the same chain -- confirm which is intended.
return count
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types): def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id) """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
@ -747,6 +549,220 @@ class StateStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent), the event inherits the state group from its parent.
There are three tables:
* `state_groups`: Stores group name, first event within the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
    """Set up the store and register its background updates.

    Args:
        db_conn: passed straight through to the parent constructor.
        hs: passed straight through to the parent constructor.
    """
    super(StateStore, self).__init__(db_conn, hs)

    # Background updates which are driven by a handler method on this class.
    update_handlers = (
        (
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
            self._background_deduplicate_state,
        ),
        (
            self.STATE_GROUP_INDEX_UPDATE_NAME,
            self._background_index_state,
        ),
    )
    for update_name, handler in update_handlers:
        self.register_background_update_handler(update_name, handler)

    # Background update which just builds an index over the membership
    # events in `current_state_events`.
    self.register_background_index_update(
        self.CURRENT_STATE_INDEX_UPDATE_NAME,
        index_name="current_state_events_member_index",
        table="current_state_events",
        columns=["state_key"],
        where_clause="type='m.room.member'",
    )
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
    """Write out the state groups for a batch of events.

    Outliers, and events whose context has no `current_state_ids`, are
    skipped. A rejected event is mapped to its context's `prev_group` and
    nothing else is stored for it. Every other event is mapped to its
    context's `state_group`; if that group has no `state_groups` row yet, the
    group is written out too, either as a delta against `prev_group` or in
    full.

    Args:
        txn: the transaction in which all of the writes are done.
        events_and_contexts: iterable of (event, context) pairs.

    Raises:
        Exception: if a context has a `prev_group` which has no row in
            `state_groups`.
    """
    group_for_event = {}

    for event, context in events_and_contexts:
        # Outliers are skipped: nothing is stored for them here.
        if event.internal_metadata.is_outlier():
            continue

        if context.current_state_ids is None:
            # AFAIK, this can never happen
            logger.error(
                "Non-outlier event %s had current_state_ids==None",
                event.event_id,
            )
            continue

        if context.rejected:
            # A rejected event just gets the same state as its predecessor.
            group_for_event[event.event_id] = context.prev_group
            continue

        group_for_event[event.event_id] = context.state_group

        # Nothing more to write if this group already has a row.
        if self._have_persisted_state_group_txn(txn, context.state_group):
            continue

        self._simple_insert_txn(
            txn,
            table="state_groups",
            values={
                "id": context.state_group,
                "room_id": event.room_id,
                "event_id": event.event_id,
            },
        )

        # Persist as a delta against prev_group if we can, provided that
        # doesn't make the chain of deltas too long (read performance
        # degrades with long chains).
        store_as_delta = False
        if context.prev_group:
            prev_group_row = self._simple_select_one_onecol_txn(
                txn,
                table="state_groups",
                keyvalues={"id": context.prev_group},
                retcol="id",
                allow_none=True,
            )
            if not prev_group_row:
                raise Exception(
                    "Trying to persist state with unpersisted prev_group: %r"
                    % (context.prev_group,)
                )

            hops = self._count_state_group_hops_txn(txn, context.prev_group)
            store_as_delta = hops < MAX_STATE_DELTA_HOPS

        if store_as_delta:
            self._simple_insert_txn(
                txn,
                table="state_group_edges",
                values={
                    "state_group": context.state_group,
                    "prev_state_group": context.prev_group,
                },
            )
            ids_to_store = context.delta_ids
        else:
            ids_to_store = context.current_state_ids

        self._simple_insert_many_txn(
            txn,
            table="state_groups_state",
            values=[
                {
                    "state_group": context.state_group,
                    "room_id": event.room_id,
                    "type": state_key_pair[0],
                    "state_key": state_key_pair[1],
                    "event_id": state_event_id,
                }
                for state_key_pair, state_event_id in ids_to_store.iteritems()
            ],
        )

        # Prefill the state group cache with this group once the
        # transaction is done. Using the current sequence number is fine
        # because the state group map is immutable; if it weren't, this
        # prefill could race with another update.
        txn.call_after(
            self._state_group_cache.update,
            self._state_group_cache.sequence,
            key=context.state_group,
            value=dict(context.current_state_ids),
            full=True,
        )

    self._simple_insert_many_txn(
        txn,
        table="event_to_state_groups",
        values=[
            {
                "state_group": group_id,
                "event_id": ev_id,
            }
            for ev_id, group_id in group_for_event.iteritems()
        ],
    )

    # Likewise prefill the per-event lookup cache for everything we mapped.
    for ev_id, group_id in group_for_event.iteritems():
        txn.call_after(
            self._get_state_group_for_event.prefill,
            (ev_id,), group_id
        )
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
# Args:
#   txn: transaction object. `execute` and `fetchone` are called on it
#     directly in the Postgres branch; otherwise it is handed to
#     `_simple_select_one_onecol_txn`.
#   state_group: id of the state group to start counting from.
# Returns:
#   The count; 0 when nothing is found.
if isinstance(self.database_engine, PostgresEngine):
# Walk the `state_group_edges` chain in one recursive query. The CTE is
# seeded with the starting group itself, and every row of the CTE is
# counted.
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
# A missing row or a falsy count is reported as zero.
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
# Instead, follow `prev_state_group` one lookup at a time until a lookup
# comes back falsy (presumably None when there is no edge row, given
# allow_none=True -- verify against the helper).
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
# NOTE(review): this branch counts only the edges followed, whereas the
# Postgres query above also counts its seed row, so the two engines look
# like they differ by one for the same chain -- confirm which is intended.
return count
def get_next_state_group(self): def get_next_state_group(self):
return self._state_groups_id_gen.get_next() return self._state_groups_id_gen.get_next()