diff --git a/changelog.d/8130.misc b/changelog.d/8130.misc new file mode 100644 index 0000000000..7944c09ade --- /dev/null +++ b/changelog.d/8130.misc @@ -0,0 +1 @@ +Update the test federation client to handle streaming responses. diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 531010185d..ad12523c4d 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -21,10 +21,12 @@ import argparse import base64 import json import sys +from typing import Any, Optional from urllib import parse as urlparse import nacl.signing import requests +import signedjson.types import srvlookup import yaml from requests.adapters import HTTPAdapter @@ -69,7 +71,9 @@ def encode_canonical_json(value): ).encode("UTF-8") -def sign_json(json_object, signing_key, signing_name): +def sign_json( + json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str +) -> Any: signatures = json_object.pop("signatures", {}) unsigned = json_object.pop("unsigned", None) @@ -122,7 +126,14 @@ def read_signing_keys(stream): return keys -def request_json(method, origin_name, origin_key, destination, path, content): +def request( + method: Optional[str], + origin_name: str, + origin_key: signedjson.types.SigningKey, + destination: str, + path: str, + content: Optional[str], +) -> requests.Response: if method is None: if content is None: method = "GET" @@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content): if method == "POST": headers["Content-Type"] = "application/json" - result = s.request( - method=method, url=dest, headers=headers, verify=False, data=content + return s.request( + method=method, + url=dest, + headers=headers, + verify=False, + data=content, + stream=True, ) - sys.stderr.write("Status Code: %d\n" % (result.status_code,)) - return result.json() def main(): @@ -222,7 +236,7 @@ def main(): with open(args.signing_key_path) as f: key = read_signing_keys(f)[0] - result = request_json( + result = request( args.method, args.server_name, key, @@ -231,7 +245,12 @@ def main(): content=args.body, ) - json.dump(result, sys.stdout) + sys.stderr.write("Status Code: %d\n" % (result.status_code,)) + + for chunk in result.iter_content(): + # we write raw utf8 to stdout. + sys.stdout.buffer.write(chunk) + print("")