Misc typing fixes for `tests`, part 1 of N (#11323)
* Annotate HomeserverTestCase.servlets * Correct annotation of federation_auth_origin * Use AnyStr custom_headers instead of a Union This allows (str, str) and (bytes, bytes). This disallows (str, bytes) and (bytes, str) * DomainSpecificString.SIGIL is a ClassVar
This commit is contained in:
parent
95547e5300
commit
4c96ce396e
|
@ -0,0 +1 @@
|
|||
Improve type annotations in Synapse's test suite.
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from synapse.http.server import HttpServer, JsonResource
|
||||
from synapse.rest import admin
|
||||
|
@ -62,6 +62,8 @@ from synapse.rest.client import (
|
|||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
|
||||
|
||||
|
||||
class ClientRestResource(JsonResource):
|
||||
"""Matrix Client API REST resource.
|
||||
|
|
|
@ -19,6 +19,7 @@ from collections import namedtuple
|
|||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
|
@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
|
|||
'domain' : The domain part of the name
|
||||
"""
|
||||
|
||||
SIGIL: str = abc.abstractproperty() # type: ignore
|
||||
SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
|
||||
|
||||
localpart = attr.ib(type=str)
|
||||
domain = attr.ib(type=str)
|
||||
|
|
|
@ -12,13 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.site import SynapseRequest, SynapseSite
|
||||
from synapse.replication.http import ReplicationRestResource
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
|
@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
unlike `BaseStreamTestCase`.
|
||||
"""
|
||||
|
||||
servlets: List[Callable[[HomeServer, JsonResource], None]] = []
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
|
|
|
@ -19,7 +19,17 @@ import json
|
|||
import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AnyStr,
|
||||
Dict,
|
||||
Iterable,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import attr
|
||||
|
@ -53,9 +63,7 @@ class RestHelper:
|
|||
tok: Optional[str] = None,
|
||||
expect_code: int = 200,
|
||||
extra_content: Optional[Dict] = None,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a room.
|
||||
|
@ -227,9 +235,7 @@ class RestHelper:
|
|||
txn_id=None,
|
||||
tok=None,
|
||||
expect_code=200,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
):
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
@ -418,7 +424,7 @@ class RestHelper:
|
|||
path,
|
||||
content=image_data,
|
||||
access_token=tok,
|
||||
custom_headers=[(b"Content-Length", str(image_length))],
|
||||
custom_headers=[("Content-Length", str(image_length))],
|
||||
)
|
||||
|
||||
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
||||
|
|
|
@ -16,7 +16,16 @@ import json
|
|||
import logging
|
||||
from collections import deque
|
||||
from io import SEEK_END, BytesIO
|
||||
from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
|
||||
from typing import (
|
||||
AnyStr,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from typing_extensions import Deque
|
||||
|
@ -222,9 +231,7 @@ def make_request(
|
|||
federation_auth_origin: Optional[bytes] = None,
|
||||
content_is_form: bool = False,
|
||||
await_result: bool = True,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
client_ip: str = "127.0.0.1",
|
||||
) -> FakeChannel:
|
||||
"""
|
||||
|
|
|
@ -20,7 +20,20 @@ import inspect
|
|||
import logging
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AnyStr,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from canonicaljson import json
|
||||
|
@ -45,6 +58,7 @@ from synapse.logging.context import (
|
|||
current_context,
|
||||
set_current_context,
|
||||
)
|
||||
from synapse.rest import RegisterServletsFunc
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
|
|||
config dict.
|
||||
|
||||
Attributes:
|
||||
servlets (list[function]): List of servlet registration function.
|
||||
servlets: List of servlet registration function.
|
||||
user_id (str): The user ID to assume if auth is hijacked.
|
||||
hijack_auth (bool): Whether to hijack auth to return the user specified
|
||||
in user_id.
|
||||
"""
|
||||
|
||||
servlets = []
|
||||
hijack_auth = True
|
||||
needs_threadpool = False
|
||||
servlets: ClassVar[List[RegisterServletsFunc]] = []
|
||||
|
||||
def __init__(self, methodName, *args, **kwargs):
|
||||
super().__init__(methodName, *args, **kwargs)
|
||||
|
@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase):
|
|||
access_token: Optional[str] = None,
|
||||
request: Type[T] = SynapseRequest,
|
||||
shorthand: bool = True,
|
||||
federation_auth_origin: str = None,
|
||||
federation_auth_origin: Optional[bytes] = None,
|
||||
content_is_form: bool = False,
|
||||
await_result: bool = True,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
client_ip: str = "127.0.0.1",
|
||||
) -> FakeChannel:
|
||||
"""
|
||||
|
@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
|
|||
a dict.
|
||||
shorthand: Whether to try and be helpful and prefix the given URL
|
||||
with the usual REST API path, if it doesn't contain it.
|
||||
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
|
||||
federation_auth_origin: if set to not-None, we will add a fake
|
||||
Authorization header pretenting to be the given server name.
|
||||
content_is_form: Whether the content is URL encoded form data. Adds the
|
||||
'Content-Type': 'application/x-www-form-urlencoded' header.
|
||||
|
@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase):
|
|||
username,
|
||||
password,
|
||||
device_id=None,
|
||||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
):
|
||||
"""
|
||||
Log in a user, and get an access token. Requires the Login API be
|
||||
|
|
Loading…
Reference in New Issue