Convert ReadWriteLock to async/await. (#8202)
This commit is contained in:
parent
b4826d6eb1
commit
d2ac767de2
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -14,15 +14,18 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import Requester, RoomStreamToken
|
||||
from synapse.util.async_helpers import ReadWriteLock
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -247,15 +250,16 @@ class PaginationHandler(object):
|
|||
)
|
||||
return purge_id
|
||||
|
||||
async def _purge_history(self, purge_id, room_id, token, delete_local_events):
|
||||
async def _purge_history(
|
||||
self, purge_id: str, room_id: str, token: str, delete_local_events: bool
|
||||
) -> None:
|
||||
"""Carry out a history purge on a room.
|
||||
|
||||
Args:
|
||||
purge_id (str): The id for this purge
|
||||
room_id (str): The room to purge from
|
||||
token (str): topological token to delete events before
|
||||
delete_local_events (bool): True to delete local events as well as
|
||||
remote ones
|
||||
purge_id: The id for this purge
|
||||
room_id: The room to purge from
|
||||
token: topological token to delete events before
|
||||
delete_local_events: True to delete local events as well as remote ones
|
||||
"""
|
||||
self._purges_in_progress_by_room.add(room_id)
|
||||
try:
|
||||
|
@ -291,9 +295,9 @@ class PaginationHandler(object):
|
|||
"""
|
||||
return self._purges_by_id.get(purge_id)
|
||||
|
||||
async def purge_room(self, room_id):
|
||||
async def purge_room(self, room_id: str) -> None:
|
||||
"""Purge the given room from the database"""
|
||||
with (await self.pagination_lock.write(room_id)):
|
||||
with await self.pagination_lock.write(room_id):
|
||||
# check we know about the room
|
||||
await self.store.get_room_version_id(room_id)
|
||||
|
||||
|
@ -307,23 +311,22 @@ class PaginationHandler(object):
|
|||
|
||||
async def get_messages(
|
||||
self,
|
||||
requester,
|
||||
room_id=None,
|
||||
pagin_config=None,
|
||||
as_client_event=True,
|
||||
event_filter=None,
|
||||
):
|
||||
requester: Requester,
|
||||
room_id: Optional[str] = None,
|
||||
pagin_config: Optional[PaginationConfig] = None,
|
||||
as_client_event: bool = True,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get messages in a room.
|
||||
|
||||
Args:
|
||||
requester (Requester): The user requesting messages.
|
||||
room_id (str): The room they want messages from.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
event_filter (Filter): Filter to apply to results or None
|
||||
requester: The user requesting messages.
|
||||
room_id: The room they want messages from.
|
||||
pagin_config: The pagination config rules to apply, if any.
|
||||
as_client_event: True to get events in client-server format.
|
||||
event_filter: Filter to apply to results or None
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
Pagination API results
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
@ -343,7 +346,7 @@ class PaginationHandler(object):
|
|||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (await self.pagination_lock.read(room_id)):
|
||||
with await self.pagination_lock.read(room_id):
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
|
|
|
@ -20,6 +20,7 @@ from contextlib import contextmanager
|
|||
from typing import Dict, Sequence, Set, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
@ -338,11 +339,11 @@ class Linearizer(object):
|
|||
|
||||
|
||||
class ReadWriteLock(object):
|
||||
"""A deferred style read write lock.
|
||||
"""An async read write lock.
|
||||
|
||||
Example:
|
||||
|
||||
with (yield read_write_lock.read("test_key")):
|
||||
with await read_write_lock.read("test_key"):
|
||||
# do some work
|
||||
"""
|
||||
|
||||
|
@ -365,8 +366,7 @@ class ReadWriteLock(object):
|
|||
# Latest writer queued
|
||||
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def read(self, key):
|
||||
async def read(self, key: str) -> ContextManager:
|
||||
new_defer = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
||||
|
@ -376,7 +376,8 @@ class ReadWriteLock(object):
|
|||
|
||||
# We wait for the latest writer to finish writing. We can safely ignore
|
||||
# any existing readers... as they're readers.
|
||||
yield make_deferred_yieldable(curr_writer)
|
||||
if curr_writer:
|
||||
await make_deferred_yieldable(curr_writer)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
@ -388,8 +389,7 @@ class ReadWriteLock(object):
|
|||
|
||||
return _ctx_manager()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write(self, key):
|
||||
async def write(self, key: str) -> ContextManager:
|
||||
new_defer = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.get(key, set())
|
||||
|
@ -405,7 +405,7 @@ class ReadWriteLock(object):
|
|||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
|
||||
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.async_helpers import ReadWriteLock
|
||||
|
||||
|
@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
|||
rwlock.read(key), # 5
|
||||
rwlock.write(key), # 6
|
||||
]
|
||||
ds = [defer.ensureDeferred(d) for d in ds]
|
||||
|
||||
self._assert_called_before_not_after(ds, 2)
|
||||
|
||||
|
@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
|||
with ds[6].result:
|
||||
pass
|
||||
|
||||
d = rwlock.write(key)
|
||||
d = defer.ensureDeferred(rwlock.write(key))
|
||||
self.assertTrue(d.called)
|
||||
with d.result:
|
||||
pass
|
||||
|
||||
d = rwlock.read(key)
|
||||
d = defer.ensureDeferred(rwlock.read(key))
|
||||
self.assertTrue(d.called)
|
||||
with d.result:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue