Add reactor to `SynapseRequest` and fix up types. (#10868)
This commit is contained in:
parent
fa74536384
commit
50022cff96
|
@ -0,0 +1 @@
|
|||
Speed up responding with large JSON objects to requests.
|
|
@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
|
|||
|
||||
def _send_response(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
code: int,
|
||||
response_object: Any,
|
||||
):
|
||||
|
@ -629,7 +629,7 @@ def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
|
|||
|
||||
|
||||
def respond_with_json(
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
code: int,
|
||||
json_object: Any,
|
||||
send_cors: bool = False,
|
||||
|
|
|
@ -14,13 +14,14 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Generator, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.interfaces import IAddress, IReactorTime
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.http import HTTPChannel
|
||||
from twisted.web.resource import IResource, Resource
|
||||
from twisted.web.server import Request, Site
|
||||
|
||||
|
@ -61,10 +62,18 @@ class SynapseRequest(Request):
|
|||
logcontext: the log context for this request
|
||||
"""
|
||||
|
||||
def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
|
||||
Request.__init__(self, channel, *args, **kw)
|
||||
def __init__(
|
||||
self,
|
||||
channel: HTTPChannel,
|
||||
site: "SynapseSite",
|
||||
*args,
|
||||
max_request_body_size: int = 1024,
|
||||
**kw,
|
||||
):
|
||||
super().__init__(channel, *args, **kw)
|
||||
self._max_request_body_size = max_request_body_size
|
||||
self.site: SynapseSite = channel.site
|
||||
self.synapse_site = site
|
||||
self.reactor = site.reactor
|
||||
self._channel = channel # this is used by the tests
|
||||
self.start_time = 0.0
|
||||
|
||||
|
@ -97,7 +106,7 @@ class SynapseRequest(Request):
|
|||
self.get_method(),
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto.decode("ascii", errors="replace"),
|
||||
self.site.site_tag,
|
||||
self.synapse_site.site_tag,
|
||||
)
|
||||
|
||||
def handleContentChunk(self, data: bytes) -> None:
|
||||
|
@ -216,7 +225,7 @@ class SynapseRequest(Request):
|
|||
request=ContextRequest(
|
||||
request_id=request_id,
|
||||
ip_address=self.getClientIP(),
|
||||
site_tag=self.site.site_tag,
|
||||
site_tag=self.synapse_site.site_tag,
|
||||
# The requester is going to be unknown at this point.
|
||||
requester=None,
|
||||
authenticated_entity=None,
|
||||
|
@ -228,7 +237,7 @@ class SynapseRequest(Request):
|
|||
)
|
||||
|
||||
# override the Server header which is set by twisted
|
||||
self.setHeader("Server", self.site.server_version_string)
|
||||
self.setHeader("Server", self.synapse_site.server_version_string)
|
||||
|
||||
with PreserveLoggingContext(self.logcontext):
|
||||
# we start the request metrics timer here with an initial stab
|
||||
|
@ -247,7 +256,7 @@ class SynapseRequest(Request):
|
|||
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def processing(self):
|
||||
def processing(self) -> Generator[None, None, None]:
|
||||
"""Record the fact that we are processing this request.
|
||||
|
||||
Returns a context manager; the correct way to use this is:
|
||||
|
@ -346,10 +355,10 @@ class SynapseRequest(Request):
|
|||
self.start_time, name=servlet_name, method=self.get_method()
|
||||
)
|
||||
|
||||
self.site.access_logger.debug(
|
||||
self.synapse_site.access_logger.debug(
|
||||
"%s - %s - Received request: %s %s",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.synapse_site.site_tag,
|
||||
self.get_method(),
|
||||
self.get_redacted_uri(),
|
||||
)
|
||||
|
@ -388,13 +397,13 @@ class SynapseRequest(Request):
|
|||
if authenticated_entity:
|
||||
requester = f"{authenticated_entity}|{requester}"
|
||||
|
||||
self.site.access_logger.log(
|
||||
self.synapse_site.access_logger.log(
|
||||
log_level,
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.synapse_site.site_tag,
|
||||
requester,
|
||||
processing_time,
|
||||
response_send_time,
|
||||
|
@ -522,7 +531,7 @@ class SynapseSite(Site):
|
|||
site_tag: str,
|
||||
config: ListenerConfig,
|
||||
resource: IResource,
|
||||
server_version_string,
|
||||
server_version_string: str,
|
||||
max_request_body_size: int,
|
||||
reactor: IReactorTime,
|
||||
):
|
||||
|
@ -542,6 +551,7 @@ class SynapseSite(Site):
|
|||
Site.__init__(self, resource, reactor=reactor)
|
||||
|
||||
self.site_tag = site_tag
|
||||
self.reactor = reactor
|
||||
|
||||
assert config.http_options is not None
|
||||
proxied = config.http_options.x_forwarded
|
||||
|
@ -550,6 +560,7 @@ class SynapseSite(Site):
|
|||
def request_factory(channel, queued: bool) -> Request:
|
||||
return request_class(
|
||||
channel,
|
||||
self,
|
||||
max_request_body_size=max_request_body_size,
|
||||
queued=queued,
|
||||
)
|
||||
|
|
|
@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict
|
|||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
|
@ -102,7 +101,7 @@ class RemoteKey(DirectServeJsonResource):
|
|||
)
|
||||
self.config = hs.config
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
assert request.postpath is not None
|
||||
if len(request.postpath) == 1:
|
||||
(server,) = request.postpath
|
||||
|
@ -119,7 +118,7 @@ class RemoteKey(DirectServeJsonResource):
|
|||
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
async def _async_render_POST(self, request: Request) -> None:
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
query = content["server_keys"]
|
||||
|
@ -128,7 +127,7 @@ class RemoteKey(DirectServeJsonResource):
|
|||
|
||||
async def query_keys(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
query: JsonDict,
|
||||
query_remote_on_cache_miss: bool = False,
|
||||
) -> None:
|
||||
|
|
|
@ -27,6 +27,7 @@ from twisted.web.server import Request
|
|||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
from synapse.http.server import finish_request, respond_with_json
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
|
@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
|
|||
)
|
||||
|
||||
|
||||
def respond_404(request: Request) -> None:
|
||||
def respond_404(request: SynapseRequest) -> None:
|
||||
respond_with_json(
|
||||
request,
|
||||
404,
|
||||
|
@ -84,7 +85,7 @@ def respond_404(request: Request) -> None:
|
|||
|
||||
|
||||
async def respond_with_file(
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
media_type: str,
|
||||
file_path: str,
|
||||
file_size: Optional[int] = None,
|
||||
|
@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:
|
|||
|
||||
|
||||
async def respond_with_responder(
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
responder: "Optional[Responder]",
|
||||
media_type: str,
|
||||
file_size: Optional[int],
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
|
@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource):
|
|||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
|
|
@ -15,10 +15,9 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
from synapse.http.servlet import parse_boolean
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
from ._base import parse_media_id, respond_404
|
||||
|
||||
|
@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
|
|||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
set_cors_headers(request)
|
||||
request.setHeader(
|
||||
b"Content-Security-Policy",
|
||||
|
|
|
@ -23,7 +23,6 @@ import twisted.internet.error
|
|||
import twisted.web.http
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import (
|
||||
FederationDeniedError,
|
||||
|
@ -34,6 +33,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.repository import ThumbnailRequirement
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import UserID
|
||||
|
@ -189,7 +189,7 @@ class MediaRepository:
|
|||
return "mxc://%s/%s" % (self.server_name, media_id)
|
||||
|
||||
async def get_local_media(
|
||||
self, request: Request, media_id: str, name: Optional[str]
|
||||
self, request: SynapseRequest, media_id: str, name: Optional[str]
|
||||
) -> None:
|
||||
"""Responds to requests for local media, if exists, or returns 404.
|
||||
|
||||
|
@ -223,7 +223,11 @@ class MediaRepository:
|
|||
)
|
||||
|
||||
async def get_remote_media(
|
||||
self, request: Request, server_name: str, media_id: str, name: Optional[str]
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
name: Optional[str],
|
||||
) -> None:
|
||||
"""Respond to requests for remote media.
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ import attr
|
|||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
|
@ -168,7 +167,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
self._start_expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
|
||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
|
|
|
@ -17,11 +17,10 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
|
||||
from ._base import (
|
||||
|
@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
set_cors_headers(request)
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
width = parse_integer(request, "width", required=True)
|
||||
|
@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _respond_local_thumbnail(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
media_id: str,
|
||||
width: int,
|
||||
height: int,
|
||||
|
@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _select_or_generate_local_thumbnail(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
|
@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _select_or_generate_remote_thumbnail(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
desired_width: int,
|
||||
|
@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _respond_remote_thumbnail(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
width: int,
|
||||
|
@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
async def _select_and_respond_with_thumbnail(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
import logging
|
||||
from typing import IO, TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_bytes_from_args
|
||||
|
@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
|
|||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||
|
|
|
@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase):
|
|||
handler = _AsyncTestCustomEndpoint({}, None).handle_request
|
||||
resource = AdditionalResource(self.hs, handler)
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
|
||||
|
@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase):
|
|||
handler = _SyncTestCustomEndpoint({}, None).handle_request
|
||||
resource = AdditionalResource(self.hs, handler)
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
|
||||
|
|
|
@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
|||
site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
|
||||
site.site_tag = "test-site"
|
||||
site.server_version_string = "Server v1"
|
||||
request = SynapseRequest(FakeChannel(site, None))
|
||||
site.reactor = Mock()
|
||||
request = SynapseRequest(FakeChannel(site, None), site)
|
||||
# Call requestReceived to finish instantiating the object.
|
||||
request.content = BytesIO()
|
||||
# Partially skip some of the internal processing of SynapseRequest.
|
||||
|
|
|
@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
resource = hs.get_media_repository_resource().children[b"download"]
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(resource),
|
||||
FakeSite(resource, self.reactor),
|
||||
"GET",
|
||||
f"/{target}/{media_id}",
|
||||
shorthand=False,
|
||||
|
|
|
@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
"""Ensure a piece of media is quarantined when trying to access it."""
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.download_resource),
|
||||
FakeSite(self.download_resource, self.reactor),
|
||||
"GET",
|
||||
server_and_media_id,
|
||||
shorthand=False,
|
||||
|
@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
# Attempt to access the media
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.download_resource),
|
||||
FakeSite(self.download_resource, self.reactor),
|
||||
"GET",
|
||||
server_name_and_media_id,
|
||||
shorthand=False,
|
||||
|
@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
# Attempt to access each piece of media
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.download_resource),
|
||||
FakeSite(self.download_resource, self.reactor),
|
||||
"GET",
|
||||
server_and_media_id_2,
|
||||
shorthand=False,
|
||||
|
|
|
@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
|
|||
# Attempt to access media
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(download_resource),
|
||||
FakeSite(download_resource, self.reactor),
|
||||
"GET",
|
||||
server_and_media_id,
|
||||
shorthand=False,
|
||||
|
@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
|
|||
# Attempt to access media
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(download_resource),
|
||||
FakeSite(download_resource, self.reactor),
|
||||
"GET",
|
||||
server_and_media_id,
|
||||
shorthand=False,
|
||||
|
@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(download_resource),
|
||||
FakeSite(download_resource, self.reactor),
|
||||
"GET",
|
||||
server_and_media_id,
|
||||
shorthand=False,
|
||||
|
|
|
@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
|||
# Try to access a media and to create `last_access_ts`
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(download_resource),
|
||||
FakeSite(download_resource, self.reactor),
|
||||
"GET",
|
||||
server_and_media_id,
|
||||
shorthand=False,
|
||||
|
|
|
@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||
# Load the password reset confirmation page
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.submit_token_resource),
|
||||
FakeSite(self.submit_token_resource, self.reactor),
|
||||
"GET",
|
||||
path,
|
||||
shorthand=False,
|
||||
|
@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||
# Confirm the password reset
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.submit_token_resource),
|
||||
FakeSite(self.submit_token_resource, self.reactor),
|
||||
"POST",
|
||||
path,
|
||||
content=b"",
|
||||
|
|
|
@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
|
|||
"""You can observe the terms form without specifying a user"""
|
||||
resource = consent_resource.ConsentResource(self.hs)
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
|
||||
self.reactor,
|
||||
FakeSite(resource, self.reactor),
|
||||
"GET",
|
||||
"/consent?v=1",
|
||||
shorthand=False,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
|
@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(resource),
|
||||
FakeSite(resource, self.reactor),
|
||||
"GET",
|
||||
consent_uri,
|
||||
access_token=access_token,
|
||||
|
@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
|
|||
# POST to the consent page, saying we've agreed
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(resource),
|
||||
FakeSite(resource, self.reactor),
|
||||
"POST",
|
||||
consent_uri + "&v=" + version,
|
||||
access_token=access_token,
|
||||
|
@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
|
|||
# changed
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(resource),
|
||||
FakeSite(resource, self.reactor),
|
||||
"GET",
|
||||
consent_uri,
|
||||
access_token=access_token,
|
||||
|
|
|
@ -383,7 +383,7 @@ class RestHelper:
|
|||
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
|
||||
channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
FakeSite(resource),
|
||||
FakeSite(resource, self.hs.get_reactor()),
|
||||
"POST",
|
||||
path,
|
||||
content=image_data,
|
||||
|
|
|
@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
|
|||
Checks that the response is a 200 and returns the decoded json body.
|
||||
"""
|
||||
channel = FakeChannel(self.site, self.reactor)
|
||||
req = SynapseRequest(channel)
|
||||
req = SynapseRequest(channel, self.site)
|
||||
req.content = BytesIO(b"")
|
||||
req.requestReceived(
|
||||
b"GET",
|
||||
|
@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
|||
)
|
||||
|
||||
channel = FakeChannel(self.site, self.reactor)
|
||||
req = SynapseRequest(channel)
|
||||
req = SynapseRequest(channel, self.site)
|
||||
req.content = BytesIO(encode_canonical_json(data))
|
||||
|
||||
req.requestReceived(
|
||||
|
|
|
@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.download_resource),
|
||||
FakeSite(self.download_resource, self.reactor),
|
||||
"GET",
|
||||
self.media_id,
|
||||
shorthand=False,
|
||||
|
@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
params = "?width=32&height=32&method=scale"
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.thumbnail_resource),
|
||||
FakeSite(self.thumbnail_resource, self.reactor),
|
||||
"GET",
|
||||
self.media_id + params,
|
||||
shorthand=False,
|
||||
|
@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.thumbnail_resource),
|
||||
FakeSite(self.thumbnail_resource, self.reactor),
|
||||
"GET",
|
||||
self.media_id + params,
|
||||
shorthand=False,
|
||||
|
@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
params = "?width=32&height=32&method=" + method
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.thumbnail_resource),
|
||||
FakeSite(self.thumbnail_resource, self.reactor),
|
||||
"GET",
|
||||
self.media_id + params,
|
||||
shorthand=False,
|
||||
|
|
|
@ -19,6 +19,7 @@ from twisted.internet.interfaces import (
|
|||
IPullProducer,
|
||||
IPushProducer,
|
||||
IReactorPluggableNameResolver,
|
||||
IReactorTime,
|
||||
IResolverSimple,
|
||||
ITransport,
|
||||
)
|
||||
|
@ -181,13 +182,14 @@ class FakeSite:
|
|||
site_tag = "test"
|
||||
access_logger = logging.getLogger("synapse.access.http.fake")
|
||||
|
||||
def __init__(self, resource: IResource):
|
||||
def __init__(self, resource: IResource, reactor: IReactorTime):
|
||||
"""
|
||||
|
||||
Args:
|
||||
resource: the resource to be used for rendering all requests
|
||||
"""
|
||||
self._resource = resource
|
||||
self.reactor = reactor
|
||||
|
||||
def getResourceFor(self, request):
|
||||
return self._resource
|
||||
|
@ -268,7 +270,7 @@ def make_request(
|
|||
|
||||
channel = FakeChannel(site, reactor, ip=client_ip)
|
||||
|
||||
req = request(channel)
|
||||
req = request(channel, site)
|
||||
req.content = BytesIO(content)
|
||||
# Twisted expects to be at the end of the content when parsing the request.
|
||||
req.content.seek(SEEK_END)
|
||||
|
|
|
@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
make_request(
|
||||
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
|
||||
self.reactor,
|
||||
FakeSite(res, self.reactor),
|
||||
b"GET",
|
||||
b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
|
||||
)
|
||||
|
||||
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
|
||||
|
@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
|
||||
)
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"500")
|
||||
|
||||
|
@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase):
|
|||
def _callback(request, **kwargs):
|
||||
d = Deferred()
|
||||
d.addCallback(_throw)
|
||||
self.reactor.callLater(1, d.callback, True)
|
||||
self.reactor.callLater(0.5, d.callback, True)
|
||||
return make_deferred_yieldable(d)
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
|
@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
|
||||
)
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"500")
|
||||
|
||||
|
@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
|
||||
)
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
|
||||
|
@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
|
||||
)
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
self.assertEqual(channel.json_body["error"], "Unrecognized request")
|
||||
|
@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
# The path was registered as GET, but this is a HEAD request.
|
||||
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertNotIn("body", channel.result)
|
||||
|
@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
|||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
body = channel.result["body"]
|
||||
|
@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
|||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"301")
|
||||
headers = channel.result["headers"]
|
||||
|
@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
|||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"304")
|
||||
headers = channel.result["headers"]
|
||||
|
@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
|||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
|
||||
channel = make_request(
|
||||
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
|
||||
)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertNotIn("body", channel.result)
|
||||
|
|
Loading…
Reference in New Issue