wrap `_get_e2e_device_keys_and_signatures_txn` in a non-txn method (#8231)

We have three things which all call `_get_e2e_device_keys_and_signatures_txn`
with their own `runInteraction`. Factor out the common code.
This commit is contained in:
Richard van der Hoff 2020-09-03 11:50:49 +01:00 committed by GitHub
parent c8758cb72f
commit 6f6f371a87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 17 deletions

1
changelog.d/8231.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

View File

@ -255,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing an device update EDU List of objects representing an device update EDU
""" """
devices = ( devices = (
await self.db_pool.runInteraction( await self.get_e2e_device_keys_and_signatures(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
query_map.keys(), query_map.keys(),
include_all_devices=True, include_all_devices=True,
include_deleted_devices=True, include_deleted_devices=True,

View File

@ -36,7 +36,7 @@ if TYPE_CHECKING:
@attr.s @attr.s
class DeviceKeyLookupResult: class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn""" """The type returned by get_e2e_device_keys_and_signatures"""
display_name = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str])
@ -60,11 +60,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
""" """
now_stream_id = self.get_device_stream_token() now_stream_id = self.get_device_stream_token()
devices = await self.db_pool.runInteraction( devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
[(user_id, None)],
)
if devices: if devices:
user_devices = devices[user_id] user_devices = devices[user_id]
@ -108,11 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list: if not query_list:
return {} return {}
results = await self.db_pool.runInteraction( results = await self.get_e2e_device_keys_and_signatures(query_list)
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
)
# Build the result structure, un-jsonify the results, and add the # Build the result structure, un-jsonify the results, and add the
# "unsigned" section # "unsigned" section
@ -135,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return rv return rv
@trace @trace
def _get_e2e_device_keys_and_signatures_txn( async def get_e2e_device_keys_and_signatures(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False self,
query_list: List[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.
Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
to indicate "all devices for this user"
include_all_devices: whether to return devices without device keys
include_deleted_devices: whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data.
"""
set_tag("include_all_devices", include_all_devices) set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices) set_tag("include_deleted_devices", include_deleted_devices)
result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
log_kv(result)
return result
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
query_clauses = [] query_clauses = []
query_params = [] query_params = []
signature_query_clauses = [] signature_query_clauses = []
@ -230,7 +255,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
) )
signing_user_signatures[signing_key_id] = signature signing_user_signatures[signing_key_id] = signature
log_kv(result)
return result return result
async def get_e2e_one_time_keys( async def get_e2e_one_time_keys(