Fix slipped logging context when media rejected (#17239)
When a module rejects a piece of media we end up trying to close the same logging context twice. Instead of fixing the existing code we refactor to use an async context manager, which is easier to write correctly.
This commit is contained in:
parent
ad179b0136
commit
bb5a692946
|
@ -0,0 +1 @@
|
||||||
|
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.
|
|
@ -650,7 +650,7 @@ class MediaRepository:
|
||||||
|
|
||||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
try:
|
try:
|
||||||
length, headers = await self.client.download_media(
|
length, headers = await self.client.download_media(
|
||||||
server_name,
|
server_name,
|
||||||
|
@ -693,8 +693,6 @@ class MediaRepository:
|
||||||
)
|
)
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
await finish()
|
|
||||||
|
|
||||||
if b"Content-Type" in headers:
|
if b"Content-Type" in headers:
|
||||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||||
else:
|
else:
|
||||||
|
@ -1045,14 +1043,9 @@ class MediaRepository:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
f,
|
|
||||||
fname,
|
|
||||||
finish,
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
await self.media_storage.write_to_file(t_byte_source, f)
|
await self.media_storage.write_to_file(t_byte_source, f)
|
||||||
await finish()
|
|
||||||
finally:
|
finally:
|
||||||
t_byte_source.close()
|
t_byte_source.close()
|
||||||
|
|
||||||
|
|
|
@ -27,10 +27,9 @@ from typing import (
|
||||||
IO,
|
IO,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
AsyncIterator,
|
||||||
BinaryIO,
|
BinaryIO,
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -97,11 +96,9 @@ class MediaStorage:
|
||||||
the file path written to in the primary media store
|
the file path written to in the primary media store
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
async with self.store_into_file(file_info) as (f, fname):
|
||||||
# Write to the main media repository
|
# Write to the main media repository
|
||||||
await self.write_to_file(source, f)
|
await self.write_to_file(source, f)
|
||||||
# Write to the other storage providers
|
|
||||||
await finish_cb()
|
|
||||||
|
|
||||||
return fname
|
return fname
|
||||||
|
|
||||||
|
@ -111,32 +108,27 @@ class MediaStorage:
|
||||||
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
|
||||||
|
|
||||||
@trace_with_opname("MediaStorage.store_into_file")
|
@trace_with_opname("MediaStorage.store_into_file")
|
||||||
@contextlib.contextmanager
|
@contextlib.asynccontextmanager
|
||||||
def store_into_file(
|
async def store_into_file(
|
||||||
self, file_info: FileInfo
|
self, file_info: FileInfo
|
||||||
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
|
) -> AsyncIterator[Tuple[BinaryIO, str]]:
|
||||||
"""Context manager used to get a file like object to write into, as
|
"""Async Context manager used to get a file like object to write into, as
|
||||||
described by file_info.
|
described by file_info.
|
||||||
|
|
||||||
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
|
Actually yields a 2-tuple (file, fname,), where file is a file
|
||||||
like object that can be written to, fname is the absolute path of file
|
like object that can be written to and fname is the absolute path of file
|
||||||
on disk, and finish_cb is a function that returns an awaitable.
|
on disk.
|
||||||
|
|
||||||
fname can be used to read the contents from after upload, e.g. to
|
fname can be used to read the contents from after upload, e.g. to
|
||||||
generate thumbnails.
|
generate thumbnails.
|
||||||
|
|
||||||
finish_cb must be called and waited on after the file has been successfully been
|
|
||||||
written to. Should not be called if there was an error. Checks for spam and
|
|
||||||
stores the file into the configured storage providers.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_info: Info about the file to store
|
file_info: Info about the file to store
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
with media_storage.store_into_file(info) as (f, fname, finish_cb):
|
async with media_storage.store_into_file(info) as (f, fname,):
|
||||||
# .. write into f ...
|
# .. write into f ...
|
||||||
await finish_cb()
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
path = self._file_info_to_path(file_info)
|
path = self._file_info_to_path(file_info)
|
||||||
|
@ -145,62 +137,42 @@ class MediaStorage:
|
||||||
dirname = os.path.dirname(fname)
|
dirname = os.path.dirname(fname)
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
|
||||||
finished_called = [False]
|
|
||||||
|
|
||||||
main_media_repo_write_trace_scope = start_active_span(
|
main_media_repo_write_trace_scope = start_active_span(
|
||||||
"writing to main media repo"
|
"writing to main media repo"
|
||||||
)
|
)
|
||||||
main_media_repo_write_trace_scope.__enter__()
|
main_media_repo_write_trace_scope.__enter__()
|
||||||
|
|
||||||
try:
|
with main_media_repo_write_trace_scope:
|
||||||
with open(fname, "wb") as f:
|
|
||||||
|
|
||||||
async def finish() -> None:
|
|
||||||
# When someone calls finish, we assume they are done writing to the main media repo
|
|
||||||
main_media_repo_write_trace_scope.__exit__(None, None, None)
|
|
||||||
|
|
||||||
with start_active_span("writing to other storage providers"):
|
|
||||||
# Ensure that all writes have been flushed and close the
|
|
||||||
# file.
|
|
||||||
f.flush()
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
|
|
||||||
ReadableFileWrapper(self.clock, fname), file_info
|
|
||||||
)
|
|
||||||
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
|
|
||||||
logger.info("Blocking media due to spam checker")
|
|
||||||
# Note that we'll delete the stored media, due to the
|
|
||||||
# try/except below. The media also won't be stored in
|
|
||||||
# the DB.
|
|
||||||
# We currently ignore any additional field returned by
|
|
||||||
# the spam-check API.
|
|
||||||
raise SpamMediaException(errcode=spam_check[0])
|
|
||||||
|
|
||||||
for provider in self.storage_providers:
|
|
||||||
with start_active_span(str(provider)):
|
|
||||||
await provider.store_file(path, file_info)
|
|
||||||
|
|
||||||
finished_called[0] = True
|
|
||||||
|
|
||||||
yield f, fname, finish
|
|
||||||
except Exception as e:
|
|
||||||
try:
|
try:
|
||||||
main_media_repo_write_trace_scope.__exit__(
|
with open(fname, "wb") as f:
|
||||||
type(e), None, e.__traceback__
|
yield f, fname
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
os.remove(fname)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
with start_active_span("writing to other storage providers"):
|
||||||
|
spam_check = (
|
||||||
|
await self._spam_checker_module_callbacks.check_media_file_for_spam(
|
||||||
|
ReadableFileWrapper(self.clock, fname), file_info
|
||||||
)
|
)
|
||||||
os.remove(fname)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
raise e from None
|
|
||||||
|
|
||||||
if not finished_called:
|
|
||||||
exc = Exception("Finished callback not called")
|
|
||||||
main_media_repo_write_trace_scope.__exit__(
|
|
||||||
type(exc), None, exc.__traceback__
|
|
||||||
)
|
)
|
||||||
raise exc
|
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
|
||||||
|
logger.info("Blocking media due to spam checker")
|
||||||
|
# Note that we'll delete the stored media, due to the
|
||||||
|
# try/except below. The media also won't be stored in
|
||||||
|
# the DB.
|
||||||
|
# We currently ignore any additional field returned by
|
||||||
|
# the spam-check API.
|
||||||
|
raise SpamMediaException(errcode=spam_check[0])
|
||||||
|
|
||||||
|
for provider in self.storage_providers:
|
||||||
|
with start_active_span(str(provider)):
|
||||||
|
await provider.store_file(path, file_info)
|
||||||
|
|
||||||
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
|
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
|
||||||
"""Attempts to fetch media described by file_info from the local cache
|
"""Attempts to fetch media described by file_info from the local cache
|
||||||
|
|
|
@ -592,7 +592,7 @@ class UrlPreviewer:
|
||||||
|
|
||||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||||
|
|
||||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
if url.startswith("data:"):
|
if url.startswith("data:"):
|
||||||
if not allow_data_urls:
|
if not allow_data_urls:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -603,8 +603,6 @@ class UrlPreviewer:
|
||||||
else:
|
else:
|
||||||
download_result = await self._download_url(url, f)
|
download_result = await self._download_url(url, f)
|
||||||
|
|
||||||
await finish()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
|
|
@ -93,13 +93,13 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
|
||||||
# from a regular 404.
|
# from a regular 404.
|
||||||
file_id = "abcdefg12345"
|
file_id = "abcdefg12345"
|
||||||
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
||||||
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
|
|
||||||
f,
|
media_storage = hs.get_media_repository().media_storage
|
||||||
fname,
|
|
||||||
finish,
|
ctx = media_storage.store_into_file(file_info)
|
||||||
):
|
(f, fname) = self.get_success(ctx.__aenter__())
|
||||||
f.write(SMALL_PNG)
|
f.write(SMALL_PNG)
|
||||||
self.get_success(finish())
|
self.get_success(ctx.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_cached_remote_media(
|
self.store.store_cached_remote_media(
|
||||||
|
|
|
@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
|
||||||
# from a regular 404.
|
# from a regular 404.
|
||||||
file_id = "abcdefg12345"
|
file_id = "abcdefg12345"
|
||||||
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
|
||||||
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
|
|
||||||
f,
|
media_storage = hs.get_media_repository().media_storage
|
||||||
fname,
|
|
||||||
finish,
|
ctx = media_storage.store_into_file(file_info)
|
||||||
):
|
(f, fname) = self.get_success(ctx.__aenter__())
|
||||||
f.write(SMALL_PNG)
|
f.write(SMALL_PNG)
|
||||||
self.get_success(finish())
|
self.get_success(ctx.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_cached_remote_media(
|
self.store.store_cached_remote_media(
|
||||||
|
|
Loading…
Reference in New Issue