Add reactor to `SynapseRequest` and fix up types. (#10868)

This commit is contained in:
Erik Johnston 2021-09-24 11:01:25 +01:00 committed by GitHub
parent fa74536384
commit 50022cff96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 123 additions and 82 deletions

View File

@ -0,0 +1 @@
Speed up responding with large JSON objects to requests.

View File

@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response( def _send_response(
self, self,
request: Request, request: SynapseRequest,
code: int, code: int,
response_object: Any, response_object: Any,
): ):
@ -629,7 +629,7 @@ def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
def respond_with_json( def respond_with_json(
request: Request, request: SynapseRequest,
code: int, code: int,
json_object: Any, json_object: Any,
send_cors: bool = False, send_cors: bool = False,

View File

@ -14,13 +14,14 @@
import contextlib import contextlib
import logging import logging
import time import time
from typing import Optional, Tuple, Union from typing import Generator, Optional, Tuple, Union
import attr import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request logcontext: the log context for this request
""" """
def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw): def __init__(
Request.__init__(self, channel, *args, **kw) self,
channel: HTTPChannel,
site: "SynapseSite",
*args,
max_request_body_size: int = 1024,
**kw,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.start_time = 0.0 self.start_time = 0.0
@ -97,7 +106,7 @@ class SynapseRequest(Request):
self.get_method(), self.get_method(),
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"), self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag, self.synapse_site.site_tag,
) )
def handleContentChunk(self, data: bytes) -> None: def handleContentChunk(self, data: bytes) -> None:
@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest( request=ContextRequest(
request_id=request_id, request_id=request_id,
ip_address=self.getClientIP(), ip_address=self.getClientIP(),
site_tag=self.site.site_tag, site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point. # The requester is going to be unknown at this point.
requester=None, requester=None,
authenticated_entity=None, authenticated_entity=None,
@ -228,7 +237,7 @@ class SynapseRequest(Request):
) )
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext): with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab # we start the request metrics timer here with an initial stab
@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc() requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager @contextlib.contextmanager
def processing(self): def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request. """Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is: Returns a context manager; the correct way to use this is:
@ -346,10 +355,10 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method() self.start_time, name=servlet_name, method=self.get_method()
) )
self.site.access_logger.debug( self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s", "%s - %s - Received request: %s %s",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.synapse_site.site_tag,
self.get_method(), self.get_method(),
self.get_redacted_uri(), self.get_redacted_uri(),
) )
@ -388,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity: if authenticated_entity:
requester = f"{authenticated_entity}|{requester}" requester = f"{authenticated_entity}|{requester}"
self.site.access_logger.log( self.synapse_site.access_logger.log(
log_level, log_level,
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]', ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.synapse_site.site_tag,
requester, requester,
processing_time, processing_time,
response_send_time, response_send_time,
@ -522,7 +531,7 @@ class SynapseSite(Site):
site_tag: str, site_tag: str,
config: ListenerConfig, config: ListenerConfig,
resource: IResource, resource: IResource,
server_version_string, server_version_string: str,
max_request_body_size: int, max_request_body_size: int,
reactor: IReactorTime, reactor: IReactorTime,
): ):
@ -542,6 +551,7 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor) Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag self.site_tag = site_tag
self.reactor = reactor
assert config.http_options is not None assert config.http_options is not None
proxied = config.http_options.x_forwarded proxied = config.http_options.x_forwarded
@ -550,6 +560,7 @@ class SynapseSite(Site):
def request_factory(channel, queued: bool) -> Request: def request_factory(channel, queued: bool) -> Request:
return request_class( return request_class(
channel, channel,
self,
max_request_body_size=max_request_body_size, max_request_body_size=max_request_body_size,
queued=queued, queued=queued,
) )

View File

@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
@ -102,7 +101,7 @@ class RemoteKey(DirectServeJsonResource):
) )
self.config = hs.config self.config = hs.config
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
assert request.postpath is not None assert request.postpath is not None
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
@ -119,7 +118,7 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request: Request) -> None: async def _async_render_POST(self, request: SynapseRequest) -> None:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
query = content["server_keys"] query = content["server_keys"]
@ -128,7 +127,7 @@ class RemoteKey(DirectServeJsonResource):
async def query_keys( async def query_keys(
self, self,
request: Request, request: SynapseRequest,
query: JsonDict, query: JsonDict,
query_remote_on_cache_miss: bool = False, query_remote_on_cache_miss: bool = False,
) -> None: ) -> None:

View File

@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
) )
def respond_404(request: Request) -> None: def respond_404(request: SynapseRequest) -> None:
respond_with_json( respond_with_json(
request, request,
404, 404,
@ -84,7 +85,7 @@ def respond_404(request: Request) -> None:
async def respond_with_file( async def respond_with_file(
request: Request, request: SynapseRequest,
media_type: str, media_type: str,
file_path: str, file_path: str,
file_size: Optional[int] = None, file_size: Optional[int] = None,
@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:
async def respond_with_responder( async def respond_with_responder(
request: Request, request: SynapseRequest,
responder: "Optional[Responder]", responder: "Optional[Responder]",
media_type: str, media_type: str,
file_size: Optional[int], file_size: Optional[int],

View File

@ -16,8 +16,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource):
await self.auth.get_user_by_req(request) await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)

View File

@ -15,10 +15,9 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean from synapse.http.servlet import parse_boolean
from synapse.http.site import SynapseRequest
from ._base import parse_media_id, respond_404 from ._base import parse_media_id, respond_404
@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request) set_cors_headers(request)
request.setHeader( request.setHeader(
b"Content-Security-Policy", b"Content-Security-Policy",

View File

@ -23,7 +23,6 @@ import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import ( from synapse.api.errors import (
FederationDeniedError, FederationDeniedError,
@ -34,6 +33,7 @@ from synapse.api.errors import (
) )
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement from synapse.config.repository import ThumbnailRequirement
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID from synapse.types import UserID
@ -189,7 +189,7 @@ class MediaRepository:
return "mxc://%s/%s" % (self.server_name, media_id) return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media( async def get_local_media(
self, request: Request, media_id: str, name: Optional[str] self, request: SynapseRequest, media_id: str, name: Optional[str]
) -> None: ) -> None:
"""Responds to requests for local media, if exists, or returns 404. """Responds to requests for local media, if exists, or returns 404.
@ -223,7 +223,11 @@ class MediaRepository:
) )
async def get_remote_media( async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str] self,
request: SynapseRequest,
server_name: str,
media_id: str,
name: Optional[str],
) -> None: ) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.

View File

@ -29,7 +29,6 @@ import attr
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
@ -168,7 +167,7 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000 self._start_expire_url_cache_data, 10 * 1000
) )
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET") request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)

View File

@ -17,11 +17,10 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import ( from ._base import (
@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname self.server_name = hs.hostname
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request) set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True) width = parse_integer(request, "width", required=True)
@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_local_thumbnail( async def _respond_local_thumbnail(
self, self,
request: Request, request: SynapseRequest,
media_id: str, media_id: str,
width: int, width: int,
height: int, height: int,
@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail( async def _select_or_generate_local_thumbnail(
self, self,
request: Request, request: SynapseRequest,
media_id: str, media_id: str,
desired_width: int, desired_width: int,
desired_height: int, desired_height: int,
@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail( async def _select_or_generate_remote_thumbnail(
self, self,
request: Request, request: SynapseRequest,
server_name: str, server_name: str,
media_id: str, media_id: str,
desired_width: int, desired_width: int,
@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_remote_thumbnail( async def _respond_remote_thumbnail(
self, self,
request: Request, request: SynapseRequest,
server_name: str, server_name: str,
media_id: str, media_id: str,
width: int, width: int,
@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_and_respond_with_thumbnail( async def _select_and_respond_with_thumbnail(
self, self,
request: Request, request: SynapseRequest,
desired_width: int, desired_width: int,
desired_height: int, desired_height: int,
desired_method: str, desired_method: str,

View File

@ -16,8 +16,6 @@
import logging import logging
from typing import IO, TYPE_CHECKING, Dict, List, Optional from typing import IO, TYPE_CHECKING, Dict, List, Optional
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args from synapse.http.servlet import parse_bytes_from_args
@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: SynapseRequest) -> None: async def _async_render_POST(self, request: SynapseRequest) -> None:

View File

@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _AsyncTestCustomEndpoint({}, None).handle_request handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
channel = make_request(self.reactor, FakeSite(resource), "GET", "/") channel = make_request(
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _SyncTestCustomEndpoint({}, None).handle_request handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
channel = make_request(self.reactor, FakeSite(resource), "GET", "/") channel = make_request(
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"]) site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
site.site_tag = "test-site" site.site_tag = "test-site"
site.server_version_string = "Server v1" site.server_version_string = "Server v1"
request = SynapseRequest(FakeChannel(site, None)) site.reactor = Mock()
request = SynapseRequest(FakeChannel(site, None), site)
# Call requestReceived to finish instantiating the object. # Call requestReceived to finish instantiating the object.
request.content = BytesIO() request.content = BytesIO()
# Partially skip some of the internal processing of SynapseRequest. # Partially skip some of the internal processing of SynapseRequest.

View File

@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
resource = hs.get_media_repository_resource().children[b"download"] resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"GET", "GET",
f"/{target}/{media_id}", f"/{target}/{media_id}",
shorthand=False, shorthand=False,

View File

@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media # Attempt to access the media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_name_and_media_id, server_name_and_media_id,
shorthand=False, shorthand=False,
@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media # Attempt to access each piece of media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_and_media_id_2, server_and_media_id_2,
shorthand=False, shorthand=False,

View File

@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media # Attempt to access media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media # Attempt to access media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts` # Try to access a media and to create `last_access_ts`
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Load the password reset confirmation page # Load the password reset confirmation page
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.submit_token_resource), FakeSite(self.submit_token_resource, self.reactor),
"GET", "GET",
path, path,
shorthand=False, shorthand=False,
@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Confirm the password reset # Confirm the password reset
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.submit_token_resource), FakeSite(self.submit_token_resource, self.reactor),
"POST", "POST",
path, path,
content=b"", content=b"",

View File

@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"""You can observe the terms form without specifying a user""" """You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs) resource = consent_resource.ConsentResource(self.hs)
channel = make_request( channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False self.reactor,
FakeSite(resource, self.reactor),
"GET",
"/consent?v=1",
shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
) )
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"GET", "GET",
consent_uri, consent_uri,
access_token=access_token, access_token=access_token,
@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# POST to the consent page, saying we've agreed # POST to the consent page, saying we've agreed
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"POST", "POST",
consent_uri + "&v=" + version, consent_uri + "&v=" + version,
access_token=access_token, access_token=access_token,
@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# changed # changed
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"GET", "GET",
consent_uri, consent_uri,
access_token=access_token, access_token=access_token,

View File

@ -383,7 +383,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
FakeSite(resource), FakeSite(resource, self.hs.get_reactor()),
"POST", "POST",
path, path,
content=image_data, content=image_data,

View File

@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body. Checks that the response is a 200 and returns the decoded json body.
""" """
channel = FakeChannel(self.site, self.reactor) channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel) req = SynapseRequest(channel, self.site)
req.content = BytesIO(b"") req.content = BytesIO(b"")
req.requestReceived( req.requestReceived(
b"GET", b"GET",
@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
) )
channel = FakeChannel(self.site, self.reactor) channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel) req = SynapseRequest(channel, self.site)
req.content = BytesIO(encode_canonical_json(data)) req.content = BytesIO(encode_canonical_json(data))
req.requestReceived( req.requestReceived(

View File

@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
self.media_id, self.media_id,
shorthand=False, shorthand=False,
@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale" params = "?width=32&height=32&method=scale"
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.thumbnail_resource), FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, self.media_id + params,
shorthand=False, shorthand=False,
@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.thumbnail_resource), FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, self.media_id + params,
shorthand=False, shorthand=False,
@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.thumbnail_resource), FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, self.media_id + params,
shorthand=False, shorthand=False,

View File

@ -19,6 +19,7 @@ from twisted.internet.interfaces import (
IPullProducer, IPullProducer,
IPushProducer, IPushProducer,
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple, IResolverSimple,
ITransport, ITransport,
) )
@ -181,13 +182,14 @@ class FakeSite:
site_tag = "test" site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake") access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource): def __init__(self, resource: IResource, reactor: IReactorTime):
""" """
Args: Args:
resource: the resource to be used for rendering all requests resource: the resource to be used for rendering all requests
""" """
self._resource = resource self._resource = resource
self.reactor = reactor
def getResourceFor(self, request): def getResourceFor(self, request):
return self._resource return self._resource
@ -268,7 +270,7 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip) channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel) req = request(channel, site)
req.content = BytesIO(content) req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request. # Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END) req.content.seek(SEEK_END)

View File

@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase):
) )
make_request( make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" self.reactor,
FakeSite(res, self.reactor),
b"GET",
b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
) )
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase):
def _callback(request, **kwargs): def _callback(request, **kwargs):
d = Deferred() d = Deferred()
d.addCallback(_throw) d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True) self.reactor.callLater(0.5, d.callback, True)
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["error"], "Unrecognized request")
@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase):
) )
# The path was registered as GET, but this is a HEAD request. # The path was registered as GET, but this is a HEAD request.
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)
@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"] body = channel.result["body"]
@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"] headers = channel.result["headers"]
@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"] headers = channel.result["headers"]
@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)