Fix bug where state_group tables got corrupted

This is due to the fact that we prefilled caches using txn.call_after,
which always gets called including on error.

We fix this by making txn.call_after only fire when a transaction
completes successfully, which is what we want most of the time anyway.
This commit is contained in:
Erik Johnston 2017-06-07 17:34:20 +01:00
parent 3accee1a8c
commit 197bd126f0
3 changed files with 24 additions and 10 deletions

View File

@ -225,7 +225,8 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine, database_engine=self.database_engine,
after_callbacks=[] after_callbacks=[],
final_callbacks=[],
) )
self._find_stream_orderings_for_times_txn(cur) self._find_stream_orderings_for_times_txn(cur)
cur.close() cur.close()

View File

@ -52,13 +52,17 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = ["txn", "name", "database_engine", "after_callbacks"] __slots__ = [
"txn", "name", "database_engine", "after_callbacks", "final_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks): def __init__(self, txn, name, database_engine, after_callbacks,
final_callbacks):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks) object.__setattr__(self, "after_callbacks", after_callbacks)
object.__setattr__(self, "final_callbacks", final_callbacks)
def call_after(self, callback, *args, **kwargs): def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the """Call the given callback on the main twisted thread after the
@ -67,6 +71,9 @@ class LoggingTransaction(object):
""" """
self.after_callbacks.append((callback, args, kwargs)) self.after_callbacks.append((callback, args, kwargs))
def call_finally(self, callback, *args, **kwargs):
self.final_callbacks.append((callback, args, kwargs))
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -217,8 +224,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, logging_context, def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
func, *args, **kwargs): logging_context, func, *args, **kwargs):
start = time.time() * 1000 start = time.time() * 1000
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -237,7 +244,8 @@ class SQLBaseStore(object):
try: try:
txn = conn.cursor() txn = conn.cursor()
txn = LoggingTransaction( txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks txn, name, self.database_engine, after_callbacks,
final_callbacks,
) )
r = func(txn, *args, **kwargs) r = func(txn, *args, **kwargs)
conn.commit() conn.commit()
@ -298,6 +306,7 @@ class SQLBaseStore(object):
start_time = time.time() * 1000 start_time = time.time() * 1000
after_callbacks = [] after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
@ -309,7 +318,7 @@ class SQLBaseStore(object):
current_context.copy_to(context) current_context.copy_to(context)
return self._new_transaction( return self._new_transaction(
conn, desc, after_callbacks, current_context, conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs func, *args, **kwargs
) )
@ -318,9 +327,13 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
finally:
for after_callback, after_args, after_kwargs in after_callbacks: for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs) after_callback(*after_args, **after_kwargs)
finally:
for after_callback, after_args, after_kwargs in final_callbacks:
after_callback(*after_args, **after_kwargs)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -936,7 +949,7 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes. # __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next() ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__() stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None) txn.call_finally(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data) txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn( self._simple_insert_txn(

View File

@ -1419,7 +1419,7 @@ class EventsStore(SQLBaseStore):
] ]
rows = self._new_transaction( rows = self._new_transaction(
conn, "do_fetch", [], None, self._fetch_event_rows, event_ids conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
) )
row_dict = { row_dict = {