Merge pull request #307 from matrix-org/erikj/search

Add basic search API
This commit is contained in:
Erik Johnston 2015-10-19 13:37:15 +01:00
commit e0bf0258ee
10 changed files with 365 additions and 2 deletions

View File

@ -32,6 +32,7 @@ from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler
class Handlers(object): class Handlers(object):
@ -68,3 +69,4 @@ class Handlers(object):
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs) self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)

View File

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
import logging
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
def search(self, user, content):
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
Returns:
dict to be returned to the client with results of search
"""
try:
search_term = content["search_categories"]["room_events"]["search_term"]
keys = content["search_categories"]["room_events"].get("keys", [
"content.body", "content.name", "content.topic",
])
except KeyError:
raise SynapseError(400, "Invalid search query")
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
room_ids = set(r.room_id for r in rooms)
# TODO: Apply room filter to rooms list
rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys)
allowed_events = yield self._filter_events_for_client(
user.to_string(), event_map.values()
)
# TODO: Filter allowed_events
# TODO: Add a limit
time_now = self.clock.time_msec()
results = {
e.event_id: {
"rank": rank_map[e.event_id],
"result": serialize_event(e, time_now)
}
for e in allowed_events
}
logger.info("Found %d results", len(results))
defer.returnValue({
"search_categories": {
"room_events": {
"results": results,
"count": len(results)
}
}
})

View File

@ -555,6 +555,22 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern(
"/search$"
)
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
results = yield self.handlers.search_handler.search(auth_user, content)
defer.returnValue((200, results))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -611,3 +627,4 @@ def register_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server) RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)

View File

@ -40,6 +40,7 @@ from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore
import logging import logging
@ -69,6 +70,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventsStore, EventsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View File

@ -519,7 +519,7 @@ class SQLBaseStore(object):
allow_none=False, allow_none=False,
desc="_simple_select_one_onecol"): desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it." return a single row, returning a single column from it.
Args: Args:
table : string giving the table name table : string giving the table name

View File

@ -307,6 +307,8 @@ class EventsStore(SQLBaseStore):
self._store_room_name_txn(txn, event) self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic: elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event) self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)

View File

@ -19,6 +19,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections import collections
import logging import logging
@ -175,6 +176,10 @@ class RoomStore(SQLBaseStore):
}, },
) )
self._store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event): def _store_room_name_txn(self, txn, event):
if hasattr(event, "content") and "name" in event.content: if hasattr(event, "content") and "name" in event.content:
self._simple_insert_txn( self._simple_insert_txn(
@ -187,6 +192,33 @@ class RoomStore(SQLBaseStore):
} }
) )
self._store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
txn.execute(sql, (event.event_id, event.room_id, key, value,))
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):

View File

@ -0,0 +1,123 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import ujson
logger = logging.getLogger(__name__)
POSTGRES_SQL = """
CREATE TABLE event_search (
event_id TEXT,
room_id TEXT,
key TEXT,
vector tsvector
);
INSERT INTO event_search SELECT
event_id, room_id, 'content.body',
to_tsvector('english', json::json->'content'->>'body')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
INSERT INTO event_search SELECT
event_id, room_id, 'content.name',
to_tsvector('english', json::json->'content'->>'name')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
INSERT INTO event_search SELECT
event_id, room_id, 'content.topic',
to_tsvector('english', json::json->'content'->>'topic')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
CREATE INDEX event_search_ev_idx ON event_search(event_id);
CREATE INDEX event_search_ev_ridx ON event_search(room_id);
"""
SQLITE_TABLE = (
"CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)"
)
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
run_postgres_upgrade(cur)
return
if isinstance(database_engine, Sqlite3Engine):
run_sqlite_upgrade(cur)
return
def run_postgres_upgrade(cur):
for statement in get_statements(POSTGRES_SQL.splitlines()):
cur.execute(statement)
def run_sqlite_upgrade(cur):
cur.execute(SQLITE_TABLE)
rowid = -1
while True:
cur.execute(
"SELECT rowid, json FROM event_json"
" WHERE rowid > ?"
" ORDER BY rowid ASC LIMIT 100",
(rowid,)
)
res = cur.fetchall()
if not res:
break
events = [
ujson.loads(js)
for _, js in res
]
rowid = max(rid for rid, _ in res)
rows = []
for ev in events:
if ev["type"] == "m.room.message":
rows.append((
ev["event_id"], ev["room_id"], "content.body",
ev["content"]["body"]
))
if ev["type"] == "m.room.name":
rows.append((
ev["event_id"], ev["room_id"], "content.name",
ev["content"]["name"]
))
if ev["type"] == "m.room.topic":
rows.append((
ev["event_id"], ev["room_id"], "content.topic",
ev["content"]["topic"]
))
if rows:
logger.info(rows)
cur.executemany(
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)",
rows
)

93
synapse/storage/search.py Normal file
View File

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from _base import SQLBaseStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
class SearchStore(SQLBaseStore):
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
room_ids (list): List of room ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
2-tuple of (dict event_id -> rank, dict event_id -> event)
"""
clauses = []
args = []
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT 0 as rank, event_id FROM event_search"
" WHERE value MATCH ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue((
{
r["event_id"]: r["rank"]
for r in results
if r["event_id"] in event_map
},
event_map
))

View File

@ -214,7 +214,6 @@ class StateStore(SQLBaseStore):
that are in the `types` list. that are in the `types` list.
Args: Args:
room_id (str)
event_ids (list) event_ids (list)
types (list): List of (type, state_key) tuples which are used to types (list): List of (type, state_key) tuples which are used to
filter the state fetched. `state_key` may be None, which matches filter the state fetched. `state_key` may be None, which matches