Sanitize TransactionStore

This commit is contained in:
Erik Johnston 2015-03-23 13:43:21 +00:00
parent f6583796fe
commit 278149f533
2 changed files with 104 additions and 87 deletions

View File

@ -179,7 +179,7 @@ class FederationHandler(BaseHandler):
# it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin)
if (retry_timings and retry_timings.retry_last_ts):
if retry_timings and retry_timings["retry_last_ts"]:
self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table, cached
from ._base import SQLBaseStore, cached
from collections import namedtuple
@ -84,13 +84,18 @@ class TransactionStore(SQLBaseStore):
def _set_received_txn_response(self, txn, transaction_id, origin, code,
response_json):
query = (
"UPDATE %s "
"SET response_code = ?, response_json = ? "
"WHERE transaction_id = ? AND origin = ?"
) % ReceivedTransactionsTable.table_name
txn.execute(query, (code, response_json, transaction_id, origin))
self._simple_update_one_txn(
txn,
table=ReceivedTransactionsTable.table_name,
keyvalues={
"transaction_id": transaction_id,
"origin": origin,
},
updatevalues={
"response_code": code,
"response_json": response_json,
}
)
def prep_send_transaction(self, transaction_id, destination,
origin_server_ts):
@ -121,38 +126,32 @@ class TransactionStore(SQLBaseStore):
# First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
# we can simply take the last one.
query = "%s ORDER BY id DESC LIMIT 1" % (
SentTransactions.select_statement("destination = ?"),
query = (
"SELECT * FROM sent_transactions"
" WHERE destination = ?"
" ORDER BY id DESC LIMIT 1"
)
txn.execute(query, (destination,))
results = SentTransactions.decode_results(txn.fetchall())
results = self.cursor_to_dict(txn)
prev_txns = [r.transaction_id for r in results]
prev_txns = [r["transaction_id"] for r in results]
# Actually add the new transaction to the sent_transactions table.
query = SentTransactions.insert_statement()
txn.execute(query, SentTransactions.EntryType(
self.get_next_stream_id(),
transaction_id=transaction_id,
destination=destination,
ts=origin_server_ts,
response_code=0,
response_json=None
))
self._simple_insert_txn(
txn,
table=SentTransactions.table_name,
values={
"transaction_id": self.get_next_stream_id(),
"destination": destination,
"ts": origin_server_ts,
"response_code": 0,
"response_json": None,
}
)
# Update the tx id -> pdu id mapping
# values = [
# (transaction_id, destination, pdu[0], pdu[1])
# for pdu in pdu_list
# ]
#
# logger.debug("Inserting: %s", repr(values))
#
# query = TransactionsToPduTable.insert_statement()
# txn.executemany(query, values)
# TODO Update the tx id -> pdu id mapping
return prev_txns
@ -171,15 +170,20 @@ class TransactionStore(SQLBaseStore):
transaction_id, destination, code, response_dict
)
def _delivered_txn(cls, txn, transaction_id, destination,
def _delivered_txn(self, txn, transaction_id, destination,
code, response_json):
query = (
"UPDATE %s "
"SET response_code = ?, response_json = ? "
"WHERE transaction_id = ? AND destination = ?"
) % SentTransactions.table_name
txn.execute(query, (code, response_json, transaction_id, destination))
self._simple_update_one_txn(
txn,
table=SentTransactions.table_name,
keyvalues={
"transaction_id": transaction_id,
"destination": destination,
},
updatevalues={
"response_code": code,
"response_json": response_json,
}
)
def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id.
@ -189,25 +193,26 @@ class TransactionStore(SQLBaseStore):
destination (str)
Returns:
list: A list of `ReceivedTransactionsTable.EntryType`
list: A list of dicts
"""
return self.runInteraction(
"get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
def _get_transactions_after(cls, txn, transaction_id, destination):
where = (
"destination = ? AND id > (select id FROM %s WHERE "
"transaction_id = ? AND destination = ?)"
) % (
SentTransactions.table_name
def _get_transactions_after(self, txn, transaction_id, destination):
query = (
"SELECT * FROM sent_transactions"
" WHERE destination = ? AND id >"
" ("
" SELECT id FROM sent_transactions"
" WHERE transaction_id = ? AND destination = ?"
" )"
)
query = SentTransactions.select_statement(where)
txn.execute(query, (destination, transaction_id, destination))
return ReceivedTransactionsTable.decode_results(txn.fetchall())
return self.cursor_to_dict(txn)
@cached()
def get_destination_retry_timings(self, destination):
@ -218,19 +223,24 @@ class TransactionStore(SQLBaseStore):
Returns:
None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme
Otherwise a dict for the retry scheme
"""
return self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
def _get_destination_retry_timings(cls, txn, destination):
query = DestinationsTable.select_statement("destination = ?")
txn.execute(query, (destination,))
result = txn.fetchall()
if result:
result = DestinationsTable.decode_single_result(result)
if result.retry_last_ts > 0:
def _get_destination_retry_timings(self, txn, destination):
result = self._simple_select_one_txn(
txn,
table=DestinationsTable.table_name,
keyvalues={
"destination": destination,
},
retcols=DestinationsTable.fields,
allow_none=True,
)
if result["retry_last_ts"] > 0:
return result
else:
return None
@ -249,11 +259,11 @@ class TransactionStore(SQLBaseStore):
# As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill(
destination,
DestinationsTable.EntryType(
destination,
retry_last_ts,
retry_interval
)
{
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval
},
)
# XXX: we could chose to not bother persisting this if our cache thinks
@ -270,18 +280,27 @@ class TransactionStore(SQLBaseStore):
retry_last_ts, retry_interval):
query = (
"REPLACE INTO %s "
"(destination, retry_last_ts, retry_interval) "
"VALUES (?, ?, ?) "
) % DestinationsTable.table_name
"INSERT INTO destinations"
" (destination, retry_last_ts, retry_interval)"
" VALUES (?, ?, ?)"
" ON DUPLICATE KEY UPDATE"
" retry_last_ts=?, retry_interval=?"
)
txn.execute(query, (destination, retry_last_ts, retry_interval))
txn.execute(
query,
(
destination,
retry_last_ts, retry_interval,
retry_last_ts, retry_interval,
)
)
def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction.
Returns:
list: A list of `DestinationsTable.EntryType`
list: A list of dicts
"""
return self.runInteraction(
@ -289,14 +308,17 @@ class TransactionStore(SQLBaseStore):
self._get_destinations_needing_retry
)
def _get_destinations_needing_retry(cls, txn):
where = "retry_last_ts > 0 and retry_next_ts < now()"
query = DestinationsTable.select_statement(where)
txn.execute(query)
return DestinationsTable.decode_results(txn.fetchall())
def _get_destinations_needing_retry(self, txn):
query = (
"SELECT * FROM destinations"
" WHERE retry_last_ts > 0 and retry_next_ts < ?"
)
txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn)
class ReceivedTransactionsTable(Table):
class ReceivedTransactionsTable(object):
table_name = "received_transactions"
fields = [
@ -308,10 +330,8 @@ class ReceivedTransactionsTable(Table):
"has_been_referenced",
]
EntryType = namedtuple("ReceivedTransactionsEntry", fields)
class SentTransactions(Table):
class SentTransactions(object):
table_name = "sent_transactions"
fields = [
@ -326,7 +346,7 @@ class SentTransactions(Table):
EntryType = namedtuple("SentTransactionsEntry", fields)
class TransactionsToPduTable(Table):
class TransactionsToPduTable(object):
table_name = "transaction_id_to_pdu"
fields = [
@ -336,10 +356,8 @@ class TransactionsToPduTable(Table):
"pdu_origin",
]
EntryType = namedtuple("TransactionsToPduEntry", fields)
class DestinationsTable(Table):
class DestinationsTable(object):
table_name = "destinations"
fields = [
@ -348,4 +366,3 @@ class DestinationsTable(Table):
"retry_interval",
]
EntryType = namedtuple("DestinationsEntry", fields)