Convert some of the general database methods to async (#8100)
This commit is contained in:
parent
e04e465b4d
commit
050e20e7ca
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -332,8 +332,7 @@ class DatabasePool(object):
|
|||
"""
|
||||
return self._db_pool.running
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_safe_to_upsert(self):
|
||||
async def _check_safe_to_upsert(self):
|
||||
"""
|
||||
Is it safe to use native UPSERT?
|
||||
|
||||
|
@ -342,7 +341,7 @@ class DatabasePool(object):
|
|||
|
||||
If the background updates have not completed, wait 15 sec and check again.
|
||||
"""
|
||||
updates = yield self.simple_select_list(
|
||||
updates = await self.simple_select_list(
|
||||
"background_updates",
|
||||
keyvalues=None,
|
||||
retcols=["update_name"],
|
||||
|
@ -614,8 +613,7 @@ class DatabasePool(object):
|
|||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
||||
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
|
@ -631,7 +629,7 @@ class DatabasePool(object):
|
|||
`or_ignore` is True
|
||||
"""
|
||||
try:
|
||||
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
|
||||
await self.runInteraction(desc, self.simple_insert_txn, table, values)
|
||||
except self.engine.module.IntegrityError:
|
||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||
# a cursor after we receive an error from the db.
|
||||
|
@ -684,8 +682,7 @@ class DatabasePool(object):
|
|||
|
||||
txn.executemany(sql, vals)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def simple_upsert(
|
||||
async def simple_upsert(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
|
@ -714,14 +711,14 @@ class DatabasePool(object):
|
|||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
Deferred(None or bool): Native upserts always return None. Emulated
|
||||
None or bool: Native upserts always return None. Emulated
|
||||
upserts return True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
"""
|
||||
attempts = 0
|
||||
while True:
|
||||
try:
|
||||
result = yield self.runInteraction(
|
||||
return await self.runInteraction(
|
||||
desc,
|
||||
self.simple_upsert_txn,
|
||||
table,
|
||||
|
@ -730,7 +727,6 @@ class DatabasePool(object):
|
|||
insertion_values,
|
||||
lock=lock,
|
||||
)
|
||||
return result
|
||||
except self.engine.module.IntegrityError as e:
|
||||
attempts += 1
|
||||
if attempts >= 5:
|
||||
|
@ -1121,8 +1117,7 @@ class DatabasePool(object):
|
|||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def simple_select_many_batch(
|
||||
async def simple_select_many_batch(
|
||||
self,
|
||||
table,
|
||||
column,
|
||||
|
@ -1156,7 +1151,7 @@ class DatabasePool(object):
|
|||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
||||
]
|
||||
for chunk in chunks:
|
||||
rows = yield self.runInteraction(
|
||||
rows = await self.runInteraction(
|
||||
desc,
|
||||
self.simple_select_many_txn,
|
||||
table,
|
||||
|
|
|
@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
service(ApplicationService): The service whose state to set.
|
||||
state(ApplicationServiceState): The connectivity state to apply.
|
||||
Returns:
|
||||
A Deferred which resolves when the state was set successfully.
|
||||
An Awaitable which resolves when the state was set successfully.
|
||||
"""
|
||||
return self.db_pool.simple_upsert(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
|
|
|
@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
rows = yield defer.ensureDeferred(
|
||||
self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
)
|
||||
|
||||
return {r["event_id"] for r in rows}
|
||||
|
|
|
@ -17,9 +17,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from typing import Awaitable, Dict, List, Optional
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
|
@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
id_server (str)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
Awaitable
|
||||
"""
|
||||
# We need to use an upsert, in case they user had already bound the
|
||||
# threepid
|
||||
|
@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
|
||||
def record_user_external_id(
|
||||
self, auth_provider: str, external_id: str, user_id: str
|
||||
) -> Deferred:
|
||||
) -> Awaitable:
|
||||
"""Record a mapping from an external user id to a mxid
|
||||
|
||||
Args:
|
||||
|
|
|
@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
return set(room_ids)
|
||||
|
||||
def get_membership_from_event_ids(
|
||||
async def get_membership_from_event_ids(
|
||||
self, member_event_ids: Iterable[str]
|
||||
) -> List[dict]:
|
||||
"""Get user_id and membership of a set of event IDs.
|
||||
"""
|
||||
|
||||
return self.db_pool.simple_select_many_batch(
|
||||
return await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=member_event_ids,
|
||||
|
|
|
@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
self.bob = UserID.from_string("@4567:test")
|
||||
self.alice = UserID.from_string("@alice:remote")
|
||||
|
||||
yield self.store.create_profile(self.frank.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
|
||||
|
||||
self.handler = hs.get_profile_handler()
|
||||
self.hs = hs
|
||||
|
@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_incoming_fed_query(self):
|
||||
yield self.store.create_profile("caroline")
|
||||
yield defer.ensureDeferred(self.store.create_profile("caroline"))
|
||||
yield self.store.set_profile_displayname("caroline", "Caroline")
|
||||
|
||||
response = yield defer.ensureDeferred(
|
||||
|
|
|
@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
([], 0)
|
||||
)
|
||||
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
||||
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
|
||||
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
|
||||
None
|
||||
)
|
||||
|
||||
|
|
|
@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_set_appservices_state_down(self):
|
||||
service = Mock(id=self.as_list[1]["id"])
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
)
|
||||
rows = yield self.db_pool.runQuery(
|
||||
self.engine.convert_param_style(
|
||||
"SELECT as_id FROM application_services_state WHERE state=?"
|
||||
|
@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_set_appservices_state_multiple_up(self):
|
||||
service = Mock(id=self.as_list[1]["id"])
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||
)
|
||||
rows = yield self.db_pool.runQuery(
|
||||
self.engine.convert_param_style(
|
||||
"SELECT as_id FROM application_services_state WHERE state=?"
|
||||
|
|
|
@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
def test_insert_1col(self):
|
||||
self.mock_txn.rowcount = 1
|
||||
|
||||
yield self.datastore.db_pool.simple_insert(
|
||||
table="tablename", values={"columname": "Value"}
|
||||
yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_insert(
|
||||
table="tablename", values={"columname": "Value"}
|
||||
)
|
||||
)
|
||||
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
|
@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
def test_insert_3cols(self):
|
||||
self.mock_txn.rowcount = 1
|
||||
|
||||
yield self.datastore.db_pool.simple_insert(
|
||||
table="tablename",
|
||||
# Use OrderedDict() so we can assert on the SQL generated
|
||||
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
|
||||
yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_insert(
|
||||
table="tablename",
|
||||
# Use OrderedDict() so we can assert on the SQL generated
|
||||
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
|
||||
)
|
||||
)
|
||||
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
|
|
|
@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_find_first_stream_ordering_after_ts(self):
|
||||
def add_event(so, ts):
|
||||
return self.store.db_pool.simple_insert(
|
||||
"events",
|
||||
{
|
||||
"stream_ordering": so,
|
||||
"received_ts": ts,
|
||||
"event_id": "event%i" % so,
|
||||
"type": "",
|
||||
"room_id": "",
|
||||
"content": "",
|
||||
"processed": True,
|
||||
"outlier": False,
|
||||
"topological_ordering": 0,
|
||||
"depth": 0,
|
||||
},
|
||||
return defer.ensureDeferred(
|
||||
self.store.db_pool.simple_insert(
|
||||
"events",
|
||||
{
|
||||
"stream_ordering": so,
|
||||
"received_ts": ts,
|
||||
"event_id": "event%i" % so,
|
||||
"type": "",
|
||||
"room_id": "",
|
||||
"content": "",
|
||||
"processed": True,
|
||||
"outlier": False,
|
||||
"topological_ordering": 0,
|
||||
"depth": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# start with the base case where there are no events in the table
|
||||
|
|
|
@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_get_users_paginate(self):
|
||||
yield self.store.register_user(self.user.to_string(), "pass")
|
||||
yield self.store.create_profile(self.user.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
|
||||
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
||||
|
||||
users, total = yield self.store.get_users_paginate(
|
||||
|
|
|
@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_displayname(self):
|
||||
yield self.store.create_profile(self.u_frank.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||
|
||||
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||
|
||||
|
@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_avatar_url(self):
|
||||
yield self.store.create_profile(self.u_frank.localpart)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.u_frank.localpart, "http://my.site/here"
|
||||
|
|
Loading…
Reference in New Issue