Micro-optimisations to get_auth_chain_ids (#8132)
This commit is contained in:
parent
3f91638da6
commit
09fd0eda81
|
@ -0,0 +1 @@
|
|||
Micro-optimisations to get_auth_chain_ids.
|
|
@ -15,14 +15,16 @@
|
|||
import itertools
|
||||
import logging
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Dict, Iterable, List, Set, Tuple
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.events import EventBase
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.types import Collection
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
|
@ -30,12 +32,14 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||
async def get_auth_chain(self, event_ids, include_given=False):
|
||||
async def get_auth_chain(
|
||||
self, event_ids: Collection[str], include_given: bool = False
|
||||
) -> List[EventBase]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
event_ids (list): state events
|
||||
include_given (bool): include the given events in result
|
||||
event_ids: state events
|
||||
include_given: include the given events in result
|
||||
|
||||
Returns:
|
||||
list of events
|
||||
|
@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
)
|
||||
return await self.get_events_as_list(event_ids)
|
||||
|
||||
def get_auth_chain_ids(
|
||||
self,
|
||||
event_ids: List[str],
|
||||
include_given: bool = False,
|
||||
ignore_events: Optional[Set[str]] = None,
|
||||
):
|
||||
async def get_auth_chain_ids(
|
||||
self, event_ids: Collection[str], include_given: bool = False,
|
||||
) -> List[str]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
event_ids: state events
|
||||
include_given: include the given events in result
|
||||
ignore_events: Set of events to exclude from the returned auth
|
||||
chain. This is useful if the caller will just discard the
|
||||
given events anyway, and saves us from figuring out their auth
|
||||
chains if not required.
|
||||
|
||||
Returns:
|
||||
list of event_ids
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
self._get_auth_chain_ids_txn,
|
||||
event_ids,
|
||||
include_given,
|
||||
ignore_events,
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
|
||||
if ignore_events is None:
|
||||
ignore_events = set()
|
||||
|
||||
def _get_auth_chain_ids_txn(
|
||||
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
||||
) -> List[str]:
|
||||
if include_given:
|
||||
results = set(event_ids)
|
||||
else:
|
||||
results = set()
|
||||
|
||||
base_sql = "SELECT auth_id FROM event_auth WHERE "
|
||||
base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
|
||||
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
|
@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
txn.execute(base_sql + clause, args)
|
||||
new_front.update(r[0] for r in txn)
|
||||
|
||||
new_front -= ignore_events
|
||||
new_front -= results
|
||||
|
||||
front = new_front
|
||||
|
|
Loading…
Reference in New Issue