make /context lazyload & filter aware (#3567)
make /context lazyload & filter aware.
This commit is contained in:
parent
b0b5566f36
commit
e9b2d047f6
|
@ -0,0 +1 @@
|
|||
make the /context API filter & lazy-load aware as per MSC1227
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Contains functions for performing events on rooms."""
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import string
|
||||
|
@ -401,7 +402,7 @@ class RoomContextHandler(object):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event_context(self, user, room_id, event_id, limit):
|
||||
def get_event_context(self, user, room_id, event_id, limit, event_filter):
|
||||
"""Retrieves events, pagination tokens and state around a given event
|
||||
in a room.
|
||||
|
||||
|
@ -411,6 +412,8 @@ class RoomContextHandler(object):
|
|||
event_id (str)
|
||||
limit (int): The maximum number of events to return in total
|
||||
(excluding state).
|
||||
event_filter (Filter|None): the filter to apply to the events returned
|
||||
(excluding the target event_id)
|
||||
|
||||
Returns:
|
||||
dict, or None if the event isn't found
|
||||
|
@ -443,7 +446,7 @@ class RoomContextHandler(object):
|
|||
)
|
||||
|
||||
results = yield self.store.get_events_around(
|
||||
room_id, event_id, before_limit, after_limit
|
||||
room_id, event_id, before_limit, after_limit, event_filter
|
||||
)
|
||||
|
||||
results["events_before"] = yield filter_evts(results["events_before"])
|
||||
|
@ -455,8 +458,23 @@ class RoomContextHandler(object):
|
|||
else:
|
||||
last_event_id = event_id
|
||||
|
||||
types = None
|
||||
filtered_types = None
|
||||
if event_filter and event_filter.lazy_load_members():
|
||||
members = set(ev.sender for ev in itertools.chain(
|
||||
results["events_before"],
|
||||
(results["event"],),
|
||||
results["events_after"],
|
||||
))
|
||||
filtered_types = [EventTypes.Member]
|
||||
types = [(EventTypes.Member, member) for member in members]
|
||||
|
||||
# XXX: why do we return the state as of the last event rather than the
|
||||
# first? Shouldn't we be consistent with /sync?
|
||||
# https://github.com/matrix-org/matrix-doc/issues/687
|
||||
|
||||
state = yield self.store.get_state_for_events(
|
||||
[last_event_id], None
|
||||
[last_event_id], types, filtered_types=filtered_types,
|
||||
)
|
||||
results["state"] = list(state[last_event_id].values())
|
||||
|
||||
|
|
|
@ -287,7 +287,7 @@ class SearchHandler(BaseHandler):
|
|||
contexts = {}
|
||||
for event in allowed_events:
|
||||
res = yield self.store.get_events_around(
|
||||
event.room_id, event.event_id, before_limit, after_limit
|
||||
event.room_id, event.event_id, before_limit, after_limit,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
|
|
@ -531,11 +531,20 @@ class RoomEventContextServlet(ClientV1RestServlet):
|
|||
|
||||
limit = parse_integer(request, "limit", default=10)
|
||||
|
||||
# picking the API shape for symmetry with /messages
|
||||
filter_bytes = parse_string(request, "filter")
|
||||
if filter_bytes:
|
||||
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
|
||||
event_filter = Filter(json.loads(filter_json))
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
results = yield self.room_context_handler.get_event_context(
|
||||
requester.user,
|
||||
room_id,
|
||||
event_id,
|
||||
limit,
|
||||
event_filter,
|
||||
)
|
||||
|
||||
if not results:
|
||||
|
|
|
@ -527,7 +527,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_around(self, room_id, event_id, before_limit, after_limit):
|
||||
def get_events_around(
|
||||
self, room_id, event_id, before_limit, after_limit, event_filter=None,
|
||||
):
|
||||
"""Retrieve events and pagination tokens around a given event in a
|
||||
room.
|
||||
|
||||
|
@ -536,6 +538,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
event_id (str)
|
||||
before_limit (int)
|
||||
after_limit (int)
|
||||
event_filter (Filter|None)
|
||||
|
||||
Returns:
|
||||
dict
|
||||
|
@ -543,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
results = yield self.runInteraction(
|
||||
"get_events_around", self._get_events_around_txn,
|
||||
room_id, event_id, before_limit, after_limit
|
||||
room_id, event_id, before_limit, after_limit, event_filter,
|
||||
)
|
||||
|
||||
events_before = yield self._get_events(
|
||||
|
@ -563,7 +566,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
"end": results["after"]["token"],
|
||||
})
|
||||
|
||||
def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit):
|
||||
def _get_events_around_txn(
|
||||
self, txn, room_id, event_id, before_limit, after_limit, event_filter,
|
||||
):
|
||||
"""Retrieves event_ids and pagination tokens around a given event in a
|
||||
room.
|
||||
|
||||
|
@ -572,6 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
event_id (str)
|
||||
before_limit (int)
|
||||
after_limit (int)
|
||||
event_filter (Filter|None)
|
||||
|
||||
Returns:
|
||||
dict
|
||||
|
@ -601,11 +607,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
rows, start_token = self._paginate_room_events_txn(
|
||||
txn, room_id, before_token, direction='b', limit=before_limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
events_before = [r.event_id for r in rows]
|
||||
|
||||
rows, end_token = self._paginate_room_events_txn(
|
||||
txn, room_id, after_token, direction='f', limit=after_limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
events_after = [r.event_id for r in rows]
|
||||
|
||||
|
|
Loading…
Reference in New Issue