Convert account data, device inbox, and censor events databases to async/await (#8063)
This commit is contained in:
parent
a3a59bab7b
commit
d68e10f308
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -16,15 +16,16 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -97,13 +98,15 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
"get_account_data_for_user", get_account_data_for_user_txn
|
"get_account_data_for_user", get_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, max_entries=5000)
|
@cached(num_args=2, max_entries=5000)
|
||||||
def get_global_account_data_by_type_for_user(self, data_type, user_id):
|
async def get_global_account_data_by_type_for_user(
|
||||||
|
self, data_type: str, user_id: str
|
||||||
|
) -> Optional[JsonDict]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: A dict
|
The account data.
|
||||||
"""
|
"""
|
||||||
result = yield self.db_pool.simple_select_one_onecol(
|
result = await self.db_pool.simple_select_one_onecol(
|
||||||
table="account_data",
|
table="account_data",
|
||||||
keyvalues={"user_id": user_id, "account_data_type": data_type},
|
keyvalues={"user_id": user_id, "account_data_type": data_type},
|
||||||
retcol="content",
|
retcol="content",
|
||||||
|
@ -280,9 +283,11 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
@cached(num_args=2, cache_context=True, max_entries=5000)
|
||||||
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
async def is_ignored_by(
|
||||||
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
|
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
|
||||||
|
) -> bool:
|
||||||
|
ignored_account_data = await self.get_global_account_data_by_type_for_user(
|
||||||
"m.ignored_user_list",
|
"m.ignored_user_list",
|
||||||
ignorer_user_id,
|
ignorer_user_id,
|
||||||
on_invalidate=cache_context.invalidate,
|
on_invalidate=cache_context.invalidate,
|
||||||
|
@ -307,24 +312,27 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
|
|
||||||
super(AccountDataStore, self).__init__(database, db_conn, hs)
|
super(AccountDataStore, self).__init__(database, db_conn, hs)
|
||||||
|
|
||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self) -> int:
|
||||||
"""Get the current max stream id for the private user data stream
|
"""Get the current max stream id for the private user data stream
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred int.
|
The maximum stream ID.
|
||||||
"""
|
"""
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_account_data_to_room(
|
||||||
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
|
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
|
||||||
|
) -> int:
|
||||||
"""Add some account_data to a room for a user.
|
"""Add some account_data to a room for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to add a tag for.
|
user_id: The user to add a tag for.
|
||||||
room_id(str): The room to add a tag for.
|
room_id: The room to add a tag for.
|
||||||
account_data_type(str): The type of account_data to add.
|
account_data_type: The type of account_data to add.
|
||||||
content(dict): A json object to associate with the tag.
|
content: A json object to associate with the tag.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that completes once the account_data has been added.
|
The maximum stream ID.
|
||||||
"""
|
"""
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
|
@ -332,7 +340,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
# no need to lock here as room_account_data has a unique constraint
|
# no need to lock here as room_account_data has a unique constraint
|
||||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||||
# retry if there is a conflict.
|
# retry if there is a conflict.
|
||||||
yield self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
desc="add_room_account_data",
|
desc="add_room_account_data",
|
||||||
table="room_account_data",
|
table="room_account_data",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
|
@ -350,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
# doesn't sound any worse than the whole update getting lost,
|
# doesn't sound any worse than the whole update getting lost,
|
||||||
# which is what would happen if we combined the two into one
|
# which is what would happen if we combined the two into one
|
||||||
# transaction.
|
# transaction.
|
||||||
yield self._update_max_stream_id(next_id)
|
await self._update_max_stream_id(next_id)
|
||||||
|
|
||||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||||
self.get_account_data_for_user.invalidate((user_id,))
|
self.get_account_data_for_user.invalidate((user_id,))
|
||||||
|
@ -359,18 +367,20 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
(user_id, room_id, account_data_type), content
|
(user_id, room_id, account_data_type), content
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
return result
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_account_data_for_user(
|
||||||
def add_account_data_for_user(self, user_id, account_data_type, content):
|
self, user_id: str, account_data_type: str, content: JsonDict
|
||||||
|
) -> int:
|
||||||
"""Add some account_data to a room for a user.
|
"""Add some account_data to a room for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to add a tag for.
|
user_id: The user to add a tag for.
|
||||||
account_data_type(str): The type of account_data to add.
|
account_data_type: The type of account_data to add.
|
||||||
content(dict): A json object to associate with the tag.
|
content: A json object to associate with the tag.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that completes once the account_data has been added.
|
The maximum stream ID.
|
||||||
"""
|
"""
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
|
@ -378,7 +388,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
# no need to lock here as account_data has a unique constraint on
|
# no need to lock here as account_data has a unique constraint on
|
||||||
# (user_id, account_data_type) so simple_upsert will retry if
|
# (user_id, account_data_type) so simple_upsert will retry if
|
||||||
# there is a conflict.
|
# there is a conflict.
|
||||||
yield self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
desc="add_user_account_data",
|
desc="add_user_account_data",
|
||||||
table="account_data",
|
table="account_data",
|
||||||
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
|
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
|
||||||
|
@ -396,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
# Note: This is only here for backwards compat to allow admins to
|
# Note: This is only here for backwards compat to allow admins to
|
||||||
# roll back to a previous Synapse version. Next time we update the
|
# roll back to a previous Synapse version. Next time we update the
|
||||||
# database version we can remove this table.
|
# database version we can remove this table.
|
||||||
yield self._update_max_stream_id(next_id)
|
await self._update_max_stream_id(next_id)
|
||||||
|
|
||||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||||
self.get_account_data_for_user.invalidate((user_id,))
|
self.get_account_data_for_user.invalidate((user_id,))
|
||||||
|
@ -404,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||||
(account_data_type, user_id)
|
(account_data_type, user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
return result
|
|
||||||
|
|
||||||
def _update_max_stream_id(self, next_id):
|
def _update_max_stream_id(self, next_id: int):
|
||||||
"""Update the max stream_id
|
"""Update the max stream_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
next_id(int): The the revision to advance to.
|
next_id: The the revision to advance to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Note: This is only here for backwards compat to allow admins to
|
# Note: This is only here for backwards compat to allow admins to
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.events.utils import prune_event_dict
|
from synapse.events.utils import prune_event_dict
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -148,17 +146,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
updatevalues={"json": pruned_json},
|
updatevalues={"json": pruned_json},
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def expire_event(self, event_id: str) -> None:
|
||||||
def expire_event(self, event_id):
|
|
||||||
"""Retrieve and expire an event that has expired, and delete its associated
|
"""Retrieve and expire an event that has expired, and delete its associated
|
||||||
expiry timestamp. If the event can't be retrieved, delete its associated
|
expiry timestamp. If the event can't be retrieved, delete its associated
|
||||||
timestamp so we don't try to expire it again in the future.
|
timestamp so we don't try to expire it again in the future.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id (str): The ID of the event to delete.
|
event_id: The ID of the event to delete.
|
||||||
"""
|
"""
|
||||||
# Try to retrieve the event's content from the database or the event cache.
|
# Try to retrieve the event's content from the database or the event cache.
|
||||||
event = yield self.get_event(event_id)
|
event = await self.get_event(event_id)
|
||||||
|
|
||||||
def delete_expired_event_txn(txn):
|
def delete_expired_event_txn(txn):
|
||||||
# Delete the expiry timestamp associated with this event from the database.
|
# Delete the expiry timestamp associated with this event from the database.
|
||||||
|
@ -193,7 +190,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
txn, "_get_event_cache", (event.event_id,)
|
txn, "_get_event_cache", (event.event_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_expired_event", delete_expired_event_txn
|
"delete_expired_event", delete_expired_event_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
def get_to_device_stream_token(self):
|
def get_to_device_stream_token(self):
|
||||||
return self._device_inbox_id_gen.get_current_token()
|
return self._device_inbox_id_gen.get_current_token()
|
||||||
|
|
||||||
def get_new_messages_for_device(
|
async def get_new_messages_for_device(
|
||||||
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
|
self,
|
||||||
):
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
last_stream_id: int,
|
||||||
|
current_stream_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> Tuple[List[dict], int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The recipient user_id.
|
user_id: The recipient user_id.
|
||||||
device_id(str): The recipient device_id.
|
device_id: The recipient device_id.
|
||||||
current_stream_id(int): The current position of the to device
|
last_stream_id: The last stream ID checked.
|
||||||
|
current_stream_id: The current position of the to device
|
||||||
message stream.
|
message stream.
|
||||||
|
limit: The maximum number of messages to retrieve.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred ([dict], int): List of messages for the device and where
|
A list of messages for the device and where in the stream the messages got to.
|
||||||
in the stream the messages got to.
|
|
||||||
"""
|
"""
|
||||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
||||||
user_id, last_stream_id
|
user_id, last_stream_id
|
||||||
)
|
)
|
||||||
if not has_changed:
|
if not has_changed:
|
||||||
return defer.succeed(([], current_stream_id))
|
return ([], current_stream_id)
|
||||||
|
|
||||||
def get_new_messages_for_device_txn(txn):
|
def get_new_messages_for_device_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
stream_pos = current_stream_id
|
stream_pos = current_stream_id
|
||||||
return messages, stream_pos
|
return messages, stream_pos
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_new_messages_for_device", get_new_messages_for_device_txn
|
"get_new_messages_for_device", get_new_messages_for_device_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def delete_messages_for_device(
|
||||||
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
|
self, user_id: str, device_id: str, up_to_stream_id: int
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The recipient user_id.
|
user_id: The recipient user_id.
|
||||||
device_id(str): The recipient device_id.
|
device_id: The recipient device_id.
|
||||||
up_to_stream_id(int): Where to delete messages up to.
|
up_to_stream_id: Where to delete messages up to.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that resolves to the number of messages deleted.
|
The number of messages deleted.
|
||||||
"""
|
"""
|
||||||
# If we have cached the last stream id we've deleted up to, we can
|
# If we have cached the last stream id we've deleted up to, we can
|
||||||
# check if there is likely to be anything that needs deleting
|
# check if there is likely to be anything that needs deleting
|
||||||
|
@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||||
return txn.rowcount
|
return txn.rowcount
|
||||||
|
|
||||||
count = yield self.db_pool.runInteraction(
|
count = await self.db_pool.runInteraction(
|
||||||
"delete_messages_for_device", delete_messages_for_device_txn
|
"delete_messages_for_device", delete_messages_for_device_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
def get_new_device_msgs_for_remote(
|
async def get_new_device_msgs_for_remote(
|
||||||
self, destination, last_stream_id, current_stream_id, limit
|
self, destination, last_stream_id, current_stream_id, limit
|
||||||
):
|
) -> Tuple[List[dict], int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
destination(str): The name of the remote server.
|
destination(str): The name of the remote server.
|
||||||
|
@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
current_stream_id(int|long): The current position of the device
|
current_stream_id(int|long): The current position of the device
|
||||||
message stream.
|
message stream.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred ([dict], int|long): List of messages for the device and where
|
A list of messages for the device and where in the stream the messages got to.
|
||||||
in the stream the messages got to.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
set_tag("destination", destination)
|
set_tag("destination", destination)
|
||||||
|
@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
if not has_changed or last_stream_id == current_stream_id:
|
if not has_changed or last_stream_id == current_stream_id:
|
||||||
log_kv({"message": "No new messages in stream"})
|
log_kv({"message": "No new messages in stream"})
|
||||||
return defer.succeed(([], current_stream_id))
|
return ([], current_stream_id)
|
||||||
|
|
||||||
if limit <= 0:
|
if limit <= 0:
|
||||||
# This can happen if we run out of room for EDUs in the transaction.
|
# This can happen if we run out of room for EDUs in the transaction.
|
||||||
return defer.succeed(([], last_stream_id))
|
return ([], last_stream_id)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
def get_new_messages_for_remote_destination_txn(txn):
|
def get_new_messages_for_remote_destination_txn(txn):
|
||||||
|
@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
stream_pos = current_stream_id
|
stream_pos = current_stream_id
|
||||||
return messages, stream_pos
|
return messages, stream_pos
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_new_device_msgs_for_remote",
|
"get_new_device_msgs_for_remote",
|
||||||
get_new_messages_for_remote_destination_txn,
|
get_new_messages_for_remote_destination_txn,
|
||||||
)
|
)
|
||||||
|
@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||||
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||||
def _background_drop_index_device_inbox(self, progress, batch_size):
|
|
||||||
def reindex_txn(conn):
|
def reindex_txn(conn):
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
||||||
txn.close()
|
txn.close()
|
||||||
|
|
||||||
yield self.db_pool.runWithConnection(reindex_txn)
|
await self.db_pool.runWithConnection(reindex_txn)
|
||||||
|
|
||||||
yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
|
await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def add_messages_to_device_inbox(
|
||||||
def add_messages_to_device_inbox(
|
self,
|
||||||
self, local_messages_by_user_then_device, remote_messages_by_destination
|
local_messages_by_user_then_device: dict,
|
||||||
):
|
remote_messages_by_destination: dict,
|
||||||
|
) -> int:
|
||||||
"""Used to send messages from this server.
|
"""Used to send messages from this server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sender_user_id(str): The ID of the user sending these messages.
|
local_messages_by_user_and_device:
|
||||||
local_messages_by_user_and_device(dict):
|
|
||||||
Dictionary of user_id to device_id to message.
|
Dictionary of user_id to device_id to message.
|
||||||
remote_messages_by_destination(dict):
|
remote_messages_by_destination:
|
||||||
Dictionary of destination server_name to the EDU JSON to send.
|
Dictionary of destination server_name to the EDU JSON to send.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred stream_id that resolves when the messages have been
|
The new stream_id.
|
||||||
inserted.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def add_messages_txn(txn, now_ms, stream_id):
|
def add_messages_txn(txn, now_ms, stream_id):
|
||||||
|
@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||||
|
|
||||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self.clock.time_msec()
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||||
)
|
)
|
||||||
for user_id in local_messages_by_user_then_device.keys():
|
for user_id in local_messages_by_user_then_device.keys():
|
||||||
|
@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||||
|
|
||||||
return self._device_inbox_id_gen.get_current_token()
|
return self._device_inbox_id_gen.get_current_token()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_messages_from_remote_to_device_inbox(
|
||||||
def add_messages_from_remote_to_device_inbox(
|
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
|
||||||
self, origin, message_id, local_messages_by_user_then_device
|
) -> int:
|
||||||
):
|
|
||||||
def add_messages_txn(txn, now_ms, stream_id):
|
def add_messages_txn(txn, now_ms, stream_id):
|
||||||
# Check if we've already inserted a matching message_id for that
|
# Check if we've already inserted a matching message_id for that
|
||||||
# origin. This can happen if the origin doesn't receive our
|
# origin. This can happen if the origin doesn't receive our
|
||||||
|
@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||||
|
|
||||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self.clock.time_msec()
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_from_remote_to_device_inbox",
|
"add_messages_from_remote_to_device_inbox",
|
||||||
add_messages_txn,
|
add_messages_txn,
|
||||||
now_ms,
|
now_ms,
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
from tests.utils import register_federation_servlets
|
from tests.utils import register_federation_servlets
|
||||||
|
|
||||||
|
@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.datastore.get_current_state_deltas.return_value = (0, None)
|
self.datastore.get_current_state_deltas.return_value = (0, None)
|
||||||
|
|
||||||
self.datastore.get_to_device_stream_token = lambda: 0
|
self.datastore.get_to_device_stream_token = lambda: 0
|
||||||
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
|
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
|
||||||
([], 0)
|
([], 0)
|
||||||
)
|
)
|
||||||
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
||||||
|
|
Loading…
Reference in New Issue