Convert some of the general database methods to async (#8100)
This commit is contained in:
parent
e04e465b4d
commit
050e20e7ca
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -332,8 +332,7 @@ class DatabasePool(object):
|
||||||
"""
|
"""
|
||||||
return self._db_pool.running
|
return self._db_pool.running
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _check_safe_to_upsert(self):
|
||||||
def _check_safe_to_upsert(self):
|
|
||||||
"""
|
"""
|
||||||
Is it safe to use native UPSERT?
|
Is it safe to use native UPSERT?
|
||||||
|
|
||||||
|
@ -342,7 +341,7 @@ class DatabasePool(object):
|
||||||
|
|
||||||
If the background updates have not completed, wait 15 sec and check again.
|
If the background updates have not completed, wait 15 sec and check again.
|
||||||
"""
|
"""
|
||||||
updates = yield self.simple_select_list(
|
updates = await self.simple_select_list(
|
||||||
"background_updates",
|
"background_updates",
|
||||||
keyvalues=None,
|
keyvalues=None,
|
||||||
retcols=["update_name"],
|
retcols=["update_name"],
|
||||||
|
@ -614,8 +613,7 @@ class DatabasePool(object):
|
||||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||||
# no complex WHERE clauses, just a dict of values for columns.
|
# no complex WHERE clauses, just a dict of values for columns.
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
||||||
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
|
|
||||||
"""Executes an INSERT query on the named table.
|
"""Executes an INSERT query on the named table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -631,7 +629,7 @@ class DatabasePool(object):
|
||||||
`or_ignore` is True
|
`or_ignore` is True
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
|
await self.runInteraction(desc, self.simple_insert_txn, table, values)
|
||||||
except self.engine.module.IntegrityError:
|
except self.engine.module.IntegrityError:
|
||||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||||
# a cursor after we receive an error from the db.
|
# a cursor after we receive an error from the db.
|
||||||
|
@ -684,8 +682,7 @@ class DatabasePool(object):
|
||||||
|
|
||||||
txn.executemany(sql, vals)
|
txn.executemany(sql, vals)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def simple_upsert(
|
||||||
def simple_upsert(
|
|
||||||
self,
|
self,
|
||||||
table,
|
table,
|
||||||
keyvalues,
|
keyvalues,
|
||||||
|
@ -714,14 +711,14 @@ class DatabasePool(object):
|
||||||
inserting
|
inserting
|
||||||
lock (bool): True to lock the table when doing the upsert.
|
lock (bool): True to lock the table when doing the upsert.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(None or bool): Native upserts always return None. Emulated
|
None or bool: Native upserts always return None. Emulated
|
||||||
upserts return True if a new entry was created, False if an existing
|
upserts return True if a new entry was created, False if an existing
|
||||||
one was updated.
|
one was updated.
|
||||||
"""
|
"""
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
result = yield self.runInteraction(
|
return await self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
self.simple_upsert_txn,
|
self.simple_upsert_txn,
|
||||||
table,
|
table,
|
||||||
|
@ -730,7 +727,6 @@ class DatabasePool(object):
|
||||||
insertion_values,
|
insertion_values,
|
||||||
lock=lock,
|
lock=lock,
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
except self.engine.module.IntegrityError as e:
|
except self.engine.module.IntegrityError as e:
|
||||||
attempts += 1
|
attempts += 1
|
||||||
if attempts >= 5:
|
if attempts >= 5:
|
||||||
|
@ -1121,8 +1117,7 @@ class DatabasePool(object):
|
||||||
|
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def simple_select_many_batch(
|
||||||
def simple_select_many_batch(
|
|
||||||
self,
|
self,
|
||||||
table,
|
table,
|
||||||
column,
|
column,
|
||||||
|
@ -1156,7 +1151,7 @@ class DatabasePool(object):
|
||||||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
||||||
]
|
]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
rows = yield self.runInteraction(
|
rows = await self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
self.simple_select_many_txn,
|
self.simple_select_many_txn,
|
||||||
table,
|
table,
|
||||||
|
|
|
@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
service(ApplicationService): The service whose state to set.
|
service(ApplicationService): The service whose state to set.
|
||||||
state(ApplicationServiceState): The connectivity state to apply.
|
state(ApplicationServiceState): The connectivity state to apply.
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred which resolves when the state was set successfully.
|
An Awaitable which resolves when the state was set successfully.
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_upsert(
|
return self.db_pool.simple_upsert(
|
||||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||||
|
|
|
@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
"""Given a list of event ids, check if we have already processed and
|
"""Given a list of event ids, check if we have already processed and
|
||||||
stored them as non outliers.
|
stored them as non outliers.
|
||||||
"""
|
"""
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = yield defer.ensureDeferred(
|
||||||
table="events",
|
self.db_pool.simple_select_many_batch(
|
||||||
retcols=("event_id",),
|
table="events",
|
||||||
column="event_id",
|
retcols=("event_id",),
|
||||||
iterable=list(event_ids),
|
column="event_id",
|
||||||
keyvalues={"outlier": False},
|
iterable=list(event_ids),
|
||||||
desc="have_events_in_timeline",
|
keyvalues={"outlier": False},
|
||||||
|
desc="have_events_in_timeline",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {r["event_id"] for r in rows}
|
return {r["event_id"] for r in rows}
|
||||||
|
|
|
@ -17,9 +17,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Optional
|
from typing import Awaitable, Dict, List, Optional
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
|
@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
id_server (str)
|
id_server (str)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Awaitable
|
||||||
"""
|
"""
|
||||||
# We need to use an upsert, in case they user had already bound the
|
# We need to use an upsert, in case they user had already bound the
|
||||||
# threepid
|
# threepid
|
||||||
|
@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
|
|
||||||
def record_user_external_id(
|
def record_user_external_id(
|
||||||
self, auth_provider: str, external_id: str, user_id: str
|
self, auth_provider: str, external_id: str, user_id: str
|
||||||
) -> Deferred:
|
) -> Awaitable:
|
||||||
"""Record a mapping from an external user id to a mxid
|
"""Record a mapping from an external user id to a mxid
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return set(room_ids)
|
return set(room_ids)
|
||||||
|
|
||||||
def get_membership_from_event_ids(
|
async def get_membership_from_event_ids(
|
||||||
self, member_event_ids: Iterable[str]
|
self, member_event_ids: Iterable[str]
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""Get user_id and membership of a set of event IDs.
|
"""Get user_id and membership of a set of event IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.db_pool.simple_select_many_batch(
|
return await self.db_pool.simple_select_many_batch(
|
||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=member_event_ids,
|
iterable=member_event_ids,
|
||||||
|
|
|
@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
self.bob = UserID.from_string("@4567:test")
|
self.bob = UserID.from_string("@4567:test")
|
||||||
self.alice = UserID.from_string("@alice:remote")
|
self.alice = UserID.from_string("@alice:remote")
|
||||||
|
|
||||||
yield self.store.create_profile(self.frank.localpart)
|
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
|
||||||
|
|
||||||
self.handler = hs.get_profile_handler()
|
self.handler = hs.get_profile_handler()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_incoming_fed_query(self):
|
def test_incoming_fed_query(self):
|
||||||
yield self.store.create_profile("caroline")
|
yield defer.ensureDeferred(self.store.create_profile("caroline"))
|
||||||
yield self.store.set_profile_displayname("caroline", "Caroline")
|
yield self.store.set_profile_displayname("caroline", "Caroline")
|
||||||
|
|
||||||
response = yield defer.ensureDeferred(
|
response = yield defer.ensureDeferred(
|
||||||
|
|
|
@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
([], 0)
|
([], 0)
|
||||||
)
|
)
|
||||||
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
||||||
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
|
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_appservices_state_down(self):
|
def test_set_appservices_state_down(self):
|
||||||
service = Mock(id=self.as_list[1]["id"])
|
service = Mock(id=self.as_list[1]["id"])
|
||||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
yield defer.ensureDeferred(
|
||||||
|
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||||
|
)
|
||||||
rows = yield self.db_pool.runQuery(
|
rows = yield self.db_pool.runQuery(
|
||||||
self.engine.convert_param_style(
|
self.engine.convert_param_style(
|
||||||
"SELECT as_id FROM application_services_state WHERE state=?"
|
"SELECT as_id FROM application_services_state WHERE state=?"
|
||||||
|
@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_appservices_state_multiple_up(self):
|
def test_set_appservices_state_multiple_up(self):
|
||||||
service = Mock(id=self.as_list[1]["id"])
|
service = Mock(id=self.as_list[1]["id"])
|
||||||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
yield defer.ensureDeferred(
|
||||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||||
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
)
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||||
|
)
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.set_appservice_state(service, ApplicationServiceState.UP)
|
||||||
|
)
|
||||||
rows = yield self.db_pool.runQuery(
|
rows = yield self.db_pool.runQuery(
|
||||||
self.engine.convert_param_style(
|
self.engine.convert_param_style(
|
||||||
"SELECT as_id FROM application_services_state WHERE state=?"
|
"SELECT as_id FROM application_services_state WHERE state=?"
|
||||||
|
|
|
@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
def test_insert_1col(self):
|
def test_insert_1col(self):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield self.datastore.db_pool.simple_insert(
|
yield defer.ensureDeferred(
|
||||||
table="tablename", values={"columname": "Value"}
|
self.datastore.db_pool.simple_insert(
|
||||||
|
table="tablename", values={"columname": "Value"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mock_txn.execute.assert_called_with(
|
self.mock_txn.execute.assert_called_with(
|
||||||
|
@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
def test_insert_3cols(self):
|
def test_insert_3cols(self):
|
||||||
self.mock_txn.rowcount = 1
|
self.mock_txn.rowcount = 1
|
||||||
|
|
||||||
yield self.datastore.db_pool.simple_insert(
|
yield defer.ensureDeferred(
|
||||||
table="tablename",
|
self.datastore.db_pool.simple_insert(
|
||||||
# Use OrderedDict() so we can assert on the SQL generated
|
table="tablename",
|
||||||
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
|
# Use OrderedDict() so we can assert on the SQL generated
|
||||||
|
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mock_txn.execute.assert_called_with(
|
self.mock_txn.execute.assert_called_with(
|
||||||
|
|
|
@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_find_first_stream_ordering_after_ts(self):
|
def test_find_first_stream_ordering_after_ts(self):
|
||||||
def add_event(so, ts):
|
def add_event(so, ts):
|
||||||
return self.store.db_pool.simple_insert(
|
return defer.ensureDeferred(
|
||||||
"events",
|
self.store.db_pool.simple_insert(
|
||||||
{
|
"events",
|
||||||
"stream_ordering": so,
|
{
|
||||||
"received_ts": ts,
|
"stream_ordering": so,
|
||||||
"event_id": "event%i" % so,
|
"received_ts": ts,
|
||||||
"type": "",
|
"event_id": "event%i" % so,
|
||||||
"room_id": "",
|
"type": "",
|
||||||
"content": "",
|
"room_id": "",
|
||||||
"processed": True,
|
"content": "",
|
||||||
"outlier": False,
|
"processed": True,
|
||||||
"topological_ordering": 0,
|
"outlier": False,
|
||||||
"depth": 0,
|
"topological_ordering": 0,
|
||||||
},
|
"depth": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# start with the base case where there are no events in the table
|
# start with the base case where there are no events in the table
|
||||||
|
|
|
@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_users_paginate(self):
|
def test_get_users_paginate(self):
|
||||||
yield self.store.register_user(self.user.to_string(), "pass")
|
yield self.store.register_user(self.user.to_string(), "pass")
|
||||||
yield self.store.create_profile(self.user.localpart)
|
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
|
||||||
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
||||||
|
|
||||||
users, total = yield self.store.get_users_paginate(
|
users, total = yield self.store.get_users_paginate(
|
||||||
|
|
|
@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_displayname(self):
|
def test_displayname(self):
|
||||||
yield self.store.create_profile(self.u_frank.localpart)
|
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||||
|
|
||||||
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_avatar_url(self):
|
def test_avatar_url(self):
|
||||||
yield self.store.create_profile(self.u_frank.localpart)
|
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||||
|
|
||||||
yield self.store.set_profile_avatar_url(
|
yield self.store.set_profile_avatar_url(
|
||||||
self.u_frank.localpart, "http://my.site/here"
|
self.u_frank.localpart, "http://my.site/here"
|
||||||
|
|
Loading…
Reference in New Issue