Add missing type hints to synapse.util (#9982)

This commit is contained in:
Patrick Cloke 2021-05-24 15:32:01 -04:00 committed by GitHub
parent 22a8838f62
commit 7adcb20fc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 39 additions and 25 deletions

1
changelog.d/9982.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to `synapse.util` module.

View File

@ -71,8 +71,13 @@ files =
synapse/types.py, synapse/types.py,
synapse/util/async_helpers.py, synapse/util/async_helpers.py,
synapse/util/caches, synapse/util/caches,
synapse/util/daemonize.py,
synapse/util/hash.py,
synapse/util/iterutils.py,
synapse/util/metrics.py, synapse/util/metrics.py,
synapse/util/macaroons.py, synapse/util/macaroons.py,
synapse/util/module_loader.py,
synapse/util/msisdn.py,
synapse/util/stringutils.py, synapse/util/stringutils.py,
synapse/visibility.py, synapse/visibility.py,
tests/replication, tests/replication,
@ -80,6 +85,7 @@ files =
tests/handlers/test_password_providers.py, tests/handlers/test_password_providers.py,
tests/rest/client/v1/test_login.py, tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py, tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py
[mypy-pymacaroons.*] [mypy-pymacaroons.*]
@ -175,5 +181,8 @@ ignore_missing_imports = True
[mypy-pympler.*] [mypy-pympler.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-phonenumbers.*]
ignore_missing_imports = True
[mypy-ijson.*] [mypy-ijson.*]
ignore_missing_imports = True ignore_missing_imports = True

View File

@ -164,7 +164,13 @@ class SAML2Config(Config):
config_path = saml2_config.get("config_path", None) config_path = saml2_config.get("config_path", None)
if config_path is not None: if config_path is not None:
mod = load_python_module(config_path) mod = load_python_module(config_path)
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict) config = getattr(mod, "CONFIG", None)
if config is None:
raise ConfigError(
"Config path specified by saml2_config.config_path does not "
"have a CONFIG property."
)
_dict_merge(merge_dict=config, into_dict=saml2_config_dict)
import saml2.config import saml2.config

View File

@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore):
""" """
keys = {} keys = {}
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`.""" """Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch) # batch_iter always returns tuples so it's safe to do len(batch)

View File

@ -17,15 +17,15 @@ import hashlib
import unpaddedbase64 import unpaddedbase64
def sha256_and_url_safe_base64(input_text): def sha256_and_url_safe_base64(input_text: str) -> str:
"""SHA256 hash an input string, encode the digest as url-safe base64, and """SHA256 hash an input string, encode the digest as url-safe base64, and
return return
:param input_text: string to hash Args:
:type input_text: str input_text: string to hash
:returns a sha256 hashed and url-safe base64 encoded digest returns:
:rtype: str A sha256 hashed and url-safe base64 encoded digest
""" """
digest = hashlib.sha256(input_text.encode()).digest() digest = hashlib.sha256(input_text.encode()).digest()
return unpaddedbase64.encode_base64(digest, urlsafe=True) return unpaddedbase64.encode_base64(digest, urlsafe=True)

View File

@ -30,12 +30,12 @@ from typing import (
T = TypeVar("T") T = TypeVar("T")
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
"""batch an iterable up into tuples with a maximum size """batch an iterable up into tuples with a maximum size
Args: Args:
iterable (iterable): the iterable to slice iterable: the iterable to slice
size (int): the maximum batch size size: the maximum batch size
Returns: Returns:
an iterator over the chunks an iterator over the chunks
@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
return iter(lambda: tuple(islice(sourceiter, size)), ()) return iter(lambda: tuple(islice(sourceiter, size)), ())
ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
"""Split the given sequence into chunks of the given size """Split the given sequence into chunks of the given size
The last chunk may be shorter than the given size. The last chunk may be shorter than the given size.

View File

@ -15,6 +15,7 @@
import importlib import importlib
import importlib.util import importlib.util
import itertools import itertools
from types import ModuleType
from typing import Any, Iterable, Tuple, Type from typing import Any, Iterable, Tuple, Type
import jsonschema import jsonschema
@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
# We need to import the module, and then pick the class out of # We need to import the module, and then pick the class out of
# that, so we split based on the last dot. # that, so we split based on the last dot.
module, clz = modulename.rsplit(".", 1) module_name, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module) module = importlib.import_module(module_name)
provider_class = getattr(module, clz) provider_class = getattr(module, clz)
# Load the module config. If None, pass an empty dictionary instead # Load the module config. If None, pass an empty dictionary instead
@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
return provider_class, provider_config return provider_class, provider_config
def load_python_module(location: str): def load_python_module(location: str) -> ModuleType:
"""Load a python module, and return a reference to its global namespace """Load a python module, and return a reference to its global namespace
Args: Args:
location (str): path to the module location: path to the module
Returns: Returns:
python module object python module object

View File

@ -17,19 +17,19 @@ import phonenumbers
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
def phone_number_to_msisdn(country, number): def phone_number_to_msisdn(country: str, number: str) -> str:
""" """
Takes an ISO-3166-1 2 letter country code and phone number and Takes an ISO-3166-1 2 letter country code and phone number and
returns an msisdn representing the canonical version of that returns an msisdn representing the canonical version of that
phone number. phone number.
Args: Args:
country (str): ISO-3166-1 2 letter country code country: ISO-3166-1 2 letter country code
number (str): Phone number in a national or international format number: Phone number in a national or international format
Returns: Returns:
(str) The canonical form of the phone number, as an msisdn The canonical form of the phone number, as an msisdn
Raises: Raises:
SynapseError if the number could not be parsed. SynapseError if the number could not be parsed.
""" """
try: try:
phoneNumber = phonenumbers.parse(number, country) phoneNumber = phonenumbers.parse(number, country)

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, List from typing import Dict, Iterable, List, Sequence
from synapse.util.iterutils import chunk_seq, sorted_topologically from synapse.util.iterutils import chunk_seq, sorted_topologically
@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
) )
def test_empty_input(self): def test_empty_input(self):
parts = chunk_seq([], 5) parts = chunk_seq([], 5) # type: Iterable[Sequence]
self.assertEqual( self.assertEqual(
list(parts), list(parts),