Convert simple_select_one and simple_select_one_onecol to async (#8162)

This commit is contained in:
Patrick Cloke 2020-08-26 07:19:32 -04:00 committed by GitHub
parent 56efa9ec71
commit 4c6c56dc58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 220 additions and 113 deletions

1
changelog.d/8162.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -29,9 +29,11 @@ from typing import (
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
overload,
) )
from prometheus_client import Histogram from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.internet import defer from twisted.internet import defer
@ -1020,14 +1022,36 @@ class DatabasePool(object):
return txn.execute_batch(sql, args) return txn.execute_batch(sql, args)
def simple_select_one( @overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
...
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
...
async def simple_select_one(
self, self,
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
retcols: Iterable[str], retcols: Iterable[str],
allow_none: bool = False, allow_none: bool = False,
desc: str = "simple_select_one", desc: str = "simple_select_one",
) -> defer.Deferred: ) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it. return a single row, returning multiple columns from it.
@ -1038,18 +1062,18 @@ class DatabasePool(object):
allow_none: If true, return None instead of failing if the SELECT allow_none: If true, return None instead of failing if the SELECT
statement returns no rows statement returns no rows
""" """
return self.runInteraction( return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
) )
def simple_select_one_onecol( async def simple_select_one_onecol(
self, self,
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
retcol: Iterable[str], retcol: Iterable[str],
allow_none: bool = False, allow_none: bool = False,
desc: str = "simple_select_one_onecol", desc: str = "simple_select_one_onecol",
) -> defer.Deferred: ) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it. return a single row, returning a single column from it.
@ -1061,7 +1085,7 @@ class DatabasePool(object):
statement returns no rows statement returns no rows
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
""" """
return self.runInteraction( return await self.runInteraction(
desc, desc,
self.simple_select_one_onecol_txn, self.simple_select_one_onecol_txn,
table, table,

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore): class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id: str, device_id: str): async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as """Retrieve a device. Only returns devices that are not marked as
hidden. hidden.
@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve device_id: The ID of the device to retrieve
Returns: Returns:
defer.Deferred for a dict containing the device information A dict containing the device information
Raises: Raises:
StoreError: if the device is not found StoreError: if the device is not found
""" """
return self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
) )
@cached(max_entries=10000) @cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id: str): async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't """Get the last stream_id we got for a user. May be None if we haven't
got any information for them. got any information for them.
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="stream_id", retcol="stream_id",

View File

@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers) return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias): async def get_room_alias_creator(self, room_alias: str) -> str:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="room_aliases", table="room_aliases",
keyvalues={"room_alias": room_alias}, keyvalues={"room_alias": room_alias},
retcol="creator", retcol="creator",

View File

@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret return ret
def count_e2e_room_keys(self, user_id, version): async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version. """Get the number of keys in a backup version.
Args: Args:
user_id (str): the user whose backup we're querying user_id: the user whose backup we're querying
version (str): the version ID of the backup we're querying about version: the version ID of the backup we're querying about
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version}, keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)", retcol="COUNT(*)",

View File

@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def get_received_ts(self, event_id): async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event. """Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events. Raises an exception for unknown events.
Args: Args:
event_id (str) event_id: The event ID to query.
Returns: Returns:
Deferred[int|None]: Timestamp in milliseconds, or None for events Timestamp in milliseconds, or None for events that were persisted
that were persisted before received_ts was implemented. before received_ts was implemented.
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="events", table="events",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
retcol="received_ts", retcol="received_ts",

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore): class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id): async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcols=( retcols=(
@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
) )
return bool(result) return bool(result)
def is_user_admin_in_group(self, group_id, user_id): async def is_user_admin_in_group(
return self.db_pool.simple_select_one_onecol( self, group_id: str, user_id: str
) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin", retcol="is_admin",
@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group", desc="is_user_admin_in_group",
) )
def is_user_invited_to_local_group(self, group_id, user_id): async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
"""Has the group server invited a user? """Has the group server invited a user?
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id", retcol="user_id",

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs) super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id): async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media """Get the metadata for a local piece of media
Returns: Returns:
None if the media_id doesn't exist. None if the media_id doesn't exist.
""" """
return self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
"local_media_repository", "local_media_repository",
{"media_id": media_id}, {"media_id": media_id},
( (
@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail", desc="store_local_thumbnail",
) )
def get_cached_remote_media(self, origin, media_id): async def get_cached_remote_media(
return self.db_pool.simple_select_one( self, origin, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
"remote_media_cache", "remote_media_cache",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (

View File

@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users return users
@cached(num_args=1) @cached(num_args=1)
def user_last_seen_monthly_active(self, user_id): async def user_last_seen_monthly_active(self, user_id: str) -> int:
""" """
Checks if a given user is part of the monthly active user group Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
Arguments:
user_id: user to add/update
Return:
Timestamp since last seen, None if never seen
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users", table="monthly_active_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="timestamp", retcol="timestamp",

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore): class ProfileWorkerStore(SQLBaseStore):
async def get_profileinfo(self, user_localpart): async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try: try:
profile = await self.db_pool.simple_select_one( profile = await self.db_pool.simple_select_one(
table="profiles", table="profiles",
@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"] avatar_url=profile["avatar_url"], display_name=profile["displayname"]
) )
def get_profile_displayname(self, user_localpart): async def get_profile_displayname(self, user_localpart: str) -> str:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="displayname", retcol="displayname",
desc="get_profile_displayname", desc="get_profile_displayname",
) )
def get_profile_avatar_url(self, user_localpart): async def get_profile_avatar_url(self, user_localpart: str) -> str:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="avatar_url", retcol="avatar_url",
desc="get_profile_avatar_url", desc="get_profile_avatar_url",
) )
def get_from_remote_profile_cache(self, user_id): async def get_from_remote_profile_cache(
return self.db_pool.simple_select_one( self, user_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"), retcols=("displayname", "avatar_url"),

View File

@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
@cached(num_args=3) @cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): async def get_last_receipt_event_id_for_user(
return self.db_pool.simple_select_one_onecol( self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized", table="receipts_linearized",
keyvalues={ keyvalues={
"room_id": room_id, "room_id": room_id,

View File

@ -17,7 +17,7 @@
import logging import logging
import re import re
from typing import Awaitable, Dict, List, Optional from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
@cached() @cached()
def get_user_by_id(self, user_id): async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
retcols=[ retcols=[
@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation", desc="del_user_pending_deactivation",
) )
def get_user_pending_deactivation(self): async def get_user_pending_deactivation(self) -> Optional[str]:
""" """
Gets one user from the table of users waiting to be parted from all the rooms Gets one user from the table of users waiting to be parted from all the rooms
they're in. they're in.
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation", "users_pending_deactivation",
keyvalues={}, keyvalues={},
retcol="user_id", retcol="user_id",

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore): class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id): async def get_rejection_reason(self, event_id: str) -> Optional[str]:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="rejections", table="rejections",
retcol="reason", retcol="reason",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},

View File

@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config self.config = hs.config
def get_room(self, room_id): async def get_room(self, room_id: str) -> dict:
"""Retrieve a room. """Retrieve a room.
Args: Args:
room_id (str): The ID of the room to retrieve. room_id: The ID of the room to retrieve.
Returns: Returns:
A dict containing the room information, or None if the room is unknown. A dict containing the room information, or None if the room is unknown.
""" """
return self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"), retcols=("room_id", "is_public", "creator"),
@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val return ret_val
@cached(max_entries=10000) @cached(max_entries=10000)
def is_room_blocked(self, room_id): async def is_room_blocked(self, room_id: str) -> Optional[bool]:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms", table="blocked_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="1", retcol="1",

View File

@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias") return event.content.get("canonical_alias")
@cached(max_entries=50000) @cached(max_entries=50000)
def _get_state_group_for_event(self, event_id): async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
retcol="state_group", retcol="state_group",

View File

@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on) return len(rooms_to_work_on)
def get_stats_positions(self): async def get_stats_positions(self) -> int:
""" """
Returns the stats processor positions. Returns the stats processor positions.
""" """
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position", table="stats_incremental_position",
keyvalues={}, keyvalues={},
retcol="stream_id", retcol="stream_id",
@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list return slice_list
@cached() @cached()
def get_earliest_token_for_stats(self, stats_type, id): async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
""" """
Fetch the "earliest token". This is used by the room stats delta Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the processor to ignore deltas that have been processed between the
@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated. being calculated.
Returns: Returns:
Deferred[int] The earliest token.
""" """
table, id_col = TYPE_TO_TABLE[stats_type] table, id_col = TYPE_TO_TABLE[stats_type]
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,), "%s_current" % (table,),
keyvalues={id_col: id}, keyvalues={id_col: id},
retcol="completed_delta_stream_id", retcol="completed_delta_stream_id",

View File

@ -15,6 +15,7 @@
import logging import logging
import re import re
from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
@cached() @cached()
def get_user_in_directory(self, user_id): async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
return self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="user_directory", table="user_directory",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"), retcols=("display_name", "avatar_url"),
@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows) users.update(rows)
return list(users) return list(users)
def get_user_directory_stream_pos(self): async def get_user_directory_stream_pos(self) -> int:
return self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos", table="user_directory_stream_pos",
keyvalues={}, keyvalues={},
retcol="stream_id", retcol="stream_id",

View File

@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank") yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
displayname = yield defer.ensureDeferred( displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank) self.handler.get_displayname(self.frank)
@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", (
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
yield self.store.set_profile_displayname(self.frank.localpart, "Frank") yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", (
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
) )
# Setting displayname a second time is forbidden # Setting displayname a second time is forbidden
@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_incoming_fed_query(self): def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline")) yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline") yield defer.ensureDeferred(
self.store.set_profile_displayname("caroline", "Caroline")
)
response = yield defer.ensureDeferred( response = yield defer.ensureDeferred(
self.query_handlers["profile"]( self.query_handlers["profile"](
@ -170,9 +186,11 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_avatar(self): def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url( yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
) )
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url) self.assertEquals("http://my.server/me.png", avatar_url)
@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/pic.gif", "http://my.server/pic.gif",
) )
@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png", "http://my.server/me.png",
) )
@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
yield self.store.set_profile_avatar_url( yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
) )
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png", "http://my.server/me.png",
) )

View File

@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = ( self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam # we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1) lambda: make_awaitable(1)
) )
self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_current_state_deltas.return_value = (0, None)

View File

@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes # Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test") self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token) self.assertTrue(access_token)
self.assertTrue(self.store.get_user_by_id(user_id)) self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned # Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id)) emails = self.get_success(self.store.user_get_threepids(user_id))

View File

@ -97,9 +97,11 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore.db_pool.simple_select_one_onecol( value = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
) )
)
self.assertEquals("Value", value) self.assertEquals("Value", value)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
@ -111,11 +113,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3) self.mock_txn.fetchone.return_value = (1, 2, 3)
ret = yield self.datastore.db_pool.simple_select_one( ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename", table="tablename",
keyvalues={"keycol": "TheKey"}, keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"], retcols=["colA", "colB", "colC"],
) )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
@ -127,12 +131,14 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0 self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None self.mock_txn.fetchone.return_value = None
ret = yield self.datastore.db_pool.simple_select_one( ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename", table="tablename",
keyvalues={"keycol": "Not here"}, keyvalues={"keycol": "Not here"},
retcols=["colA"], retcols=["colA"],
allow_none=True, allow_none=True,
) )
)
self.assertFalse(ret) self.assertFalse(ret)

View File

@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name") self.store.store_device("user_id", "device_id", "display_name")
) )
res = yield self.store.get_device("user_id", "device_id") res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"user_id": "user_id", "user_id": "user_id",
@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1") self.store.store_device("user_id", "device_id", "display_name 1")
) )
res = yield self.store.get_device("user_id", "device_id") res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do a no-op first # do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id") res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do the update # do the update
@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
) )
# check it worked # check it worked
res = yield self.store.get_device("user_id", "device_id") res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"]) self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self): def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") yield defer.ensureDeferred(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals( self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_avatar_url(self): def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url( yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here" self.u_frank.localpart, "http://my.site/here"
) )
)
self.assertEquals( self.assertEquals(
"http://my.site/here", "http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)), (
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
) )

View File

@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None, "user_type": None,
"deactivated": 0, "deactivated": 0,
}, },
(yield self.store.get_user_by_id(self.user_id)), (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"is_public": True, "is_public": True,
}, },
(yield self.store.get_room(self.room.to_string())), (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_room_unknown_room(self): def test_get_room_unknown_room(self):
self.assertIsNone((yield self.store.get_room("!uknown:test")),) self.assertIsNone(
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_room_with_stats(self): def test_get_room_with_stats(self):
@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"public": True, "public": True,
}, },
(yield self.store.get_room_with_stats(self.room.to_string())), (
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self): def test_get_room_with_stats_unknown_room(self):
self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
)
class RoomEventsStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase):