Merge pull request #403 from matrix-org/erikj/search-ts

Allow paginating search ordered by recents
Erik Johnston 2015-12-01 14:46:27 +00:00
commit 2430fcd462
4 changed files with 229 additions and 93 deletions
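
The batch tokens that make this pagination work are opaque base64 blobs. Below is a minimal sketch of their shape as implied by the diff — using the standard library rather than synapse's encode_base64 helper, with made-up field values:

    # Sketch only: batch tokens encode "group\ngroup_key\npagination_token",
    # and the store-level pagination token is "origin_server_ts,stream_ordering".
    import base64

    def make_batch_token(group, group_key, pagination_token):
        raw = "%s\n%s\n%s" % (group, group_key, pagination_token)
        return base64.b64encode(raw.encode("utf-8")).decode("ascii")

    def parse_batch_token(token):
        # Inverse of the above; raises ValueError on a malformed token.
        group, group_key, pagination_token = (
            base64.b64decode(token).decode("utf-8").split("\n")
        )
        return group, group_key, pagination_token

    # A "all" token continues the whole search; a "room_id" token is scoped
    # to one room's result group. Values here are illustrative.
    global_token = make_batch_token("all", "", "1448978787000,12345")
    room_token = make_batch_token("room_id", "!abc:example.org", "1448978787000,12345")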


@@ -131,6 +131,17 @@ class SearchHandler(BaseHandler):
         if batch_group == "room_id":
             room_ids.intersection_update({batch_group_key})
 
+        if not room_ids:
+            defer.returnValue({
+                "search_categories": {
+                    "room_events": {
+                        "results": {},
+                        "count": 0,
+                        "highlights": [],
+                    }
+                }
+            })
+
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
         room_groups = {}  # Holds result of grouping by room, if applicable
@@ -178,24 +189,18 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            # In this case we specifically loop through each room as the given
-            # limit applies to each room, rather than a global list.
-            # This is not necessarilly a good idea.
-            for room_id in room_ids:
             room_events = []
-                if batch_group == "room_id" and batch_group_key == room_id:
-                    pagination_token = batch_token
-                else:
-                    pagination_token = None
             i = 0
+
+            pagination_token = batch_token
 
             # We keep looping and we keep filtering until we reach the limit
             # or we run out of things.
             # But only go around 5 times since otherwise synapse will be sad.
             while len(room_events) < search_filter.limit() and i < 5:
                 i += 1
-                search_result = yield self.store.search_room(
-                    room_id, search_term, keys, search_filter.limit() * 2,
+                search_result = yield self.store.search_rooms(
+                    room_ids, search_term, keys, search_filter.limit() * 2,
                     pagination_token=pagination_token,
                 )
@@ -225,39 +230,27 @@ class SearchHandler(BaseHandler):
                 else:
                     pagination_token = results[-1]["pagination_token"]
 
-                if room_events:
-                    res = results_map[room_events[-1].event_id]
-                    pagination_token = res["pagination_token"]
+            for event in room_events:
+                group = room_groups.setdefault(event.room_id, {
+                    "results": [],
+                })
+                group["results"].append(event.event_id)
 
-                group = room_groups.setdefault(room_id, {})
-                if pagination_token:
-                    next_batch = encode_base64("%s\n%s\n%s" % (
+            if room_events and len(room_events) >= search_filter.limit():
+                last_event_id = room_events[-1].event_id
+                pagination_token = results_map[last_event_id]["pagination_token"]
+
+                global_next_batch = encode_base64("%s\n%s\n%s" % (
+                    "all", "", pagination_token
+                ))
+
+                for room_id, group in room_groups.items():
+                    group["next_batch"] = encode_base64("%s\n%s\n%s" % (
                         "room_id", room_id, pagination_token
                     ))
-                    group["next_batch"] = next_batch
-
-                    if batch_token:
-                        global_next_batch = next_batch
-
-                group["results"] = [e.event_id for e in room_events]
-                group["order"] = max(
-                    e.origin_server_ts/1000 for e in room_events
-                    if hasattr(e, "origin_server_ts")
-                )
 
             allowed_events.extend(room_events)
 
-            # Normalize the group orders
-            if room_groups:
-                if len(room_groups) > 1:
-                    mx = max(g["order"] for g in room_groups.values())
-                    mn = min(g["order"] for g in room_groups.values())
-
-                    for g in room_groups.values():
-                        g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
-                else:
-                    room_groups.values()[0]["order"] = 1
-
         else:
             # We should never get here due to the guard earlier.
             raise NotImplementedError()
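
The behavioural change in these two hunks is worth spelling out: previously the "recent" ordering ran one search_room query per room, applied the limit per room, and kept a separate pagination token for each room group. Now a single search_rooms query spans every visible room, the limit applies to the merged recency-ordered list, and both the global next_batch and every group's next_batch encode the same pagination position (the last returned event). The per-group "order" normalisation is dropped along with the per-room loop.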


@@ -51,6 +51,14 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
 
 class EventsStore(SQLBaseStore):
+    EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+
+    def __init__(self, hs):
+        super(EventsStore, self).__init__(hs)
+        self.register_background_update_handler(
+            self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
+        )
+
     @defer.inlineCallbacks
     def persist_events(self, events_and_contexts, backfilled=False,
                        is_new_state=True):
@@ -365,6 +373,7 @@ class EventsStore(SQLBaseStore):
                     "processed": True,
                     "outlier": event.internal_metadata.is_outlier(),
                     "content": encode_json(event.content).decode("UTF-8"),
+                    "origin_server_ts": int(event.origin_server_ts),
                 }
                 for event, _ in events_and_contexts
             ],
@@ -964,3 +973,71 @@ class EventsStore(SQLBaseStore):
 
         ret = yield self.runInteraction("count_messages", _count_messages)
         defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def _background_reindex_origin_server_ts(self, progress, batch_size):
+        target_min_stream_id = progress["target_min_stream_id_inclusive"]
+        max_stream_id = progress["max_stream_id_exclusive"]
+        rows_inserted = progress.get("rows_inserted", 0)
+
+        INSERT_CLUMP_SIZE = 1000
+
+        def reindex_search_txn(txn):
+            sql = (
+                "SELECT stream_ordering, event_id FROM events"
+                " WHERE ? <= stream_ordering AND stream_ordering < ?"
+                " ORDER BY stream_ordering DESC"
+                " LIMIT ?"
+            )
+
+            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            min_stream_id = rows[-1][0]
+            event_ids = [row[1] for row in rows]
+
+            events = self._get_events_txn(txn, event_ids)
+
+            rows = []
+            for event in events:
+                try:
+                    event_id = event.event_id
+                    origin_server_ts = event.origin_server_ts
+                except (KeyError, AttributeError):
+                    # If the event is missing a necessary field then
+                    # skip over it.
+                    continue
+
+                rows.append((origin_server_ts, event_id))
+
+            sql = (
+                "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
+            )
+
+            for index in range(0, len(rows), INSERT_CLUMP_SIZE):
+                clump = rows[index:index + INSERT_CLUMP_SIZE]
+                txn.executemany(sql, clump)
+
+            progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+                "rows_inserted": rows_inserted + len(rows)
+            }
+
+            self._background_update_progress_txn(
+                txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
+            )
+
+            return len(rows)
+
+        result = yield self.runInteraction(
+            self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
+        )
+
+        if not result:
+            yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+
+        defer.returnValue(result)
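
For readers unfamiliar with synapse's background updates, the reindex above follows a batch-and-checkpoint pattern: do a bounded chunk of work per transaction, persist progress, and return the number of rows handled so the scheduler knows when to stop. A minimal, self-contained sketch of that pattern, with fetch_batch and save_progress as hypothetical stand-ins for the store plumbing:

    # Hedged sketch of the batch-and-checkpoint loop body; not synapse's API.
    def run_one_batch(progress, batch_size, fetch_batch, save_progress):
        lo = progress["target_min_stream_id_inclusive"]
        hi = progress["max_stream_id_exclusive"]

        rows = fetch_batch(lo, hi, batch_size)  # newest-first, like the SQL above
        if not rows:
            return 0  # falsy result => caller ends the background update

        save_progress({
            # Unchanged lower bound; the next pass resumes below the oldest
            # row just processed (the batch is in descending stream order).
            "target_min_stream_id_inclusive": lo,
            "max_stream_id_exclusive": rows[-1][0],
            "rows_inserted": progress.get("rows_inserted", 0) + len(rows),
        })
        return len(rows)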


@@ -0,0 +1,57 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.prepare_database import get_statements
+
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = (
+    "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;"
+    "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);"
+)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    for statement in get_statements(ALTER_TABLE.splitlines()):
+        cur.execute(statement)
+
+    cur.execute("SELECT MIN(stream_ordering) FROM events")
+    rows = cur.fetchall()
+    min_stream_id = rows[0][0]
+
+    cur.execute("SELECT MAX(stream_ordering) FROM events")
+    rows = cur.fetchall()
+    max_stream_id = rows[0][0]
+
+    if min_stream_id is not None and max_stream_id is not None:
+        progress = {
+            "target_min_stream_id_inclusive": min_stream_id,
+            "max_stream_id_exclusive": max_stream_id + 1,
+            "rows_inserted": 0,
+        }
+        progress_json = ujson.dumps(progress)
+
+        sql = (
+            "INSERT into background_updates (update_name, progress_json)"
+            " VALUES (?, ?)"
+        )
+
+        sql = database_engine.convert_param_style(sql)
+
+        cur.execute(sql, ("event_origin_server_ts", progress_json))
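
Note the division of labour here: the delta only adds the (nullable) column, creates the index, and seeds a row in background_updates; backfilling origin_server_ts for existing events is left to the handler registered above, so the upgrade does not rewrite the whole events table at startup. Roughly, the seeded row looks like this (the stream-ordering bounds are illustrative, not real values):

    # Illustrative only: the shape of the background_updates row the delta seeds.
    import json

    seeded_row = {
        "update_name": "event_origin_server_ts",
        "progress_json": json.dumps({
            "target_min_stream_id_inclusive": 1,   # assumed MIN(stream_ordering)
            "max_stream_id_exclusive": 100001,     # assumed MAX(stream_ordering) + 1
            "rows_inserted": 0,
        }),
    }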


@@ -212,11 +212,11 @@ class SearchStore(BackgroundUpdateStore):
         })
 
     @defer.inlineCallbacks
-    def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+    def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
         """Performs a full text search over events with given keys.
 
         Args:
-            room_id (str): The room_id to search in
+            room_id (list): The room_ids to search in
             search_term (str): Search term to search for
             keys (list): List of keys to search in, currently supports
                 "content.body", "content.name", "content.topic"
@@ -226,7 +226,15 @@ class SearchStore(BackgroundUpdateStore):
             list of dicts
         """
         clauses = []
-        args = [search_term, room_id]
+        args = [search_term]
+
+        # Make sure we don't explode because the person is in too many rooms.
+        # We filter the results below regardless.
+        if len(room_ids) < 500:
+            clauses.append(
+                "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
+            )
+            args.extend(room_ids)
 
         local_clauses = []
         for key in keys:
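
The clause built here keeps values out of the SQL string: only the placeholders are interpolated, and the room ids travel in args. For example (hypothetical room ids):

    # What the IN-clause construction above produces for three rooms.
    room_ids = ["!a:hs", "!b:hs", "!c:hs"]
    clause = "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
    assert clause == "room_id IN (?,?,?)"
    # room_ids are then appended to `args`, one value per placeholder.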
@@ -239,25 +247,25 @@ class SearchStore(BackgroundUpdateStore):
 
         if pagination_token:
             try:
-                topo, stream = pagination_token.split(",")
-                topo = int(topo)
+                origin_server_ts, stream = pagination_token.split(",")
+                origin_server_ts = int(origin_server_ts)
                 stream = int(stream)
             except:
                 raise SynapseError(400, "Invalid pagination token")
 
             clauses.append(
-                "(topological_ordering < ?"
-                " OR (topological_ordering = ? AND stream_ordering < ?))"
+                "(origin_server_ts < ?"
+                " OR (origin_server_ts = ? AND stream_ordering < ?))"
             )
-            args.extend([topo, topo, stream])
+            args.extend([origin_server_ts, origin_server_ts, stream])
 
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
                 "SELECT ts_rank_cd(vector, query) as rank,"
-                " topological_ordering, stream_ordering, room_id, event_id"
+                " origin_server_ts, stream_ordering, room_id, event_id"
                 " FROM plainto_tsquery('english', ?) as query, event_search"
                 " NATURAL JOIN events"
-                " WHERE vector @@ query AND room_id = ?"
+                " WHERE vector @@ query AND "
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
             # We use CROSS JOIN here to ensure we use the right indexes.
@@ -270,24 +278,23 @@ class SearchStore(BackgroundUpdateStore):
             # MATCH unless it uses the full text search index
             sql = (
                 "SELECT rank(matchinfo) as rank, room_id, event_id,"
-                " topological_ordering, stream_ordering"
+                " origin_server_ts, stream_ordering"
                 " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
                 " FROM event_search"
                 " WHERE value MATCH ?"
                 " )"
                 " CROSS JOIN events USING (event_id)"
-                " WHERE room_id = ?"
+                " WHERE "
             )
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        for clause in clauses:
-            sql += " AND " + clause
+        sql += " AND ".join(clauses)
 
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
-        sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+        sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
 
         args.append(limit)
@@ -295,6 +302,8 @@ class SearchStore(BackgroundUpdateStore):
             "search_rooms", self.cursor_to_dict, sql, *args
         )
 
+        results = filter(lambda row: row["room_id"] in room_ids, results)
+
         events = yield self._get_events([r["event_id"] for r in results])
 
         event_map = {
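
One portability note on the new filter(...) call: under Python 2, which synapse targeted at the time, filter returns a list that can be iterated repeatedly; under Python 3 it would be a one-shot iterator, and the first comprehension over results would exhaust it before the second one below runs. A list comprehension (or wrapping in list(...)) would be the Python 3-safe spelling.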
@@ -312,7 +321,7 @@ class SearchStore(BackgroundUpdateStore):
                 "event": event_map[r["event_id"]],
                 "rank": r["rank"],
                 "pagination_token": "%s,%s" % (
-                    r["topological_ordering"], r["stream_ordering"]
+                    r["origin_server_ts"], r["stream_ordering"]
                 ),
             }
             for r in results