Convert storage layer to be mysql compatible

This commit is contained in:
Erik Johnston 2015-03-19 15:59:48 +00:00
parent 58ed393235
commit d7a0496f3e
13 changed files with 171 additions and 101 deletions

View File

@ -51,6 +51,8 @@ import logging
import os import os
import re import re
import threading
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -89,6 +91,9 @@ class DataStore(RoomMemberStore, RoomStore,
self.min_token_deferred = self._get_min_token() self.min_token_deferred = self._get_min_token()
self.min_token = None self.min_token = None
self._next_stream_id_lock = threading.Lock()
self._next_stream_id = int(hs.get_clock().time_msec()) * 1000
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False, def persist_event(self, event, context, backfilled=False,
@ -172,7 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
"type": s.type, "type": s.type,
"state_key": s.state_key, "state_key": s.state_key,
}, },
or_replace=True,
) )
if event.is_state() and is_new_state: if event.is_state() and is_new_state:
@ -186,7 +190,6 @@ class DataStore(RoomMemberStore, RoomStore,
"type": event.type, "type": event.type,
"state_key": event.state_key, "state_key": event.state_key,
}, },
or_replace=True,
) )
for prev_state_id, _ in event.prev_state: for prev_state_id, _ in event.prev_state:
@ -285,7 +288,6 @@ class DataStore(RoomMemberStore, RoomStore,
"internal_metadata": metadata_json.decode("UTF-8"), "internal_metadata": metadata_json.decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"), "json": encode_canonical_json(event_dict).decode("UTF-8"),
}, },
or_replace=True,
) )
content = encode_canonical_json( content = encode_canonical_json(
@ -303,8 +305,9 @@ class DataStore(RoomMemberStore, RoomStore,
"depth": event.depth, "depth": event.depth,
} }
if stream_ordering is not None: if stream_ordering is None:
vals["stream_ordering"] = stream_ordering stream_ordering = self.get_next_stream_id()
unrec = { unrec = {
k: v k: v
@ -322,21 +325,18 @@ class DataStore(RoomMemberStore, RoomStore,
unrec unrec
).decode("UTF-8") ).decode("UTF-8")
try: sql = (
self._simple_insert_txn( "INSERT INTO events"
txn, " (stream_ordering, topological_ordering, event_id, type,"
"events", " room_id, content, processed, outlier, depth)"
vals, " VALUES (%s,?,?,?,?,?,?,?,?)"
or_replace=(not outlier), ) % (stream_ordering,)
or_ignore=bool(outlier),
txn.execute(
sql,
(event.depth, event.event_id, event.type, event.room_id,
content, True, outlier, event.depth)
) )
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
event.event_id,
exc_info=True,
)
raise _RollbackButIsFineException("_persist_event")
if context.rejected: if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected) self._store_rejections_txn(txn, event.event_id, context.rejected)
@ -357,7 +357,6 @@ class DataStore(RoomMemberStore, RoomStore,
txn, txn,
"state_events", "state_events",
vals, vals,
or_replace=True,
) )
if is_new_state and not context.rejected: if is_new_state and not context.rejected:
@ -370,7 +369,6 @@ class DataStore(RoomMemberStore, RoomStore,
"type": event.type, "type": event.type,
"state_key": event.state_key, "state_key": event.state_key,
}, },
or_replace=True,
) )
for e_id, h in event.prev_state: for e_id, h in event.prev_state:
@ -383,7 +381,6 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id, "room_id": event.room_id,
"is_state": 1, "is_state": 1,
}, },
or_ignore=True,
) )
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
@ -408,7 +405,6 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id, "room_id": event.room_id,
"auth_id": auth_id, "auth_id": auth_id,
}, },
or_ignore=True,
) )
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
@ -420,8 +416,7 @@ class DataStore(RoomMemberStore, RoomStore,
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
self._get_event_cache.pop(event.redacts) self._get_event_cache.pop(event.redacts)
txn.execute( txn.execute(
"INSERT OR IGNORE INTO redactions " "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
"(event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts) (event.event_id, event.redacts)
) )
@ -515,7 +510,8 @@ class DataStore(RoomMemberStore, RoomStore,
"ip": ip, "ip": ip,
"user_agent": user_agent, "user_agent": user_agent,
"last_seen": int(self._clock.time_msec()), "last_seen": int(self._clock.time_msec()),
} },
or_replace=True,
) )
def get_user_ip_and_agents(self, user): def get_user_ip_and_agents(self, user):
@ -559,6 +555,12 @@ class DataStore(RoomMemberStore, RoomStore,
"have_events", f, "have_events", f,
) )
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
self._next_stream_id += 1
return i
def read_schema(path): def read_schema(path):
""" Read the named database schema. """ Read the named database schema.
@ -594,7 +596,7 @@ def prepare_database(db_conn):
else: else:
_setup_new_database(cur) _setup_new_database(cur)
cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close() cur.close()
db_conn.commit() db_conn.commit()
@ -657,19 +659,17 @@ def _setup_new_database(cur):
directory_entries = os.listdir(sql_dir) directory_entries = os.listdir(sql_dir)
sql_script = "BEGIN TRANSACTION;\n"
for filename in fnmatch.filter(directory_entries, "*.sql"): for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename) sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc) logger.debug("Applying schema %s", sql_loc)
sql_script += read_schema(sql_loc) executescript(cur, sql_loc)
sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
cur.executescript(sql_script)
cur.execute( cur.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)" _convert_param_style(
" VALUES (?,?)", "REPLACE INTO schema_version (version, upgraded)"
(max_current_ver, False) " VALUES (?,?)"
),
(max_current_ver, False,)
) )
_upgrade_existing_database( _upgrade_existing_database(
@ -737,6 +737,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if not upgraded: if not upgraded:
start_ver += 1 start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1): for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v) logger.debug("Upgrading schema to v%d", v)
@ -753,6 +755,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
directory_entries.sort() directory_entries.sort()
for file_name in directory_entries: for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name) relative_path = os.path.join(str(v), file_name)
logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files: if relative_path in applied_delta_files:
continue continue
@ -774,9 +777,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module.run_upgrade(cur) module.run_upgrade(cur)
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
delta_schema = read_schema(absolute_path)
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)
cur.executescript(delta_schema) executescript(cur, absolute_path)
else: else:
# Not a valid delta file. # Not a valid delta file.
logger.warn( logger.warn(
@ -788,24 +790,85 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done. # Mark as done.
cur.execute( cur.execute(
_convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)" "INSERT INTO applied_schema_deltas (version, file)"
" VALUES (?,?)", " VALUES (?,?)"
),
(v, relative_path) (v, relative_path)
) )
cur.execute( cur.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)" _convert_param_style(
" VALUES (?,?)", "REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)"
),
(v, True) (v, True)
) )
def _convert_param_style(sql):
return sql.replace("?", "%s")
def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
for line in f:
line = line.strip()
if in_comment:
# Check if this line contains an end to the comment
comments = line.split("*/", 1)
if len(comments) == 1:
continue
line = comments[1]
in_comment = False
# Remove inline block comments
line = re.sub(r"/\*.*\*/", " ", line)
# Does this line start a comment?
comments = line.split("/*", 1)
if len(comments) > 1:
line = comments[0]
in_comment = True
# Deal with line comments
line = line.split("--", 1)[0]
line = line.split("//", 1)[0]
# Find *all* semicolons. We need to treat first and last entry
# specially.
statements = line.split(";")
# We must prepend statement_buffer to the first statement
first_statement = "%s %s" % (
statement_buffer.strip(),
statements[0].strip()
)
statements[0] = first_statement
# Every entry, except the last, is a full statement
for statement in statements[:-1]:
yield statement.strip()
# The last entry did *not* end in a semicolon, so we store it for the
# next semicolon we find
statement_buffer = statements[-1].strip()
def executescript(txn, schema_path):
with open(schema_path, 'r') as f:
for statement in get_statements(f):
txn.execute(statement)
def _get_or_create_schema_state(txn): def _get_or_create_schema_state(txn):
schema_path = os.path.join( schema_path = os.path.join(
dir_path, "schema", "schema_version.sql", dir_path, "schema", "schema_version.sql",
) )
create_schema = read_schema(schema_path) executescript(txn, schema_path)
txn.executescript(create_schema)
txn.execute("SELECT version, upgraded FROM schema_version") txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone() row = txn.fetchone()
@ -814,10 +877,13 @@ def _get_or_create_schema_state(txn):
if current_version: if current_version:
txn.execute( txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?", _convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
(current_version,) (current_version,)
) )
return current_version, txn.fetchall(), upgraded applied_deltas = [d for d, in txn.fetchall()]
return current_version, applied_deltas, upgraded
return None return None
@ -849,7 +915,9 @@ def prepare_sqlite3_database(db_conn):
if row and row[0]: if row and row[0]:
db_conn.execute( db_conn.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)" _convert_param_style(
" VALUES (?,?)", "REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)"
),
(row[0], False) (row[0], False)
) )

View File

@ -102,6 +102,10 @@ def cached(max_entries=1000):
return wrap return wrap
def _convert_param_style(sql):
return sql.replace("?", "%s")
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
@ -122,6 +126,8 @@ class LoggingTransaction(object):
# TODO(paul): Maybe use 'info' and 'debug' for values? # TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql_logger.debug("[SQL] {%s} %s", self.name, sql)
sql = _convert_param_style(sql)
try: try:
if args and args[0]: if args and args[0]:
values = args[0] values = args[0]
@ -305,11 +311,11 @@ class SQLBaseStore(object):
The result of decoder(results) The result of decoder(results)
""" """
def interaction(txn): def interaction(txn):
cursor = txn.execute(query, args) txn.execute(query, args)
if decoder: if decoder:
return decoder(cursor) return decoder(txn)
else: else:
return cursor.fetchall() return txn.fetchall()
return self.runInteraction(desc, interaction) return self.runInteraction(desc, interaction)
@ -337,8 +343,7 @@ class SQLBaseStore(object):
def _simple_insert_txn(self, txn, table, values, or_replace=False, def _simple_insert_txn(self, txn, table, values, or_replace=False,
or_ignore=False): or_ignore=False):
sql = "%s INTO %s (%s) VALUES(%s)" % ( sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else ("REPLACE" if or_replace else "INSERT"),
"INSERT OR IGNORE" if or_ignore else "INSERT"),
table, table,
", ".join(k for k in values), ", ".join(k for k in values),
", ".join("?" for k in values) ", ".join("?" for k in values)
@ -449,7 +454,6 @@ class SQLBaseStore(object):
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = ( sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s" "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
"ORDER BY rowid asc"
) % { ) % {
"retcol": retcol, "retcol": retcol,
"table": table, "table": table,
@ -505,14 +509,14 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
if keyvalues: if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
else: else:
sql = "SELECT %s FROM %s ORDER BY rowid asc" % ( sql = "SELECT %s FROM %s" % (
", ".join(retcols), ", ".join(retcols),
table table
) )
@ -546,7 +550,7 @@ class SQLBaseStore(object):
retcols=None, allow_none=False): retcols=None, allow_none=False):
""" Combined SELECT then UPDATE.""" """ Combined SELECT then UPDATE."""
if retcols: if retcols:
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k) for k in keyvalues) " AND ".join("%s = ?" % (k) for k in keyvalues)
@ -580,8 +584,8 @@ class SQLBaseStore(object):
updatevalues.values() + keyvalues.values() updatevalues.values() + keyvalues.values()
) )
if txn.rowcount == 0: # if txn.rowcount == 0:
raise StoreError(404, "No row found") # raise StoreError(404, "No row found")
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
@ -802,7 +806,7 @@ class Table(object):
_select_where_clause = "SELECT %s FROM %s WHERE %s" _select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s" _select_clause = "SELECT %s FROM %s"
_insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod @classmethod
def select_statement(cls, where_clause=None): def select_statement(cls, where_clause=None):

View File

@ -147,11 +147,11 @@ class ApplicationServiceStore(SQLBaseStore):
return True return True
def _get_as_id_txn(self, txn, token): def _get_as_id_txn(self, txn, token):
cursor = txn.execute( txn.execute(
"SELECT id FROM application_services WHERE token=?", "SELECT id FROM application_services WHERE token=?",
(token,) (token,)
) )
res = cursor.fetchone() res = txn.fetchone()
if res: if res:
return res[0] return res[0]

View File

@ -111,12 +111,12 @@ class DirectoryStore(SQLBaseStore):
) )
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):
cursor = txn.execute( txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?", "SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),) (room_alias.to_string(),)
) )
res = cursor.fetchone() res = txn.fetchone()
if res: if res:
room_id = res[0] room_id = res[0]
else: else:

View File

@ -242,7 +242,6 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
"min_depth": depth, "min_depth": depth,
}, },
or_replace=True,
) )
def _handle_prev_events(self, txn, outlier, event_id, prev_events, def _handle_prev_events(self, txn, outlier, event_id, prev_events,
@ -262,7 +261,6 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
"is_state": 0, "is_state": 0,
}, },
or_ignore=True,
) )
# Update the extremities table if this is not an outlier. # Update the extremities table if this is not an outlier.
@ -281,19 +279,19 @@ class EventFederationStore(SQLBaseStore):
# We only insert as a forward extremity the new event if there are # We only insert as a forward extremity the new event if there are
# no other events that reference it as a prev event # no other events that reference it as a prev event
query = ( query = (
"INSERT OR IGNORE INTO %(table)s (event_id, room_id) " "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
"SELECT ?, ? WHERE NOT EXISTS (" )
"SELECT 1 FROM %(event_edges)s WHERE "
"prev_event_id = ? "
")"
) % {
"table": "event_forward_extremities",
"event_edges": "event_edges",
}
logger.debug("query: %s", query) txn.execute(query, (event_id,))
txn.execute(query, (event_id, room_id, event_id)) if not txn.fetchone():
query = (
"INSERT INTO event_forward_extremities"
" (event_id, room_id)"
" VALUES (?, ?)"
)
txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get # Insert all the prev_events as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway. # deleted in a second if they're incorrect anyway.
@ -306,7 +304,6 @@ class EventFederationStore(SQLBaseStore):
"event_id": e_id, "event_id": e_id,
"room_id": room_id, "room_id": room_id,
}, },
or_ignore=True,
) )
# Also delete from the backwards extremities table all ones that # Also delete from the backwards extremities table all ones that

View File

@ -45,7 +45,6 @@ class PresenceStore(SQLBaseStore):
updatevalues={"state": new_state["state"], updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"], "status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()}, "mtime": self._clock.time_msec()},
retcols=["state"],
) )
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):

View File

@ -153,7 +153,7 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule # now insert the new rule
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES (" sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")" sql += ", ".join(["?" for _ in new_rule.keys()])+")"
@ -182,7 +182,7 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES (" sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")" sql += ", ".join(["?" for _ in new_rule.keys()])+")"

View File

@ -39,14 +39,10 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
if not row:
raise StoreError(400, "Bad user ID supplied.")
row_id = row["id"]
yield self._simple_insert( yield self._simple_insert(
"access_tokens", "access_tokens",
{ {
"user_id": row_id, "user_id": user_id,
"token": token "token": token
} }
) )
@ -82,7 +78,7 @@ class RegistrationStore(SQLBaseStore):
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID
txn.execute("INSERT INTO access_tokens(user_id, token) " + txn.execute("INSERT INTO access_tokens(user_id, token) " +
"VALUES (?,?)", [txn.lastrowid, token]) "VALUES (?,?)", [user_id, token])
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
query = ("SELECT users.name, users.password_hash FROM users" query = ("SELECT users.name, users.password_hash FROM users"
@ -124,12 +120,12 @@ class RegistrationStore(SQLBaseStore):
"SELECT users.name, users.admin," "SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id" " access_tokens.device_id, access_tokens.id as token_id"
" FROM users" " FROM users"
" INNER JOIN access_tokens on users.id = access_tokens.user_id" " INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?" " WHERE token = ?"
) )
cursor = txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.cursor_to_dict(cursor) rows = self.cursor_to_dict(txn)
if rows: if rows:
return rows[0] return rows[0]

View File

@ -114,9 +114,9 @@ class RoomStore(SQLBaseStore):
"name": name_subquery, "name": name_subquery,
} }
c = txn.execute(sql, (is_public,)) txn.execute(sql, (is_public,))
return c.fetchall() return txn.fetchall()
rows = yield self.runInteraction( rows = yield self.runInteraction(
"get_rooms", f "get_rooms", f

View File

@ -68,7 +68,7 @@ class RoomMemberStore(SQLBaseStore):
# Update room hosts table # Update room hosts table
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
sql = ( sql = (
"INSERT OR IGNORE INTO room_hosts (room_id, host) " "REPLACE INTO room_hosts (room_id, host) "
"VALUES (?, ?)" "VALUES (?, ?)"
) )
txn.execute(sql, (event.room_id, domain)) txn.execute(sql, (event.room_id, domain))

View File

@ -15,6 +15,8 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.stringutils import random_string
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -89,14 +91,15 @@ class StateStore(SQLBaseStore):
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
group = _make_group_id(self._clock)
state_group = self._simple_insert_txn( state_group = self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": group,
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
}, },
or_ignore=True,
) )
for state in state_events.values(): for state in state_events.values():
@ -110,7 +113,6 @@ class StateStore(SQLBaseStore):
"state_key": state.state_key, "state_key": state.state_key,
"event_id": state.event_id, "event_id": state.event_id,
}, },
or_ignore=True,
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -122,3 +124,7 @@ class StateStore(SQLBaseStore):
}, },
or_replace=True, or_replace=True,
) )
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)

View File

@ -110,7 +110,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None: if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering") return "(%d < %s)" % (self.stream, "stream_ordering")
else: else:
return "(%d < %s OR (%d == %s AND %d < %s))" % ( return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.stream, "stream_ordering", self.stream, "stream_ordering",
@ -120,7 +120,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None: if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering") return "(%d >= %s)" % (self.stream, "stream_ordering")
else: else:
return "(%d > %s OR (%d == %s AND %d >= %s))" % ( return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.topological, "topological_ordering", self.topological, "topological_ordering",
self.stream, "stream_ordering", self.stream, "stream_ordering",

View File

@ -121,8 +121,8 @@ class TransactionStore(SQLBaseStore):
SentTransactions.select_statement("destination = ?"), SentTransactions.select_statement("destination = ?"),
) )
results = txn.execute(query, (destination,)) txn.execute(query, (destination,))
results = SentTransactions.decode_results(results) results = SentTransactions.decode_results(txn)
prev_txns = [r.transaction_id for r in results] prev_txns = [r.transaction_id for r in results]
@ -266,7 +266,7 @@ class TransactionStore(SQLBaseStore):
retry_last_ts, retry_interval): retry_last_ts, retry_interval):
query = ( query = (
"INSERT OR REPLACE INTO %s " "REPLACE INTO %s "
"(destination, retry_last_ts, retry_interval) " "(destination, retry_last_ts, retry_interval) "
"VALUES (?, ?, ?) " "VALUES (?, ?, ?) "
) % DestinationsTable.table_name ) % DestinationsTable.table_name