642 lines
21 KiB
Python
642 lines
21 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2014 OpenMarket Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
|
|
from synapse.api.errors import StoreError
|
|
from synapse.events import FrozenEvent
|
|
from synapse.events.utils import prune_event
|
|
from synapse.util.logutils import log_function
|
|
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
|
|
|
|
from twisted.internet import defer
|
|
|
|
import collections
|
|
import json
|
|
import sys
|
|
import time
|
|
|
|
|
|
# Module-wide logger, plus two dedicated channels so that individual SQL
# statements and transaction start/end records can be enabled separately.
logger = logging.getLogger(__name__)
sql_logger, transaction_logger = map(
    logging.getLogger, ("synapse.storage.SQL", "synapse.storage.txn")
)
|
|
|
|
|
|
class LoggingTransaction(object):
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging to the .execute() method.

    Reads and writes of any attribute other than ``txn`` and ``name`` are
    forwarded to the wrapped transaction, so callers can use e.g.
    ``fetchall()`` or ``rowcount`` directly on the wrapper.
    """
    __slots__ = ["txn", "name"]

    def __init__(self, txn, name):
        # __setattr__ below forwards to the wrapped txn, so the wrapper's own
        # slots have to be populated through object.__setattr__.
        object.__setattr__(self, "txn", txn)
        object.__setattr__(self, "name", name)

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself.
        return getattr(self.txn, name)

    def __setattr__(self, name, value):
        setattr(self.txn, name, value)

    def execute(self, sql, *args, **kwargs):
        """Run ``sql`` on the wrapped txn, logging the statement, its bound
        values and the elapsed wall-clock time to the SQL logger.

        Returns:
            Whatever the wrapped txn's ``execute`` returns.
        Raises:
            Any exception from the wrapped ``execute``; it is logged first.
        """
        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
        try:
            if args and args[0]:
                values = args[0]
                sql_logger.debug(
                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
                    self.name,
                    *values
                )
        except Exception:
            # Don't let logging failures stop SQL from working
            pass

        # Wall-clock time: time.clock() reports CPU time on Unix (so time
        # spent waiting on the database would not be counted) and no longer
        # exists on Python >= 3.8.
        start = time.time() * 1000
        try:
            return self.txn.execute(
                sql, *args, **kwargs
            )
        except Exception:
            logger.exception("[SQL FAIL] {%s}", self.name)
            raise
        finally:
            end = time.time() * 1000
            sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
|
|
|
|
|
|
class SQLBaseStore(object):
    """Base class for stores: wraps the homeserver's database pool with a
    logging ``runInteraction`` and provides simple single-table helpers."""

    # Counter used to name transactions in the logs. It lives on the class so
    # that every store instance shares it.
    _TXN_ID = 0

    def __init__(self, hs):
        self.hs = hs
        self._db_pool = hs.get_db_pool()
        self._clock = hs.get_clock()

    @defer.inlineCallbacks
    def runInteraction(self, desc, func, *args, **kwargs):
        """Wraps the .runInteraction() method on the underlying db_pool.

        ``func`` is called as ``func(txn, *args, **kwargs)``, where ``txn`` is
        a LoggingTransaction wrapping the pool's transaction and named after
        ``desc`` plus a (non-unique) transaction counter.

        Returns:
            Deferred: the return value of ``func``.
        """
        current_context = LoggingContext.current_context()

        def inner_func(txn, *args, **kwargs):
            with LoggingContext("runInteraction") as context:
                current_context.copy_to(context)
                start = time.time() * 1000

                # Read *and* write the counter on the class. Assigning through
                # ``self`` would create an instance attribute and leave the
                # class attribute (the one read here) stuck at 0.
                txn_id = SQLBaseStore._TXN_ID

                # We don't really need these to be unique, so lets stop it from
                # growing really large.
                SQLBaseStore._TXN_ID = (txn_id + 1) % (sys.maxsize - 1)

                name = "%s-%x" % (desc, txn_id, )

                transaction_logger.debug("[TXN START] {%s}", name)
                try:
                    return func(LoggingTransaction(txn, name), *args, **kwargs)
                except Exception:
                    logger.exception("[TXN FAIL] {%s}", name)
                    raise
                finally:
                    # Wall-clock time, so time spent waiting on the database
                    # is included (time.clock() is CPU time on Unix).
                    end = time.time() * 1000
                    transaction_logger.debug(
                        "[TXN END] {%s} %f",
                        name, end - start
                    )

        with PreserveLoggingContext():
            result = yield self._db_pool.runInteraction(
                inner_func, *args, **kwargs
            )
        defer.returnValue(result)

    def cursor_to_dict(self, cursor):
        """Converts a SQL cursor into an list of dicts.

        Args:
            cursor : The DBAPI cursor which has executed a query.
        Returns:
            A list of dicts where the key is the column header.
        """
        col_headers = [column[0] for column in cursor.description]
        return [dict(zip(col_headers, row)) for row in cursor.fetchall()]

    def _execute(self, decoder, query, *args):
        """Runs a single query for a result set.

        Args:
            decoder - The function which can resolve the cursor results to
                something meaningful. If falsy, the raw rows are returned.
            query - The query string to execute
            *args - Query args.
        Returns:
            Deferred: The result of decoder(cursor), or cursor.fetchall()
        """
        def interaction(txn):
            # NOTE(review): relies on txn.execute() returning the cursor,
            # which the sqlite3 module does but DB-API does not guarantee.
            cursor = txn.execute(query, args)
            if decoder:
                return decoder(cursor)
            else:
                return cursor.fetchall()

        return self.runInteraction("_execute", interaction)

    def _execute_and_decode(self, query, *args):
        """Runs ``query`` and returns (a Deferred of) a list of row dicts."""
        return self._execute(self.cursor_to_dict, query, *args)

    # "Simple" SQL API methods that operate on a single table with no JOINs,
    # no complex WHERE clauses, just a dict of values for columns.

    def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
        """Executes an INSERT query on the named table.

        Args:
            table : string giving the table name
            values : dict of new column names and values for them
            or_replace : bool; if True performs an INSERT OR REPLACE
            or_ignore : bool; if True (and or_replace is False) performs an
                INSERT OR IGNORE
        Returns:
            Deferred: the cursor's lastrowid after the insert.
        """
        return self.runInteraction(
            "_simple_insert",
            self._simple_insert_txn, table, values, or_replace=or_replace,
            or_ignore=or_ignore,
        )

    @log_function
    def _simple_insert_txn(self, txn, table, values, or_replace=False,
                           or_ignore=False):
        """Transaction-level body of _simple_insert; returns txn.lastrowid."""
        if or_replace:
            verb = "INSERT OR REPLACE"
        elif or_ignore:
            verb = "INSERT OR IGNORE"
        else:
            verb = "INSERT"

        # Derive both the column list and the bound args from one key list so
        # they are guaranteed to line up.
        columns = list(values)
        args = [values[k] for k in columns]

        sql = "%s INTO %s (%s) VALUES(%s)" % (
            verb,
            table,
            ", ".join(columns),
            ", ".join("?" for _ in columns),
        )

        logger.debug(
            "[SQL] %s Args=%s",
            sql, args,
        )

        txn.execute(sql, args)
        return txn.lastrowid

    def _simple_select_one(self, table, keyvalues, retcols,
                           allow_none=False):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning the requested columns from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcols : list of strings giving the names of the columns to return

            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
        Returns:
            Deferred: a dict mapping each name in retcols to its value, or
            None when no row matched and allow_none is set. Fails with
            StoreError(404) when no row matched and allow_none is not set.
        """
        return self._simple_selectupdate_one(
            table, keyvalues, retcols=retcols, allow_none=allow_none
        )

    def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                  allow_none=False):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning a single column from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcol : string giving the name of the column to return
            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
        Returns:
            Deferred: the value of ``retcol`` in the first matching row.
        """
        return self.runInteraction(
            "_simple_select_one_onecol",
            self._simple_select_one_onecol_txn,
            table, keyvalues, retcol, allow_none=allow_none,
        )

    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
                                      allow_none=False):
        """Transaction-level body of _simple_select_one_onecol.

        Returns the value from the first matching row (by rowid). When no
        row matches, returns None if allow_none is set, else raises
        StoreError(404).
        """
        ret = self._simple_select_onecol_txn(
            txn,
            table=table,
            keyvalues=keyvalues,
            retcol=retcol,
        )

        if ret:
            return ret[0]
        else:
            if allow_none:
                return None
            else:
                raise StoreError(404, "No row found")

    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
        """Return the values of ``retcol`` for all rows matching
        ``keyvalues``, ordered by rowid."""
        sql = (
            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
            "ORDER BY rowid asc"
        ) % {
            "retcol": retcol,
            "table": table,
            "where": " AND ".join("%s = ?" % k for k in keyvalues),
        }

        txn.execute(sql, list(keyvalues.values()))

        return [r[0] for r in txn.fetchall()]

    def _simple_select_onecol(self, table, keyvalues, retcol):
        """Executes a SELECT query on the named table, which returns a list
        comprising of the values of the named column from the selected rows.

        Args:
            table (str): table name
            keyvalues (dict): column names and values to select the rows with
            retcol (str): column whos value we wish to retrieve.

        Returns:
            Deferred: Results in a list
        """
        return self.runInteraction(
            "_simple_select_onecol",
            self._simple_select_onecol_txn,
            table, keyvalues, retcol
        )

    def _simple_select_list(self, table, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with
            retcols : list of strings giving the names of the columns to return
        """
        return self.runInteraction(
            "_simple_select_list",
            self._simple_select_list_txn,
            table, keyvalues, retcols
        )

    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            txn : Transaction object
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with
            retcols : list of strings giving the names of the columns to return
        """
        sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
            ", ".join(retcols),
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        txn.execute(sql, list(keyvalues.values()))
        return self.cursor_to_dict(txn)

    def _simple_update_one(self, table, keyvalues, updatevalues,
                           retcols=None):
        """Executes an UPDATE query on the named table, setting new values for
        columns in a row matching the key values.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            updatevalues : dict giving column names and values to update
            retcols : optional list of column names to return

        If present, retcols gives a list of column names on which to perform
        a SELECT statement *before* performing the UPDATE statement. The values
        of these will be returned in a dict.

        These are performed within the same transaction, allowing an atomic
        get-and-set. This can be used to implement compare-and-set by putting
        the update column in the 'keyvalues' dict as well.
        """
        return self._simple_selectupdate_one(table, keyvalues, updatevalues,
                                             retcols=retcols)

    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
                                 retcols=None, allow_none=False):
        """ Combined SELECT then UPDATE.

        In one transaction: if ``retcols`` is given, SELECT those columns
        from the row matching ``keyvalues``; then, if ``updatevalues`` is
        given, UPDATE that row.

        Returns:
            Deferred: dict of retcols -> value when retcols is given,
            otherwise None. Also None when allow_none is set and the SELECT
            matched nothing (in which case no UPDATE is attempted).
        """
        if retcols:
            select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
                ", ".join(retcols),
                table,
                " AND ".join("%s = ?" % (k) for k in keyvalues)
            )

        if updatevalues:
            update_sql = "UPDATE %s SET %s WHERE %s" % (
                table,
                ", ".join("%s = ?" % (k) for k in updatevalues),
                " AND ".join("%s = ?" % (k) for k in keyvalues)
            )

        def func(txn):
            ret = None
            if retcols:
                txn.execute(select_sql, list(keyvalues.values()))

                row = txn.fetchone()
                if not row:
                    if allow_none:
                        return None
                    raise StoreError(404, "No row found")
                # NOTE(review): the sqlite3 module reports rowcount == -1 for
                # SELECT statements, so this check only fires with drivers
                # that populate rowcount for SELECTs.
                if txn.rowcount > 1:
                    raise StoreError(500, "More than one row matched")

                ret = dict(zip(retcols, row))

            if updatevalues:
                # Build real lists: on Python 3 dict.values() is a view that
                # supports neither ``+`` nor sqlite3 parameter binding.
                txn.execute(
                    update_sql,
                    list(updatevalues.values()) + list(keyvalues.values())
                )

                if txn.rowcount == 0:
                    raise StoreError(404, "No row found")
                if txn.rowcount > 1:
                    raise StoreError(500, "More than one row matched")

            return ret
        return self.runInteraction("_simple_selectupdate_one", func)

    def _simple_delete_one(self, table, keyvalues):
        """Executes a DELETE query on the named table, expecting to delete a
        single row.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        Returns:
            Deferred: fails with StoreError(404) if nothing was deleted, or
            StoreError(500) if more than one row was deleted.
        """
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        def func(txn):
            txn.execute(sql, list(keyvalues.values()))
            if txn.rowcount == 0:
                raise StoreError(404, "No row found")
            if txn.rowcount > 1:
                raise StoreError(500, "more than one row matched")
        return self.runInteraction("_simple_delete_one", func)

    def _simple_delete(self, table, keyvalues):
        """Executes a DELETE query on the named table.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        """
        # The table and key values must be forwarded to the txn function;
        # without them _simple_delete_txn fails with a TypeError.
        return self.runInteraction(
            "_simple_delete", self._simple_delete_txn, table, keyvalues
        )

    def _simple_delete_txn(self, txn, table, keyvalues):
        """Transaction-level body of _simple_delete: deletes every row of
        ``table`` matching ``keyvalues``."""
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        return txn.execute(sql, list(keyvalues.values()))

    def _simple_max_id(self, table):
        """Executes a SELECT query on the named table, expecting to return the
        max value for the column "id".

        Args:
            table : string giving the table name
        Returns:
            Deferred: the maximum id, or 0 if the table is empty.
        """
        sql = "SELECT MAX(id) AS id FROM %s" % table

        def func(txn):
            txn.execute(sql)
            max_id = self.cursor_to_dict(txn)[0]["id"]
            if max_id is None:
                return 0
            return max_id

        return self.runInteraction("_simple_max_id", func)

    def _get_events(self, event_ids):
        """Deferred wrapper around _get_events_txn."""
        return self.runInteraction(
            "_get_events", self._get_events_txn, event_ids
        )

    def _get_events_txn(self, txn, event_ids):
        """Fetch the events with the given IDs, preserving the order of
        ``event_ids`` and skipping any that are not found."""
        events = []
        for e_id in event_ids:
            ev = self._get_event_txn(txn, e_id)

            if ev:
                events.append(ev)

        return events

    def _get_event_txn(self, txn, event_id, check_redacted=True,
                       get_prev_content=True):
        """Load a single event from event_json.

        Args:
            txn : Transaction object
            event_id : ID of the event to load
            check_redacted (bool): if True and a redaction of this event
                exists, return the pruned event with ``redacted_by`` (and
                ``redacted_because`` when the redaction can be loaded) set
                in its unsigned data.
            get_prev_content (bool): if True and the event's unsigned data
                has ``replaces_state``, copy that event's content into
                ``prev_content``.
        Returns:
            FrozenEvent, or None if no such event is stored.
        """
        sql = (
            "SELECT internal_metadata, json, r.event_id FROM event_json as e "
            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
            "WHERE e.event_id = ? "
            "LIMIT 1 "
        )

        txn.execute(sql, (event_id,))

        res = txn.fetchone()

        if not res:
            return None

        internal_metadata, js, redacted = res

        d = json.loads(js)
        internal_metadata = json.loads(internal_metadata)

        ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)

        if check_redacted and redacted:
            ev = prune_event(ev)

            ev.unsigned["redacted_by"] = redacted
            # Get the redaction event.

            because = self._get_event_txn(
                txn,
                redacted,
                check_redacted=False
            )

            if because:
                ev.unsigned["redacted_because"] = because

        if get_prev_content and "replaces_state" in ev.unsigned:
            prev = self._get_event_txn(
                txn,
                ev.unsigned["replaces_state"],
                get_prev_content=False,
            )
            if prev:
                ev.unsigned["prev_content"] = prev.get_dict()["content"]

        return ev

    def _parse_events(self, rows):
        """Deferred wrapper around _parse_events_txn."""
        return self.runInteraction(
            "_parse_events", self._parse_events_txn, rows
        )

    def _parse_events_txn(self, txn, rows):
        """Load the events named by the ``event_id`` key of each row."""
        event_ids = [r["event_id"] for r in rows]

        return self._get_events_txn(txn, event_ids)

    def _has_been_redacted_txn(self, txn, event):
        """Return the event_id of a redaction of ``event``, or None."""
        sql = "SELECT event_id FROM redactions WHERE redacts = ?"
        txn.execute(sql, (event.event_id,))
        result = txn.fetchone()
        return result[0] if result else None
|
|
|
|
|
|
class Table(object):
    """Base class describing a single database table.

    Subclasses fill in ``table_name``, ``fields`` and ``EntryType``; the
    classmethods below then build SQL statements for the table and turn
    result rows into ``EntryType`` instances.
    """

    # str: the name of the table
    table_name = None

    # list: the field (column) names, in SELECT/INSERT order
    fields = None

    # type: a tuple type used to decode result rows, called as
    # EntryType(*row)
    EntryType = None

    _select_where_clause = "SELECT %s FROM %s WHERE %s"
    _select_clause = "SELECT %s FROM %s"
    _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"

    @classmethod
    def select_statement(cls, where_clause=None):
        """Build a SELECT of every field in the table.

        Args:
            where_clause (str): The WHERE clause to use; when falsy the
                statement has no WHERE clause.

        Returns:
            str: An SQL statement selecting rows from the table.
        """
        columns = ", ".join(cls.fields)
        if not where_clause:
            return cls._select_clause % (columns, cls.table_name)
        return cls._select_where_clause % (
            columns, cls.table_name, where_clause
        )

    @classmethod
    def insert_statement(cls):
        """Build an INSERT OR REPLACE with one ``?`` placeholder per field.

        Returns:
            str: the SQL statement.
        """
        columns = ", ".join(cls.fields)
        placeholders = ", ".join("?" for _ in cls.fields)
        return cls._insert_clause % (cls.table_name, columns, placeholders)

    @classmethod
    def decode_single_result(cls, results):
        """Decode the first tuple of ``results`` into an `EntryType`.

        Args:
            results (iterable): row tuples; fully consumed.

        Returns:
            EntryType: built from the first row, or None if there are no rows.
        """
        rows = list(results)
        return cls.EntryType(*rows[0]) if rows else None

    @classmethod
    def decode_results(cls, results):
        """Decode every tuple of ``results`` into an `EntryType`.

        Args:
            results (iterable): row tuples.

        Returns:
            list: one `EntryType` per row, in order.
        """
        return [cls.EntryType(*values) for values in results]

    @classmethod
    def get_fields_string(cls, prefix=None):
        """Return the fields joined by ", ", each qualified as
        ``prefix.field`` when a prefix is given."""
        names = cls.fields
        if prefix:
            names = ["%s.%s" % (prefix, field) for field in cls.fields]
        return ", ".join(names)
|
|
|
|
|
|
class JoinHelper(object):
    """ Used to help do joins on tables by looking at the tables' fields and
    creating a list of unique fields to use with SELECTs and a namedtuple
    to dump the results into.

    Attributes:
        tables (tuple): The `Table` classes passed to the constructor
        EntryType (type): namedtuple whose fields are the tables' fields,
            de-duplicated across tables, in first-seen order
    """

    def __init__(self, *tables):
        self.tables = tables

        merged = []
        for table in tables:
            # Filter against fields gathered from *earlier* tables only, then
            # append this table's new fields in one go.
            fresh = [f for f in table.fields if f not in merged]
            merged.extend(fresh)

        self.EntryType = collections.namedtuple("JoinHelperEntry", merged)

    def get_fields(self, **prefixes):
        """Get a string representing a list of fields for use in SELECT
        statements with the given prefixes applied to each.

        Each field is qualified with the prefix of the first table (in
        constructor order) that declares it; ``prefixes`` is keyed by the
        table class's ``__name__``.

        For example::

            JoinHelper(PdusTable, StateTable).get_fields(
                PdusTable="pdus",
                StateTable="state"
            )
        """
        def owner_of(field):
            # First table declaring ``field``, or None if none does.
            for table in self.tables:
                if field in table.fields:
                    return table
            return None

        qualified = []
        for field in self.EntryType._fields:
            owner = owner_of(field)
            if owner is not None:
                qualified.append("%s.%s" % (prefixes[owner.__name__], field))

        return ", ".join(qualified)

    def decode_results(self, rows):
        """Convert each row tuple into an `EntryType` instance."""
        return [self.EntryType(*values) for values in rows]
|