Add type hints to `synapse/storage/databases/main/events_bg_updates.py` (#11654)
This commit is contained in:
parent
2c7f5e74e5
commit
07a3b5daba
|
@ -0,0 +1 @@
|
|||
Add missing type hints to storage classes.
|
4
mypy.ini
4
mypy.ini
|
@ -28,7 +28,6 @@ exclude = (?x)
|
|||
|synapse/storage/databases/main/cache.py
|
||||
|synapse/storage/databases/main/devices.py
|
||||
|synapse/storage/databases/main/event_federation.py
|
||||
|synapse/storage/databases/main/events_bg_updates.py
|
||||
|synapse/storage/databases/main/group_server.py
|
||||
|synapse/storage/databases/main/metrics.py
|
||||
|synapse/storage/databases/main/monthly_active_users.py
|
||||
|
@ -200,6 +199,9 @@ disallow_untyped_defs = True
|
|||
[mypy-synapse.storage.databases.main.event_push_actions]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.storage.databases.main.events_bg_updates]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.storage.databases.main.events_worker]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -240,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
################################################################################
|
||||
|
||||
async def _background_reindex_fields_sender(self, progress, batch_size):
|
||||
async def _background_reindex_fields_sender(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
|
||||
def reindex_txn(txn):
|
||||
def reindex_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id, json FROM events"
|
||||
" INNER JOIN event_json USING (event_id)"
|
||||
|
@ -307,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
return result
|
||||
|
||||
async def _background_reindex_origin_server_ts(self, progress, batch_size):
|
||||
async def _background_reindex_origin_server_ts(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
def reindex_search_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id FROM events"
|
||||
" WHERE ? <= stream_ordering AND stream_ordering < ?"
|
||||
|
@ -381,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
return result
|
||||
|
||||
async def _cleanup_extremities_bg_update(self, progress, batch_size):
|
||||
async def _cleanup_extremities_bg_update(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Background update to clean out extremities that should have been
|
||||
deleted previously.
|
||||
|
||||
|
@ -402,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
# have any descendants, but if they do then we should delete those
|
||||
# extremities.
|
||||
|
||||
def _cleanup_extremities_bg_update_txn(txn):
|
||||
def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
|
||||
# The set of extremity event IDs that we're checking this round
|
||||
original_set = set()
|
||||
|
||||
# A dict[str, set[str]] of event ID to their prev events.
|
||||
graph = {}
|
||||
# A dict[str, Set[str]] of event ID to their prev events.
|
||||
graph: Dict[str, Set[str]] = {}
|
||||
|
||||
# The set of descendants of the original set that are not rejected
|
||||
# nor soft-failed. Ancestors of these events should be removed
|
||||
|
@ -536,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
room_ids = {row["room_id"] for row in rows}
|
||||
for room_id in room_ids:
|
||||
txn.call_after(
|
||||
self.get_latest_event_ids_in_room.invalidate, (room_id,)
|
||||
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
|
@ -558,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
|
||||
)
|
||||
|
||||
def _drop_table_txn(txn):
|
||||
def _drop_table_txn(txn: LoggingTransaction) -> None:
|
||||
txn.execute("DROP TABLE _extremities_to_check")
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -567,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
return num_handled
|
||||
|
||||
async def _redactions_received_ts(self, progress, batch_size):
|
||||
async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
|
||||
"""Handles filling out the `received_ts` column in redactions."""
|
||||
last_event_id = progress.get("last_event_id", "")
|
||||
|
||||
def _redactions_received_ts_txn(txn):
|
||||
def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
|
||||
# Fetch the set of event IDs that we want to update
|
||||
sql = """
|
||||
SELECT event_id FROM redactions
|
||||
|
@ -622,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
return count
|
||||
|
||||
async def _event_fix_redactions_bytes(self, progress, batch_size):
|
||||
async def _event_fix_redactions_bytes(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Undoes hex encoded censored redacted event JSON."""
|
||||
|
||||
def _event_fix_redactions_bytes_txn(txn):
|
||||
def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
|
||||
# This update is quite fast due to new index.
|
||||
txn.execute(
|
||||
"""
|
||||
|
@ -650,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
return 1
|
||||
|
||||
async def _event_store_labels(self, progress, batch_size):
|
||||
async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
|
||||
"""Background update handler which will store labels for existing events."""
|
||||
last_event_id = progress.get("last_event_id", "")
|
||||
|
||||
def _event_store_labels_txn(txn):
|
||||
def _event_store_labels_txn(txn: LoggingTransaction) -> int:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT event_id, json FROM event_json
|
||||
|
@ -754,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
),
|
||||
)
|
||||
|
||||
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
|
||||
return cast(
|
||||
List[Tuple[str, str, JsonDict, bool, bool]],
|
||||
[(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
|
||||
)
|
||||
|
||||
results = await self.db_pool.runInteraction(
|
||||
desc="_rejected_events_metadata_get", func=get_rejected_events
|
||||
|
@ -912,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
def _calculate_chain_cover_txn(
|
||||
self,
|
||||
txn: Cursor,
|
||||
txn: LoggingTransaction,
|
||||
last_room_id: str,
|
||||
last_depth: int,
|
||||
last_stream: int,
|
||||
|
@ -1023,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
PersistEventsStore._add_chain_cover_index(
|
||||
txn,
|
||||
self.db_pool,
|
||||
self.event_chain_id_gen,
|
||||
self.event_chain_id_gen, # type: ignore[attr-defined]
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
cast(Dict[str, Sequence[str]], event_to_auth_chain),
|
||||
)
|
||||
|
||||
return _CalculateChainCover(
|
||||
|
@ -1046,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
"""
|
||||
current_event_id = progress.get("current_event_id", "")
|
||||
|
||||
def purged_chain_cover_txn(txn) -> int:
|
||||
def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
|
||||
# The event ID from events will be null if the chain ID / sequence
|
||||
# number points to a purged event.
|
||||
sql = """
|
||||
|
@ -1181,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
# Iterate the parent IDs and invalidate caches.
|
||||
for parent_id in {r[1] for r in relations_to_insert}:
|
||||
cache_tuple = (parent_id,)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_relations_for_event, cache_tuple
|
||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aggregation_groups_for_event, cache_tuple
|
||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||
txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_thread_summary, cache_tuple
|
||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if results:
|
||||
|
@ -1220,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
"""
|
||||
batch_size = max(batch_size, 1)
|
||||
|
||||
def process(txn: Cursor) -> int:
|
||||
def process(txn: LoggingTransaction) -> int:
|
||||
last_stream = progress.get("last_stream", -(1 << 31))
|
||||
txn.execute(
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue