Use StrCollection in place of Collection[str] in (most) handlers code. (#14922)

Due to the increased safety of StrCollection over Collection[str]
and Sequence[str].
This commit is contained in:
Patrick Cloke 2023-01-26 12:31:58 -05:00 committed by GitHub
parent dc901a885f
commit ba79fb4a61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 43 additions and 58 deletions

1
changelog.d/14922.misc Normal file
View File

@ -0,0 +1 @@
Use `StrCollection` to avoid potential bugs with `Collection[str]`.

View File

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
@ -26,7 +26,7 @@ from synapse.replication.http.account_data import (
ReplicationRemoveUserAccountDataRestServlet,
)
from synapse.streams import EventSource
from synapse.types import JsonDict, StreamKeyType, UserID
from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
user: UserID,
from_key: int,
limit: int,
room_ids: Collection[str],
room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:

View File

@ -18,7 +18,6 @@ from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.types import (
JsonDict,
StrCollection,
StreamKeyType,
StreamToken,
UserID,
@ -146,7 +146,7 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
self, user_id: str, room_ids: StrCollection, from_token: StreamToken
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
async def notify_device_update(
self, user_id: str, device_ids: Collection[str]
self, user_id: str, device_ids: StrCollection
) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
from typing import TYPE_CHECKING, List, Mapping, Optional, Union
from synapse import event_auth
from synapse.api.constants import (
@ -29,7 +29,7 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.types import StateMap, get_domain_from_id
from synapse.types import StateMap, StrCollection, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -290,7 +290,7 @@ class EventAuthHandler:
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
) -> Collection[str]:
) -> StrCollection:
"""
Generate a list of rooms in which membership allows access to a room.
@ -331,7 +331,7 @@ class EventAuthHandler:
return result
async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool:
"""
Check whether a user is a member of any of the provided rooms.

View File

@ -20,17 +20,7 @@ import itertools
import logging
from enum import Enum
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Histogram
@ -70,7 +60,7 @@ from synapse.replication.http.federation import (
)
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@ -179,7 +169,7 @@ class FederationHandler:
# A dictionary mapping room IDs to (initial destination, other destinations)
# tuples.
self._partial_state_syncs_maybe_needing_restart: Dict[
str, Tuple[Optional[str], Collection[str]]
str, Tuple[Optional[str], StrCollection]
] = {}
# A lock guarding the partial state flag for rooms.
# When the lock is held for a given room, no other concurrent code may
@ -437,7 +427,7 @@ class FederationHandler:
)
)
async def try_backfill(domains: Collection[str]) -> bool:
async def try_backfill(domains: StrCollection) -> bool:
# TODO: Should we try multiple of these at a time?
# Number of contacted remote homeservers that have denied our backfill
@ -1730,7 +1720,7 @@ class FederationHandler:
def _start_partial_state_room_sync(
self,
initial_destination: Optional[str],
other_destinations: Collection[str],
other_destinations: StrCollection,
room_id: str,
) -> None:
"""Starts the background process to resync the state of a partial state room,
@ -1812,7 +1802,7 @@ class FederationHandler:
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
other_destinations: Collection[str],
other_destinations: StrCollection,
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
@ -1949,9 +1939,9 @@ class FederationHandler:
def _prioritise_destinations_for_partial_state_resync(
initial_destination: Optional[str],
other_destinations: Collection[str],
other_destinations: StrCollection,
room_id: str,
) -> Collection[str]:
) -> StrCollection:
"""Work out the order in which we should ask servers to resync events.
If an `initial_destination` is given, it takes top priority. Otherwise

View File

@ -80,6 +80,7 @@ from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
StateMap,
StrCollection,
UserID,
get_domain_from_id,
)
@ -615,7 +616,7 @@ class FederationEventHandler:
@trace
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
self, dest: str, room_id: str, limit: int, extremities: StrCollection
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
@ -1565,7 +1566,7 @@ class FederationEventHandler:
@trace
@tag_args
async def _get_events_and_persist(
self, destination: str, room_id: str, event_ids: Collection[str]
self, destination: str, room_id: str, event_ids: StrCollection
) -> None:
"""Fetch the given events from a server, and persist them as outliers.

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set
import attr
@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamKeyType
from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType
from synapse.types.state import StateFilter
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
@ -391,7 +391,7 @@ class PaginationHandler:
"""
return self._delete_by_id.get(delete_id)
def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]:
def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]:
"""Get all active delete ids by room
Args:

View File

@ -20,16 +20,7 @@ import random
import string
from collections import OrderedDict
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Collection,
Dict,
List,
Optional,
Tuple,
)
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
import attr
from typing_extensions import TypedDict
@ -72,6 +63,7 @@ from synapse.types import (
RoomID,
RoomStreamToken,
StateMap,
StrCollection,
StreamKeyType,
StreamToken,
UserID,
@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
user: UserID,
from_key: RoomStreamToken,
limit: int,
room_ids: Collection[str],
room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:

View File

@ -36,7 +36,7 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.types import JsonDict, Requester
from synapse.types import JsonDict, Requester, StrCollection
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@ -870,7 +870,7 @@ class _RoomQueueEntry:
# The room ID of this entry.
room_id: str
# The server to query if the room is not known locally.
via: Sequence[str]
via: StrCollection
# The minimum number of hops necessary to get to this room (compared to the
# originally requested room).
depth: int = 0

View File

@ -14,7 +14,7 @@
import itertools
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from unpaddedbase64 import decode_base64, encode_base64
@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.types import JsonDict, StreamKeyType, UserID
from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
from synapse.types.state import StateFilter
from synapse.visibility import filter_events_for_client
@ -418,7 +418,7 @@ class SearchHandler:
async def _search_by_rank(
self,
user: UserID,
room_ids: Collection[str],
room_ids: StrCollection,
search_term: str,
keys: Iterable[str],
search_filter: Filter,
@ -491,7 +491,7 @@ class SearchHandler:
async def _search_by_recent(
self,
user: UserID,
room_ids: Collection[str],
room_ids: StrCollection,
search_term: str,
keys: Iterable[str],
search_filter: Filter,

View File

@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
from synapse.types import (
JsonDict,
StrCollection,
UserID,
contains_invalid_mxid_characters,
create_requester,
@ -141,7 +141,8 @@ class UserAttributes:
confirm_localpart: bool = False
display_name: Optional[str] = None
picture: Optional[str] = None
emails: Collection[str] = attr.Factory(list)
# mypy thinks these are incompatible for some reason.
emails: StrCollection = attr.Factory(list) # type: ignore[assignment]
@attr.s(slots=True, auto_attribs=True)
@ -159,7 +160,7 @@ class UsernameMappingSession:
# attributes returned by the ID mapper
display_name: Optional[str]
emails: Collection[str]
emails: StrCollection
# An optional dictionary of extra attributes to be provided to the client in the
# login response.
@ -174,7 +175,7 @@ class UsernameMappingSession:
# choices made by the user
chosen_localpart: Optional[str] = None
use_display_name: bool = True
emails_to_use: Collection[str] = ()
emails_to_use: StrCollection = ()
terms_accepted_version: Optional[str] = None

View File

@ -17,7 +17,6 @@ from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Collection,
Dict,
FrozenSet,
List,
@ -62,6 +61,7 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
StrCollection,
StreamKeyType,
StreamToken,
UserID,
@ -1179,7 +1179,7 @@ class SyncHandler:
async def _find_missing_partial_state_memberships(
self,
room_id: str,
members_to_fetch: Collection[str],
members_to_fetch: StrCollection,
events_with_membership_auth: Mapping[str, EventBase],
found_state_ids: StateMap[str],
) -> StateMap[str]:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
from typing import TYPE_CHECKING, List, Tuple, Union
from synapse.api.errors import (
NotFoundError,
@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet):
raise UnrecognizedRequestError()
def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
def _rule_spec_from_path(path: List[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
Args: