Fix-up type hints in tests/server.py. (#15084)
This file was being ignored by mypy, we remove that and add the missing type hints & deal with any fallout.
This commit is contained in:
parent
61bfcd669a
commit
c9b9143655
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
2
mypy.ini
2
mypy.ini
|
@ -31,8 +31,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/__init__.py
|
|synapse/storage/databases/__init__.py
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|tests/server.py
|
|
||||||
)$
|
)$
|
||||||
|
|
||||||
[mypy-synapse.federation.transport.client]
|
[mypy-synapse.federation.transport.client]
|
||||||
|
|
|
@ -11,12 +11,13 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
|
from typing import List, Optional, Sequence, Tuple, cast
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.appservice import (
|
from synapse.appservice import (
|
||||||
ApplicationService,
|
ApplicationService,
|
||||||
|
@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
|
||||||
|
|
||||||
from ..utils import MockClock
|
from ..utils import MockClock
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from twisted.internet.testing import MemoryReactor
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
|
|
@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
|
||||||
IOpenSSLClientConnectionCreator,
|
IOpenSSLClientConnectionCreator,
|
||||||
IProtocolFactory,
|
IProtocolFactory,
|
||||||
)
|
)
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||||
from twisted.web._newclient import ResponseNeverReceived
|
from twisted.web._newclient import ResponseNeverReceived
|
||||||
from twisted.web.client import Agent
|
from twisted.web.client import Agent
|
||||||
|
@ -466,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
else:
|
else:
|
||||||
assert isinstance(proxy_server_transport, FakeTransport)
|
assert isinstance(proxy_server_transport, FakeTransport)
|
||||||
client_protocol = proxy_server_transport.other
|
client_protocol = proxy_server_transport.other
|
||||||
c2s_transport = client_protocol.transport
|
assert isinstance(client_protocol, Protocol)
|
||||||
|
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
|
||||||
c2s_transport.other = server_ssl_protocol
|
c2s_transport.other = server_ssl_protocol
|
||||||
|
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
|
|
|
@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
|
||||||
_WrappingProtocol,
|
_WrappingProtocol,
|
||||||
)
|
)
|
||||||
from twisted.internet.interfaces import IProtocol, IProtocolFactory
|
from twisted.internet.interfaces import IProtocol, IProtocolFactory
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
|
@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
else:
|
else:
|
||||||
assert isinstance(proxy_server_transport, FakeTransport)
|
assert isinstance(proxy_server_transport, FakeTransport)
|
||||||
client_protocol = proxy_server_transport.other
|
client_protocol = proxy_server_transport.other
|
||||||
c2s_transport = client_protocol.transport
|
assert isinstance(client_protocol, Protocol)
|
||||||
|
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
|
||||||
c2s_transport.other = server_ssl_protocol
|
c2s_transport.other = server_ssl_protocol
|
||||||
|
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
|
|
|
@ -34,7 +34,7 @@ from synapse.util import Clock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.handlers.test_oidc import HAS_OIDC
|
from tests.handlers.test_oidc import HAS_OIDC
|
||||||
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
|
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
|
||||||
from tests.server import FakeChannel, make_request
|
from tests.server import FakeChannel
|
||||||
from tests.unittest import override_config, skip_unless
|
from tests.unittest import override_config, skip_unless
|
||||||
|
|
||||||
|
|
||||||
|
@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
|
||||||
channel = self.submit_logout_token(logout_token)
|
channel = self.submit_logout_token(logout_token)
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
# Now try to exchange the login token
|
# Now try to exchange the login token, it should fail.
|
||||||
channel = make_request(
|
self.helper.login_via_token(login_token, 403)
|
||||||
self.hs.get_reactor(),
|
|
||||||
self.site,
|
|
||||||
"POST",
|
|
||||||
"/login",
|
|
||||||
content={"type": "m.login.token", "token": login_token},
|
|
||||||
)
|
|
||||||
# It should have failed
|
|
||||||
self.assertEqual(channel.code, 403)
|
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
|
|
|
@ -36,6 +36,7 @@ from urllib.parse import urlencode
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactorClock
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Site
|
from twisted.web.server import Site
|
||||||
|
|
||||||
|
@ -67,6 +68,7 @@ class RestHelper:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hs: HomeServer
|
hs: HomeServer
|
||||||
|
reactor: MemoryReactorClock
|
||||||
site: Site
|
site: Site
|
||||||
auth_user_id: Optional[str]
|
auth_user_id: Optional[str]
|
||||||
|
|
||||||
|
@ -142,7 +144,7 @@ class RestHelper:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"POST",
|
"POST",
|
||||||
path,
|
path,
|
||||||
|
@ -216,7 +218,7 @@ class RestHelper:
|
||||||
data["reason"] = reason
|
data["reason"] = reason
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"POST",
|
"POST",
|
||||||
path,
|
path,
|
||||||
|
@ -313,7 +315,7 @@ class RestHelper:
|
||||||
data.update(extra_data or {})
|
data.update(extra_data or {})
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"PUT",
|
"PUT",
|
||||||
path,
|
path,
|
||||||
|
@ -394,7 +396,7 @@ class RestHelper:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"PUT",
|
"PUT",
|
||||||
path,
|
path,
|
||||||
|
@ -433,7 +435,7 @@ class RestHelper:
|
||||||
path = path + f"?access_token={tok}"
|
path = path + f"?access_token={tok}"
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
path,
|
path,
|
||||||
|
@ -488,7 +490,7 @@ class RestHelper:
|
||||||
if body is not None:
|
if body is not None:
|
||||||
content = json.dumps(body).encode("utf8")
|
content = json.dumps(body).encode("utf8")
|
||||||
|
|
||||||
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
|
channel = make_request(self.reactor, self.site, method, path, content)
|
||||||
|
|
||||||
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
||||||
expect_code,
|
expect_code,
|
||||||
|
@ -573,8 +575,8 @@ class RestHelper:
|
||||||
image_length = len(image_data)
|
image_length = len(image_data)
|
||||||
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
|
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
FakeSite(resource, self.hs.get_reactor()),
|
FakeSite(resource, self.reactor),
|
||||||
"POST",
|
"POST",
|
||||||
path,
|
path,
|
||||||
content=image_data,
|
content=image_data,
|
||||||
|
@ -603,7 +605,7 @@ class RestHelper:
|
||||||
expect_code: The return code to expect from attempting the whoami request
|
expect_code: The return code to expect from attempting the whoami request
|
||||||
"""
|
"""
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
"account/whoami",
|
"account/whoami",
|
||||||
|
@ -642,7 +644,7 @@ class RestHelper:
|
||||||
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
|
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
|
||||||
"""Log in (as a new user) via OIDC
|
"""Log in (as a new user) via OIDC
|
||||||
|
|
||||||
Returns the result of the final token login.
|
Returns the result of the final token login and the fake authorization grant.
|
||||||
|
|
||||||
Requires that "oidc_config" in the homeserver config be set appropriately
|
Requires that "oidc_config" in the homeserver config be set appropriately
|
||||||
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
||||||
|
@ -672,10 +674,28 @@ class RestHelper:
|
||||||
assert m, channel.text_body
|
assert m, channel.text_body
|
||||||
login_token = m.group(1)
|
login_token = m.group(1)
|
||||||
|
|
||||||
# finally, submit the matrix login token to the login API, which gives us our
|
return self.login_via_token(login_token, expected_status), grant
|
||||||
# matrix access token and device id.
|
|
||||||
|
def login_via_token(
|
||||||
|
self,
|
||||||
|
login_token: str,
|
||||||
|
expected_status: int = 200,
|
||||||
|
) -> JsonDict:
|
||||||
|
"""Submit the matrix login token to the login API, which gives us our
|
||||||
|
matrix access token and device id.Log in (as a new user) via OIDC
|
||||||
|
|
||||||
|
Returns the result of the token login.
|
||||||
|
|
||||||
|
Requires that "oidc_config" in the homeserver config be set appropriately
|
||||||
|
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
||||||
|
"public_base_url".
|
||||||
|
|
||||||
|
Also requires the login servlet and the OIDC callback resource to be mounted at
|
||||||
|
the normal places.
|
||||||
|
"""
|
||||||
|
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"POST",
|
"POST",
|
||||||
"/login",
|
"/login",
|
||||||
|
@ -684,7 +704,7 @@ class RestHelper:
|
||||||
assert (
|
assert (
|
||||||
channel.code == expected_status
|
channel.code == expected_status
|
||||||
), f"unexpected status in response: {channel.code}"
|
), f"unexpected status in response: {channel.code}"
|
||||||
return channel.json_body, grant
|
return channel.json_body
|
||||||
|
|
||||||
def auth_via_oidc(
|
def auth_via_oidc(
|
||||||
self,
|
self,
|
||||||
|
@ -805,7 +825,7 @@ class RestHelper:
|
||||||
with fake_serer.patch_homeserver(hs=self.hs):
|
with fake_serer.patch_homeserver(hs=self.hs):
|
||||||
# now hit the callback URI with the right params and a made-up code
|
# now hit the callback URI with the right params and a made-up code
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
callback_uri,
|
callback_uri,
|
||||||
|
@ -849,7 +869,7 @@ class RestHelper:
|
||||||
# is the easiest way of figuring out what the Host header ought to be set to
|
# is the easiest way of figuring out what the Host header ought to be set to
|
||||||
# to keep Synapse happy.
|
# to keep Synapse happy.
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
uri,
|
uri,
|
||||||
|
@ -867,7 +887,7 @@ class RestHelper:
|
||||||
location = get_location(channel)
|
location = get_location(channel)
|
||||||
parts = urllib.parse.urlsplit(location)
|
parts = urllib.parse.urlsplit(location)
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.reactor,
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
urllib.parse.urlunsplit(("", "") + parts[2:]),
|
urllib.parse.urlunsplit(("", "") + parts[2:]),
|
||||||
|
@ -900,9 +920,7 @@ class RestHelper:
|
||||||
+ urllib.parse.urlencode({"session": ui_auth_session_id})
|
+ urllib.parse.urlencode({"session": ui_auth_session_id})
|
||||||
)
|
)
|
||||||
# hit the redirect url (which will issue a cookie and state)
|
# hit the redirect url (which will issue a cookie and state)
|
||||||
channel = make_request(
|
channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
|
||||||
self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
|
|
||||||
)
|
|
||||||
# that should serve a confirmation page
|
# that should serve a confirmation page
|
||||||
assert channel.code == HTTPStatus.OK, channel.text_body
|
assert channel.code == HTTPStatus.OK, channel.text_body
|
||||||
channel.extract_cookies(cookies)
|
channel.extract_cookies(cookies)
|
||||||
|
|
253
tests/server.py
253
tests/server.py
|
@ -22,20 +22,25 @@ import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from io import SEEK_END, BytesIO
|
from io import SEEK_END, BytesIO
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
MutableMapping,
|
MutableMapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Deque
|
from typing_extensions import Deque, ParamSpec
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet import address, threads, udp
|
from twisted.internet import address, threads, udp
|
||||||
|
@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.internet.interfaces import (
|
from twisted.internet.interfaces import (
|
||||||
IAddress,
|
IAddress,
|
||||||
|
IConnector,
|
||||||
IConsumer,
|
IConsumer,
|
||||||
IHostnameResolver,
|
IHostnameResolver,
|
||||||
|
IProducer,
|
||||||
IProtocol,
|
IProtocol,
|
||||||
IPullProducer,
|
IPullProducer,
|
||||||
IPushProducer,
|
IPushProducer,
|
||||||
|
@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
|
||||||
IResolverSimple,
|
IResolverSimple,
|
||||||
ITransport,
|
ITransport,
|
||||||
)
|
)
|
||||||
|
from twisted.internet.protocol import ClientFactory, DatagramProtocol
|
||||||
|
from twisted.python import threadpool
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
@ -61,6 +70,7 @@ from twisted.web.resource import IResource
|
||||||
from twisted.web.server import Request, Site
|
from twisted.web.server import Request, Site
|
||||||
|
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.events.presence_router import load_legacy_presence_router
|
from synapse.events.presence_router import load_legacy_presence_router
|
||||||
from synapse.events.spamcheck import load_legacy_spam_checkers
|
from synapse.events.spamcheck import load_legacy_spam_checkers
|
||||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||||
|
@ -88,6 +98,9 @@ from tests.utils import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
# the type of thing that can be passed into `make_request` in the headers list
|
# the type of thing that can be passed into `make_request` in the headers list
|
||||||
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
|
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
|
||||||
|
|
||||||
|
@ -98,12 +111,14 @@ class TimedOutException(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@implementer(IConsumer)
|
@implementer(ITransport, IPushProducer, IConsumer)
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
class FakeChannel:
|
class FakeChannel:
|
||||||
"""
|
"""
|
||||||
A fake Twisted Web Channel (the part that interfaces with the
|
A fake Twisted Web Channel (the part that interfaces with the
|
||||||
wire).
|
wire).
|
||||||
|
|
||||||
|
See twisted.web.http.HTTPChannel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
site: Union[Site, "FakeSite"]
|
site: Union[Site, "FakeSite"]
|
||||||
|
@ -142,7 +157,7 @@ class FakeChannel:
|
||||||
|
|
||||||
Raises an exception if the request has not yet completed.
|
Raises an exception if the request has not yet completed.
|
||||||
"""
|
"""
|
||||||
if not self.is_finished:
|
if not self.is_finished():
|
||||||
raise Exception("Request not yet completed")
|
raise Exception("Request not yet completed")
|
||||||
return self.result["body"].decode("utf8")
|
return self.result["body"].decode("utf8")
|
||||||
|
|
||||||
|
@ -165,27 +180,36 @@ class FakeChannel:
|
||||||
h.addRawHeader(*i)
|
h.addRawHeader(*i)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
def writeHeaders(self, version, code, reason, headers):
|
def writeHeaders(
|
||||||
|
self, version: bytes, code: bytes, reason: bytes, headers: Headers
|
||||||
|
) -> None:
|
||||||
self.result["version"] = version
|
self.result["version"] = version
|
||||||
self.result["code"] = code
|
self.result["code"] = code
|
||||||
self.result["reason"] = reason
|
self.result["reason"] = reason
|
||||||
self.result["headers"] = headers
|
self.result["headers"] = headers
|
||||||
|
|
||||||
def write(self, content: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
assert isinstance(content, bytes), "Should be bytes! " + repr(content)
|
assert isinstance(data, bytes), "Should be bytes! " + repr(data)
|
||||||
|
|
||||||
if "body" not in self.result:
|
if "body" not in self.result:
|
||||||
self.result["body"] = b""
|
self.result["body"] = b""
|
||||||
|
|
||||||
self.result["body"] += content
|
self.result["body"] += data
|
||||||
|
|
||||||
|
def writeSequence(self, data: Iterable[bytes]) -> None:
|
||||||
|
for x in data:
|
||||||
|
self.write(x)
|
||||||
|
|
||||||
|
def loseConnection(self) -> None:
|
||||||
|
self.unregisterProducer()
|
||||||
|
self.transport.loseConnection()
|
||||||
|
|
||||||
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
|
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
|
||||||
def registerProducer( # type: ignore[override]
|
def registerProducer(self, producer: IProducer, streaming: bool) -> None:
|
||||||
self,
|
# TODO This should ensure that the IProducer is an IPushProducer or
|
||||||
producer: Union[IPullProducer, IPushProducer],
|
# IPullProducer, unfortunately twisted.protocols.basic.FileSender does
|
||||||
streaming: bool,
|
# implement those, but doesn't declare it.
|
||||||
) -> None:
|
self._producer = cast(Union[IPushProducer, IPullProducer], producer)
|
||||||
self._producer = producer
|
|
||||||
self.producerStreaming = streaming
|
self.producerStreaming = streaming
|
||||||
|
|
||||||
def _produce() -> None:
|
def _produce() -> None:
|
||||||
|
@ -202,6 +226,16 @@ class FakeChannel:
|
||||||
|
|
||||||
self._producer = None
|
self._producer = None
|
||||||
|
|
||||||
|
def stopProducing(self) -> None:
|
||||||
|
if self._producer is not None:
|
||||||
|
self._producer.stopProducing()
|
||||||
|
|
||||||
|
def pauseProducing(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def resumeProducing(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def requestDone(self, _self: Request) -> None:
|
def requestDone(self, _self: Request) -> None:
|
||||||
self.result["done"] = True
|
self.result["done"] = True
|
||||||
if isinstance(_self, SynapseRequest):
|
if isinstance(_self, SynapseRequest):
|
||||||
|
@ -281,12 +315,12 @@ class FakeSite:
|
||||||
self.reactor = reactor
|
self.reactor = reactor
|
||||||
self.experimental_cors_msc3886 = experimental_cors_msc3886
|
self.experimental_cors_msc3886 = experimental_cors_msc3886
|
||||||
|
|
||||||
def getResourceFor(self, request):
|
def getResourceFor(self, request: Request) -> IResource:
|
||||||
return self._resource
|
return self._resource
|
||||||
|
|
||||||
|
|
||||||
def make_request(
|
def make_request(
|
||||||
reactor,
|
reactor: MemoryReactorClock,
|
||||||
site: Union[Site, FakeSite],
|
site: Union[Site, FakeSite],
|
||||||
method: Union[bytes, str],
|
method: Union[bytes, str],
|
||||||
path: Union[bytes, str],
|
path: Union[bytes, str],
|
||||||
|
@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
A MemoryReactorClock that supports callFromThread.
|
A MemoryReactorClock that supports callFromThread.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.threadpool = ThreadPool(self)
|
self.threadpool = ThreadPool(self)
|
||||||
|
|
||||||
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
|
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
|
||||||
self._udp = []
|
self._udp: List[udp.Port] = []
|
||||||
self.lookups: Dict[str, str] = {}
|
self.lookups: Dict[str, str] = {}
|
||||||
self._thread_callbacks: Deque[Callable[[], None]] = deque()
|
self._thread_callbacks: Deque[Callable[..., R]] = deque()
|
||||||
|
|
||||||
lookups = self.lookups
|
lookups = self.lookups
|
||||||
|
|
||||||
@implementer(IResolverSimple)
|
@implementer(IResolverSimple)
|
||||||
class FakeResolver:
|
class FakeResolver:
|
||||||
def getHostByName(self, name, timeout=None):
|
def getHostByName(
|
||||||
|
self, name: str, timeout: Optional[Sequence[int]] = None
|
||||||
|
) -> "Deferred[str]":
|
||||||
if name not in lookups:
|
if name not in lookups:
|
||||||
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
|
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
|
||||||
return succeed(lookups[name])
|
return succeed(lookups[name])
|
||||||
|
@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
|
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
|
def listenUDP(
|
||||||
|
self,
|
||||||
|
port: int,
|
||||||
|
protocol: DatagramProtocol,
|
||||||
|
interface: str = "",
|
||||||
|
maxPacketSize: int = 8196,
|
||||||
|
) -> udp.Port:
|
||||||
p = udp.Port(port, protocol, interface, maxPacketSize, self)
|
p = udp.Port(port, protocol, interface, maxPacketSize, self)
|
||||||
p.startListening()
|
p.startListening()
|
||||||
self._udp.append(p)
|
self._udp.append(p)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def callFromThread(self, callback, *args, **kwargs):
|
def callFromThread(
|
||||||
|
self, callable: Callable[..., Any], *args: object, **kwargs: object
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Make the callback fire in the next reactor iteration.
|
Make the callback fire in the next reactor iteration.
|
||||||
"""
|
"""
|
||||||
cb = lambda: callback(*args, **kwargs)
|
cb = lambda: callable(*args, **kwargs)
|
||||||
# it's not safe to call callLater() here, so we append the callback to a
|
# it's not safe to call callLater() here, so we append the callback to a
|
||||||
# separate queue.
|
# separate queue.
|
||||||
self._thread_callbacks.append(cb)
|
self._thread_callbacks.append(cb)
|
||||||
|
|
||||||
def getThreadPool(self):
|
def callInThread(
|
||||||
return self.threadpool
|
self, callable: Callable[..., Any], *args: object, **kwargs: object
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
|
def suggestThreadPoolSize(self, size: int) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def getThreadPool(self) -> "threadpool.ThreadPool":
|
||||||
|
# Cast to match super-class.
|
||||||
|
return cast(threadpool.ThreadPool, self.threadpool)
|
||||||
|
|
||||||
|
def add_tcp_client_callback(
|
||||||
|
self, host: str, port: int, callback: Callable[[], None]
|
||||||
|
) -> None:
|
||||||
"""Add a callback that will be invoked when we receive a connection
|
"""Add a callback that will be invoked when we receive a connection
|
||||||
attempt to the given IP/port using `connectTCP`.
|
attempt to the given IP/port using `connectTCP`.
|
||||||
|
|
||||||
|
@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
"""
|
"""
|
||||||
self._tcp_callbacks[(host, port)] = callback
|
self._tcp_callbacks[(host, port)] = callback
|
||||||
|
|
||||||
def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
|
def connectTCP(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
factory: ClientFactory,
|
||||||
|
timeout: float = 30,
|
||||||
|
bindAddress: Optional[Tuple[str, int]] = None,
|
||||||
|
) -> IConnector:
|
||||||
"""Fake L{IReactorTCP.connectTCP}."""
|
"""Fake L{IReactorTCP.connectTCP}."""
|
||||||
|
|
||||||
conn = super().connectTCP(
|
conn = super().connectTCP(
|
||||||
|
@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
|
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def advance(self, amount):
|
def advance(self, amount: float) -> None:
|
||||||
# first advance our reactor's time, and run any "callLater" callbacks that
|
# first advance our reactor's time, and run any "callLater" callbacks that
|
||||||
# makes ready
|
# makes ready
|
||||||
super().advance(amount)
|
super().advance(amount)
|
||||||
|
@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
class ThreadPool:
|
class ThreadPool:
|
||||||
"""
|
"""
|
||||||
Threadless thread pool.
|
Threadless thread pool.
|
||||||
|
|
||||||
|
See twisted.python.threadpool.ThreadPool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reactor):
|
def __init__(self, reactor: IReactorTime):
|
||||||
self._reactor = reactor
|
self._reactor = reactor
|
||||||
|
|
||||||
def start(self):
|
def start(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
|
def callInThreadWithCallback(
|
||||||
def _(res):
|
self,
|
||||||
|
onResult: Callable[[bool, Union[Failure, R]], None],
|
||||||
|
function: Callable[P, R],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> "Deferred[None]":
|
||||||
|
def _(res: Any) -> None:
|
||||||
if isinstance(res, Failure):
|
if isinstance(res, Failure):
|
||||||
onResult(False, res)
|
onResult(False, res)
|
||||||
else:
|
else:
|
||||||
onResult(True, res)
|
onResult(True, res)
|
||||||
|
|
||||||
d = Deferred()
|
d: "Deferred[None]" = Deferred()
|
||||||
d.addCallback(lambda x: function(*args, **kwargs))
|
d.addCallback(lambda x: function(*args, **kwargs))
|
||||||
d.addBoth(_)
|
d.addBoth(_)
|
||||||
self._reactor.callLater(0, d.callback, True)
|
self._reactor.callLater(0, d.callback, True)
|
||||||
|
@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
|
||||||
for database in server.get_datastores().databases:
|
for database in server.get_datastores().databases:
|
||||||
pool = database._db_pool
|
pool = database._db_pool
|
||||||
|
|
||||||
def runWithConnection(func, *args, **kwargs):
|
def runWithConnection(
|
||||||
|
func: Callable[..., R], *args: Any, **kwargs: Any
|
||||||
|
) -> Awaitable[R]:
|
||||||
return threads.deferToThreadPool(
|
return threads.deferToThreadPool(
|
||||||
pool._reactor,
|
pool._reactor,
|
||||||
pool.threadpool,
|
pool.threadpool,
|
||||||
|
@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def runInteraction(interaction, *args, **kwargs):
|
def runInteraction(
|
||||||
|
desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||||
|
) -> Awaitable[R]:
|
||||||
return threads.deferToThreadPool(
|
return threads.deferToThreadPool(
|
||||||
pool._reactor,
|
pool._reactor,
|
||||||
pool.threadpool,
|
pool.threadpool,
|
||||||
pool._runInteraction,
|
pool._runInteraction,
|
||||||
interaction,
|
desc,
|
||||||
|
func,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pool.runWithConnection = runWithConnection
|
pool.runWithConnection = runWithConnection # type: ignore[assignment]
|
||||||
pool.runInteraction = runInteraction
|
pool.runInteraction = runInteraction # type: ignore[assignment]
|
||||||
# Replace the thread pool with a threadless 'thread' pool
|
# Replace the thread pool with a threadless 'thread' pool
|
||||||
pool.threadpool = ThreadPool(clock._reactor)
|
pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
|
||||||
pool.running = True
|
pool.running = True
|
||||||
|
|
||||||
# We've just changed the Databases to run DB transactions on the same
|
# We've just changed the Databases to run DB transactions on the same
|
||||||
|
@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
|
||||||
|
|
||||||
|
|
||||||
@implementer(ITransport)
|
@implementer(ITransport)
|
||||||
@attr.s(cmp=False)
|
@attr.s(cmp=False, auto_attribs=True)
|
||||||
class FakeTransport:
|
class FakeTransport:
|
||||||
"""
|
"""
|
||||||
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
||||||
|
@ -588,48 +663,50 @@ class FakeTransport:
|
||||||
If you want bidirectional communication, you'll need two instances.
|
If you want bidirectional communication, you'll need two instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
other = attr.ib()
|
other: IProtocol
|
||||||
"""The Protocol object which will receive any data written to this transport.
|
"""The Protocol object which will receive any data written to this transport.
|
||||||
|
|
||||||
:type: twisted.internet.interfaces.IProtocol
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_reactor = attr.ib()
|
_reactor: IReactorTime
|
||||||
"""Test reactor
|
"""Test reactor
|
||||||
|
|
||||||
:type: twisted.internet.interfaces.IReactorTime
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_protocol = attr.ib(default=None)
|
_protocol: Optional[IProtocol] = None
|
||||||
"""The Protocol which is producing data for this transport. Optional, but if set
|
"""The Protocol which is producing data for this transport. Optional, but if set
|
||||||
will get called back for connectionLost() notifications etc.
|
will get called back for connectionLost() notifications etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_peer_address: Optional[IAddress] = attr.ib(default=None)
|
_peer_address: IAddress = attr.Factory(
|
||||||
|
lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
|
||||||
|
)
|
||||||
"""The value to be returned by getPeer"""
|
"""The value to be returned by getPeer"""
|
||||||
|
|
||||||
_host_address: Optional[IAddress] = attr.ib(default=None)
|
_host_address: IAddress = attr.Factory(
|
||||||
|
lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
|
||||||
|
)
|
||||||
"""The value to be returned by getHost"""
|
"""The value to be returned by getHost"""
|
||||||
|
|
||||||
disconnecting = False
|
disconnecting = False
|
||||||
disconnected = False
|
disconnected = False
|
||||||
connected = True
|
connected = True
|
||||||
buffer = attr.ib(default=b"")
|
buffer: bytes = b""
|
||||||
producer = attr.ib(default=None)
|
producer: Optional[IPushProducer] = None
|
||||||
autoflush = attr.ib(default=True)
|
autoflush: bool = True
|
||||||
|
|
||||||
def getPeer(self) -> Optional[IAddress]:
|
def getPeer(self) -> IAddress:
|
||||||
return self._peer_address
|
return self._peer_address
|
||||||
|
|
||||||
def getHost(self) -> Optional[IAddress]:
|
def getHost(self) -> IAddress:
|
||||||
return self._host_address
|
return self._host_address
|
||||||
|
|
||||||
def loseConnection(self, reason=None):
|
def loseConnection(self) -> None:
|
||||||
if not self.disconnecting:
|
if not self.disconnecting:
|
||||||
logger.info("FakeTransport: loseConnection(%s)", reason)
|
logger.info("FakeTransport: loseConnection()")
|
||||||
self.disconnecting = True
|
self.disconnecting = True
|
||||||
if self._protocol:
|
if self._protocol:
|
||||||
self._protocol.connectionLost(reason)
|
self._protocol.connectionLost(
|
||||||
|
Failure(RuntimeError("FakeTransport.loseConnection()"))
|
||||||
|
)
|
||||||
|
|
||||||
# if we still have data to write, delay until that is done
|
# if we still have data to write, delay until that is done
|
||||||
if self.buffer:
|
if self.buffer:
|
||||||
|
@ -640,38 +717,38 @@ class FakeTransport:
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self.disconnected = True
|
self.disconnected = True
|
||||||
|
|
||||||
def abortConnection(self):
|
def abortConnection(self) -> None:
|
||||||
logger.info("FakeTransport: abortConnection()")
|
logger.info("FakeTransport: abortConnection()")
|
||||||
|
|
||||||
if not self.disconnecting:
|
if not self.disconnecting:
|
||||||
self.disconnecting = True
|
self.disconnecting = True
|
||||||
if self._protocol:
|
if self._protocol:
|
||||||
self._protocol.connectionLost(None)
|
self._protocol.connectionLost(None) # type: ignore[arg-type]
|
||||||
|
|
||||||
self.disconnected = True
|
self.disconnected = True
|
||||||
|
|
||||||
def pauseProducing(self):
|
def pauseProducing(self) -> None:
|
||||||
if not self.producer:
|
if not self.producer:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.producer.pauseProducing()
|
self.producer.pauseProducing()
|
||||||
|
|
||||||
def resumeProducing(self):
|
def resumeProducing(self) -> None:
|
||||||
if not self.producer:
|
if not self.producer:
|
||||||
return
|
return
|
||||||
self.producer.resumeProducing()
|
self.producer.resumeProducing()
|
||||||
|
|
||||||
def unregisterProducer(self):
|
def unregisterProducer(self) -> None:
|
||||||
if not self.producer:
|
if not self.producer:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.producer = None
|
self.producer = None
|
||||||
|
|
||||||
def registerProducer(self, producer, streaming):
|
def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
|
||||||
self.producer = producer
|
self.producer = producer
|
||||||
self.producerStreaming = streaming
|
self.producerStreaming = streaming
|
||||||
|
|
||||||
def _produce():
|
def _produce() -> None:
|
||||||
if not self.producer:
|
if not self.producer:
|
||||||
# we've been unregistered
|
# we've been unregistered
|
||||||
return
|
return
|
||||||
|
@ -683,7 +760,7 @@ class FakeTransport:
|
||||||
if not streaming:
|
if not streaming:
|
||||||
self._reactor.callLater(0.0, _produce)
|
self._reactor.callLater(0.0, _produce)
|
||||||
|
|
||||||
def write(self, byt):
|
def write(self, byt: bytes) -> None:
|
||||||
if self.disconnecting:
|
if self.disconnecting:
|
||||||
raise Exception("Writing to disconnecting FakeTransport")
|
raise Exception("Writing to disconnecting FakeTransport")
|
||||||
|
|
||||||
|
@ -695,11 +772,11 @@ class FakeTransport:
|
||||||
if self.autoflush:
|
if self.autoflush:
|
||||||
self._reactor.callLater(0.0, self.flush)
|
self._reactor.callLater(0.0, self.flush)
|
||||||
|
|
||||||
def writeSequence(self, seq):
|
def writeSequence(self, seq: Iterable[bytes]) -> None:
|
||||||
for x in seq:
|
for x in seq:
|
||||||
self.write(x)
|
self.write(x)
|
||||||
|
|
||||||
def flush(self, maxbytes=None):
|
def flush(self, maxbytes: Optional[int] = None) -> None:
|
||||||
if not self.buffer:
|
if not self.buffer:
|
||||||
# nothing to do. Don't write empty buffers: it upsets the
|
# nothing to do. Don't write empty buffers: it upsets the
|
||||||
# TLSMemoryBIOProtocol
|
# TLSMemoryBIOProtocol
|
||||||
|
@ -750,17 +827,17 @@ def connect_client(
|
||||||
|
|
||||||
|
|
||||||
class TestHomeServer(HomeServer):
|
class TestHomeServer(HomeServer):
|
||||||
DATASTORE_CLASS = DataStore
|
DATASTORE_CLASS = DataStore # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
def setup_test_homeserver(
|
def setup_test_homeserver(
|
||||||
cleanup_func,
|
cleanup_func: Callable[[Callable[[], None]], None],
|
||||||
name="test",
|
name: str = "test",
|
||||||
config=None,
|
config: Optional[HomeServerConfig] = None,
|
||||||
reactor=None,
|
reactor: Optional[ISynapseReactor] = None,
|
||||||
homeserver_to_use: Type[HomeServer] = TestHomeServer,
|
homeserver_to_use: Type[HomeServer] = TestHomeServer,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
) -> HomeServer:
|
||||||
"""
|
"""
|
||||||
Setup a homeserver suitable for running tests against. Keyword arguments
|
Setup a homeserver suitable for running tests against. Keyword arguments
|
||||||
are passed to the Homeserver constructor.
|
are passed to the Homeserver constructor.
|
||||||
|
@ -775,13 +852,14 @@ def setup_test_homeserver(
|
||||||
HomeserverTestCase.
|
HomeserverTestCase.
|
||||||
"""
|
"""
|
||||||
if reactor is None:
|
if reactor is None:
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor as _reactor
|
||||||
|
|
||||||
|
reactor = cast(ISynapseReactor, _reactor)
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = default_config(name, parse=True)
|
config = default_config(name, parse=True)
|
||||||
|
|
||||||
config.caches.resize_all_caches()
|
config.caches.resize_all_caches()
|
||||||
config.ldap_enabled = False
|
|
||||||
|
|
||||||
if "clock" not in kwargs:
|
if "clock" not in kwargs:
|
||||||
kwargs["clock"] = MockClock()
|
kwargs["clock"] = MockClock()
|
||||||
|
@ -832,6 +910,8 @@ def setup_test_homeserver(
|
||||||
# Create the database before we actually try and connect to it, based off
|
# Create the database before we actually try and connect to it, based off
|
||||||
# the template database we generate in setupdb()
|
# the template database we generate in setupdb()
|
||||||
if isinstance(db_engine, PostgresEngine):
|
if isinstance(db_engine, PostgresEngine):
|
||||||
|
import psycopg2.extensions
|
||||||
|
|
||||||
db_conn = db_engine.module.connect(
|
db_conn = db_engine.module.connect(
|
||||||
database=POSTGRES_BASE_DB,
|
database=POSTGRES_BASE_DB,
|
||||||
user=POSTGRES_USER,
|
user=POSTGRES_USER,
|
||||||
|
@ -839,6 +919,7 @@ def setup_test_homeserver(
|
||||||
port=POSTGRES_PORT,
|
port=POSTGRES_PORT,
|
||||||
password=POSTGRES_PASSWORD,
|
password=POSTGRES_PASSWORD,
|
||||||
)
|
)
|
||||||
|
assert isinstance(db_conn, psycopg2.extensions.connection)
|
||||||
db_conn.autocommit = True
|
db_conn.autocommit = True
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
|
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
|
||||||
|
@ -867,14 +948,15 @@ def setup_test_homeserver(
|
||||||
hs.setup_background_tasks()
|
hs.setup_background_tasks()
|
||||||
|
|
||||||
if isinstance(db_engine, PostgresEngine):
|
if isinstance(db_engine, PostgresEngine):
|
||||||
database = hs.get_datastores().databases[0]
|
database_pool = hs.get_datastores().databases[0]
|
||||||
|
|
||||||
# We need to do cleanup on PostgreSQL
|
# We need to do cleanup on PostgreSQL
|
||||||
def cleanup():
|
def cleanup() -> None:
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
import psycopg2.extensions
|
||||||
|
|
||||||
# Close all the db pools
|
# Close all the db pools
|
||||||
database._db_pool.close()
|
database_pool._db_pool.close()
|
||||||
|
|
||||||
dropped = False
|
dropped = False
|
||||||
|
|
||||||
|
@ -886,6 +968,7 @@ def setup_test_homeserver(
|
||||||
port=POSTGRES_PORT,
|
port=POSTGRES_PORT,
|
||||||
password=POSTGRES_PASSWORD,
|
password=POSTGRES_PASSWORD,
|
||||||
)
|
)
|
||||||
|
assert isinstance(db_conn, psycopg2.extensions.connection)
|
||||||
db_conn.autocommit = True
|
db_conn.autocommit = True
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
|
|
||||||
|
@ -918,23 +1001,23 @@ def setup_test_homeserver(
|
||||||
# Need to let the HS build an auth handler and then mess with it
|
# Need to let the HS build an auth handler and then mess with it
|
||||||
# because AuthHandler's constructor requires the HS, so we can't make one
|
# because AuthHandler's constructor requires the HS, so we can't make one
|
||||||
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
||||||
async def hash(p):
|
async def hash(p: str) -> str:
|
||||||
return hashlib.md5(p.encode("utf8")).hexdigest()
|
return hashlib.md5(p.encode("utf8")).hexdigest()
|
||||||
|
|
||||||
hs.get_auth_handler().hash = hash
|
hs.get_auth_handler().hash = hash # type: ignore[assignment]
|
||||||
|
|
||||||
async def validate_hash(p, h):
|
async def validate_hash(p: str, h: str) -> bool:
|
||||||
return hashlib.md5(p.encode("utf8")).hexdigest() == h
|
return hashlib.md5(p.encode("utf8")).hexdigest() == h
|
||||||
|
|
||||||
hs.get_auth_handler().validate_hash = validate_hash
|
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
|
||||||
|
|
||||||
# Make the threadpool and database transactions synchronous for testing.
|
# Make the threadpool and database transactions synchronous for testing.
|
||||||
_make_test_homeserver_synchronous(hs)
|
_make_test_homeserver_synchronous(hs)
|
||||||
|
|
||||||
# Load any configured modules into the homeserver
|
# Load any configured modules into the homeserver
|
||||||
module_api = hs.get_module_api()
|
module_api = hs.get_module_api()
|
||||||
for module, config in hs.config.modules.loaded_modules:
|
for module, module_config in hs.config.modules.loaded_modules:
|
||||||
module(config=config, api=module_api)
|
module(config=module_config, api=module_api)
|
||||||
|
|
||||||
load_legacy_spam_checkers(hs)
|
load_legacy_spam_checkers(hs)
|
||||||
load_legacy_third_party_event_rules(hs)
|
load_legacy_third_party_event_rules(hs)
|
||||||
|
|
|
@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.python.threadpool import ThreadPool
|
from twisted.python.threadpool import ThreadPool
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -82,7 +82,7 @@ from tests.server import (
|
||||||
)
|
)
|
||||||
from tests.test_utils import event_injection, setup_awaitable_errors
|
from tests.test_utils import event_injection, setup_awaitable_errors
|
||||||
from tests.test_utils.logging_setup import setup_logging
|
from tests.test_utils.logging_setup import setup_logging
|
||||||
from tests.utils import default_config, setupdb
|
from tests.utils import checked_cast, default_config, setupdb
|
||||||
|
|
||||||
setupdb()
|
setupdb()
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
from tests.rest.client.utils import RestHelper
|
from tests.rest.client.utils import RestHelper
|
||||||
|
|
||||||
self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
|
self.helper = RestHelper(
|
||||||
|
self.hs,
|
||||||
|
checked_cast(MemoryReactorClock, self.hs.get_reactor()),
|
||||||
|
self.site,
|
||||||
|
getattr(self, "user_id", None),
|
||||||
|
)
|
||||||
|
|
||||||
if hasattr(self, "user_id"):
|
if hasattr(self, "user_id"):
|
||||||
if self.hijack_auth:
|
if self.hijack_auth:
|
||||||
|
|
Loading…
Reference in New Issue