Request & follow redirects for /media/v3/download (#16701)

Implement MSC3860 to follow redirects for federated media downloads.

Note that the Client-Server API doesn't support this (yet) since the media
repository in Synapse doesn't have a way of supporting redirects.
This commit is contained in:
Patrick Cloke 2023-11-29 14:03:42 -05:00 committed by GitHub
parent a14678492e
commit d6c3b7584f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 211 additions and 37 deletions

View File

@ -0,0 +1 @@
Follow redirects when downloading media over federation (per [MSC3860](https://github.com/matrix-org/matrix-spec-proposals/pull/3860)).

View File

@ -21,6 +21,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet, AbstractSet,
Awaitable, Awaitable,
BinaryIO,
Callable, Callable,
Collection, Collection,
Container, Container,
@ -1862,6 +1863,43 @@ class FederationClient(FederationBase):
return filtered_statuses, filtered_failures return filtered_statuses, filtered_failures
async def download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the r0 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't download media %s/%s with the v3 API, falling back to the r0 API",
destination,
media_id,
)
return await self.transport_layer.download_media_r0(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
)
@attr.s(frozen=True, slots=True, auto_attribs=True) @attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse: class TimestampToEventResponse:

View File

@ -18,6 +18,7 @@ import urllib
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
BinaryIO,
Callable, Callable,
Collection, Collection,
Dict, Dict,
@ -804,6 +805,58 @@ class TransportLayerClient:
destination=destination, path=path, data={"user_ids": user_ids} destination=destination, path=path, data={"user_ids": user_ids}
) )
async def download_media_r0(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
return await self.client.get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
)
async def download_media_v3(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
return await self.client.get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
# Matrix 1.7 allows for this to redirect to another URL, this should
# just be ignored for an old homeserver, so always provide it.
"allow_redirect": "true",
},
follow_redirects=True,
)
def _create_path(federation_prefix: str, path: str, *args: str) -> str: def _create_path(federation_prefix: str, path: str, *args: str) -> str:
""" """

View File

@ -153,12 +153,18 @@ class MatrixFederationRequest:
"""Query arguments. """Query arguments.
""" """
txn_id: Optional[str] = None txn_id: str = attr.ib(init=False)
"""Unique ID for this request (for logging) """Unique ID for this request (for logging), this is autogenerated.
""" """
uri: bytes = attr.ib(init=False) uri: bytes = b""
"""The URI of this request """The URI of this request, usually generated from the above information.
"""
_generate_uri: bool = True
"""True to automatically generate the uri field based on the above information.
Set to False if manually configuring the URI.
""" """
def __attrs_post_init__(self) -> None: def __attrs_post_init__(self) -> None:
@ -168,6 +174,7 @@ class MatrixFederationRequest:
object.__setattr__(self, "txn_id", txn_id) object.__setattr__(self, "txn_id", txn_id)
if self._generate_uri:
destination_bytes = self.destination.encode("ascii") destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii") path_bytes = self.path.encode("ascii")
query_bytes = encode_query_args(self.query) query_bytes = encode_query_args(self.query)
@ -513,6 +520,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False, ignore_backoff: bool = False,
backoff_on_404: bool = False, backoff_on_404: bool = False,
backoff_on_all_error_codes: bool = False, backoff_on_all_error_codes: bool = False,
follow_redirects: bool = False,
) -> IResponse: ) -> IResponse:
""" """
Sends a request to the given server. Sends a request to the given server.
@ -555,6 +563,9 @@ class MatrixFederationHttpClient:
backoff_on_404: Back off if we get a 404 backoff_on_404: Back off if we get a 404
backoff_on_all_error_codes: Back off if we get any error response backoff_on_all_error_codes: Back off if we get any error response
follow_redirects: True to follow the Location header of 307/308 redirect
responses. This does not recurse.
Returns: Returns:
Resolves with the HTTP response object on success. Resolves with the HTTP response object on success.
@ -714,6 +725,26 @@ class MatrixFederationHttpClient:
response.code, response.code,
response_phrase, response_phrase,
) )
elif (
response.code in (307, 308)
and follow_redirects
and response.headers.hasHeader("Location")
):
# The Location header *might* be relative so resolve it.
location = response.headers.getRawHeaders(b"Location")[0]
new_uri = urllib.parse.urljoin(request.uri, location)
return await self._send_request(
attr.evolve(request, uri=new_uri, generate_uri=False),
retry_on_dns_fail,
timeout,
long_retries,
ignore_backoff,
backoff_on_404,
backoff_on_all_error_codes,
# Do not continue following redirects.
follow_redirects=False,
)
else: else:
logger.info( logger.info(
"{%s} [%s] Got response headers: %d %s", "{%s} [%s] Got response headers: %d %s",
@ -1383,6 +1414,7 @@ class MatrixFederationHttpClient:
retry_on_dns_fail: bool = True, retry_on_dns_fail: bool = True,
max_size: Optional[int] = None, max_size: Optional[int] = None,
ignore_backoff: bool = False, ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]: ) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver """GETs a file from a given homeserver
Args: Args:
@ -1392,6 +1424,8 @@ class MatrixFederationHttpClient:
args: Optional dictionary used to create the query string. args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data ignore_backoff: true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
follow_redirects: True to follow the Location header of 307/308 redirect
responses. This does not recurse.
Returns: Returns:
Resolves with an (int,dict) tuple of Resolves with an (int,dict) tuple of
@ -1412,7 +1446,10 @@ class MatrixFederationHttpClient:
) )
response = await self._send_request( response = await self._send_request(
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
follow_redirects=follow_redirects,
) )
headers = dict(response.headers.getAllRawHeaders()) headers = dict(response.headers.getAllRawHeaders())

View File

@ -77,7 +77,7 @@ class MediaRepository:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.client = hs.get_federation_http_client() self.client = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@ -644,22 +644,13 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id) file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish): with self.media_storage.store_into_file(file_info) as (f, fname, finish):
request_path = "/".join(
("/_matrix/media/r0/download", server_name, media_id)
)
try: try:
length, headers = await self.client.get_file( length, headers = await self.client.download_media(
server_name, server_name,
request_path, media_id,
output_stream=f, output_stream=f,
max_size=self.max_upload_size, max_size=self.max_upload_size,
args={ max_timeout_ms=max_timeout_ms,
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
) )
except RequestSendFailed as e: except RequestSendFailed as e:
logger.warning( logger.warning(

View File

@ -27,10 +27,11 @@ from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import Codes from synapse.api.errors import Codes, HttpResponseException
from synapse.events import EventBase from synapse.events import EventBase
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -247,6 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
retry_on_dns_fail: bool = True, retry_on_dns_fail: bool = True,
max_size: Optional[int] = None, max_size: Optional[int] = None,
ignore_backoff: bool = False, ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
"""A mock for MatrixFederationHttpClient.get_file.""" """A mock for MatrixFederationHttpClient.get_file."""
@ -257,10 +259,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
output_stream.write(data) output_stream.write(data)
return response return response
def write_err(f: Failure) -> Failure:
f.trap(HttpResponseException)
output_stream.write(f.value.response)
return f
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args)) self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d. # Note that this callback changes the value held by d.
d_after_callback = d.addCallback(write_to) d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback) return make_deferred_yieldable(d_after_callback)
# Mock out the homeserver's MatrixFederationHttpClient # Mock out the homeserver's MatrixFederationHttpClient
@ -316,10 +323,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.fetches), 1) self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual( self.assertEqual(
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id
) )
self.assertEqual( self.assertEqual(
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"} self.fetches[0][3],
{"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"},
) )
headers = { headers = {
@ -671,6 +679,52 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"cross-origin"], [b"cross-origin"],
) )
def test_unknown_v3_endpoint(self) -> None:
"""
If the v3 endpoint fails, try the r0 one.
"""
channel = self.make_request(
"GET",
f"/_matrix/media/v3/download/{self.media_id}",
shorthand=False,
await_result=False,
)
self.pump()
# We've made one fetch, to example.com, using the media URL, and asking
# the other server not to do a remote fetch
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id
)
# The result which says the endpoint is unknown.
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
self.fetches[0][0].errback(
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
)
self.pump()
# There should now be another request to the r0 URL.
self.assertEqual(len(self.fetches), 2)
self.assertEqual(self.fetches[1][1], "example.com")
self.assertEqual(
self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}"
)
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
}
self.fetches[1][0].callback(
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
self.assertEqual(channel.code, 200)
class TestSpamCheckerLegacy: class TestSpamCheckerLegacy:
"""A spam checker module that rejects all media that includes the bytes """A spam checker module that rejects all media that includes the bytes

View File

@ -133,7 +133,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET") self.assertEqual(request.method, b"GET")
self.assertEqual( self.assertEqual(
request.path, request.path,
f"/_matrix/media/r0/download/{target}/{media_id}".encode(), f"/_matrix/media/v3/download/{target}/{media_id}".encode(),
) )
self.assertEqual( self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]