Remove usage of internal header encoding API (#17894)

Specifically, stop relying on these private (underscore-prefixed) Twisted attributes:

```py
from twisted.web.http_headers import Headers

Headers()._canonicalNameCaps
Headers()._encodeName
```

Introduced in https://github.com/matrix-org/synapse/pull/15913 <-
https://github.com/matrix-org/synapse/pull/15773
This commit is contained in:
Eric Eastwood 2024-11-04 12:20:07 -06:00 committed by GitHub
parent 9c0a3963bc
commit 2c9ed5e510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 60 additions and 36 deletions

1
changelog.d/17894.misc Normal file
View File

@ -0,0 +1 @@
Remove usage of internal header encoding API.

View File

@ -51,25 +51,17 @@ logger = logging.getLogger(__name__)
# "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616 # "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616
# section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be # section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be
# consumed by the immediate recipient and not be forwarded on. # consumed by the immediate recipient and not be forwarded on.
HOP_BY_HOP_HEADERS = { HOP_BY_HOP_HEADERS_LOWERCASE = {
"Connection", "connection",
"Keep-Alive", "keep-alive",
"Proxy-Authenticate", "proxy-authenticate",
"Proxy-Authorization", "proxy-authorization",
"TE", "te",
"Trailers", "trailers",
"Transfer-Encoding", "transfer-encoding",
"Upgrade", "upgrade",
} }
assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE)
if hasattr(Headers, "_canonicalNameCaps"):
# Twisted < 24.7.0rc1
_canonicalHeaderName = Headers()._canonicalNameCaps # type: ignore[attr-defined]
else:
# Twisted >= 24.7.0rc1
# But note that `_encodeName` still exists on prior versions,
# it just encodes differently
_canonicalHeaderName = Headers()._encodeName
def parse_connection_header_value( def parse_connection_header_value(
@ -92,12 +84,12 @@ def parse_connection_header_value(
Returns: Returns:
The set of header names that should not be copied over from the remote response. The set of header names that should not be copied over from the remote response.
The keys are capitalized in canonical capitalization. The keys are lowercased.
""" """
extra_headers_to_remove: Set[str] = set() extra_headers_to_remove: Set[str] = set()
if connection_header_value: if connection_header_value:
extra_headers_to_remove = { extra_headers_to_remove = {
_canonicalHeaderName(connection_option.strip()).decode("ascii") connection_option.decode("ascii").strip().lower()
for connection_option in connection_header_value.split(b",") for connection_option in connection_header_value.split(b",")
} }
@ -194,7 +186,7 @@ class ProxyResource(_AsyncResource):
# The `Connection` header also defines which headers should not be copied over. # The `Connection` header also defines which headers should not be copied over.
connection_header = response_headers.getRawHeaders(b"connection") connection_header = response_headers.getRawHeaders(b"connection")
extra_headers_to_remove = parse_connection_header_value( extra_headers_to_remove_lowercase = parse_connection_header_value(
connection_header[0] if connection_header else None connection_header[0] if connection_header else None
) )
@ -202,10 +194,10 @@ class ProxyResource(_AsyncResource):
for k, v in response_headers.getAllRawHeaders(): for k, v in response_headers.getAllRawHeaders():
# Do not copy over any hop-by-hop headers. These are meant to only be # Do not copy over any hop-by-hop headers. These are meant to only be
# consumed by the immediate recipient and not be forwarded on. # consumed by the immediate recipient and not be forwarded on.
header_key = k.decode("ascii") header_key_lowercase = k.decode("ascii").lower()
if ( if (
header_key in HOP_BY_HOP_HEADERS header_key_lowercase in HOP_BY_HOP_HEADERS_LOWERCASE
or header_key in extra_headers_to_remove or header_key_lowercase in extra_headers_to_remove_lowercase
): ):
continue continue

View File

@ -903,12 +903,19 @@ class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
headers=Headers( headers=Headers(
{ {
"Content-Type": ["application/json"], "Content-Type": ["application/json"],
"Connection": ["close, X-Foo, X-Bar"], "X-Test": ["test"],
# Define some hop-by-hop headers (try with varying casing to
# make sure we still match up the headers)
"Connection": ["close, X-fOo, X-Bar, X-baz"],
# Should be removed because it's defined in the `Connection` header # Should be removed because it's defined in the `Connection` header
"X-Foo": ["foo"], "X-Foo": ["foo"],
"X-Bar": ["bar"], "X-Bar": ["bar"],
# (not in canonical case)
"x-baZ": ["baz"],
# Should be removed because it's a hop-by-hop header # Should be removed because it's a hop-by-hop header
"Proxy-Authorization": "abcdef", "Proxy-Authorization": "abcdef",
# Should be removed because it's a hop-by-hop header (not in canonical case)
"transfer-EnCoDiNg": "abcdef",
} }
), ),
) )
@ -938,9 +945,17 @@ class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
header_names = set(headers.keys()) header_names = set(headers.keys())
# Make sure the response does not include the hop-by-hop headers # Make sure the response does not include the hop-by-hop headers
self.assertNotIn(b"X-Foo", header_names) self.assertIncludes(
self.assertNotIn(b"X-Bar", header_names) header_names,
self.assertNotIn(b"Proxy-Authorization", header_names) {
b"Content-Type",
b"X-Test",
# Default headers from Twisted
b"Date",
b"Server",
},
exact=True,
)
# Make sure the response is as expected back on the main worker # Make sure the response is as expected back on the main worker
self.assertEqual(res, {"foo": "bar"}) self.assertEqual(res, {"foo": "bar"})

View File

@ -22,27 +22,42 @@ from typing import Set
from parameterized import parameterized from parameterized import parameterized
from synapse.http.proxy import parse_connection_header_value from synapse.http.proxy import (
HOP_BY_HOP_HEADERS_LOWERCASE,
parse_connection_header_value,
)
from tests.unittest import TestCase from tests.unittest import TestCase
def mix_case(s: str) -> str:
"""
Mix up the case of each character in the string (upper or lower case)
"""
return "".join(c.upper() if i % 2 == 0 else c.lower() for i, c in enumerate(s))
class ProxyTests(TestCase): class ProxyTests(TestCase):
@parameterized.expand( @parameterized.expand(
[ [
[b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}], [b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}],
# No whitespace # No whitespace
[b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}], [b"close,X-Foo,X-Bar", {"close", "x-foo", "x-bar"}],
# More whitespace # More whitespace
[b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}], [b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}],
# "close" directive is not the first position # "close" directive is not the first position
[b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}], [b"X-Foo, X-Bar, close", {"x-foo", "x-bar", "close"}],
# Normalizes header capitalization # Normalizes header capitalization
[b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}], [b"keep-alive, x-fOo, x-bAr", {"keep-alive", "x-foo", "x-bar"}],
# Handles header names with whitespace # Handles header names with whitespace
[ [
b"keep-alive, x foo, x bar", b"keep-alive, x foo, x bar",
{"Keep-Alive", "X foo", "X bar"}, {"keep-alive", "x foo", "x bar"},
],
# Make sure we handle all of the hop-by-hop headers
[
mix_case(", ".join(HOP_BY_HOP_HEADERS_LOWERCASE)).encode("ascii"),
HOP_BY_HOP_HEADERS_LOWERCASE,
], ],
] ]
) )
@ -54,7 +69,8 @@ class ProxyTests(TestCase):
""" """
Tests that the connection header value is parsed correctly Tests that the connection header value is parsed correctly
""" """
self.assertEqual( self.assertIncludes(
expected_extra_headers_to_remove, expected_extra_headers_to_remove,
parse_connection_header_value(connection_header_value), parse_connection_header_value(connection_header_value),
exact=True,
) )