diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 87b4d381c7..6a2339f2eb 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,7 +17,7 @@ from synapse.appservice.scheduler import AppServiceScheduler from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomMemberHandler, RoomListHandler + RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, ) from .message import MessageHandler from .events import EventStreamHandler, EventHandler @@ -70,3 +70,4 @@ class Handlers(object): self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) + self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 60f9fa58b0..36878a6c20 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,6 +33,7 @@ from collections import OrderedDict from unpaddedbase64 import decode_base64 import logging +import math import string logger = logging.getLogger(__name__) @@ -747,6 +748,60 @@ class RoomListHandler(BaseHandler): defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) +class RoomContextHandler(BaseHandler): + @defer.inlineCallbacks + def get_event_context(self, user, room_id, event_id, limit): + """Retrieves events, pagination tokens and state around a given event + in a room. + + Args: + user (UserID) + room_id (str) + event_id (str) + limit (int): The maximum number of events to return in total + (excluding state). + + Returns: + dict + """ + before_limit = math.floor(limit/2.) + after_limit = limit - before_limit + + now_token = yield self.hs.get_event_sources().get_current_token() + + results = yield self.store.get_events_around( + room_id, event_id, before_limit, after_limit + ) + + results["events_before"] = yield self._filter_events_for_client( + user.to_string(), results["events_before"] + ) + + results["events_after"] = yield self._filter_events_for_client( + user.to_string(), results["events_after"] + ) + + if results["events_after"]: + last_event_id = results["events_after"][-1].event_id + else: + last_event_id = event_id + + state = yield self.store.get_state_for_events( + [last_event_id], None + ) + results["state"] = state[last_event_id].values() + + results["start"] = now_token.copy_and_replace( + "room_key", results["start"] + ).to_string() + + results["end"] = now_token.copy_and_replace( + "room_key", results["end"] + ).to_string() + + defer.returnValue(results) + + class RoomEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 4cee1c1599..2dcaee86cd 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -397,6 +397,41 @@ class RoomTriggerBackfill(ClientV1RestServlet): defer.returnValue((200, res)) +class RoomEventContext(ClientV1RestServlet): + PATTERN = client_path_pattern( + "/rooms/(?P[^/]*)/context/(?P[^/]*)$" + ) + + def __init__(self, hs): + super(RoomEventContext, self).__init__(hs) + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, event_id): + user, _ = yield self.auth.get_user_by_req(request) + + limit = int(request.args.get("limit", [10])[0]) + + results = yield self.handlers.room_context_handler.get_event_context( + user, room_id, event_id, limit, + ) + + time_now = self.clock.time_msec() + results["events_before"] = [ + serialize_event(event, time_now) for event in results["events_before"] + ] + results["events_after"] = [ + serialize_event(event, time_now) for event in results["events_after"] + ] + results["state"] = [ + serialize_event(event, time_now) for event in results["state"] + ] + + logger.info("Responding with %r", results) + + defer.returnValue((200, results)) + + # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): @@ -628,3 +663,4 @@ def register_servlets(hs, http_server): RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) + RoomEventContext(hs).register(http_server) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 3cab06fdef..15d4c2bf68 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -23,7 +23,7 @@ paginate bacwards. This is implemented by keeping two ordering columns: stream_ordering and topological_ordering. Stream ordering is basically insertion/received order -(except for events from backfill requests). The topolgical_ordering is a +(except for events from backfill requests). The topological_ordering is a weak ordering of events based on the pdu graph. This means that we have to have two different types of tokens, depending on @@ -436,3 +436,138 @@ class StreamStore(SQLBaseStore): internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) + + @defer.inlineCallbacks + def get_events_around(self, room_id, event_id, before_limit, after_limit): + """Retrieve events and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + + Returns: + dict + """ + + results = yield self.runInteraction( + "get_events_around", self._get_events_around_txn, + room_id, event_id, before_limit, after_limit + ) + + events_before = yield self._get_events( + [e for e in results["before"]["event_ids"]], + get_prev_content=True + ) + + events_after = yield self._get_events( + [e for e in results["after"]["event_ids"]], + get_prev_content=True + ) + + defer.returnValue({ + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + }) + + def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): + """Retrieves event_ids and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + + Returns: + dict + """ + + results = self._simple_select_one_txn( + txn, + "events", + keyvalues={ + "event_id": event_id, + "room_id": room_id, + }, + retcols=["stream_ordering", "topological_ordering"], + ) + + stream_ordering = results["stream_ordering"] + topological_ordering = results["topological_ordering"] + + query_before = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND (topological_ordering < ?" + " OR (topological_ordering = ? AND stream_ordering < ?))" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + + query_after = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND (topological_ordering > ?" + " OR (topological_ordering = ? AND stream_ordering > ?))" + " ORDER BY topological_ordering ASC, stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute( + query_before, + ( + room_id, topological_ordering, topological_ordering, + stream_ordering, before_limit, + ) + ) + + rows = self.cursor_to_dict(txn) + events_before = [r["event_id"] for r in rows] + + if rows: + start_token = str(RoomStreamToken( + rows[0]["topological_ordering"], + rows[0]["stream_ordering"] - 1, + )) + else: + start_token = str(RoomStreamToken( + topological_ordering, + stream_ordering - 1, + )) + + txn.execute( + query_after, + ( + room_id, topological_ordering, topological_ordering, + stream_ordering, after_limit, + ) + ) + + rows = self.cursor_to_dict(txn) + events_after = [r["event_id"] for r in rows] + + if rows: + end_token = str(RoomStreamToken( + rows[-1]["topological_ordering"], + rows[-1]["stream_ordering"], + )) + else: + end_token = str(RoomStreamToken( + topological_ordering, + stream_ordering, + )) + + return { + "before": { + "event_ids": events_before, + "token": start_token, + }, + "after": { + "event_ids": events_after, + "token": end_token, + }, + }