Return attrs for more media repo APIs. (#16611)
parent 91587d4cf9 · commit ff716b483b
@@ -0,0 +1 @@
+Improve type hints.
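
This commit replaces the plain dicts returned by several media repository storage APIs with frozen attrs classes (LocalMedia, RemoteMedia, UrlCache), so callers use attribute access instead of string keys. A minimal sketch of the pattern, with a trimmed field list rather than the full classes from this commit:

    from typing import Optional

    import attr


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class LocalMedia:
        media_id: str
        media_type: str
        media_length: int
        quarantined_by: Optional[str]


    media = LocalMedia(
        media_id="abc123",
        media_type="image/png",
        media_length=1024,
        quarantined_by=None,
    )
    # Attribute access replaces dict lookups such as media["media_length"].
    assert media.media_length == 1024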
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 from synapse.api.errors import (
     AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
             server_name = host
 
         if self._is_mine_server_name(server_name):
-            media_info = await self.store.get_local_media(media_id)
+            media_info: Optional[
+                Union[LocalMedia, RemoteMedia]
+            ] = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
@@ -322,12 +325,12 @@ class ProfileHandler:
 
         if self.max_avatar_size:
             # Ensure avatar does not exceed max allowed avatar size
-            if media_info["media_length"] > self.max_avatar_size:
+            if media_info.media_length > self.max_avatar_size:
                 logger.warning(
                     "Forbidding avatar change to %s: %d bytes is above the allowed size "
                     "limit",
                     mxc,
-                    media_info["media_length"],
+                    media_info.media_length,
                 )
                 return False
 
@@ -335,12 +338,12 @@ class ProfileHandler:
         # Ensure the avatar's file type is allowed
         if (
             self.allowed_avatar_mimetypes
-            and media_info["media_type"] not in self.allowed_avatar_mimetypes
+            and media_info.media_type not in self.allowed_avatar_mimetypes
        ):
            logger.warning(
                "Forbidding avatar change to %s: mimetype %s not allowed",
                mxc,
-                media_info["media_type"],
+                media_info.media_type,
            )
            return False
 
@@ -806,7 +806,7 @@ class SsoHandler:
             media_id = profile["avatar_url"].split("/")[-1]
             if self._is_mine_server_name(server_name):
                 media = await self._media_repo.store.get_local_media(media_id)
-                if media is not None and upload_name == media["upload_name"]:
+                if media is not None and upload_name == media.upload_name:
                     logger.info("skipping saving the user avatar")
                     return True
 
@@ -19,6 +19,7 @@ import shutil
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
+import attr
 from matrix_common.types.mxc_uri import MXCUri
 
 import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
             Resolves once a response has successfully been written to request
         """
         media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info or media_info.quarantined_by:
             respond_404(request)
             return
 
         self.mark_recently_accessed(None, media_id)
 
-        media_type = media_info["media_type"]
+        media_type = media_info.media_type
         if not media_type:
             media_type = "application/octet-stream"
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        url_cache = media_info["url_cache"]
+        media_length = media_info.media_length
+        upload_name = name if name else media_info.upload_name
+        url_cache = media_info.url_cache
 
         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
@@ -310,16 +312,20 @@ class MediaRepository:
 
         # We deliberately stream the file outside the lock
         if responder:
-            media_type = media_info["media_type"]
-            media_length = media_info["media_length"]
-            upload_name = name if name else media_info["upload_name"]
+            upload_name = name if name else media_info.upload_name
             await respond_with_responder(
-                request, responder, media_type, media_length, upload_name
+                request,
+                responder,
+                media_info.media_type,
+                media_info.media_length,
+                upload_name,
             )
         else:
             respond_404(request)
 
-    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+    async def get_remote_media_info(
+        self, server_name: str, media_id: str
+    ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
@@ -353,7 +359,7 @@ class MediaRepository:
 
     async def _get_remote_media_impl(
         self, server_name: str, media_id: str
-    ) -> Tuple[Optional[Responder], dict]:
+    ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -373,15 +379,17 @@ class MediaRepository:
 
         # If we have an entry in the DB, try and look for it
         if media_info:
-            file_id = media_info["filesystem_id"]
+            file_id = media_info.filesystem_id
             file_info = FileInfo(server_name, file_id)
 
-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
 
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            if not media_info.media_type:
+                media_info = attr.evolve(
+                    media_info, media_type="application/octet-stream"
+                )
 
             responder = await self.media_storage.fetch_media(file_info)
             if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
             if not media_info:
                 raise e
 
-            file_id = media_info["filesystem_id"]
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            file_id = media_info.filesystem_id
+            if not media_info.media_type:
+                media_info = attr.evolve(media_info, media_type="application/octet-stream")
             file_info = FileInfo(server_name, file_id)
 
             # We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
             # otherwise they'll request thumbnails and get a 404 if they're not
             # ready yet.
             await self._generate_thumbnails(
-                server_name, media_id, file_id, media_info["media_type"]
+                server_name, media_id, file_id, media_info.media_type
             )
 
             responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
-    ) -> dict:
+    ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -518,7 +526,7 @@ class MediaRepository:
             origin=server_name,
             media_id=media_id,
             media_type=media_type,
-            time_now_ms=self.clock.time_msec(),
+            time_now_ms=time_now_ms,
             upload_name=upload_name,
             media_length=length,
             filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
 
         logger.info("Stored remote media in file %r", fname)
 
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        return media_info
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
 
     def _get_thumbnail_requirements(
         self, media_type: str
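
Because the new attrs classes are frozen, the "application/octet-stream" fallback can no longer be patched onto the result in place; the hunks above use attr.evolve, which returns a copy with the changed field. A small sketch of that pattern, using a cut-down stand-in class rather than the real RemoteMedia:

    import attr


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Media:  # illustrative stand-in, not the full RemoteMedia class
        media_id: str
        media_type: str


    m = Media(media_id="abc", media_type="")
    if not m.media_type:
        # Frozen instances cannot be mutated, so evolve() builds a modified copy.
        m = attr.evolve(m, media_type="application/octet-stream")
    assert m.media_type == "application/octet-stream"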
@@ -240,15 +240,14 @@ class UrlPreviewer:
         cache_result = await self.store.get_url_cache(url, ts)
         if (
             cache_result
-            and cache_result["expires_ts"] > ts
-            and cache_result["response_code"] / 100 == 2
+            and cache_result.expires_ts > ts
+            and cache_result.response_code // 100 == 2
         ):
             # It may be stored as text in the database, not as bytes (such as
             # PostgreSQL). If so, encode it back before handing it on.
-            og = cache_result["og"]
-            if isinstance(og, str):
-                og = og.encode("utf8")
-            return og
+            if isinstance(cache_result.og, str):
+                return cache_result.og.encode("utf8")
+            return cache_result.og
 
         # If this URL can be accessed via an allowed oEmbed, use that instead.
         url_to_download = url
@@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
             thumbnail_infos,
             media_id,
             media_id,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
             server_name=None,
         )
 
@@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
                 file_info = FileInfo(
                     server_name=None,
                     file_id=media_id,
-                    url_cache=media_info["url_cache"],
+                    url_cache=bool(media_info.url_cache),
                     thumbnail=info,
                 )
 
@@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
             desired_height,
             desired_method,
             desired_type,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
         )
 
         if file_path:
@@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
             server_name, media_id
         )
 
-        file_id = media_info["filesystem_id"]
+        file_id = media_info.filesystem_id
 
         for info in thumbnail_infos:
             t_w = info.width == desired_width
@@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
             if t_w and t_h and t_method and t_type:
                 file_info = FileInfo(
                     server_name=server_name,
-                    file_id=media_info["filesystem_id"],
+                    file_id=file_id,
                     thumbnail=info,
                 )
 
@@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
             m_type,
             thumbnail_infos,
             media_id,
-            media_info["filesystem_id"],
+            media_info.filesystem_id,
             url_cache=False,
             server_name=server_name,
         )
@@ -15,9 +15,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
-    Any,
     Collection,
-    Dict,
     Iterable,
     List,
     Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
     media_length: int
     upload_name: str
     created_ts: int
+    url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+    media_origin: str
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: Optional[str]
+    filesystem_id: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+    response_code: int
+    expires_ts: int
+    og: Union[str, bytes]
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         super().__init__(database, db_conn, hs)
         self.server_name: str = hs.hostname
 
-    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+    async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
         """Get the metadata for a local piece of media
 
         Returns:
             None if the media_id doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "last_access_ts",
                 "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
         )
+        if row is None:
+            return None
+        return LocalMedia(media_id=media_id, **row)
 
     async def get_local_media_by_user_paginate(
         self,
@@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                    media_length,
                    upload_name,
                    created_ts,
+                   url_cache,
                    last_access_ts,
                    quarantined_by,
                    safe_from_quarantine
@@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                    media_length=row[2],
                    upload_name=row[3],
                    created_ts=row[4],
-                   last_access_ts=row[5],
-                   quarantined_by=row[6],
-                   safe_from_quarantine=bool(row[7]),
+                   url_cache=row[5],
+                   last_access_ts=row[6],
+                   quarantined_by=row[7],
+                   safe_from_quarantine=bool(row[8]),
                )
                for row in txn
            ]
@@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+    async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
             # get the most recently cached result (relative to the given ts)
-            sql = (
-                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                " FROM local_media_repository_url_cache"
-                " WHERE url = ? AND download_ts <= ?"
-                " ORDER BY download_ts DESC LIMIT 1"
-            )
+            sql = """
+                SELECT response_code, expires_ts, og
+                FROM local_media_repository_url_cache
+                WHERE url = ? AND download_ts <= ?
+                ORDER BY download_ts DESC LIMIT 1
+            """
             txn.execute(sql, (url, ts))
             row = txn.fetchone()
 
             if not row:
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
-                sql = (
-                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                    " FROM local_media_repository_url_cache"
-                    " WHERE url = ? AND download_ts > ?"
-                    " ORDER BY download_ts ASC LIMIT 1"
-                )
+                sql = """
+                    SELECT response_code, expires_ts, og
+                    FROM local_media_repository_url_cache
+                    WHERE url = ? AND download_ts > ?
+                    ORDER BY download_ts ASC LIMIT 1
+                """
                 txn.execute(sql, (url, ts))
                 row = txn.fetchone()
 
                 if not row:
                     return None
 
-            return dict(
-                zip(
-                    (
-                        "response_code",
-                        "etag",
-                        "expires_ts",
-                        "og",
-                        "media_id",
-                        "download_ts",
-                    ),
-                    row,
-                )
-            )
+            return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
 
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
@@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         response_code: int,
         etag: Optional[str],
         expires_ts: int,
-        og: Optional[str],
+        og: str,
         media_id: str,
         download_ts: int,
     ) -> None:
@@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_cached_remote_media(
         self, origin: str, media_id: str
-    ) -> Optional[Dict[str, Any]]:
-        return await self.db_pool.simple_select_one(
+    ) -> Optional[RemoteMedia]:
+        row = await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name",
                 "created_ts",
                 "filesystem_id",
+                "last_access_ts",
                 "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
         )
+        if row is None:
+            return row
+        return RemoteMedia(media_origin=origin, media_id=media_id, **row)
 
     async def store_cached_remote_media(
         self,
@@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         t_width: int,
         t_height: int,
         t_type: str,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThumbnailInfo]:
         """Fetch the thumbnail info of given width, height and type."""
 
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
             keyvalues={
                 "media_origin": origin,
@@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "thumbnail_method",
                 "thumbnail_type",
                 "thumbnail_length",
-                "filesystem_id",
             ),
             allow_none=True,
             desc="get_remote_media_thumbnail",
         )
+        if row is None:
+            return None
+        return ThumbnailInfo(
+            width=row["thumbnail_width"],
+            height=row["thumbnail_height"],
+            method=row["thumbnail_method"],
+            type=row["thumbnail_type"],
+            length=row["thumbnail_length"],
+        )
 
     @trace
     async def store_remote_media_thumbnail(
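
get_url_cache now selects only the three columns it needs and wraps the tuple row in the new UrlCache class. A minimal sketch of that construction, with made-up row values:

    from typing import Union

    import attr


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class UrlCache:  # same fields as the class added above
        response_code: int
        expires_ts: int
        og: Union[str, bytes]


    # Hypothetical row as returned by txn.fetchone(): (response_code, expires_ts, og).
    row = (200, 1_700_000_000_000, '{"og:title": "Example"}')
    cache = UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
    assert cache.response_code // 100 == 2  # only 2xx responses are served from cache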
@@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
         assert info is not None
-        file_id = info["filesystem_id"]
+        file_id = info.filesystem_id
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
             origin, file_id
@@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
         # quarantining
         channel = self.make_request(
@@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["quarantined_by"])
+        self.assertTrue(media_info.quarantined_by)
 
         # remove from quarantine
         channel = self.make_request(
@@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
     def test_quarantine_protected_media(self) -> None:
         """
@@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify protection
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # quarantining
         channel = self.make_request(
@@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify that is not in quarantine
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
 
 class ProtectMediaByIDTestCase(_AdminMediaTests):
@@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
         # protect
         channel = self.make_request(
@@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # unprotect
         channel = self.make_request(
@@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
 
 class PurgeMediaCacheTestCase(_AdminMediaTests):
@@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
         def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
             """Given an MXC URI, assert whether it has been purged or not."""
             if mxc_uri.server_name == self.hs.config.server.server_name:
-                found_media_dict = self.get_success(
-                    self.store.get_local_media(mxc_uri.media_id)
+                found_media = bool(
+                    self.get_success(self.store.get_local_media(mxc_uri.media_id))
                 )
             else:
-                found_media_dict = self.get_success(
-                    self.store.get_cached_remote_media(
-                        mxc_uri.server_name, mxc_uri.media_id
+                found_media = bool(
+                    self.get_success(
+                        self.store.get_cached_remote_media(
+                            mxc_uri.server_name, mxc_uri.media_id
+                        )
                     )
                 )
 
             if expect_purged:
-                self.assertIsNone(
-                    found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
-                )
+                self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
             else:
-                self.assertIsNotNone(
-                    found_media_dict,
+                self.assertTrue(
+                    found_media,
                     msg=f"{mxc_uri} unexpectedly purged",
                 )
 