diff --git a/changelog.d/17543.bugfix b/changelog.d/17543.bugfix new file mode 100644 index 0000000000..152b305e58 --- /dev/null +++ b/changelog.d/17543.bugfix @@ -0,0 +1 @@ +Fix authenticated media responses using a wrong limit when following redirects over federation. \ No newline at end of file diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 6fd75fd381..12c41c39e9 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -464,6 +464,8 @@ class MatrixFederationHttpClient: self.max_long_retries = hs.config.federation.max_long_retries self.max_short_retries = hs.config.federation.max_short_retries + self.max_download_size = hs.config.media.max_upload_size + self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) self._sleeper = AwakenableSleeper(self.reactor) @@ -1756,8 +1758,10 @@ class MatrixFederationHttpClient: request.destination, str_url, ) + # We don't know how large the response will be upfront, so limit it to + # the `max_upload_size` config value. length, headers, _, _ = await self._simple_http_client.get_file( - str_url, output_stream, expected_size + str_url, output_stream, self.max_download_size ) logger.info( diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index e2f033fdae..6827412373 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -17,6 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # +import io from typing import Any, Dict, Generator from unittest.mock import ANY, Mock, create_autospec @@ -32,7 +33,9 @@ from twisted.web.http import HTTPChannel from twisted.web.http_headers import Headers from synapse.api.errors import HttpResponseException, RequestSendFailed +from synapse.api.ratelimiting import Ratelimiter from synapse.config._base import ConfigError +from synapse.config.ratelimiting import RatelimitSettings from synapse.http.matrixfederationclient import ( ByteParser, MatrixFederationHttpClient, @@ -337,6 +340,81 @@ class FederationClientTests(HomeserverTestCase): r = self.successResultOf(d) self.assertEqual(r.code, 200) + def test_authed_media_redirect_response(self) -> None: + """ + Validate that, when following a `Location` redirect, the + maximum size is _not_ set to the initial response `Content-Length` and + the media file can be downloaded. + """ + limiter = Ratelimiter( + store=self.hs.get_datastores().main, + clock=self.clock, + cfg=RatelimitSettings(key="", per_second=0.17, burst_count=1048576), + ) + + output_stream = io.BytesIO() + + d = defer.ensureDeferred( + self.cl.federation_get_file( + "testserv:8008", "path", output_stream, limiter, "127.0.0.1", 10000 + ) + ) + + self.pump() + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + + # complete the connection and wire it up to a fake transport + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # Deferred does not have a result + self.assertNoResult(d) + + redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: http://testserv:8008/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Length: %i\r\n" + b"Content-Type: multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a\r\n\r\n" + % (len(redirect_data)) + ) + protocol.dataReceived(redirect_data) + + # Still no result, not followed the redirect yet + self.assertNoResult(d) + + # Now send the response returned by the server at `Location` + clients = self.reactor.tcpClients + (host, port, factory, _timeout, _bindAddress) = clients[1] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # make sure the length is longer than the initial response + data = b"Hello world!" * 30 + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Length: %i\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"%s\r\n" + b"\r\n" % (len(data), data) + ) + + # We should get a successful response + length, _, _ = self.successResultOf(d) + self.assertEqual(length, len(data)) + self.assertEqual(output_stream.getvalue(), data) + @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) def test_timeout_reading_body(self, method_name: str) -> None: """