Add some type hints to datastore (#12717)
This commit is contained in:
parent
942c30b16b
commit
6edefef602
|
@ -0,0 +1 @@
|
||||||
|
Add some type hints to datastore.
|
2
mypy.ini
2
mypy.ini
|
@ -28,8 +28,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|synapse/storage/databases/main/event_federation.py
|
|synapse/storage/databases/main/event_federation.py
|
||||||
|synapse/storage/databases/main/push_rule.py
|
|
||||||
|synapse/storage/databases/main/roommember.py
|
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|tests/api/test_auth.py
|
|tests/api/test_auth.py
|
||||||
|
|
|
@ -15,7 +15,17 @@
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Hashable,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
@ -409,7 +419,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
destinations: Optional[Set[str]] = None
|
destinations: Optional[Collection[str]] = None
|
||||||
if not event.prev_event_ids():
|
if not event.prev_event_ids():
|
||||||
# If there are no prev event IDs then the state is empty
|
# If there are no prev event IDs then the state is empty
|
||||||
# and so no remote servers in the room
|
# and so no remote servers in the room
|
||||||
|
@ -444,7 +454,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
destinations = {
|
sharded_destinations = {
|
||||||
d
|
d
|
||||||
for d in destinations
|
for d in destinations
|
||||||
if self._federation_shard_config.should_handle(
|
if self._federation_shard_config.should_handle(
|
||||||
|
@ -456,12 +466,12 @@ class FederationSender(AbstractFederationSender):
|
||||||
# If we are sending the event on behalf of another server
|
# If we are sending the event on behalf of another server
|
||||||
# then it already has the event and there is no reason to
|
# then it already has the event and there is no reason to
|
||||||
# send the event to it.
|
# send the event to it.
|
||||||
destinations.discard(send_on_behalf_of)
|
sharded_destinations.discard(send_on_behalf_of)
|
||||||
|
|
||||||
logger.debug("Sending %s to %r", event, destinations)
|
logger.debug("Sending %s to %r", event, sharded_destinations)
|
||||||
|
|
||||||
if destinations:
|
if sharded_destinations:
|
||||||
await self._send_pdu(event, destinations)
|
await self._send_pdu(event, sharded_destinations)
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
ts = await self.store.get_received_ts(event.event_id)
|
ts = await self.store.get_received_ts(event.event_id)
|
||||||
|
|
|
@ -411,10 +411,10 @@ class SyncHandler:
|
||||||
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
|
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
|
||||||
return sync_result
|
return sync_result
|
||||||
|
|
||||||
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
rules = await self.store.get_push_rules_for_user(user_id)
|
rules_raw = await self.store.get_push_rules_for_user(user_id)
|
||||||
rules = format_push_rules_for_user(user, rules)
|
rules = format_push_rules_for_user(user, rules_raw)
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
async def ephemeral_by_room(
|
async def ephemeral_by_room(
|
||||||
|
|
|
@ -148,9 +148,9 @@ class PushRuleRestServlet(RestServlet):
|
||||||
# we build up the full structure and then decide which bits of it
|
# we build up the full structure and then decide which bits of it
|
||||||
# to send which means doing unnecessary work sometimes but is
|
# to send which means doing unnecessary work sometimes but is
|
||||||
# is probably not going to make a whole lot of difference
|
# is probably not going to make a whole lot of difference
|
||||||
rules = await self.store.get_push_rules_for_user(user_id)
|
rules_raw = await self.store.get_push_rules_for_user(user_id)
|
||||||
|
|
||||||
rules = format_push_rules_for_user(requester.user, rules)
|
rules = format_push_rules_for_user(requester.user, rules_raw)
|
||||||
|
|
||||||
path_parts = path.split("/")[1:]
|
path_parts = path.split("/")[1:]
|
||||||
|
|
||||||
|
|
|
@ -239,13 +239,13 @@ class StateHandler:
|
||||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||||
|
|
||||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
|
||||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||||
|
|
||||||
async def get_hosts_in_room_at_events(
|
async def get_hosts_in_room_at_events(
|
||||||
self, room_id: str, event_ids: Collection[str]
|
self, room_id: str, event_ids: Collection[str]
|
||||||
) -> Set[str]:
|
) -> FrozenSet[str]:
|
||||||
"""Get the hosts that were in a room at the given event ids
|
"""Get the hosts that were in a room at the given event ids
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -26,11 +26,7 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.stats import UserSortOrder
|
from synapse.storage.databases.main.stats import UserSortOrder
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||||
IdGenerator,
|
|
||||||
MultiWriterIdGenerator,
|
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
@ -155,8 +151,6 @@ class DataStore(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
|
||||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
|
||||||
self._group_updates_id_gen = StreamIdGenerator(
|
self._group_updates_id_gen = StreamIdGenerator(
|
||||||
db_conn, "local_group_updates", "stream_id"
|
db_conn, "local_group_updates", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,16 +14,19 @@
|
||||||
import calendar
|
import calendar
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
|
||||||
|
|
||||||
from synapse.metrics import GaugeBucketCollector
|
from synapse.metrics import GaugeBucketCollector
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.databases.main.event_push_actions import (
|
from synapse.storage.databases.main.event_push_actions import (
|
||||||
EventPushActionsWorkerStore,
|
EventPushActionsWorkerStore,
|
||||||
)
|
)
|
||||||
from synapse.storage.types import Cursor
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
@wrap_as_background_process("read_forward_extremities")
|
@wrap_as_background_process("read_forward_extremities")
|
||||||
async def _read_forward_extremities(self) -> None:
|
async def _read_forward_extremities(self) -> None:
|
||||||
def fetch(txn):
|
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT t1.c, t2.c
|
SELECT t1.c, t2.c
|
||||||
|
@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
) t2 ON t1.room_id = t2.room_id
|
) t2 ON t1.room_id = t2.room_id
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
return txn.fetchall()
|
return cast(List[Tuple[int, int]], txn.fetchall())
|
||||||
|
|
||||||
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
|
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
|
||||||
|
|
||||||
|
@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
call to this function, it will return None.
|
call to this function, it will return None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_messages(txn):
|
def _count_messages(txn: LoggingTransaction) -> int:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(*) FROM events
|
SELECT COUNT(*) FROM events
|
||||||
WHERE type = 'm.room.encrypted'
|
WHERE type = 'm.room.encrypted'
|
||||||
AND stream_ordering > ?
|
AND stream_ordering > ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
|
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
|
||||||
|
|
||||||
async def count_daily_sent_e2ee_messages(self) -> int:
|
async def count_daily_sent_e2ee_messages(self) -> int:
|
||||||
def _count_messages(txn):
|
def _count_messages(txn: LoggingTransaction) -> int:
|
||||||
# This is good enough as if you have silly characters in your own
|
# This is good enough as if you have silly characters in your own
|
||||||
# hostname then that's your own fault.
|
# hostname then that's your own fault.
|
||||||
like_clause = "%:" + self.hs.hostname
|
like_clause = "%:" + self.hs.hostname
|
||||||
|
@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_daily_active_e2ee_rooms(self) -> int:
|
async def count_daily_active_e2ee_rooms(self) -> int:
|
||||||
def _count(txn):
|
def _count(txn: LoggingTransaction) -> int:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(DISTINCT room_id) FROM events
|
SELECT COUNT(DISTINCT room_id) FROM events
|
||||||
WHERE type = 'm.room.encrypted'
|
WHERE type = 'm.room.encrypted'
|
||||||
AND stream_ordering > ?
|
AND stream_ordering > ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
call to this function, it will return None.
|
call to this function, it will return None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_messages(txn):
|
def _count_messages(txn: LoggingTransaction) -> int:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(*) FROM events
|
SELECT COUNT(*) FROM events
|
||||||
WHERE type = 'm.room.message'
|
WHERE type = 'm.room.message'
|
||||||
AND stream_ordering > ?
|
AND stream_ordering > ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
||||||
|
|
||||||
async def count_daily_sent_messages(self) -> int:
|
async def count_daily_sent_messages(self) -> int:
|
||||||
def _count_messages(txn):
|
def _count_messages(txn: LoggingTransaction) -> int:
|
||||||
# This is good enough as if you have silly characters in your own
|
# This is good enough as if you have silly characters in your own
|
||||||
# hostname then that's your own fault.
|
# hostname then that's your own fault.
|
||||||
like_clause = "%:" + self.hs.hostname
|
like_clause = "%:" + self.hs.hostname
|
||||||
|
@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_daily_active_rooms(self) -> int:
|
async def count_daily_active_rooms(self) -> int:
|
||||||
def _count(txn):
|
def _count(txn: LoggingTransaction) -> int:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(DISTINCT room_id) FROM events
|
SELECT COUNT(DISTINCT room_id) FROM events
|
||||||
WHERE type = 'm.room.message'
|
WHERE type = 'm.room.message'
|
||||||
AND stream_ordering > ?
|
AND stream_ordering > ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
|
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
|
||||||
|
@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
"count_monthly_users", self._count_users, thirty_days_ago
|
"count_monthly_users", self._count_users, thirty_days_ago
|
||||||
)
|
)
|
||||||
|
|
||||||
def _count_users(self, txn: Cursor, time_from: int) -> int:
|
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
|
||||||
"""
|
"""
|
||||||
Returns number of users seen in the past time_from period
|
Returns number of users seen in the past time_from period
|
||||||
"""
|
"""
|
||||||
|
@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
# Mypy knows that fetchone() might return None if there are no rows.
|
# Mypy knows that fetchone() might return None if there are no rows.
|
||||||
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
|
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
|
||||||
# returns exactly one row.
|
# returns exactly one row.
|
||||||
(count,) = txn.fetchone() # type: ignore[misc]
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
async def count_r30_users(self) -> Dict[str, int]:
|
async def count_r30_users(self) -> Dict[str, int]:
|
||||||
|
@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
A mapping of counts globally as well as broken out by platform.
|
A mapping of counts globally as well as broken out by platform.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_r30_users(txn):
|
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
|
||||||
thirty_days_in_secs = 86400 * 30
|
thirty_days_in_secs = 86400 * 30
|
||||||
now = int(self._clock.time())
|
now = int(self._clock.time())
|
||||||
thirty_days_ago_in_secs = now - thirty_days_in_secs
|
thirty_days_ago_in_secs = now - thirty_days_in_secs
|
||||||
|
@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||||
|
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
results["all"] = count
|
results["all"] = count
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
- "web" (any web application -- it's not possible to distinguish Element Web here)
|
- "web" (any web application -- it's not possible to distinguish Element Web here)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_r30v2_users(txn):
|
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
|
||||||
thirty_days_in_secs = 86400 * 30
|
thirty_days_in_secs = 86400 * 30
|
||||||
now = int(self._clock.time())
|
now = int(self._clock.time())
|
||||||
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
|
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
|
||||||
|
@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
thirty_days_in_secs * 1000,
|
thirty_days_in_secs * 1000,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
row = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
if row is None:
|
results["all"] = count
|
||||||
results["all"] = 0
|
|
||||||
else:
|
|
||||||
results["all"] = row[0]
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
Generates daily visit data for use in cohort/ retention analysis
|
Generates daily visit data for use in cohort/ retention analysis
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _generate_user_daily_visits(txn):
|
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
|
||||||
logger.info("Calling _generate_user_daily_visits")
|
logger.info("Calling _generate_user_daily_visits")
|
||||||
today_start = self._get_start_of_day()
|
today_start = self._get_start_of_day()
|
||||||
a_day_in_milliseconds = 24 * 60 * 60 * 1000
|
a_day_in_milliseconds = 24 * 60 * 60 * 1000
|
||||||
|
|
|
@ -14,14 +14,18 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.config.homeserver import ExperimentalConfig
|
from synapse.config.homeserver import ExperimentalConfig
|
||||||
from synapse.push.baserules import list_with_base_rules
|
from synapse.push.baserules import list_with_base_rules
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||||
|
@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
|
AbstractStreamIdGenerator,
|
||||||
AbstractStreamIdTracker,
|
AbstractStreamIdTracker,
|
||||||
|
IdGenerator,
|
||||||
StreamIdGenerator,
|
StreamIdGenerator,
|
||||||
)
|
)
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
@ -57,7 +64,11 @@ def _is_experimental_rule_enabled(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
|
def _load_rules(
|
||||||
|
rawrules: List[JsonDict],
|
||||||
|
enabled_map: Dict[str, bool],
|
||||||
|
experimental_config: ExperimentalConfig,
|
||||||
|
) -> List[JsonDict]:
|
||||||
ruleslist = []
|
ruleslist = []
|
||||||
for rawrule in rawrules:
|
for rawrule in rawrules:
|
||||||
rule = dict(rawrule)
|
rule = dict(rawrule)
|
||||||
|
@ -137,7 +148,7 @@ class PushRulesWorkerStore(
|
||||||
)
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_max_push_rules_stream_id(self):
|
def get_max_push_rules_stream_id(self) -> int:
|
||||||
"""Get the position of the push rules stream.
|
"""Get the position of the push rules stream.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -146,7 +157,7 @@ class PushRulesWorkerStore(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cached(max_entries=5000)
|
@cached(max_entries=5000)
|
||||||
async def get_push_rules_for_user(self, user_id):
|
async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
table="push_rules",
|
table="push_rules",
|
||||||
keyvalues={"user_name": user_id},
|
keyvalues={"user_name": user_id},
|
||||||
|
@ -168,7 +179,7 @@ class PushRulesWorkerStore(
|
||||||
return _load_rules(rows, enabled_map, self.hs.config.experimental)
|
return _load_rules(rows, enabled_map, self.hs.config.experimental)
|
||||||
|
|
||||||
@cached(max_entries=5000)
|
@cached(max_entries=5000)
|
||||||
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
|
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||||
results = await self.db_pool.simple_select_list(
|
results = await self.db_pool.simple_select_list(
|
||||||
table="push_rules_enable",
|
table="push_rules_enable",
|
||||||
keyvalues={"user_name": user_id},
|
keyvalues={"user_name": user_id},
|
||||||
|
@ -184,13 +195,13 @@ class PushRulesWorkerStore(
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def have_push_rules_changed_txn(txn):
|
def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
||||||
" WHERE user_id = ? AND ? < stream_id"
|
" WHERE user_id = ? AND ? < stream_id"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id, last_id))
|
txn.execute(sql, (user_id, last_id))
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return bool(count)
|
return bool(count)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -202,11 +213,13 @@ class PushRulesWorkerStore(
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
num_args=1,
|
num_args=1,
|
||||||
)
|
)
|
||||||
async def bulk_get_push_rules(self, user_ids):
|
async def bulk_get_push_rules(
|
||||||
|
self, user_ids: Collection[str]
|
||||||
|
) -> Dict[str, List[JsonDict]]:
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
results = {user_id: [] for user_id in user_ids}
|
results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="push_rules",
|
table="push_rules",
|
||||||
|
@ -250,7 +263,7 @@ class PushRulesWorkerStore(
|
||||||
condition["pattern"] = new_room_id
|
condition["pattern"] = new_room_id
|
||||||
|
|
||||||
# Add the rule for the new room
|
# Add the rule for the new room
|
||||||
await self.add_push_rule(
|
await self.add_push_rule( # type: ignore[attr-defined]
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
rule_id=new_rule_id,
|
rule_id=new_rule_id,
|
||||||
priority_class=rule["priority_class"],
|
priority_class=rule["priority_class"],
|
||||||
|
@ -286,11 +299,13 @@ class PushRulesWorkerStore(
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
num_args=1,
|
num_args=1,
|
||||||
)
|
)
|
||||||
async def bulk_get_push_rules_enabled(self, user_ids):
|
async def bulk_get_push_rules_enabled(
|
||||||
|
self, user_ids: Collection[str]
|
||||||
|
) -> Dict[str, Dict[str, bool]]:
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
results = {user_id: {} for user_id in user_ids}
|
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="push_rules_enable",
|
table="push_rules_enable",
|
||||||
|
@ -306,7 +321,7 @@ class PushRulesWorkerStore(
|
||||||
|
|
||||||
async def get_all_push_rule_updates(
|
async def get_all_push_rule_updates(
|
||||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
|
||||||
"""Get updates for push_rules replication stream.
|
"""Get updates for push_rules replication stream.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -331,7 +346,9 @@ class PushRulesWorkerStore(
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def get_all_push_rule_updates_txn(txn):
|
def get_all_push_rule_updates_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT stream_id, user_id
|
SELECT stream_id, user_id
|
||||||
FROM push_rules_stream
|
FROM push_rules_stream
|
||||||
|
@ -340,7 +357,10 @@ class PushRulesWorkerStore(
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (last_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
|
updates = cast(
|
||||||
|
List[Tuple[int, Tuple[str]]],
|
||||||
|
[(stream_id, (user_id,)) for stream_id, user_id in txn],
|
||||||
|
)
|
||||||
|
|
||||||
limited = False
|
limited = False
|
||||||
upper_bound = current_id
|
upper_bound = current_id
|
||||||
|
@ -356,15 +376,30 @@ class PushRulesWorkerStore(
|
||||||
|
|
||||||
|
|
||||||
class PushRuleStore(PushRulesWorkerStore):
|
class PushRuleStore(PushRulesWorkerStore):
|
||||||
|
# Because we have write access, this will be a StreamIdGenerator
|
||||||
|
# (see PushRulesWorkerStore.__init__)
|
||||||
|
_push_rules_stream_id_gen: AbstractStreamIdGenerator
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
|
|
||||||
async def add_push_rule(
|
async def add_push_rule(
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id: str,
|
||||||
rule_id,
|
rule_id: str,
|
||||||
priority_class,
|
priority_class: int,
|
||||||
conditions,
|
conditions: List[Dict[str, str]],
|
||||||
actions,
|
actions: List[Union[JsonDict, str]],
|
||||||
before=None,
|
before: Optional[str] = None,
|
||||||
after=None,
|
after: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
conditions_json = json_encoder.encode(conditions)
|
conditions_json = json_encoder.encode(conditions)
|
||||||
actions_json = json_encoder.encode(actions)
|
actions_json = json_encoder.encode(actions)
|
||||||
|
@ -400,17 +435,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
|
|
||||||
def _add_push_rule_relative_txn(
|
def _add_push_rule_relative_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
stream_id,
|
stream_id: int,
|
||||||
event_stream_ordering,
|
event_stream_ordering: int,
|
||||||
user_id,
|
user_id: str,
|
||||||
rule_id,
|
rule_id: str,
|
||||||
priority_class,
|
priority_class: int,
|
||||||
conditions_json,
|
conditions_json: str,
|
||||||
actions_json,
|
actions_json: str,
|
||||||
before,
|
before: str,
|
||||||
after,
|
after: str,
|
||||||
):
|
) -> None:
|
||||||
# Lock the table since otherwise we'll have annoying races between the
|
# Lock the table since otherwise we'll have annoying races between the
|
||||||
# SELECT here and the UPSERT below.
|
# SELECT here and the UPSERT below.
|
||||||
self.database_engine.lock_table(txn, "push_rules")
|
self.database_engine.lock_table(txn, "push_rules")
|
||||||
|
@ -470,15 +505,15 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
|
|
||||||
def _add_push_rule_highest_priority_txn(
|
def _add_push_rule_highest_priority_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
stream_id,
|
stream_id: int,
|
||||||
event_stream_ordering,
|
event_stream_ordering: int,
|
||||||
user_id,
|
user_id: str,
|
||||||
rule_id,
|
rule_id: str,
|
||||||
priority_class,
|
priority_class: int,
|
||||||
conditions_json,
|
conditions_json: str,
|
||||||
actions_json,
|
actions_json: str,
|
||||||
):
|
) -> None:
|
||||||
# Lock the table since otherwise we'll have annoying races between the
|
# Lock the table since otherwise we'll have annoying races between the
|
||||||
# SELECT here and the UPSERT below.
|
# SELECT here and the UPSERT below.
|
||||||
self.database_engine.lock_table(txn, "push_rules")
|
self.database_engine.lock_table(txn, "push_rules")
|
||||||
|
@ -510,17 +545,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
|
|
||||||
def _upsert_push_rule_txn(
|
def _upsert_push_rule_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
stream_id,
|
stream_id: int,
|
||||||
event_stream_ordering,
|
event_stream_ordering: int,
|
||||||
user_id,
|
user_id: str,
|
||||||
rule_id,
|
rule_id: str,
|
||||||
priority_class,
|
priority_class: int,
|
||||||
priority,
|
priority: int,
|
||||||
conditions_json,
|
conditions_json: str,
|
||||||
actions_json,
|
actions_json: str,
|
||||||
update_stream=True,
|
update_stream: bool = True,
|
||||||
):
|
) -> None:
|
||||||
"""Specialised version of simple_upsert_txn that picks a push_rule_id
|
"""Specialised version of simple_upsert_txn that picks a push_rule_id
|
||||||
using the _push_rule_id_gen if it needs to insert the rule. It assumes
|
using the _push_rule_id_gen if it needs to insert the rule. It assumes
|
||||||
that the "push_rules" table is locked"""
|
that the "push_rules" table is locked"""
|
||||||
|
@ -600,7 +635,11 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
rule_id: The rule_id of the rule to be deleted
|
rule_id: The rule_id of the rule to be deleted
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
|
def delete_push_rule_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
stream_id: int,
|
||||||
|
event_stream_ordering: int,
|
||||||
|
) -> None:
|
||||||
# we don't use simple_delete_one_txn because that would fail if the
|
# we don't use simple_delete_one_txn because that would fail if the
|
||||||
# user did not have a push_rule_enable row.
|
# user did not have a push_rule_enable row.
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
|
@ -661,14 +700,14 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
|
|
||||||
def _set_push_rule_enabled_txn(
|
def _set_push_rule_enabled_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
stream_id,
|
stream_id: int,
|
||||||
event_stream_ordering,
|
event_stream_ordering: int,
|
||||||
user_id,
|
user_id: str,
|
||||||
rule_id,
|
rule_id: str,
|
||||||
enabled,
|
enabled: bool,
|
||||||
is_default_rule,
|
is_default_rule: bool,
|
||||||
):
|
) -> None:
|
||||||
new_id = self._push_rules_enable_id_gen.get_next()
|
new_id = self._push_rules_enable_id_gen.get_next()
|
||||||
|
|
||||||
if not is_default_rule:
|
if not is_default_rule:
|
||||||
|
@ -740,7 +779,11 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
"""
|
"""
|
||||||
actions_json = json_encoder.encode(actions)
|
actions_json = json_encoder.encode(actions)
|
||||||
|
|
||||||
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
|
def set_push_rule_actions_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
stream_id: int,
|
||||||
|
event_stream_ordering: int,
|
||||||
|
) -> None:
|
||||||
if is_default_rule:
|
if is_default_rule:
|
||||||
# Add a dummy rule to the rules table with the user specified
|
# Add a dummy rule to the rules table with the user specified
|
||||||
# actions.
|
# actions.
|
||||||
|
@ -794,8 +837,15 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _insert_push_rules_update_txn(
|
def _insert_push_rules_update_txn(
|
||||||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
|
self,
|
||||||
):
|
txn: LoggingTransaction,
|
||||||
|
stream_id: int,
|
||||||
|
event_stream_ordering: int,
|
||||||
|
user_id: str,
|
||||||
|
rule_id: str,
|
||||||
|
op: str,
|
||||||
|
data: Optional[JsonDict] = None,
|
||||||
|
) -> None:
|
||||||
values = {
|
values = {
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
"event_stream_ordering": event_stream_ordering,
|
"event_stream_ordering": event_stream_ordering,
|
||||||
|
@ -814,5 +864,5 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||||
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
|
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_max_push_rules_stream_id(self):
|
def get_max_push_rules_stream_id(self) -> int:
|
||||||
return self._push_rules_stream_id_gen.get_current_token()
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
|
@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import (
|
||||||
wrap_as_background_process,
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.engines import Sqlite3Engine
|
from synapse.storage.engines import Sqlite3Engine
|
||||||
from synapse.storage.roommember import (
|
from synapse.storage.roommember import (
|
||||||
|
@ -46,7 +51,7 @@ from synapse.storage.roommember import (
|
||||||
ProfileInfo,
|
ProfileInfo,
|
||||||
RoomsForUser,
|
RoomsForUser,
|
||||||
)
|
)
|
||||||
from synapse.types import PersistedEventPosition, get_domain_from_id
|
from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||||
|
@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("_count_known_servers")
|
@wrap_as_background_process("_count_known_servers")
|
||||||
async def _count_known_servers(self):
|
async def _count_known_servers(self) -> int:
|
||||||
"""
|
"""
|
||||||
Count the servers that this server knows about.
|
Count the servers that this server knows about.
|
||||||
|
|
||||||
|
@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
`synapse_federation_known_servers` LaterGauge to collect.
|
`synapse_federation_known_servers` LaterGauge to collect.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _transact(txn):
|
def _transact(txn: LoggingTransaction) -> int:
|
||||||
if isinstance(self.database_engine, Sqlite3Engine):
|
if isinstance(self.database_engine, Sqlite3Engine):
|
||||||
query = """
|
query = """
|
||||||
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
|
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
|
||||||
|
@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
self._known_servers_count = max([count, 1])
|
self._known_servers_count = max([count, 1])
|
||||||
return self._known_servers_count
|
return self._known_servers_count
|
||||||
|
|
||||||
def _check_safe_current_state_events_membership_updated_txn(self, txn):
|
def _check_safe_current_state_events_membership_updated_txn(
|
||||||
|
self, txn: LoggingTransaction
|
||||||
|
) -> None:
|
||||||
"""Checks if it is safe to assume the new current_state_events
|
"""Checks if it is safe to assume the new current_state_events
|
||||||
membership column is up to date
|
membership column is up to date
|
||||||
"""
|
"""
|
||||||
|
@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
|
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
|
||||||
# If we can assume current_state_events.membership is up to date
|
# If we can assume current_state_events.membership is up to date
|
||||||
# then we can avoid a join, which is a Very Good Thing given how
|
# then we can avoid a join, which is a Very Good Thing given how
|
||||||
# frequently this function gets called.
|
# frequently this function gets called.
|
||||||
|
@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
A mapping from user ID to ProfileInfo.
|
A mapping from user ID to ProfileInfo.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
|
def _get_users_in_room_with_profiles(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Dict[str, ProfileInfo]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT state_key, display_name, avatar_url FROM room_memberships as m
|
SELECT state_key, display_name, avatar_url FROM room_memberships as m
|
||||||
INNER JOIN current_state_events as c
|
INNER JOIN current_state_events as c
|
||||||
|
@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
dict of membership states, pointing to a MemberSummary named tuple.
|
dict of membership states, pointing to a MemberSummary named tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_room_summary_txn(txn):
|
def _get_room_summary_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Dict[str, MemberSummary]:
|
||||||
# first get counts.
|
# first get counts.
|
||||||
# We do this all in one transaction to keep the cache small.
|
# We do this all in one transaction to keep the cache small.
|
||||||
# FIXME: get rid of this when we have room_stats
|
# FIXME: get rid of this when we have room_stats
|
||||||
|
@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
res = {}
|
res: Dict[str, MemberSummary] = {}
|
||||||
for count, membership in txn:
|
for count, membership in txn:
|
||||||
res.setdefault(membership, MemberSummary([], count))
|
res.setdefault(membership, MemberSummary([], count))
|
||||||
|
|
||||||
|
@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
membership_list: List[str],
|
membership_list: List[str],
|
||||||
) -> List[RoomsForUser]:
|
) -> List[RoomsForUser]:
|
||||||
|
@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rooms_for_user_with_stream_ordering_txn(
|
def _get_rooms_for_user_with_stream_ordering_txn(
|
||||||
self, txn, user_id: str
|
self, txn: LoggingTransaction, user_id: str
|
||||||
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||||
# We use `current_state_events` here and not `local_current_membership`
|
# We use `current_state_events` here and not `local_current_membership`
|
||||||
# as a) this gets called with remote users and b) this only gets called
|
# as a) this gets called with remote users and b) this only gets called
|
||||||
|
@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rooms_for_users_with_stream_ordering_txn(
|
def _get_rooms_for_users_with_stream_ordering_txn(
|
||||||
self, txn, user_ids: Collection[str]
|
self, txn: LoggingTransaction, user_ids: Collection[str]
|
||||||
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
|
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
|
||||||
|
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
|
@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
txn.execute(sql, [Membership.JOIN] + args)
|
txn.execute(sql, [Membership.JOIN] + args)
|
||||||
|
|
||||||
result = {user_id: set() for user_id in user_ids}
|
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
|
||||||
|
user_id: set() for user_id in user_ids
|
||||||
|
}
|
||||||
for user_id, room_id, instance, stream_id in txn:
|
for user_id, room_id, instance, stream_id in txn:
|
||||||
result[user_id].add(
|
result[user_id].add(
|
||||||
GetRoomsForUserWithStreamOrdering(
|
GetRoomsForUserWithStreamOrdering(
|
||||||
|
@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
def _get_users_server_still_shares_room_with_txn(txn):
|
def _get_users_server_still_shares_room_with_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Set[str]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT state_key FROM current_state_events
|
SELECT state_key FROM current_state_events
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
async def get_joined_users_from_context(
|
async def get_joined_users_from_context(
|
||||||
self, event: EventBase, context: EventContext
|
self, event: EventBase, context: EventContext
|
||||||
) -> Dict[str, ProfileInfo]:
|
) -> Dict[str, ProfileInfo]:
|
||||||
state_group = context.state_group
|
state_group: Union[object, int] = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
current_state_ids = await context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
|
assert current_state_ids is not None
|
||||||
|
assert state_group is not None
|
||||||
return await self._get_joined_users_from_context(
|
return await self._get_joined_users_from_context(
|
||||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_joined_users_from_state(
|
async def get_joined_users_from_state(
|
||||||
self, room_id, state_entry
|
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||||
) -> Dict[str, ProfileInfo]:
|
) -> Dict[str, ProfileInfo]:
|
||||||
state_group = state_entry.state_group
|
state_group: Union[object, int] = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
|
assert state_group is not None
|
||||||
with Measure(self._clock, "get_joined_users_from_state"):
|
with Measure(self._clock, "get_joined_users_from_state"):
|
||||||
return await self._get_joined_users_from_context(
|
return await self._get_joined_users_from_context(
|
||||||
room_id, state_group, state_entry.state, context=state_entry
|
room_id, state_group, state_entry.state, context=state_entry
|
||||||
|
@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
||||||
async def _get_joined_users_from_context(
|
async def _get_joined_users_from_context(
|
||||||
self,
|
self,
|
||||||
room_id,
|
room_id: str,
|
||||||
state_group,
|
state_group: Union[object, int],
|
||||||
current_state_ids,
|
current_state_ids: StateMap[str],
|
||||||
cache_context,
|
cache_context: _CacheContext,
|
||||||
event=None,
|
event: Optional[EventBase] = None,
|
||||||
context=None,
|
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
|
||||||
) -> Dict[str, ProfileInfo]:
|
) -> Dict[str, ProfileInfo]:
|
||||||
# We don't use `state_group`, it's there so that we can cache based
|
# We don't use `state_group`, it's there so that we can cache based
|
||||||
# on it. However, it's important that it's never None, since two current_states
|
# on it. However, it's important that it's never None, since two current_states
|
||||||
|
@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
return users_in_room
|
return users_in_room
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def _get_joined_profile_from_event_id(self, event_id):
|
def _get_joined_profile_from_event_id(
|
||||||
|
self, event_id: str
|
||||||
|
) -> Optional[Tuple[str, ProfileInfo]]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="_get_joined_profile_from_event_id",
|
cached_method_name="_get_joined_profile_from_event_id",
|
||||||
list_name="event_ids",
|
list_name="event_ids",
|
||||||
)
|
)
|
||||||
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
async def _get_joined_profiles_from_event_ids(
|
||||||
|
self, event_ids: Iterable[str]
|
||||||
|
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
|
||||||
"""For given set of member event_ids check if they point to a join
|
"""For given set of member event_ids check if they point to a join
|
||||||
event and if so return the associated user and profile info.
|
event and if so return the associated user and profile info.
|
||||||
|
|
||||||
|
@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
event_ids: The member event IDs to lookup
|
event_ids: The member event IDs to lookup
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
|
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
|
||||||
to `user_id` and ProfileInfo (or None if not join event).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
|
@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_joined_hosts(self, room_id: str, state_entry):
|
async def get_joined_hosts(
|
||||||
state_group = state_entry.state_group
|
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||||
|
) -> FrozenSet[str]:
|
||||||
|
state_group: Union[object, int] = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
|
assert state_group is not None
|
||||||
with Measure(self._clock, "get_joined_hosts"):
|
with Measure(self._clock, "get_joined_hosts"):
|
||||||
return await self._get_joined_hosts(
|
return await self._get_joined_hosts(
|
||||||
room_id, state_group, state_entry=state_entry
|
room_id, state_group, state_entry=state_entry
|
||||||
|
@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||||
async def _get_joined_hosts(
|
async def _get_joined_hosts(
|
||||||
self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
|
self,
|
||||||
|
room_id: str,
|
||||||
|
state_group: Union[object, int],
|
||||||
|
state_entry: "_StateCacheEntry",
|
||||||
) -> FrozenSet[str]:
|
) -> FrozenSet[str]:
|
||||||
# We don't use `state_group`, it's there so that we can cache based on
|
# We don't use `state_group`, it's there so that we can cache based on
|
||||||
# it. However, its important that its never None, since two
|
# it. However, its important that its never None, since two
|
||||||
|
@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# `get_joined_hosts` is called with the "current" state group for the
|
# `get_joined_hosts` is called with the "current" state group for the
|
||||||
# room, and so consecutive calls will be for consecutive state groups
|
# room, and so consecutive calls will be for consecutive state groups
|
||||||
# which point to the previous state group.
|
# which point to the previous state group.
|
||||||
cache = await self._get_joined_hosts_cache(room_id)
|
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
|
||||||
|
|
||||||
# If the state group in the cache matches, we already have the data we need.
|
# If the state group in the cache matches, we already have the data we need.
|
||||||
if state_entry.state_group == cache.state_group:
|
if state_entry.state_group == cache.state_group:
|
||||||
|
@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
elif state_entry.prev_group == cache.state_group:
|
elif state_entry.prev_group == cache.state_group:
|
||||||
# The cached work is for the previous state group, so we work out
|
# The cached work is for the previous state group, so we work out
|
||||||
# the delta.
|
# the delta.
|
||||||
|
assert state_entry.delta_ids is not None
|
||||||
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
||||||
if typ != EventTypes.Member:
|
if typ != EventTypes.Member:
|
||||||
continue
|
continue
|
||||||
|
@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
Returns False if they have since re-joined."""
|
Returns False if they have since re-joined."""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> int:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT"
|
"SELECT"
|
||||||
" COUNT(*)"
|
" COUNT(*)"
|
||||||
|
@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
The forgotten rooms.
|
The forgotten rooms.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_forgotten_rooms_for_user_txn(txn):
|
def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
|
||||||
# This is a slightly convoluted query that first looks up all rooms
|
# This is a slightly convoluted query that first looks up all rooms
|
||||||
# that the user has forgotten in the past, then rechecks that list
|
# that the user has forgotten in the past, then rechecks that list
|
||||||
# to see if any have subsequently been updated. This is done so that
|
# to see if any have subsequently been updated. This is done so that
|
||||||
|
@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
clause,
|
clause,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _is_local_host_in_room_ignoring_users_txn(txn):
|
def _is_local_host_in_room_ignoring_users_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> bool:
|
||||||
txn.execute(sql, (room_id, Membership.JOIN, *args))
|
txn.execute(sql, (room_id, Membership.JOIN, *args))
|
||||||
|
|
||||||
return bool(txn.fetchone())
|
return bool(txn.fetchone())
|
||||||
|
@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
where_clause="forgotten = 1",
|
where_clause="forgotten = 1",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _background_add_membership_profile(self, progress, batch_size):
|
async def _background_add_membership_profile(
|
||||||
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
target_min_stream_id = progress.get(
|
target_min_stream_id = progress.get(
|
||||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
max_stream_id = progress.get(
|
max_stream_id = progress.get(
|
||||||
"max_stream_id_exclusive", self._stream_order_on_start + 1
|
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_membership_profile_txn(txn):
|
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT stream_ordering, event_id, events.room_id, event_json.json
|
SELECT stream_ordering, event_id, events.room_id, event_json.json
|
||||||
FROM events
|
FROM events
|
||||||
|
@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _background_current_state_membership(self, progress, batch_size):
|
async def _background_current_state_membership(
|
||||||
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
"""Update the new membership column on current_state_events.
|
"""Update the new membership column on current_state_events.
|
||||||
|
|
||||||
This works by iterating over all rooms in alphebetical order.
|
This works by iterating over all rooms in alphebetical order.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _background_current_state_membership_txn(txn, last_processed_room):
|
def _background_current_state_membership_txn(
|
||||||
|
txn: LoggingTransaction, last_processed_room: str
|
||||||
|
) -> Tuple[int, bool]:
|
||||||
processed = 0
|
processed = 0
|
||||||
while processed < batch_size:
|
while processed < batch_size:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
return row_count
|
return row_count
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
class RoomMemberStore(
|
||||||
|
RoomMemberWorkerStore,
|
||||||
|
RoomMemberBackgroundUpdateStore,
|
||||||
|
CacheInvalidationWorkerStore,
|
||||||
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
|
@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||||
async def forget(self, user_id: str, room_id: str) -> None:
|
async def forget(self, user_id: str, room_id: str) -> None:
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> None:
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE"
|
"UPDATE"
|
||||||
" room_memberships"
|
" room_memberships"
|
||||||
|
@ -1288,5 +1328,5 @@ class _JoinedHostsCache:
|
||||||
# equal to anything else).
|
# equal to anything else).
|
||||||
state_group: Union[object, int] = attr.Factory(object)
|
state_group: Union[object, int] = attr.Factory(object)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return sum(len(v) for v in self.hosts_to_joined_users.values())
|
return sum(len(v) for v in self.hosts_to_joined_users.values())
|
||||||
|
|
Loading…
Reference in New Issue