Fix additional type hints from Twisted upgrade. (#9518)
parent 4db07f9aef
commit 33a02f0f52
@@ -0,0 +1 @@
+Fix incorrect type hints.
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
 
 from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """
         Args:
             method: HTTP method: GET/POST/etc
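
Note: the `Generator[...]` annotation above describes the undecorated generator that `@defer.inlineCallbacks` wraps: it yields Deferreds, is resumed with their results, and its `return` value becomes the result of the Deferred the decorator hands back. A minimal, self-contained sketch of the same annotation style (the helper and function names below are invented for illustration):

    from typing import Any, Generator

    from twisted.internet import defer

    def _delayed_value(value: int) -> defer.Deferred:
        # An already-fired Deferred, standing in for real async work.
        d = defer.Deferred()
        d.callback(value)
        return d

    @defer.inlineCallbacks
    def double_later(value: int) -> Generator[defer.Deferred, Any, int]:
        # Each `yield` hands a Deferred to inlineCallbacks, which resumes the
        # generator with that Deferred's result once it fires.
        result = yield _delayed_value(value)
        return result * 2
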
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
         # We need to make sure the host header is set to the netloc of the
         # server and that a user-agent is provided.
         if headers is None:
-            headers = Headers()
+            request_headers = Headers()
         else:
-            headers = headers.copy()
+            request_headers = headers.copy()
 
-        if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not headers.hasHeader(b"user-agent"):
-            headers.addRawHeader(b"user-agent", self.user_agent)
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)
 
         res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, request_headers, bodyProducer)
         )
 
         return res
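
Note: the rename above introduces a separate local (`request_headers`) rather than re-binding the `Optional[Headers]` parameter, so the non-Optional type of the value actually passed to the agent is explicit. A rough sketch of the general shape of that change, with stand-in names:

    from typing import Optional

    def build_greeting(greeting: Optional[str] = None) -> str:
        # Bind a new, always-present local instead of re-assigning the
        # Optional parameter; its type is then unambiguous below.
        effective_greeting = greeting if greeting is not None else "hello"
        return effective_greeting + ", world"
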
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
         RequestSendFailed: if the Content-Type header is missing or isn't JSON
 
     """
-    c_type = headers.getRawHeaders(b"Content-Type")
-    if c_type is None:
+    content_type_headers = headers.getRawHeaders(b"Content-Type")
+    if content_type_headers is None:
         raise RequestSendFailed(
             RuntimeError("No Content-Type header received from remote server"),
             can_retry=False,
         )
 
-    c_type = c_type[0].decode("ascii")  # only the first header
+    c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RequestSendFailed(
@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()
 
-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
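
Note: `Failure.check()` returns the matching exception class (or None), but mypy cannot use that result to narrow the type of `f.value`, hence the explicit type comment plus `# type: ignore` above. A small sketch of the same idiom, with an invented exception class:

    from twisted.python import failure

    class WidgetError(Exception):
        def __init__(self, code: int):
            super().__init__("widget error %d" % code)
            self.code = code

    def error_code_from_failure(f: failure.Failure) -> int:
        if f.check(WidgetError):
            # mypy doesn't understand that f.check asserts the type.
            exc = f.value  # type: WidgetError  # type: ignore
            return exc.code
        return 500
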
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
             "Failed handle request via %r: %r",
             request.request_metrics.name,
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
             `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -142,7 +147,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
         raw_callback_return = method_handler(request)
 
         # Is it synchronous? We'll allow this for now.
-        if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+        if isawaitable(raw_callback_return):
             callback_return = await raw_callback_return
         else:
             callback_return = raw_callback_return  # type: ignore
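
Note: `inspect.isawaitable` accepts coroutines and also Twisted `Deferred`s (they implement `__await__`), which is what lets it replace the isinstance check against `defer.Deferred` and `types.CoroutineType`. A minimal sketch of the dispatch it enables:

    from inspect import isawaitable

    from twisted.internet import defer

    async def resolve(raw_return):
        # Await anything awaitable (a coroutine or a Deferred); pass plain
        # values straight through.
        if isawaitable(raw_return):
            return await raw_return
        return raw_return

    # resolve(defer.succeed(42)) and resolve(some_coroutine()) are both handled.
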
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
 
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return
         self._request.write(b"".join(data))
 
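
Note: `_request` is now `Optional[Request]` (it is presumably cleared elsewhere once the response is finished), so the write path guards against `None` as well as empty data. A rough sketch of the same pattern with a stand-in writer class:

    from io import BytesIO
    from typing import List, Optional

    class BufferWriter:
        def __init__(self, sink: BytesIO) -> None:
            self._sink = sink  # type: Optional[BytesIO]

        def stop(self) -> None:
            # After stop(), writes become no-ops instead of failing on a
            # detached sink.
            self._sink = None

        def write_chunks(self, data: List[bytes]) -> None:
            if not data or not self._sink:
                return
            self._sink.write(b"".join(data))
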
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
 
 import attr
 from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
 
     def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = channel.site
+        self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
 
-    def get_redacted_uri(self):
-        uri = self.uri
+    def get_redacted_uri(self) -> str:
+        """Gets the redacted URI associated with the request (or placeholder if the URI
+        has not yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.uri`.
+
+        Returns:
+            The redacted URI as a string.
+        """
+        uri = self.uri  # type: Union[bytes, str]
         if isinstance(uri, bytes):
-            uri = self.uri.decode("ascii", errors="replace")
+            uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)
 
-    def get_method(self):
-        """Gets the method associated with the request (or placeholder if not
-        method has yet been received).
+    def get_method(self) -> str:
+        """Gets the method associated with the request (or placeholder if method
+        has not yet been received).
 
         Note: This is necessary as the placeholder value in twisted is str
         rather than bytes, so we need to sanitise `self.method`.
 
         Returns:
-            str
+            The request method as a string.
         """
-        method = self.method
+        method = self.method  # type: Union[bytes, str]
         if isinstance(method, bytes):
-            method = self.method.decode("ascii")
+            return self.method.decode("ascii")
         return method
 
     def render(self, resrc):
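
Note: as the new docstrings say, Twisted initialises these attributes to str placeholders and only replaces them with bytes once the request line has been parsed, so the getters must accept both. The decode-if-bytes idiom in isolation:

    from typing import Union

    def sanitise_method(method: Union[bytes, str]) -> str:
        # A str value is the unparsed placeholder; a parsed method is bytes.
        if isinstance(method, bytes):
            return method.decode("ascii")
        return method

    assert sanitise_method(b"GET") == "GET"
    assert sanitise_method("(placeholder)") == "(placeholder)"
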
@@ -432,7 +441,9 @@ class SynapseSite(Site):
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+        self.requestFactory = (
+            XForwardedForRequest if proxied else SynapseRequest
+        )  # type: Type[Request]
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
 
@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.python.failure import Failure
 
@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+                endpoint = TCP4ClientEndpoint(
+                    _reactor, self.host, self.port
+                )  # type: IStreamClientEndpoint
             elif isinstance(ip, IPv6Address):
                 endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
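
Note: mypy infers a variable's type from its first assignment, so when different branches assign different endpoint classes, annotating the first assignment with the shared interface (`IStreamClientEndpoint` here) keeps the later branches valid. The same idea with invented classes:

    from typing import Union

    class Ipv4Endpoint:
        def connect(self) -> str:
            return "connecting over IPv4"

    class Ipv6Endpoint:
        def connect(self) -> str:
            return "connecting over IPv6"

    def make_endpoint(prefer_v6: bool) -> Union[Ipv4Endpoint, Ipv6Endpoint]:
        if prefer_v6:
            # Annotate the first assignment with the wider type so the other
            # branch's assignment is compatible with the inferred type.
            endpoint = Ipv6Endpoint()  # type: Union[Ipv4Endpoint, Ipv6Endpoint]
        else:
            endpoint = Ipv4Endpoint()
        return endpoint
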
@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
 REGISTRY.register(ReactorLastSeenMetric())
 
 
-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
     def f(*args, **kwargs):
         now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):
 
 try:
     # Ensure the reactor has all the attributes we expect
-    reactor.runUntilCurrent
-    reactor._newTimedCalls
-    reactor.threadCallQueue
+    reactor.seconds  # type: ignore
+    reactor.runUntilCurrent  # type: ignore
+    reactor._newTimedCalls  # type: ignore
+    reactor.threadCallQueue  # type: ignore
 
     # runUntilCurrent is called when we have pending calls. It is called once
     # per iteratation after fd polling.
-    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent)  # type: ignore
 
     # We manually run the GC each reactor tick so that we can get some metrics
     # about time spent doing GC,
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -307,7 +307,7 @@ class ModuleApi:
     @defer.inlineCallbacks
     def get_state_events_in_room(
         self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """Gets current state events for the given room.
 
         (This is exposed for compatibility with the old SpamCheckerApi. We should
@@ -15,11 +15,12 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 
 from prometheus_client import Counter
 
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -71,7 +72,7 @@ class HttpPusher(Pusher):
         self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
         self.failing_since = pusher_config.failing_since
-        self.timed_call = None
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
         self._pusherpool = hs.get_pusherpool()
@@ -108,9 +108,7 @@ class ReplicationDataHandler:
 
         # Map from stream to list of deferreds waiting for the stream to
        # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = (
-            {}
-        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
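
Note: an empty `{}` literal gives mypy nothing to infer a value type from, so the collapsed assignment keeps a `# type:` comment spelling it out. The pattern in isolation (names invented):

    from typing import Dict, List, Tuple

    # Without the comment, mypy cannot infer useful key/value types for the
    # empty dict literal.
    waiters = {}  # type: Dict[str, List[Tuple[int, str]]]

    waiters.setdefault("events", []).append((17, "stream position"))
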
@@ -38,6 +38,7 @@ from typing import (
 
 import twisted.internet.base
 import twisted.internet.tcp
+from twisted.internet import defer
 from twisted.mail.smtp import sendmail
 from twisted.web.iweb import IPolicyForHTTPS
 
@@ -403,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return RoomShutdownHandler(self)
 
     @cache_in_self
-    def get_sendmail(self) -> sendmail:
+    def get_sendmail(self) -> Callable[..., defer.Deferred]:
         return sendmail
 
     @cache_in_self
@@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             shorthand=False,
         )
         self.assertEqual(channel.code, 302, channel.result)
-        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        cas_uri = location_headers[0]
         cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
 
         # it should redirect us to the login page of the cas server
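
Note: `Headers.getRawHeaders` returns `None` when the header is absent, so the tests now assert the header list exists before indexing it; the `assert` also narrows the type for mypy. A minimal sketch against Twisted's `Headers`:

    from twisted.web.http_headers import Headers

    headers = Headers({b"Location": [b"https://example.com/next"]})

    location_headers = headers.getRawHeaders(b"Location")
    # getRawHeaders returns None for a missing header; assert before indexing
    # so both the test and mypy know this is safe.
    assert location_headers
    assert location_headers[0] == b"https://example.com/next"
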
@@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=saml",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        saml_uri = location_headers[0]
         saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
 
         # it should redirect us to the login page of the SAML server
@@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=oidc",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        oidc_uri = location_headers[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
         # ... and should have set a cookie including the redirect url
-        cookies = dict(
-            h.split(";")[0].split("=", maxsplit=1)
-            for h in channel.headers.getRawHeaders("Set-Cookie")
-        )
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies = {}  # type: Dict[str, str]
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value
 
         oidc_session_cookie = cookies["oidc_session"]
         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
@@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
-        self.assertTrue(
-            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
-        )
+        content_type_headers = channel.headers.getRawHeaders("Content-Type")
+        assert content_type_headers
+        self.assertTrue(content_type_headers[-1].startswith("text/html"))
         p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
@@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 302)
         location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
         self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
 
     @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
 
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
-        picker_url = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        picker_url = location_headers[0]
         self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
@@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
 
         # send a request to the completion page, which should 302 to the client redirectUrl
         chan = self.make_request(
@@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
 
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)