Fixes to `federation_client` dev script (#14479)
* Attempt to fix federation-client devscript handling of .well-known The script was setting the wrong value in the Host header * Fix TLS verification Turns out that actually doing TLS verification isn't that hard. Let's enable it.
This commit is contained in:
parent
e1b15f25f3
commit
8d133a8464
|
@ -0,0 +1 @@
|
||||||
|
`scripts-dev/federation_client`: Fix routing on servers with `.well-known` files.
|
|
@ -46,11 +46,12 @@ import signedjson.key
|
||||||
import signedjson.types
|
import signedjson.types
|
||||||
import srvlookup
|
import srvlookup
|
||||||
import yaml
|
import yaml
|
||||||
|
from requests import PreparedRequest, Response
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
from urllib3 import HTTPConnectionPool
|
from urllib3 import HTTPConnectionPool
|
||||||
|
|
||||||
# uncomment the following to enable debug logging of http requests
|
# uncomment the following to enable debug logging of http requests
|
||||||
# from httplib import HTTPConnection
|
# from http.client import HTTPConnection
|
||||||
# HTTPConnection.debuglevel = 1
|
# HTTPConnection.debuglevel = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,6 +104,7 @@ def request(
|
||||||
destination: str,
|
destination: str,
|
||||||
path: str,
|
path: str,
|
||||||
content: Optional[str],
|
content: Optional[str],
|
||||||
|
verify_tls: bool,
|
||||||
) -> requests.Response:
|
) -> requests.Response:
|
||||||
if method is None:
|
if method is None:
|
||||||
if content is None:
|
if content is None:
|
||||||
|
@ -141,7 +143,6 @@ def request(
|
||||||
s.mount("matrix://", MatrixConnectionAdapter())
|
s.mount("matrix://", MatrixConnectionAdapter())
|
||||||
|
|
||||||
headers: Dict[str, str] = {
|
headers: Dict[str, str] = {
|
||||||
"Host": destination,
|
|
||||||
"Authorization": authorization_headers[0],
|
"Authorization": authorization_headers[0],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +153,7 @@ def request(
|
||||||
method=method,
|
method=method,
|
||||||
url=dest,
|
url=dest,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
verify=False,
|
verify=verify_tls,
|
||||||
data=content,
|
data=content,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
@ -202,6 +203,12 @@ def main() -> None:
|
||||||
|
|
||||||
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
|
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--insecure",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable TLS certificate verification",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"path", help="request path, including the '/_matrix/federation/...' prefix."
|
"path", help="request path, including the '/_matrix/federation/...' prefix."
|
||||||
)
|
)
|
||||||
|
@ -227,6 +234,7 @@ def main() -> None:
|
||||||
args.destination,
|
args.destination,
|
||||||
args.path,
|
args.path,
|
||||||
content=args.body,
|
content=args.body,
|
||||||
|
verify_tls=not args.insecure,
|
||||||
)
|
)
|
||||||
|
|
||||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||||
|
@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None:
|
||||||
|
|
||||||
|
|
||||||
class MatrixConnectionAdapter(HTTPAdapter):
|
class MatrixConnectionAdapter(HTTPAdapter):
|
||||||
@staticmethod
|
def send(
|
||||||
def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
|
self,
|
||||||
if s[-1] == "]":
|
request: PreparedRequest,
|
||||||
# ipv6 literal (with no port)
|
*args: Any,
|
||||||
return s, 8448
|
**kwargs: Any,
|
||||||
|
) -> Response:
|
||||||
|
# overrides the send() method in the base class.
|
||||||
|
|
||||||
if ":" in s:
|
# We need to look for .well-known redirects before passing the request up to
|
||||||
out = s.rsplit(":", 1)
|
# HTTPAdapter.send().
|
||||||
|
assert isinstance(request.url, str)
|
||||||
|
parsed = urlparse.urlsplit(request.url)
|
||||||
|
server_name = parsed.netloc
|
||||||
|
well_known = self._get_well_known(parsed.netloc)
|
||||||
|
|
||||||
|
if well_known:
|
||||||
|
server_name = well_known
|
||||||
|
|
||||||
|
# replace the scheme in the uri with https, so that cert verification is done
|
||||||
|
# also replace the hostname if we got a .well-known result
|
||||||
|
request.url = urlparse.urlunsplit(
|
||||||
|
("https", server_name, parsed.path, parsed.query, parsed.fragment)
|
||||||
|
)
|
||||||
|
|
||||||
|
# at this point we also add the host header (otherwise urllib will add one
|
||||||
|
# based on the `host` from the connection returned by `get_connection`,
|
||||||
|
# which will be wrong if there is an SRV record).
|
||||||
|
request.headers["Host"] = server_name
|
||||||
|
|
||||||
|
return super().send(request, *args, **kwargs)
|
||||||
|
|
||||||
|
def get_connection(
|
||||||
|
self, url: str, proxies: Optional[Dict[str, str]] = None
|
||||||
|
) -> HTTPConnectionPool:
|
||||||
|
# overrides the get_connection() method in the base class
|
||||||
|
parsed = urlparse.urlsplit(url)
|
||||||
|
(host, port, ssl_server_name) = self._lookup(parsed.netloc)
|
||||||
|
print(
|
||||||
|
f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
|
||||||
|
)
|
||||||
|
return self.poolmanager.connection_from_host(
|
||||||
|
host,
|
||||||
|
port=port,
|
||||||
|
scheme="https",
|
||||||
|
pool_kwargs={"server_hostname": ssl_server_name},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _lookup(server_name: str) -> Tuple[str, int, str]:
|
||||||
|
"""
|
||||||
|
Do an SRV lookup on a server name and return the host:port to connect to
|
||||||
|
Given the server_name (after any .well-known lookup), return the host, port and
|
||||||
|
the ssl server name
|
||||||
|
"""
|
||||||
|
if server_name[-1] == "]":
|
||||||
|
# ipv6 literal (with no port)
|
||||||
|
return server_name, 8448, server_name
|
||||||
|
|
||||||
|
if ":" in server_name:
|
||||||
|
# explicit port
|
||||||
|
out = server_name.rsplit(":", 1)
|
||||||
try:
|
try:
|
||||||
port = int(out[1])
|
port = int(out[1])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("Invalid host:port '%s'" % s)
|
raise ValueError("Invalid host:port '%s'" % (server_name,))
|
||||||
return out[0], port
|
return out[0], port, out[0]
|
||||||
|
|
||||||
# try a .well-known lookup
|
|
||||||
if not skip_well_known:
|
|
||||||
well_known = MatrixConnectionAdapter.get_well_known(s)
|
|
||||||
if well_known:
|
|
||||||
return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
srv = srvlookup.lookup("matrix", "tcp", s)[0]
|
srv = srvlookup.lookup("matrix", "tcp", server_name)[0]
|
||||||
return srv.host, srv.port
|
print(
|
||||||
|
f"SRV lookup on _matrix._tcp.{server_name} gave {srv}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return srv.host, srv.port, server_name
|
||||||
except Exception:
|
except Exception:
|
||||||
return s, 8448
|
return server_name, 8448, server_name
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_well_known(server_name: str) -> Optional[str]:
|
def _get_well_known(server_name: str) -> Optional[str]:
|
||||||
uri = "https://%s/.well-known/matrix/server" % (server_name,)
|
if ":" in server_name:
|
||||||
print("fetching %s" % (uri,), file=sys.stderr)
|
# explicit port, or ipv6 literal. Either way, no .well-known
|
||||||
|
return None
|
||||||
|
|
||||||
|
# TODO: check for ipv4 literals
|
||||||
|
|
||||||
|
uri = f"https://{server_name}/.well-known/matrix/server"
|
||||||
|
print(f"fetching {uri}", file=sys.stderr)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = requests.get(uri)
|
resp = requests.get(uri)
|
||||||
|
@ -304,19 +369,6 @@ class MatrixConnectionAdapter(HTTPAdapter):
|
||||||
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
|
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_connection(
|
|
||||||
self, url: str, proxies: Optional[Dict[str, str]] = None
|
|
||||||
) -> HTTPConnectionPool:
|
|
||||||
parsed = urlparse.urlparse(url)
|
|
||||||
|
|
||||||
(host, port) = self.lookup(parsed.netloc)
|
|
||||||
netloc = "%s:%d" % (host, port)
|
|
||||||
print("Connecting to %s" % (netloc,), file=sys.stderr)
|
|
||||||
url = urlparse.urlunparse(
|
|
||||||
("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
|
|
||||||
)
|
|
||||||
return super().get_connection(url, proxies)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue