Asynchronous Uploads (#15503)
Support asynchronous uploads as defined in MSC2246.
This commit is contained in:
parent
80922dc46e
commit
999bd77d3a
|
@ -0,0 +1 @@
|
||||||
|
Add support for asynchronous uploads as defined by [MSC2246](https://github.com/matrix-org/matrix-spec-proposals/pull/2246). Contributed by @sumnerevans at @beeper.
|
|
@ -1753,6 +1753,19 @@ rc_third_party_invite:
|
||||||
burst_count: 10
|
burst_count: 10
|
||||||
```
|
```
|
||||||
---
|
---
|
||||||
|
### `rc_media_create`
|
||||||
|
|
||||||
|
This option ratelimits creation of MXC URIs via the `/_matrix/media/v1/create`
|
||||||
|
endpoint based on the account that's creating the media. Defaults to
|
||||||
|
`per_second: 10`, `burst_count: 50`.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
rc_media_create:
|
||||||
|
per_second: 10
|
||||||
|
burst_count: 50
|
||||||
|
```
|
||||||
|
---
|
||||||
### `rc_federation`
|
### `rc_federation`
|
||||||
|
|
||||||
Defines limits on federation requests.
|
Defines limits on federation requests.
|
||||||
|
@ -1814,6 +1827,27 @@ Example configuration:
|
||||||
media_store_path: "DATADIR/media_store"
|
media_store_path: "DATADIR/media_store"
|
||||||
```
|
```
|
||||||
---
|
---
|
||||||
|
### `max_pending_media_uploads`
|
||||||
|
|
||||||
|
How many *pending media uploads* can a given user have? A pending media upload
|
||||||
|
is a created MXC URI that (a) is not expired (the `unused_expires_at` timestamp
|
||||||
|
has not passed) and (b) the media has not yet been uploaded for. Defaults to 5.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
max_pending_media_uploads: 5
|
||||||
|
```
|
||||||
|
---
|
||||||
|
### `unused_expiration_time`
|
||||||
|
|
||||||
|
How long to wait in milliseconds before expiring created media IDs. Defaults to
|
||||||
|
"24h"
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
unused_expiration_time: "1h"
|
||||||
|
```
|
||||||
|
---
|
||||||
### `media_storage_providers`
|
### `media_storage_providers`
|
||||||
|
|
||||||
Media storage providers allow media to be stored in different
|
Media storage providers allow media to be stored in different
|
||||||
|
|
|
@ -83,6 +83,8 @@ class Codes(str, Enum):
|
||||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||||
# USER_LOCKED = "M_USER_LOCKED"
|
# USER_LOCKED = "M_USER_LOCKED"
|
||||||
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
|
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
|
||||||
|
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
|
||||||
|
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
|
||||||
|
|
||||||
# Part of MSC3848
|
# Part of MSC3848
|
||||||
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
|
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
|
||||||
|
|
|
@ -204,3 +204,10 @@ class RatelimitConfig(Config):
|
||||||
"rc_third_party_invite",
|
"rc_third_party_invite",
|
||||||
defaults={"per_second": 0.0025, "burst_count": 5},
|
defaults={"per_second": 0.0025, "burst_count": 5},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ratelimit create media requests:
|
||||||
|
self.rc_media_create = RatelimitSettings.parse(
|
||||||
|
config,
|
||||||
|
"rc_media_create",
|
||||||
|
defaults={"per_second": 10, "burst_count": 50},
|
||||||
|
)
|
||||||
|
|
|
@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config):
|
||||||
"prevent_media_downloads_from", []
|
"prevent_media_downloads_from", []
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.unused_expiration_time = self.parse_duration(
|
||||||
|
config.get("unused_expiration_time", "24h")
|
||||||
|
)
|
||||||
|
|
||||||
|
self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5)
|
||||||
|
|
||||||
self.media_store_path = self.ensure_directory(
|
self.media_store_path = self.ensure_directory(
|
||||||
config.get("media_store_path", "media_store")
|
config.get("media_store_path", "media_store")
|
||||||
)
|
)
|
||||||
|
|
|
@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [
|
||||||
"audio/x-flac",
|
"audio/x-flac",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Default timeout_ms for download and thumbnail requests
|
||||||
|
DEFAULT_MAX_TIMEOUT_MS = 20_000
|
||||||
|
|
||||||
|
# Maximum allowed timeout_ms for download and thumbnail requests
|
||||||
|
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
|
||||||
|
|
||||||
|
|
||||||
def respond_404(request: SynapseRequest) -> None:
|
def respond_404(request: SynapseRequest) -> None:
|
||||||
assert request.path is not None
|
assert request.path is not None
|
||||||
|
|
|
@ -27,13 +27,16 @@ import twisted.web.http
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
|
Codes,
|
||||||
FederationDeniedError,
|
FederationDeniedError,
|
||||||
HttpResponseException,
|
HttpResponseException,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
|
cs_error,
|
||||||
)
|
)
|
||||||
from synapse.config.repository import ThumbnailRequirement
|
from synapse.config.repository import ThumbnailRequirement
|
||||||
|
from synapse.http.server import respond_with_json
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import defer_to_thread
|
from synapse.logging.context import defer_to_thread
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
|
@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
|
||||||
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
|
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
|
||||||
from synapse.media.url_previewer import UrlPreviewer
|
from synapse.media.url_previewer import UrlPreviewer
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.databases.main.media_repository import RemoteMedia
|
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
@ -80,6 +83,8 @@ class MediaRepository:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.max_upload_size = hs.config.media.max_upload_size
|
self.max_upload_size = hs.config.media.max_upload_size
|
||||||
self.max_image_pixels = hs.config.media.max_image_pixels
|
self.max_image_pixels = hs.config.media.max_image_pixels
|
||||||
|
self.unused_expiration_time = hs.config.media.unused_expiration_time
|
||||||
|
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
|
||||||
|
|
||||||
Thumbnailer.set_limits(self.max_image_pixels)
|
Thumbnailer.set_limits(self.max_image_pixels)
|
||||||
|
|
||||||
|
@ -185,6 +190,117 @@ class MediaRepository:
|
||||||
else:
|
else:
|
||||||
self.recently_accessed_locals.add(media_id)
|
self.recently_accessed_locals.add(media_id)
|
||||||
|
|
||||||
|
@trace
|
||||||
|
async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
|
||||||
|
"""Create and store a media ID for a local user and return the MXC URI and its
|
||||||
|
expiration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_user: The user_id of the uploader
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the MXC URI of the stored content and the timestamp at
|
||||||
|
which the MXC URI expires.
|
||||||
|
"""
|
||||||
|
media_id = random_string(24)
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
await self.store.store_local_media_id(
|
||||||
|
media_id=media_id,
|
||||||
|
time_now_ms=now,
|
||||||
|
user_id=auth_user,
|
||||||
|
)
|
||||||
|
return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time
|
||||||
|
|
||||||
|
@trace
|
||||||
|
async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
|
||||||
|
"""Check if the user is over the limit for pending media uploads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_user: The user_id of the uploader
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple with a boolean and an integer indicating whether the user has too
|
||||||
|
many pending media uploads and the timestamp at which the first pending
|
||||||
|
media will expire, respectively.
|
||||||
|
"""
|
||||||
|
pending, first_expiration_ts = await self.store.count_pending_media(
|
||||||
|
user_id=auth_user
|
||||||
|
)
|
||||||
|
return pending >= self.max_pending_media_uploads, first_expiration_ts
|
||||||
|
|
||||||
|
@trace
|
||||||
|
async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
|
||||||
|
"""Verify that the media ID can be uploaded to by the given user. This
|
||||||
|
function checks that:
|
||||||
|
|
||||||
|
* the media ID exists
|
||||||
|
* the media ID does not already have content
|
||||||
|
* the user uploading is the same as the one who created the media ID
|
||||||
|
* the media ID has not expired
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_id: The media ID to verify
|
||||||
|
auth_user: The user_id of the uploader
|
||||||
|
"""
|
||||||
|
media = await self.store.get_local_media(media_id)
|
||||||
|
if media is None:
|
||||||
|
raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
|
||||||
|
|
||||||
|
if media.user_id != auth_user.to_string():
|
||||||
|
raise SynapseError(
|
||||||
|
403,
|
||||||
|
"Only the creator of the media ID can upload to it",
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if media.media_length is not None:
|
||||||
|
raise SynapseError(
|
||||||
|
409,
|
||||||
|
"Media ID already has content",
|
||||||
|
errcode=Codes.CANNOT_OVERWRITE_MEDIA,
|
||||||
|
)
|
||||||
|
|
||||||
|
expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
|
||||||
|
if media.created_ts < expired_time_ms:
|
||||||
|
raise NotFoundError("Media ID has expired")
|
||||||
|
|
||||||
|
@trace
|
||||||
|
async def update_content(
|
||||||
|
self,
|
||||||
|
media_id: str,
|
||||||
|
media_type: str,
|
||||||
|
upload_name: Optional[str],
|
||||||
|
content: IO,
|
||||||
|
content_length: int,
|
||||||
|
auth_user: UserID,
|
||||||
|
) -> None:
|
||||||
|
"""Update the content of the given media ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
media_id: The media ID to replace.
|
||||||
|
media_type: The content type of the file.
|
||||||
|
upload_name: The name of the file, if provided.
|
||||||
|
content: A file like object that is the content to store
|
||||||
|
content_length: The length of the content
|
||||||
|
auth_user: The user_id of the uploader
|
||||||
|
"""
|
||||||
|
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||||
|
fname = await self.media_storage.store_file(content, file_info)
|
||||||
|
logger.info("Stored local media in file %r", fname)
|
||||||
|
|
||||||
|
await self.store.update_local_media(
|
||||||
|
media_id=media_id,
|
||||||
|
media_type=media_type,
|
||||||
|
upload_name=upload_name,
|
||||||
|
media_length=content_length,
|
||||||
|
user_id=auth_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info("Failed to generate thumbnails: %s", e)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def create_content(
|
async def create_content(
|
||||||
self,
|
self,
|
||||||
|
@ -231,8 +347,74 @@ class MediaRepository:
|
||||||
|
|
||||||
return MXCUri(self.server_name, media_id)
|
return MXCUri(self.server_name, media_id)
|
||||||
|
|
||||||
|
def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
|
||||||
|
respond_with_json(
|
||||||
|
request,
|
||||||
|
504,
|
||||||
|
cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED),
|
||||||
|
send_cors=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_local_media_info(
|
||||||
|
self, request: SynapseRequest, media_id: str, max_timeout_ms: int
|
||||||
|
) -> Optional[LocalMedia]:
|
||||||
|
"""Gets the info dictionary for given local media ID. If the media has
|
||||||
|
not been uploaded yet, this function will wait up to ``max_timeout_ms``
|
||||||
|
milliseconds for the media to be uploaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The incoming request.
|
||||||
|
media_id: The media ID of the content. (This is the same as
|
||||||
|
the file_id for local content.)
|
||||||
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
|
media to be uploaded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either the info dictionary for the given local media ID or
|
||||||
|
``None``. If ``None``, then no further processing is necessary as
|
||||||
|
this function will send the necessary JSON response.
|
||||||
|
"""
|
||||||
|
wait_until = self.clock.time_msec() + max_timeout_ms
|
||||||
|
while True:
|
||||||
|
# Get the info for the media
|
||||||
|
media_info = await self.store.get_local_media(media_id)
|
||||||
|
if not media_info:
|
||||||
|
logger.info("Media %s is unknown", media_id)
|
||||||
|
respond_404(request)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if media_info.quarantined_by:
|
||||||
|
logger.info("Media %s is quarantined", media_id)
|
||||||
|
respond_404(request)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# The file has been uploaded, so stop looping
|
||||||
|
if media_info.media_length is not None:
|
||||||
|
return media_info
|
||||||
|
|
||||||
|
# Check if the media ID has expired and still hasn't been uploaded to.
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
expired_time_ms = now - self.unused_expiration_time
|
||||||
|
if media_info.created_ts < expired_time_ms:
|
||||||
|
logger.info("Media %s has expired without being uploaded", media_id)
|
||||||
|
respond_404(request)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if now >= wait_until:
|
||||||
|
break
|
||||||
|
|
||||||
|
await self.clock.sleep(0.5)
|
||||||
|
|
||||||
|
logger.info("Media %s has not yet been uploaded", media_id)
|
||||||
|
self.respond_not_yet_uploaded(request)
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_local_media(
|
async def get_local_media(
|
||||||
self, request: SynapseRequest, media_id: str, name: Optional[str]
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
media_id: str,
|
||||||
|
name: Optional[str],
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Responds to requests for local media, if exists, or returns 404.
|
"""Responds to requests for local media, if exists, or returns 404.
|
||||||
|
|
||||||
|
@ -242,13 +424,14 @@ class MediaRepository:
|
||||||
the file_id for local content.)
|
the file_id for local content.)
|
||||||
name: Optional name that, if specified, will be used as
|
name: Optional name that, if specified, will be used as
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
|
media to be uploaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves once a response has successfully been written to request
|
Resolves once a response has successfully been written to request
|
||||||
"""
|
"""
|
||||||
media_info = await self.store.get_local_media(media_id)
|
media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
|
||||||
if not media_info or media_info.quarantined_by:
|
if not media_info:
|
||||||
respond_404(request)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.mark_recently_accessed(None, media_id)
|
self.mark_recently_accessed(None, media_id)
|
||||||
|
@ -273,6 +456,7 @@ class MediaRepository:
|
||||||
server_name: str,
|
server_name: str,
|
||||||
media_id: str,
|
media_id: str,
|
||||||
name: Optional[str],
|
name: Optional[str],
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Respond to requests for remote media.
|
"""Respond to requests for remote media.
|
||||||
|
|
||||||
|
@ -282,6 +466,8 @@ class MediaRepository:
|
||||||
media_id: The media ID of the content (as defined by the remote server).
|
media_id: The media ID of the content (as defined by the remote server).
|
||||||
name: Optional name that, if specified, will be used as
|
name: Optional name that, if specified, will be used as
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
|
media to be uploaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves once a response has successfully been written to request
|
Resolves once a response has successfully been written to request
|
||||||
|
@ -307,11 +493,11 @@ class MediaRepository:
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
async with self.remote_media_linearizer.queue(key):
|
async with self.remote_media_linearizer.queue(key):
|
||||||
responder, media_info = await self._get_remote_media_impl(
|
responder, media_info = await self._get_remote_media_impl(
|
||||||
server_name, media_id
|
server_name, media_id, max_timeout_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
# We deliberately stream the file outside the lock
|
# We deliberately stream the file outside the lock
|
||||||
if responder:
|
if responder and media_info:
|
||||||
upload_name = name if name else media_info.upload_name
|
upload_name = name if name else media_info.upload_name
|
||||||
await respond_with_responder(
|
await respond_with_responder(
|
||||||
request,
|
request,
|
||||||
|
@ -324,7 +510,7 @@ class MediaRepository:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
|
||||||
async def get_remote_media_info(
|
async def get_remote_media_info(
|
||||||
self, server_name: str, media_id: str
|
self, server_name: str, media_id: str, max_timeout_ms: int
|
||||||
) -> RemoteMedia:
|
) -> RemoteMedia:
|
||||||
"""Gets the media info associated with the remote file, downloading
|
"""Gets the media info associated with the remote file, downloading
|
||||||
if necessary.
|
if necessary.
|
||||||
|
@ -332,6 +518,8 @@ class MediaRepository:
|
||||||
Args:
|
Args:
|
||||||
server_name: Remote server_name where the media originated.
|
server_name: Remote server_name where the media originated.
|
||||||
media_id: The media ID of the content (as defined by the remote server).
|
media_id: The media ID of the content (as defined by the remote server).
|
||||||
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
|
media to be uploaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The media info of the file
|
The media info of the file
|
||||||
|
@ -347,7 +535,7 @@ class MediaRepository:
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
async with self.remote_media_linearizer.queue(key):
|
async with self.remote_media_linearizer.queue(key):
|
||||||
responder, media_info = await self._get_remote_media_impl(
|
responder, media_info = await self._get_remote_media_impl(
|
||||||
server_name, media_id
|
server_name, media_id, max_timeout_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure we actually use the responder so that it releases resources
|
# Ensure we actually use the responder so that it releases resources
|
||||||
|
@ -358,7 +546,7 @@ class MediaRepository:
|
||||||
return media_info
|
return media_info
|
||||||
|
|
||||||
async def _get_remote_media_impl(
|
async def _get_remote_media_impl(
|
||||||
self, server_name: str, media_id: str
|
self, server_name: str, media_id: str, max_timeout_ms: int
|
||||||
) -> Tuple[Optional[Responder], RemoteMedia]:
|
) -> Tuple[Optional[Responder], RemoteMedia]:
|
||||||
"""Looks for media in local cache, if not there then attempt to
|
"""Looks for media in local cache, if not there then attempt to
|
||||||
download from remote server.
|
download from remote server.
|
||||||
|
@ -367,6 +555,8 @@ class MediaRepository:
|
||||||
server_name: Remote server_name where the media originated.
|
server_name: Remote server_name where the media originated.
|
||||||
media_id: The media ID of the content (as defined by the
|
media_id: The media ID of the content (as defined by the
|
||||||
remote server).
|
remote server).
|
||||||
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
|
media to be uploaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of responder and the media info of the file.
|
A tuple of responder and the media info of the file.
|
||||||
|
@ -399,8 +589,7 @@ class MediaRepository:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
media_info = await self._download_remote_file(
|
media_info = await self._download_remote_file(
|
||||||
server_name,
|
server_name, media_id, max_timeout_ms
|
||||||
media_id,
|
|
||||||
)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
raise
|
raise
|
||||||
|
@ -433,6 +622,7 @@ class MediaRepository:
|
||||||
self,
|
self,
|
||||||
server_name: str,
|
server_name: str,
|
||||||
media_id: str,
|
media_id: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> RemoteMedia:
|
) -> RemoteMedia:
|
||||||
"""Attempt to download the remote file from the given server name,
|
"""Attempt to download the remote file from the given server name,
|
||||||
using the given file_id as the local id.
|
using the given file_id as the local id.
|
||||||
|
@ -442,7 +632,8 @@ class MediaRepository:
|
||||||
media_id: The media ID of the content (as defined by the
|
media_id: The media ID of the content (as defined by the
|
||||||
remote server). This is different than the file_id, which is
|
remote server). This is different than the file_id, which is
|
||||||
locally generated.
|
locally generated.
|
||||||
file_id: Local file ID
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
|
media to be uploaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The media info of the file.
|
The media info of the file.
|
||||||
|
@ -466,7 +657,8 @@ class MediaRepository:
|
||||||
# tell the remote server to 404 if it doesn't
|
# tell the remote server to 404 if it doesn't
|
||||||
# recognise the server_name, to make sure we don't
|
# recognise the server_name, to make sure we don't
|
||||||
# end up with a routing loop.
|
# end up with a routing loop.
|
||||||
"allow_remote": "false"
|
"allow_remote": "false",
|
||||||
|
"timeout_ms": str(max_timeout_ms),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except RequestSendFailed as e:
|
except RequestSendFailed as e:
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright 2023 Beeper Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from synapse.api.errors import LimitExceededError
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
from synapse.http.server import respond_with_json
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.media.media_repository import MediaRepository
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateResource(RestServlet):
|
||||||
|
PATTERNS = [re.compile("/_matrix/media/v1/create")]
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.media_repo = media_repo
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
|
||||||
|
|
||||||
|
# A rate limiter for creating new media IDs.
|
||||||
|
self._create_media_rate_limiter = Ratelimiter(
|
||||||
|
store=hs.get_datastores().main,
|
||||||
|
clock=self.clock,
|
||||||
|
cfg=hs.config.ratelimiting.rc_media_create,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_POST(self, request: SynapseRequest) -> None:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
# If the create media requests for the user are over the limit, drop them.
|
||||||
|
await self._create_media_rate_limiter.ratelimit(requester)
|
||||||
|
|
||||||
|
(
|
||||||
|
reached_pending_limit,
|
||||||
|
first_expiration_ts,
|
||||||
|
) = await self.media_repo.reached_pending_media_limit(requester.user)
|
||||||
|
if reached_pending_limit:
|
||||||
|
raise LimitExceededError(
|
||||||
|
limiter_name="max_pending_media_uploads",
|
||||||
|
retry_after_ms=first_expiration_ts - self.clock.time_msec(),
|
||||||
|
)
|
||||||
|
|
||||||
|
content_uri, unused_expires_at = await self.media_repo.create_media_id(
|
||||||
|
requester.user
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Created Media URI %r that if unused will expire at %d",
|
||||||
|
content_uri,
|
||||||
|
unused_expires_at,
|
||||||
|
)
|
||||||
|
respond_with_json(
|
||||||
|
request,
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"content_uri": content_uri,
|
||||||
|
"unused_expires_at": unused_expires_at,
|
||||||
|
},
|
||||||
|
send_cors=True,
|
||||||
|
)
|
|
@ -17,9 +17,13 @@ import re
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from synapse.http.server import set_corp_headers, set_cors_headers
|
from synapse.http.server import set_corp_headers, set_cors_headers
|
||||||
from synapse.http.servlet import RestServlet, parse_boolean
|
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.media._base import respond_404
|
from synapse.media._base import (
|
||||||
|
DEFAULT_MAX_TIMEOUT_MS,
|
||||||
|
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
||||||
|
respond_404,
|
||||||
|
)
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -65,12 +69,16 @@ class DownloadResource(RestServlet):
|
||||||
)
|
)
|
||||||
# Limited non-standard form of CSP for IE11
|
# Limited non-standard form of CSP for IE11
|
||||||
request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
|
request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
|
||||||
request.setHeader(
|
request.setHeader(b"Referrer-Policy", b"no-referrer")
|
||||||
b"Referrer-Policy",
|
max_timeout_ms = parse_integer(
|
||||||
b"no-referrer",
|
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
|
||||||
)
|
)
|
||||||
|
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
|
||||||
|
|
||||||
if self._is_mine_server_name(server_name):
|
if self._is_mine_server_name(server_name):
|
||||||
await self.media_repo.get_local_media(request, media_id, file_name)
|
await self.media_repo.get_local_media(
|
||||||
|
request, media_id, file_name, max_timeout_ms
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
allow_remote = parse_boolean(request, "allow_remote", default=True)
|
allow_remote = parse_boolean(request, "allow_remote", default=True)
|
||||||
if not allow_remote:
|
if not allow_remote:
|
||||||
|
@ -83,5 +91,5 @@ class DownloadResource(RestServlet):
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.media_repo.get_remote_media(
|
await self.media_repo.get_remote_media(
|
||||||
request, server_name, media_id, file_name
|
request, server_name, media_id, file_name, max_timeout_ms
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,10 +18,11 @@ from synapse.config._base import ConfigError
|
||||||
from synapse.http.server import HttpServer, JsonResource
|
from synapse.http.server import HttpServer, JsonResource
|
||||||
|
|
||||||
from .config_resource import MediaConfigResource
|
from .config_resource import MediaConfigResource
|
||||||
|
from .create_resource import CreateResource
|
||||||
from .download_resource import DownloadResource
|
from .download_resource import DownloadResource
|
||||||
from .preview_url_resource import PreviewUrlResource
|
from .preview_url_resource import PreviewUrlResource
|
||||||
from .thumbnail_resource import ThumbnailResource
|
from .thumbnail_resource import ThumbnailResource
|
||||||
from .upload_resource import UploadResource
|
from .upload_resource import AsyncUploadServlet, UploadServlet
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource):
|
||||||
|
|
||||||
# Note that many of these should not exist as v1 endpoints, but empirically
|
# Note that many of these should not exist as v1 endpoints, but empirically
|
||||||
# a lot of traffic still goes to them.
|
# a lot of traffic still goes to them.
|
||||||
|
CreateResource(hs, media_repo).register(http_server)
|
||||||
UploadResource(hs, media_repo).register(http_server)
|
UploadServlet(hs, media_repo).register(http_server)
|
||||||
|
AsyncUploadServlet(hs, media_repo).register(http_server)
|
||||||
DownloadResource(hs, media_repo).register(http_server)
|
DownloadResource(hs, media_repo).register(http_server)
|
||||||
ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
|
ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
|
||||||
http_server
|
http_server
|
||||||
|
|
|
@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he
|
||||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.media._base import (
|
from synapse.media._base import (
|
||||||
|
DEFAULT_MAX_TIMEOUT_MS,
|
||||||
|
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
|
||||||
FileInfo,
|
FileInfo,
|
||||||
ThumbnailInfo,
|
ThumbnailInfo,
|
||||||
respond_404,
|
respond_404,
|
||||||
|
@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet):
|
||||||
method = parse_string(request, "method", "scale")
|
method = parse_string(request, "method", "scale")
|
||||||
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
||||||
m_type = "image/png"
|
m_type = "image/png"
|
||||||
|
max_timeout_ms = parse_integer(
|
||||||
|
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
|
||||||
|
)
|
||||||
|
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
|
||||||
|
|
||||||
if self._is_mine_server_name(server_name):
|
if self._is_mine_server_name(server_name):
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
await self._select_or_generate_local_thumbnail(
|
await self._select_or_generate_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._respond_local_thumbnail(
|
await self._respond_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type, max_timeout_ms
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(None, media_id)
|
self.media_repo.mark_recently_accessed(None, media_id)
|
||||||
else:
|
else:
|
||||||
|
@ -95,13 +101,20 @@ class ThumbnailResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.dynamic_thumbnails:
|
remote_resp_function = (
|
||||||
await self._select_or_generate_remote_thumbnail(
|
self._select_or_generate_remote_thumbnail
|
||||||
request, server_name, media_id, width, height, method, m_type
|
if self.dynamic_thumbnails
|
||||||
|
else self._respond_remote_thumbnail
|
||||||
)
|
)
|
||||||
else:
|
await remote_resp_function(
|
||||||
await self._respond_remote_thumbnail(
|
request,
|
||||||
request, server_name, media_id, width, height, method, m_type
|
server_name,
|
||||||
|
media_id,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
method,
|
||||||
|
m_type,
|
||||||
|
max_timeout_ms,
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
|
@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet):
|
||||||
height: int,
|
height: int,
|
||||||
method: str,
|
method: str,
|
||||||
m_type: str,
|
m_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
media_info = await self.store.get_local_media(media_id)
|
media_info = await self.media_repo.get_local_media_info(
|
||||||
|
request, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
if not media_info:
|
if not media_info:
|
||||||
respond_404(request)
|
|
||||||
return
|
|
||||||
if media_info.quarantined_by:
|
|
||||||
logger.info("Media is quarantined")
|
|
||||||
respond_404(request)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||||
|
@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet):
|
||||||
desired_height: int,
|
desired_height: int,
|
||||||
desired_method: str,
|
desired_method: str,
|
||||||
desired_type: str,
|
desired_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
media_info = await self.store.get_local_media(media_id)
|
media_info = await self.media_repo.get_local_media_info(
|
||||||
|
request, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
|
||||||
if not media_info:
|
if not media_info:
|
||||||
respond_404(request)
|
|
||||||
return
|
|
||||||
if media_info.quarantined_by:
|
|
||||||
logger.info("Media is quarantined")
|
|
||||||
respond_404(request)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||||
|
@ -206,8 +214,14 @@ class ThumbnailResource(RestServlet):
|
||||||
desired_height: int,
|
desired_height: int,
|
||||||
desired_method: str,
|
desired_method: str,
|
||||||
desired_type: str,
|
desired_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
|
server_name, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
if not media_info:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||||
server_name, media_id
|
server_name, media_id
|
||||||
|
@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet):
|
||||||
height: int,
|
height: int,
|
||||||
method: str,
|
method: str,
|
||||||
m_type: str,
|
m_type: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Don't download the whole remote file
|
# TODO: Don't download the whole remote file
|
||||||
# We should proxy the thumbnail from the remote server instead of
|
# We should proxy the thumbnail from the remote server instead of
|
||||||
# downloading the remote file and generating our own thumbnails.
|
# downloading the remote file and generating our own thumbnails.
|
||||||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
|
server_name, media_id, max_timeout_ms
|
||||||
|
)
|
||||||
|
if not media_info:
|
||||||
|
return
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||||
server_name, media_id
|
server_name, media_id
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import IO, TYPE_CHECKING, Dict, List, Optional
|
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.server import respond_with_json
|
from synapse.http.server import respond_with_json
|
||||||
|
@ -29,23 +29,24 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# The name of the lock to use when uploading media.
|
||||||
|
_UPLOAD_MEDIA_LOCK_NAME = "upload_media"
|
||||||
|
|
||||||
class UploadResource(RestServlet):
|
|
||||||
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
|
|
||||||
|
|
||||||
|
class BaseUploadServlet(RestServlet):
|
||||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.filepaths = media_repo.filepaths
|
self.filepaths = media_repo.filepaths
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.clock = hs.get_clock()
|
self.server_name = hs.hostname
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.max_upload_size = hs.config.media.max_upload_size
|
self.max_upload_size = hs.config.media.max_upload_size
|
||||||
self.clock = hs.get_clock()
|
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> None:
|
def _get_file_metadata(
|
||||||
requester = await self.auth.get_user_by_req(request)
|
self, request: SynapseRequest
|
||||||
|
) -> Tuple[int, Optional[str], str]:
|
||||||
raw_content_length = request.getHeader("Content-Length")
|
raw_content_length = request.getHeader("Content-Length")
|
||||||
if raw_content_length is None:
|
if raw_content_length is None:
|
||||||
raise SynapseError(msg="Request must specify a Content-Length", code=400)
|
raise SynapseError(msg="Request must specify a Content-Length", code=400)
|
||||||
|
@ -88,6 +89,16 @@ class UploadResource(RestServlet):
|
||||||
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
|
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
|
||||||
# TODO(markjh): parse content-dispostion
|
# TODO(markjh): parse content-dispostion
|
||||||
|
|
||||||
|
return content_length, upload_name, media_type
|
||||||
|
|
||||||
|
|
||||||
|
class UploadServlet(BaseUploadServlet):
|
||||||
|
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")]
|
||||||
|
|
||||||
|
async def on_POST(self, request: SynapseRequest) -> None:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
content_length, upload_name, media_type = self._get_file_metadata(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content: IO = request.content # type: ignore
|
content: IO = request.content # type: ignore
|
||||||
content_uri = await self.media_repo.create_content(
|
content_uri = await self.media_repo.create_content(
|
||||||
|
@ -103,3 +114,53 @@ class UploadResource(RestServlet):
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request, 200, {"content_uri": str(content_uri)}, send_cors=True
|
request, 200, {"content_uri": str(content_uri)}, send_cors=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncUploadServlet(BaseUploadServlet):
|
||||||
|
PATTERNS = [
|
||||||
|
re.compile(
|
||||||
|
"/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
|
) -> None:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
if server_name != self.server_name:
|
||||||
|
raise SynapseError(
|
||||||
|
404,
|
||||||
|
"Non-local server name specified",
|
||||||
|
errcode=Codes.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id)
|
||||||
|
if not lock:
|
||||||
|
raise SynapseError(
|
||||||
|
409,
|
||||||
|
"Media ID cannot be overwritten",
|
||||||
|
errcode=Codes.CANNOT_OVERWRITE_MEDIA,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
await self.media_repo.verify_can_upload(media_id, requester.user)
|
||||||
|
content_length, upload_name, media_type = self._get_file_metadata(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
content: IO = request.content # type: ignore
|
||||||
|
await self.media_repo.update_content(
|
||||||
|
media_id,
|
||||||
|
media_type,
|
||||||
|
upload_name,
|
||||||
|
content,
|
||||||
|
content_length,
|
||||||
|
requester.user,
|
||||||
|
)
|
||||||
|
except SpamMediaException:
|
||||||
|
# For uploading of media we want to respond with a 400, instead of
|
||||||
|
# the default 404, as that would just be confusing.
|
||||||
|
raise SynapseError(400, "Bad content")
|
||||||
|
|
||||||
|
logger.info("Uploaded content for media ID %r", media_id)
|
||||||
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
|
@ -49,13 +49,14 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
|
||||||
class LocalMedia:
|
class LocalMedia:
|
||||||
media_id: str
|
media_id: str
|
||||||
media_type: str
|
media_type: str
|
||||||
media_length: int
|
media_length: Optional[int]
|
||||||
upload_name: str
|
upload_name: str
|
||||||
created_ts: int
|
created_ts: int
|
||||||
url_cache: Optional[str]
|
url_cache: Optional[str]
|
||||||
last_access_ts: int
|
last_access_ts: int
|
||||||
quarantined_by: Optional[str]
|
quarantined_by: Optional[str]
|
||||||
safe_from_quarantine: bool
|
safe_from_quarantine: bool
|
||||||
|
user_id: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
@ -149,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
||||||
self._drop_media_index_without_method,
|
self._drop_media_index_without_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hs.config.media.can_load_media_repo:
|
||||||
|
self.unused_expiration_time: Optional[
|
||||||
|
int
|
||||||
|
] = hs.config.media.unused_expiration_time
|
||||||
|
else:
|
||||||
|
self.unused_expiration_time = None
|
||||||
|
|
||||||
async def _drop_media_index_without_method(
|
async def _drop_media_index_without_method(
|
||||||
self, progress: JsonDict, batch_size: int
|
self, progress: JsonDict, batch_size: int
|
||||||
) -> int:
|
) -> int:
|
||||||
|
@ -202,6 +210,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
"url_cache",
|
"url_cache",
|
||||||
"last_access_ts",
|
"last_access_ts",
|
||||||
"safe_from_quarantine",
|
"safe_from_quarantine",
|
||||||
|
"user_id",
|
||||||
),
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_local_media",
|
desc="get_local_media",
|
||||||
|
@ -218,6 +227,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
url_cache=row[5],
|
url_cache=row[5],
|
||||||
last_access_ts=row[6],
|
last_access_ts=row[6],
|
||||||
safe_from_quarantine=row[7],
|
safe_from_quarantine=row[7],
|
||||||
|
user_id=row[8],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_local_media_by_user_paginate(
|
async def get_local_media_by_user_paginate(
|
||||||
|
@ -272,7 +282,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
url_cache,
|
url_cache,
|
||||||
last_access_ts,
|
last_access_ts,
|
||||||
quarantined_by,
|
quarantined_by,
|
||||||
safe_from_quarantine
|
safe_from_quarantine,
|
||||||
|
user_id
|
||||||
FROM local_media_repository
|
FROM local_media_repository
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
ORDER BY {order_by_column} {order}, media_id ASC
|
ORDER BY {order_by_column} {order}, media_id ASC
|
||||||
|
@ -295,6 +306,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
last_access_ts=row[6],
|
last_access_ts=row[6],
|
||||||
quarantined_by=row[7],
|
quarantined_by=row[7],
|
||||||
safe_from_quarantine=bool(row[8]),
|
safe_from_quarantine=bool(row[8]),
|
||||||
|
user_id=row[9],
|
||||||
)
|
)
|
||||||
for row in txn
|
for row in txn
|
||||||
]
|
]
|
||||||
|
@ -391,6 +403,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
"get_local_media_ids", _get_local_media_ids_txn
|
"get_local_media_ids", _get_local_media_ids_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@trace
|
||||||
|
async def store_local_media_id(
|
||||||
|
self,
|
||||||
|
media_id: str,
|
||||||
|
time_now_ms: int,
|
||||||
|
user_id: UserID,
|
||||||
|
) -> None:
|
||||||
|
await self.db_pool.simple_insert(
|
||||||
|
"local_media_repository",
|
||||||
|
{
|
||||||
|
"media_id": media_id,
|
||||||
|
"created_ts": time_now_ms,
|
||||||
|
"user_id": user_id.to_string(),
|
||||||
|
},
|
||||||
|
desc="store_local_media_id",
|
||||||
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def store_local_media(
|
async def store_local_media(
|
||||||
self,
|
self,
|
||||||
|
@ -416,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="store_local_media",
|
desc="store_local_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def update_local_media(
|
||||||
|
self,
|
||||||
|
media_id: str,
|
||||||
|
media_type: str,
|
||||||
|
upload_name: Optional[str],
|
||||||
|
media_length: int,
|
||||||
|
user_id: UserID,
|
||||||
|
url_cache: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
await self.db_pool.simple_update_one(
|
||||||
|
"local_media_repository",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id.to_string(),
|
||||||
|
"media_id": media_id,
|
||||||
|
},
|
||||||
|
updatevalues={
|
||||||
|
"media_type": media_type,
|
||||||
|
"upload_name": upload_name,
|
||||||
|
"media_length": media_length,
|
||||||
|
"url_cache": url_cache,
|
||||||
|
},
|
||||||
|
desc="update_local_media",
|
||||||
|
)
|
||||||
|
|
||||||
async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
|
async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
|
||||||
"""Mark a local media as safe or unsafe from quarantining."""
|
"""Mark a local media as safe or unsafe from quarantining."""
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
|
@ -425,6 +478,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
desc="mark_local_media_as_safe",
|
desc="mark_local_media_as_safe",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
|
||||||
|
"""Count the number of pending media for a user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of two integers: the total pending media requests and the earliest
|
||||||
|
expiration timestamp.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
|
||||||
|
sql = """
|
||||||
|
SELECT COUNT(*), MIN(created_ts)
|
||||||
|
FROM local_media_repository
|
||||||
|
WHERE user_id = ?
|
||||||
|
AND created_ts > ?
|
||||||
|
AND media_length IS NULL
|
||||||
|
"""
|
||||||
|
assert self.unused_expiration_time is not None
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
user_id.to_string(),
|
||||||
|
self._clock.time_msec() - self.unused_expiration_time,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
row = txn.fetchone()
|
||||||
|
if not row:
|
||||||
|
return 0, 0
|
||||||
|
return row[0], (row[1] + self.unused_expiration_time if row[1] else 0)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_pending_media", get_pending_media_txn
|
||||||
|
)
|
||||||
|
|
||||||
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
|
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
|
||||||
"""Get the media_id and ts for a cached URL as of the given timestamp
|
"""Get the media_id and ts for a cached URL as of the given timestamp
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -318,7 +318,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
|
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
|
||||||
)
|
)
|
||||||
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
|
self.assertEqual(
|
||||||
|
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"}
|
||||||
|
)
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||||
|
|
Loading…
Reference in New Issue