Use StrCollection in additional places. (#16301)

Patrick Cloke 2023-09-13 07:57:19 -04:00 committed by GitHub
parent e9addf6a01
commit d38d0dffc9
16 changed files with 59 additions and 67 deletions

changelog.d/16301.misc Normal file
View File

@@ -0,0 +1 @@
+Improve type hints.
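For context, StrCollection and StrSequence are type aliases from synapse.types. Plain Collection[str] or Sequence[str] annotations are error-prone because str is itself a collection/sequence of one-character strings, so a bare string passed by mistake still type-checks; the aliases only admit real containers of strings. The sketch below is from memory and the exact definitions in synapse/types/__init__.py may differ slightly:

from typing import AbstractSet, List, Tuple, Union

# A Collection[str] that deliberately excludes str itself.
StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]

# A Sequence[str] that deliberately excludes str itself (ordering preserved).
StrSequence = Union[Tuple[str, ...], List[str]]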

View File

@@ -27,9 +27,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
-    Iterable,
     List,
     NoReturn,
     Optional,
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
 from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
     load_legacy_third_party_event_rules,
 )
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrCollection
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
@@ -278,7 +276,7 @@ def register_start(
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
     """
     Start Prometheus metrics server.
     """
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
 def listen_manhole(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     manhole_settings: ManholeConfig,
     manhole_globals: dict,
@@ -339,7 +337,7 @@ def listen_manhole(
 def listen_tcp(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     factory: ServerFactory,
     reactor: IReactorTCP = reactor,
@@ -448,7 +446,7 @@ def listen_http(
 def listen_ssl(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     factory: ServerFactory,
     context_factory: IOpenSSLContextFactory,

View File

@@ -26,7 +26,6 @@ from textwrap import dedent
 from typing import (
     Any,
     ClassVar,
-    Collection,
     Dict,
     Iterable,
     Iterator,
@@ -384,7 +383,7 @@ class RootConfig:
     config_classes: List[Type[Config]] = []
-    def __init__(self, config_files: Collection[str] = ()):
+    def __init__(self, config_files: StrSequence = ()):
         # Capture absolute paths here, so we can reload config after we daemonize.
         self.config_files = [os.path.abspath(path) for path in config_files]

View File

@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
     def keys(self) -> Iterable[str]:
         return self._dict.keys()
-    def prev_event_ids(self) -> Sequence[str]:
+    def prev_event_ids(self) -> List[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
         self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
         return self._event_id
-    def prev_event_ids(self) -> Sequence[str]:
+    def prev_event_ids(self) -> List[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import attr
 from signedjson.types import SigningKey
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
-from synapse.types import EventID, JsonDict
+from synapse.types import EventID, JsonDict, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class EventBuilder:
     async def build(
         self,
-        prev_event_ids: Collection[str],
+        prev_event_ids: StrCollection,
         auth_event_ids: Optional[List[str]],
         depth: Optional[int] = None,
     ) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
         format_version = self.room_version.event_format
         # The types of auth/prev events changes between event versions.
-        prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
+        prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
         auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         if format_version == EventFormatVersions.ROOM_V1_V2:
             auth_events = await self._store.add_event_hashes(auth_event_ids)

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections.abc
-from typing import Iterable, List, Type, Union, cast
+from typing import List, Type, Union, cast
 import jsonschema
 from pydantic import Field, StrictBool, StrictStr
@@ -36,7 +36,7 @@ from synapse.events.utils import (
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.http.servlet import validate_json_object
 from synapse.rest.models import RequestBodyModel
-from synapse.types import EventID, JsonDict, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
 class EventValidator:
@@ -225,7 +225,7 @@ class EventValidator:
         self._ensure_state_event(event)
-    def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
+    def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
         for s in keys:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))

View File

@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrSequence
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
 # the value actually has to be a List, but List is invariant so we can't specify that
 # the entries can either be Lists or bytes.
 RawHeaderValue = Union[
-    List[str],
+    StrSequence,
     List[bytes],
     List[Union[str, bytes]],
-    Tuple[str, ...],
     Tuple[bytes, ...],
     Tuple[Union[str, bytes], ...],
 ]

View File

@@ -18,7 +18,6 @@ import logging
 from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
-    Iterable,
     List,
     Mapping,
     Optional,
@@ -38,7 +37,7 @@ from twisted.web.server import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import redact_uri
 from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
 from synapse.util import json_decoder
 if TYPE_CHECKING:
@@ -340,7 +339,7 @@ def parse_string(
     name: str,
     default: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -352,7 +351,7 @@ def parse_string(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -365,7 +364,7 @@ def parse_string(
     *,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -376,7 +375,7 @@ def parse_string(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -485,7 +484,7 @@ def parse_enum(
 def _parse_string_value(
     value: bytes,
-    allowed_values: Optional[Iterable[str]],
+    allowed_values: Optional[StrCollection],
     name: str,
     encoding: str,
 ) -> str:
@@ -511,7 +510,7 @@ def parse_strings_from_args(
     args: Mapping[bytes, Sequence[bytes]],
     name: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -523,7 +522,7 @@ def parse_strings_from_args(
     name: str,
     default: List[str],
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -535,7 +534,7 @@ def parse_strings_from_args(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -548,7 +547,7 @@ def parse_strings_from_args(
     default: Optional[List[str]] = None,
     *,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -559,7 +558,7 @@ def parse_strings_from_args(
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     """
@@ -610,7 +609,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -623,7 +622,7 @@ def parse_string_from_args(
     default: Optional[str] = None,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -635,7 +634,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -646,7 +645,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
     return validate_json_object(content, model_type)
-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
     absent = []
     for k in required:
         if k not in body:

View File

@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     Mapping,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Type,
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
 from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
 from synapse.metrics._types import Collector
+from synapse.types import StrSequence
 from synapse.util import SYNAPSE_VERSION
 logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ class LaterGauge(Collector):
     name: str
     desc: str
-    labels: Optional[Sequence[str]] = attr.ib(hash=False)
+    labels: Optional[StrSequence] = attr.ib(hash=False)
     # callback: should either return a value (if there are no labels for this metric),
     # or dict mapping from a label tuple to a value
     caller: Callable[
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
         self,
         name: str,
         desc: str,
-        labels: Sequence[str],
-        sub_metrics: Sequence[str],
+        labels: StrSequence,
+        sub_metrics: StrSequence,
     ):
         self.name = name
         self.desc = desc

View File

@@ -104,7 +104,7 @@ class _NotifierUserStream:
     def __init__(
         self,
         user_id: str,
-        rooms: Collection[str],
+        rooms: StrCollection,
         current_token: StreamToken,
         time_now_ms: int,
     ):
@@ -457,7 +457,7 @@ class Notifier:
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
         users: Optional[Collection[Union[str, UserID]]] = None,
-        rooms: Optional[Collection[str]] = None,
+        rooms: Optional[StrCollection] = None,
     ) -> None:
         """Used to inform listeners that something has happened event wise.
@@ -529,7 +529,7 @@ class Notifier:
         user_id: str,
         timeout: int,
         callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
-        room_ids: Optional[Collection[str]] = None,
+        room_ids: Optional[StrCollection] = None,
         from_token: StreamToken = StreamToken.START,
     ) -> T:
         """Wait until the callback returns a non empty response or the

View File

@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
 from synapse.api.errors import InteractiveAuthIncompleteError
 from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 logger = logging.getLogger(__name__)
 def client_patterns(
     path_regex: str,
-    releases: Iterable[str] = ("r0", "v3"),
+    releases: StrCollection = ("r0", "v3"),
     unstable: bool = True,
     v1: bool = False,
 ) -> Iterable[Pattern]:

View File

@@ -20,7 +20,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     DefaultDict,
     Dict,
     FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
     async def compute_state_after_events(
         self,
         room_id: str,
-        event_ids: Collection[str],
+        event_ids: StrCollection,
         state_filter: Optional[StateFilter] = None,
         await_full_state: bool = True,
     ) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
         return await ret.get_state(self._state_storage_controller, state_filter)
     async def get_current_user_ids_in_room(
-        self, room_id: str, latest_event_ids: Collection[str]
+        self, room_id: str, latest_event_ids: StrCollection
     ) -> Set[str]:
         """
         Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
         return await self.store.get_joined_user_ids_from_state(room_id, state)
     async def get_hosts_in_room_at_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: StrCollection
     ) -> FrozenSet[str]:
         """Get the hosts that were in a room at the given event ids
@@ -470,7 +469,7 @@ class StateHandler:
     @trace
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+        self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
     store: "DataStore"
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database

View File

@@ -17,7 +17,6 @@ import logging
 from typing import (
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Iterable,
     List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 logger = logging.getLogger(__name__)
@@ -45,7 +44,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:

View File

@@ -19,7 +19,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Generator,
     Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 logger = logging.getLogger(__name__)
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
     # This is usually synapse.state.StateResolutionStore, but it's replaced with a
     # TestStateResolutionStore in tests.
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         ...
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
-        auth_difference_unpersisted_part: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: StrCollection = union - intersection
     else:
         auth_difference_unpersisted_part = ()
     state_sets_ids = [set(state_set.values()) for state_set in state_sets]

View File

@@ -47,7 +47,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, StrCollection
+from synapse.types import JsonDict, StrCollection, StrSequence
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},

View File

@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
 from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
-from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util import Clock
@@ -150,12 +150,12 @@ async def filter_events_for_client(
 async def filter_event_for_clients_with_state(
     store: DataStore,
-    user_ids: Collection[str],
+    user_ids: StrCollection,
     event: EventBase,
     context: EventContext,
     is_peeking: bool = False,
     filter_send_to_client: bool = True,
-) -> Collection[str]:
+) -> StrCollection:
     """
     Checks to see if an event is visible to the users in the list at the time of
     the event.