Expose DataStore._get_events as get_events_as_list

This is in preparation for reaction work which requires it.
This commit is contained in:
Erik Johnston 2019-05-14 13:37:44 +01:00
parent df2ebd75d3
commit 4fb44fb5b9
6 changed files with 54 additions and 30 deletions

View File

@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore(
event_ids = json.loads(entry["event_ids"]) event_ids = json.loads(entry["event_ids"])
events = yield self._get_events(event_ids) events = yield self.get_events_as_list(event_ids)
defer.returnValue( defer.returnValue(
AppServiceTransaction(service=service, id=entry["txn_id"], events=events) AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn "get_new_events_for_appservice", get_new_events_for_appservice_txn
) )
events = yield self._get_events(event_ids) events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events)) defer.returnValue((upper_bound, events))

View File

@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
""" """
return self.get_auth_chain_ids( return self.get_auth_chain_ids(
event_ids, include_given=include_given event_ids, include_given=include_given
).addCallback(self._get_events) ).addCallback(self.get_events_as_list)
def get_auth_chain_ids(self, event_ids, include_given=False): def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events. """Get auth events for given event_ids. The events *must* be state events.
@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list, event_list,
limit, limit,
) )
.addCallback(self._get_events) .addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
) )
@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events, latest_events,
limit, limit,
) )
events = yield self._get_events(ids) events = yield self.get_events_as_list(ids)
defer.returnValue(events) defer.returnValue(events)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

View File

@ -103,7 +103,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred : A FrozenEvent. Deferred : A FrozenEvent.
""" """
events = yield self._get_events( events = yield self.get_events_as_list(
[event_id], [event_id],
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
@ -142,7 +142,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred : Dict from event_id to event. Deferred : Dict from event_id to event.
""" """
events = yield self._get_events( events = yield self.get_events_as_list(
event_ids, event_ids,
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
@ -152,13 +152,32 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_events( def get_events_as_list(
self, self,
event_ids, event_ids,
check_redacted=True, check_redacted=True,
get_prev_content=False, get_prev_content=False,
allow_rejected=False, allow_rejected=False,
): ):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
Returns:
Deferred[list]: List of events fetched from the database. The
events are in the same order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
"""
if not event_ids: if not event_ids:
defer.returnValue([]) defer.returnValue([])
@ -202,21 +221,22 @@ class EventsWorkerStore(SQLBaseStore):
# #
# The problem is that we end up at this point when an event # The problem is that we end up at this point when an event
# which has been redacted is pulled out of the database by # which has been redacted is pulled out of the database by
# _enqueue_events, because _enqueue_events needs to check the # _enqueue_events, because _enqueue_events needs to check
# redaction before it can cache the redacted event. So obviously, # the redaction before it can cache the redacted event. So
# calling get_event to get the redacted event out of the database # obviously, calling get_event to get the redacted event out
# gives us an infinite loop. # of the database gives us an infinite loop.
# #
# For now (quick hack to fix during 0.99 release cycle), we just # For now (quick hack to fix during 0.99 release cycle), we
# go and fetch the relevant row from the db, but it would be nice # just go and fetch the relevant row from the db, but it
# to think about how we can cache this rather than hit the db # would be nice to think about how we can cache this rather
# every time we access a redaction event. # than hit the db every time we access a redaction event.
# #
# One thought on how to do this: # One thought on how to do this:
# 1. split _get_events up so that it is divided into (a) get the # 1. split get_events_as_list up so that it is divided into
# rawish event from the db/cache, (b) do the redaction/rejection # (a) get the rawish event from the db/cache, (b) do the
# filtering # redaction/rejection filtering
# 2. have _get_event_from_row just call the first half of that # 2. have _get_event_from_row just call the first half of
# that
orig_sender = yield self._simple_select_one_onecol( orig_sender = yield self._simple_select_one_onecol(
table="events", table="events",

View File

@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results)) results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results]) events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events} event_map = {ev.event_id: ev for ev in events}
@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results)) results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results]) events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events} event_map = {ev.event_id: ev for ev in events}

View File

@ -319,7 +319,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f) rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) ret = yield self.get_events_as_list([
r.event_id for r in rows], get_prev_content=True,
)
self._set_before_and_after(ret, rows, topo_order=from_id is None) self._set_before_and_after(ret, rows, topo_order=from_id is None)
@ -367,7 +369,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f) rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True,
)
self._set_before_and_after(ret, rows, topo_order=False) self._set_before_and_after(ret, rows, topo_order=False)
@ -394,7 +398,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
logger.debug("stream before") logger.debug("stream before")
events = yield self._get_events( events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True [r.event_id for r in rows], get_prev_content=True
) )
logger.debug("stream after") logger.debug("stream after")
@ -580,11 +584,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter, event_filter,
) )
events_before = yield self._get_events( events_before = yield self.get_events_as_list(
[e for e in results["before"]["event_ids"]], get_prev_content=True [e for e in results["before"]["event_ids"]], get_prev_content=True
) )
events_after = yield self._get_events( events_after = yield self.get_events_as_list(
[e for e in results["after"]["event_ids"]], get_prev_content=True [e for e in results["after"]["event_ids"]], get_prev_content=True
) )
@ -697,7 +701,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_all_new_events_stream", get_all_new_events_stream_txn "get_all_new_events_stream", get_all_new_events_stream_txn
) )
events = yield self._get_events(event_ids) events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events)) defer.returnValue((upper_bound, events))
@ -849,7 +853,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter, event_filter,
) )
events = yield self._get_events( events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True [r.event_id for r in rows], get_prev_content=True
) )

View File

@ -340,7 +340,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out # we aren't testing store._base stuff here, so mock this out
self.store._get_events = Mock(return_value=events) self.store.get_events_as_list = Mock(return_value=events)
yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 10, events)