Micro-optimisations to get_auth_chain_ids (#8132)
This commit is contained in:
parent
3f91638da6
commit
09fd0eda81
|
@ -0,0 +1 @@
|
||||||
|
Micro-optimisations to get_auth_chain_ids.
|
|
@ -15,14 +15,16 @@
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from queue import Empty, PriorityQueue
|
from queue import Empty, PriorityQueue
|
||||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
from typing import Dict, Iterable, List, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||||
|
from synapse.types import Collection
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
|
@ -30,12 +32,14 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||||
async def get_auth_chain(self, event_ids, include_given=False):
|
async def get_auth_chain(
|
||||||
|
self, event_ids: Collection[str], include_given: bool = False
|
||||||
|
) -> List[EventBase]:
|
||||||
"""Get auth events for given event_ids. The events *must* be state events.
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids (list): state events
|
event_ids: state events
|
||||||
include_given (bool): include the given events in result
|
include_given: include the given events in result
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of events
|
list of events
|
||||||
|
@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
return await self.get_events_as_list(event_ids)
|
return await self.get_events_as_list(event_ids)
|
||||||
|
|
||||||
def get_auth_chain_ids(
|
async def get_auth_chain_ids(
|
||||||
self,
|
self, event_ids: Collection[str], include_given: bool = False,
|
||||||
event_ids: List[str],
|
) -> List[str]:
|
||||||
include_given: bool = False,
|
|
||||||
ignore_events: Optional[Set[str]] = None,
|
|
||||||
):
|
|
||||||
"""Get auth events for given event_ids. The events *must* be state events.
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids: state events
|
event_ids: state events
|
||||||
include_given: include the given events in result
|
include_given: include the given events in result
|
||||||
ignore_events: Set of events to exclude from the returned auth
|
|
||||||
chain. This is useful if the caller will just discard the
|
|
||||||
given events anyway, and saves us from figuring out their auth
|
|
||||||
chains if not required.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of event_ids
|
list of event_ids
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_auth_chain_ids",
|
"get_auth_chain_ids",
|
||||||
self._get_auth_chain_ids_txn,
|
self._get_auth_chain_ids_txn,
|
||||||
event_ids,
|
event_ids,
|
||||||
include_given,
|
include_given,
|
||||||
ignore_events,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
|
def _get_auth_chain_ids_txn(
|
||||||
if ignore_events is None:
|
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
||||||
ignore_events = set()
|
) -> List[str]:
|
||||||
|
|
||||||
if include_given:
|
if include_given:
|
||||||
results = set(event_ids)
|
results = set(event_ids)
|
||||||
else:
|
else:
|
||||||
results = set()
|
results = set()
|
||||||
|
|
||||||
base_sql = "SELECT auth_id FROM event_auth WHERE "
|
base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
|
||||||
|
|
||||||
front = set(event_ids)
|
front = set(event_ids)
|
||||||
while front:
|
while front:
|
||||||
|
@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||||
txn.execute(base_sql + clause, args)
|
txn.execute(base_sql + clause, args)
|
||||||
new_front.update(r[0] for r in txn)
|
new_front.update(r[0] for r in txn)
|
||||||
|
|
||||||
new_front -= ignore_events
|
|
||||||
new_front -= results
|
new_front -= results
|
||||||
|
|
||||||
front = new_front
|
front = new_front
|
||||||
|
|
Loading…
Reference in New Issue