Convert tags and metrics databases to async/await (#8062)
This commit is contained in:
parent
a0acdfa9e9
commit
04faa0bfa9
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -15,8 +15,6 @@
|
||||||
import typing
|
import typing
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.metrics import BucketCollector
|
from synapse.metrics import BucketCollector
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
|
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
|
||||||
self._current_forward_extremities_amount = Counter([x[0] for x in res])
|
self._current_forward_extremities_amount = Counter([x[0] for x in res])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def count_daily_messages(self):
|
||||||
def count_daily_messages(self):
|
|
||||||
"""
|
"""
|
||||||
Returns an estimate of the number of messages sent in the last day.
|
Returns an estimate of the number of messages sent in the last day.
|
||||||
|
|
||||||
|
@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
(count,) = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
|
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
||||||
return ret
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def count_daily_sent_messages(self):
|
||||||
def count_daily_sent_messages(self):
|
|
||||||
def _count_messages(txn):
|
def _count_messages(txn):
|
||||||
# This is good enough as if you have silly characters in your own
|
# This is good enough as if you have silly characters in your own
|
||||||
# hostname then thats your own fault.
|
# hostname then thats your own fault.
|
||||||
|
@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
(count,) = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"count_daily_sent_messages", _count_messages
|
"count_daily_sent_messages", _count_messages
|
||||||
)
|
)
|
||||||
return ret
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def count_daily_active_rooms(self):
|
||||||
def count_daily_active_rooms(self):
|
|
||||||
def _count(txn):
|
def _count(txn):
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
|
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
|
||||||
|
@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||||
(count,) = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
|
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
|
||||||
return ret
|
|
||||||
|
|
|
@ -15,14 +15,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.storage._base import db_to_json
|
from synapse.storage._base import db_to_json
|
||||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -30,30 +29,26 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TagsWorkerStore(AccountDataWorkerStore):
|
class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
@cached()
|
@cached()
|
||||||
def get_tags_for_user(self, user_id):
|
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
|
||||||
"""Get all the tags for a user.
|
"""Get all the tags for a user.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the tags for.
|
user_id: The user to get the tags for.
|
||||||
Returns:
|
Returns:
|
||||||
A deferred dict mapping from room_id strings to dicts mapping from
|
A mapping from room_id strings to dicts mapping from tag strings to
|
||||||
tag strings to tag content.
|
tag content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
deferred = self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@deferred.addCallback
|
tags_by_room = {}
|
||||||
def tags_by_room(rows):
|
for row in rows:
|
||||||
tags_by_room = {}
|
room_tags = tags_by_room.setdefault(row["room_id"], {})
|
||||||
for row in rows:
|
room_tags[row["tag"]] = db_to_json(row["content"])
|
||||||
room_tags = tags_by_room.setdefault(row["room_id"], {})
|
return tags_by_room
|
||||||
room_tags[row["tag"]] = db_to_json(row["content"])
|
|
||||||
return tags_by_room
|
|
||||||
|
|
||||||
return deferred
|
|
||||||
|
|
||||||
async def get_all_updated_tags(
|
async def get_all_updated_tags(
|
||||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||||
|
@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
|
|
||||||
return results, upto_token, limited
|
return results, upto_token, limited
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_updated_tags(
|
||||||
def get_updated_tags(self, user_id, stream_id):
|
self, user_id: str, stream_id: int
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
"""Get all the tags for the rooms where the tags have changed since the
|
"""Get all the tags for the rooms where the tags have changed since the
|
||||||
given version
|
given version
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the tags for.
|
user_id(str): The user to get the tags for.
|
||||||
stream_id(int): The earliest update to get for the user.
|
stream_id(int): The earliest update to get for the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred dict mapping from room_id strings to lists of tag
|
A mapping from room_id strings to lists of tag strings for all the
|
||||||
strings for all the rooms that changed since the stream_id token.
|
rooms that changed since the stream_id token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_updated_tags_txn(txn):
|
def get_updated_tags_txn(txn):
|
||||||
|
@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
if not changed:
|
if not changed:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
room_ids = yield self.db_pool.runInteraction(
|
room_ids = await self.db_pool.runInteraction(
|
||||||
"get_updated_tags", get_updated_tags_txn
|
"get_updated_tags", get_updated_tags_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
if room_ids:
|
if room_ids:
|
||||||
tags_by_room = yield self.get_tags_for_user(user_id)
|
tags_by_room = await self.get_tags_for_user(user_id)
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
results[room_id] = tags_by_room.get(room_id, {})
|
results[room_id] = tags_by_room.get(room_id, {})
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_tags_for_room(self, user_id, room_id):
|
async def get_tags_for_room(
|
||||||
|
self, user_id: str, room_id: str
|
||||||
|
) -> Dict[str, JsonDict]:
|
||||||
"""Get all the tags for the given room
|
"""Get all the tags for the given room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get tags for
|
user_id: The user to get tags for
|
||||||
room_id(str): The room to get tags for
|
room_id: The room to get tags for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred list of string tags.
|
A mapping of tags to tag content.
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
table="room_tags",
|
table="room_tags",
|
||||||
keyvalues={"user_id": user_id, "room_id": room_id},
|
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||||
retcols=("tag", "content"),
|
retcols=("tag", "content"),
|
||||||
desc="get_tags_for_room",
|
desc="get_tags_for_room",
|
||||||
).addCallback(
|
|
||||||
lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
|
|
||||||
)
|
)
|
||||||
|
return {row["tag"]: db_to_json(row["content"]) for row in rows}
|
||||||
|
|
||||||
|
|
||||||
class TagsStore(TagsWorkerStore):
|
class TagsStore(TagsWorkerStore):
|
||||||
@defer.inlineCallbacks
|
async def add_tag_to_room(
|
||||||
def add_tag_to_room(self, user_id, room_id, tag, content):
|
self, user_id: str, room_id: str, tag: str, content: JsonDict
|
||||||
|
) -> int:
|
||||||
"""Add a tag to a room for a user.
|
"""Add a tag to a room for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to add a tag for.
|
user_id: The user to add a tag for.
|
||||||
room_id(str): The room to add a tag for.
|
room_id: The room to add a tag for.
|
||||||
tag(str): The tag name to add.
|
tag: The tag name to add.
|
||||||
content(dict): A json object to associate with the tag.
|
content: A json object to associate with the tag.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that completes once the tag has been added.
|
The next account data ID.
|
||||||
"""
|
"""
|
||||||
content_json = json.dumps(content)
|
content_json = json.dumps(content)
|
||||||
|
|
||||||
|
@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore):
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
return result
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
|
||||||
def remove_tag_from_room(self, user_id, room_id, tag):
|
|
||||||
"""Remove a tag from a room for a user.
|
"""Remove a tag from a room for a user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that completes once the tag has been removed
|
The next account data ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def remove_tag_txn(txn, next_id):
|
def remove_tag_txn(txn, next_id):
|
||||||
|
@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore):
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
return result
|
|
||||||
|
|
||||||
def _update_revision_txn(self, txn, user_id, room_id, next_id):
|
def _update_revision_txn(
|
||||||
|
self, txn, user_id: str, room_id: str, next_id: int
|
||||||
|
) -> None:
|
||||||
"""Update the latest revision of the tags for the given user and room.
|
"""Update the latest revision of the tags for the given user and room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn: The database cursor
|
txn: The database cursor
|
||||||
user_id(str): The ID of the user.
|
user_id: The ID of the user.
|
||||||
room_id(str): The ID of the room.
|
room_id: The ID of the room.
|
||||||
next_id(int): The the revision to advance to.
|
next_id: The the revision to advance to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
|
|
|
@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
|
@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
return_value=defer.succeed("!something:localhost")
|
return_value=defer.succeed("!something:localhost")
|
||||||
)
|
)
|
||||||
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
|
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
|
||||||
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
|
self._rlsn._store.get_tags_for_room = Mock(
|
||||||
|
side_effect=lambda user_id, room_id: make_awaitable({})
|
||||||
|
)
|
||||||
|
|
||||||
@override_config({"hs_disabled": True})
|
@override_config({"hs_disabled": True})
|
||||||
def test_maybe_send_server_notice_disabled_hs(self):
|
def test_maybe_send_server_notice_disabled_hs(self):
|
||||||
|
|
Loading…
Reference in New Issue