Convert ReadWriteLock to async/await. (#8202)

This commit is contained in:
Patrick Cloke 2020-08-28 16:47:11 -04:00 committed by GitHub
parent b4826d6eb1
commit d2ac767de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 39 additions and 33 deletions

1
changelog.d/8202.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -14,15 +14,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Optional
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken from synapse.streams.config import PaginationConfig
from synapse.types import Requester, RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -247,15 +250,16 @@ class PaginationHandler(object):
) )
return purge_id return purge_id
async def _purge_history(self, purge_id, room_id, token, delete_local_events): async def _purge_history(
self, purge_id: str, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Carry out a history purge on a room. """Carry out a history purge on a room.
Args: Args:
purge_id (str): The id for this purge purge_id: The id for this purge
room_id (str): The room to purge from room_id: The room to purge from
token (str): topological token to delete events before token: topological token to delete events before
delete_local_events (bool): True to delete local events as well as delete_local_events: True to delete local events as well as remote ones
remote ones
""" """
self._purges_in_progress_by_room.add(room_id) self._purges_in_progress_by_room.add(room_id)
try: try:
@ -291,9 +295,9 @@ class PaginationHandler(object):
""" """
return self._purges_by_id.get(purge_id) return self._purges_by_id.get(purge_id)
async def purge_room(self, room_id): async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database""" """Purge the given room from the database"""
with (await self.pagination_lock.write(room_id)): with await self.pagination_lock.write(room_id):
# check we know about the room # check we know about the room
await self.store.get_room_version_id(room_id) await self.store.get_room_version_id(room_id)
@ -307,23 +311,22 @@ class PaginationHandler(object):
async def get_messages( async def get_messages(
self, self,
requester, requester: Requester,
room_id=None, room_id: Optional[str] = None,
pagin_config=None, pagin_config: Optional[PaginationConfig] = None,
as_client_event=True, as_client_event: bool = True,
event_filter=None, event_filter: Optional[Filter] = None,
): ) -> Dict[str, Any]:
"""Get messages in a room. """Get messages in a room.
Args: Args:
requester (Requester): The user requesting messages. requester: The user requesting messages.
room_id (str): The room they want messages from. room_id: The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config: The pagination config rules to apply, if any.
config rules to apply, if any. as_client_event: True to get events in client-server format.
as_client_event (bool): True to get events in client-server format. event_filter: Filter to apply to results or None
event_filter (Filter): Filter to apply to results or None
Returns: Returns:
dict: Pagination API results Pagination API results
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -343,7 +346,7 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
with (await self.pagination_lock.read(room_id)): with await self.pagination_lock.read(room_id):
( (
membership, membership,
member_event_id, member_event_id,

View File

@ -20,6 +20,7 @@ from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union from typing import Dict, Sequence, Set, Union
import attr import attr
from typing_extensions import ContextManager
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
@ -338,11 +339,11 @@ class Linearizer(object):
class ReadWriteLock(object): class ReadWriteLock(object):
"""A deferred style read write lock. """An async read write lock.
Example: Example:
with (yield read_write_lock.read("test_key")): with await read_write_lock.read("test_key"):
# do some work # do some work
""" """
@ -365,8 +366,7 @@ class ReadWriteLock(object):
# Latest writer queued # Latest writer queued
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks async def read(self, key: str) -> ContextManager:
def read(self, key):
new_defer = defer.Deferred() new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set()) curr_readers = self.key_to_current_readers.setdefault(key, set())
@ -376,7 +376,8 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore # We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers. # any existing readers... as they're readers.
yield make_deferred_yieldable(curr_writer) if curr_writer:
await make_deferred_yieldable(curr_writer)
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -388,8 +389,7 @@ class ReadWriteLock(object):
return _ctx_manager() return _ctx_manager()
@defer.inlineCallbacks async def write(self, key: str) -> ContextManager:
def write(self, key):
new_defer = defer.Deferred() new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set()) curr_readers = self.key_to_current_readers.get(key, set())
@ -405,7 +405,7 @@ class ReadWriteLock(object):
curr_readers.clear() curr_readers.clear()
self.key_to_current_writer[key] = new_defer self.key_to_current_writer[key] = new_defer
yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.util.async_helpers import ReadWriteLock from synapse.util.async_helpers import ReadWriteLock
@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
rwlock.read(key), # 5 rwlock.read(key), # 5
rwlock.write(key), # 6 rwlock.write(key), # 6
] ]
ds = [defer.ensureDeferred(d) for d in ds]
self._assert_called_before_not_after(ds, 2) self._assert_called_before_not_after(ds, 2)
@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
with ds[6].result: with ds[6].result:
pass pass
d = rwlock.write(key) d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called) self.assertTrue(d.called)
with d.result: with d.result:
pass pass
d = rwlock.read(key) d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called) self.assertTrue(d.called)
with d.result: with d.result:
pass pass