Use StrCollection in additional places. (#16301)
This commit is contained in:
parent
e9addf6a01
commit
d38d0dffc9
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
|
@ -27,9 +27,7 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
|
@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
|
|||
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
||||
load_legacy_third_party_event_rules,
|
||||
)
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.types import ISynapseReactor, StrCollection
|
||||
from synapse.util import SYNAPSE_VERSION
|
||||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
||||
from synapse.util.daemonize import daemonize_process
|
||||
|
@ -278,7 +276,7 @@ def register_start(
|
|||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
||||
|
||||
|
||||
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
|
||||
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
|
||||
"""
|
||||
Start Prometheus metrics server.
|
||||
"""
|
||||
|
@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
|
|||
|
||||
|
||||
def listen_manhole(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: StrCollection,
|
||||
port: int,
|
||||
manhole_settings: ManholeConfig,
|
||||
manhole_globals: dict,
|
||||
|
@ -339,7 +337,7 @@ def listen_manhole(
|
|||
|
||||
|
||||
def listen_tcp(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: StrCollection,
|
||||
port: int,
|
||||
factory: ServerFactory,
|
||||
reactor: IReactorTCP = reactor,
|
||||
|
@ -448,7 +446,7 @@ def listen_http(
|
|||
|
||||
|
||||
def listen_ssl(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: StrCollection,
|
||||
port: int,
|
||||
factory: ServerFactory,
|
||||
context_factory: IOpenSSLContextFactory,
|
||||
|
|
|
@ -26,7 +26,6 @@ from textwrap import dedent
|
|||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
|
@ -384,7 +383,7 @@ class RootConfig:
|
|||
|
||||
config_classes: List[Type[Config]] = []
|
||||
|
||||
def __init__(self, config_files: Collection[str] = ()):
|
||||
def __init__(self, config_files: StrSequence = ()):
|
||||
# Capture absolute paths here, so we can reload config after we daemonize.
|
||||
self.config_files = [os.path.abspath(path) for path in config_files]
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
|||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
|
@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
|
|||
def keys(self) -> Iterable[str]:
|
||||
return self._dict.keys()
|
||||
|
||||
def prev_event_ids(self) -> Sequence[str]:
|
||||
def prev_event_ids(self) -> List[str]:
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
|
||||
|
@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
|
|||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
|
||||
return self._event_id
|
||||
|
||||
def prev_event_ids(self) -> Sequence[str]:
|
||||
def prev_event_ids(self) -> List[str]:
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from signedjson.types import SigningKey
|
||||
|
@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
|
|||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import EventID, JsonDict
|
||||
from synapse.types import EventID, JsonDict, StrCollection
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
@ -103,7 +103,7 @@ class EventBuilder:
|
|||
|
||||
async def build(
|
||||
self,
|
||||
prev_event_ids: Collection[str],
|
||||
prev_event_ids: StrCollection,
|
||||
auth_event_ids: Optional[List[str]],
|
||||
depth: Optional[int] = None,
|
||||
) -> EventBase:
|
||||
|
@ -136,7 +136,7 @@ class EventBuilder:
|
|||
|
||||
format_version = self.room_version.event_format
|
||||
# The types of auth/prev events changes between event versions.
|
||||
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
|
||||
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
|
||||
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||
if format_version == EventFormatVersions.ROOM_V1_V2:
|
||||
auth_events = await self._store.add_event_hashes(auth_event_ids)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections.abc
|
||||
from typing import Iterable, List, Type, Union, cast
|
||||
from typing import List, Type, Union, cast
|
||||
|
||||
import jsonschema
|
||||
from pydantic import Field, StrictBool, StrictStr
|
||||
|
@ -36,7 +36,7 @@ from synapse.events.utils import (
|
|||
from synapse.federation.federation_server import server_matches_acl_event
|
||||
from synapse.http.servlet import validate_json_object
|
||||
from synapse.rest.models import RequestBodyModel
|
||||
from synapse.types import EventID, JsonDict, RoomID, UserID
|
||||
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
|
||||
|
||||
|
||||
class EventValidator:
|
||||
|
@ -225,7 +225,7 @@ class EventValidator:
|
|||
|
||||
self._ensure_state_event(event)
|
||||
|
||||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
|
||||
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
|
||||
for s in keys:
|
||||
if s not in d:
|
||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||
|
|
|
@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
|
|||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.types import ISynapseReactor, StrSequence
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
|
||||
|
@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
|
|||
# the value actually has to be a List, but List is invariant so we can't specify that
|
||||
# the entries can either be Lists or bytes.
|
||||
RawHeaderValue = Union[
|
||||
List[str],
|
||||
StrSequence,
|
||||
List[bytes],
|
||||
List[Union[str, bytes]],
|
||||
Tuple[str, ...],
|
||||
Tuple[bytes, ...],
|
||||
Tuple[Union[str, bytes], ...],
|
||||
]
|
||||
|
|
|
@ -18,7 +18,6 @@ import logging
|
|||
from http import HTTPStatus
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
|
@ -38,7 +37,7 @@ from twisted.web.server import Request
|
|||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http import redact_uri
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.types import JsonDict, RoomAlias, RoomID
|
||||
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -340,7 +339,7 @@ def parse_string(
|
|||
name: str,
|
||||
default: str,
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> str:
|
||||
...
|
||||
|
@ -352,7 +351,7 @@ def parse_string(
|
|||
name: str,
|
||||
*,
|
||||
required: Literal[True],
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> str:
|
||||
...
|
||||
|
@ -365,7 +364,7 @@ def parse_string(
|
|||
*,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
@ -376,7 +375,7 @@ def parse_string(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
|
@ -485,7 +484,7 @@ def parse_enum(
|
|||
|
||||
def _parse_string_value(
|
||||
value: bytes,
|
||||
allowed_values: Optional[Iterable[str]],
|
||||
allowed_values: Optional[StrCollection],
|
||||
name: str,
|
||||
encoding: str,
|
||||
) -> str:
|
||||
|
@ -511,7 +510,7 @@ def parse_strings_from_args(
|
|||
args: Mapping[bytes, Sequence[bytes]],
|
||||
name: str,
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[List[str]]:
|
||||
...
|
||||
|
@ -523,7 +522,7 @@ def parse_strings_from_args(
|
|||
name: str,
|
||||
default: List[str],
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> List[str]:
|
||||
...
|
||||
|
@ -535,7 +534,7 @@ def parse_strings_from_args(
|
|||
name: str,
|
||||
*,
|
||||
required: Literal[True],
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> List[str]:
|
||||
...
|
||||
|
@ -548,7 +547,7 @@ def parse_strings_from_args(
|
|||
default: Optional[List[str]] = None,
|
||||
*,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[List[str]]:
|
||||
...
|
||||
|
@ -559,7 +558,7 @@ def parse_strings_from_args(
|
|||
name: str,
|
||||
default: Optional[List[str]] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
|
@ -610,7 +609,7 @@ def parse_string_from_args(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
@ -623,7 +622,7 @@ def parse_string_from_args(
|
|||
default: Optional[str] = None,
|
||||
*,
|
||||
required: Literal[True],
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> str:
|
||||
...
|
||||
|
@ -635,7 +634,7 @@ def parse_string_from_args(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
@ -646,7 +645,7 @@ def parse_string_from_args(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
|
@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
|
|||
return validate_json_object(content, model_type)
|
||||
|
||||
|
||||
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
|
||||
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
|||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
|
@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
|
|||
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
|
||||
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
|
||||
from synapse.metrics._types import Collector
|
||||
from synapse.types import StrSequence
|
||||
from synapse.util import SYNAPSE_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -81,7 +81,7 @@ class LaterGauge(Collector):
|
|||
|
||||
name: str
|
||||
desc: str
|
||||
labels: Optional[Sequence[str]] = attr.ib(hash=False)
|
||||
labels: Optional[StrSequence] = attr.ib(hash=False)
|
||||
# callback: should either return a value (if there are no labels for this metric),
|
||||
# or dict mapping from a label tuple to a value
|
||||
caller: Callable[
|
||||
|
@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
|
|||
self,
|
||||
name: str,
|
||||
desc: str,
|
||||
labels: Sequence[str],
|
||||
sub_metrics: Sequence[str],
|
||||
labels: StrSequence,
|
||||
sub_metrics: StrSequence,
|
||||
):
|
||||
self.name = name
|
||||
self.desc = desc
|
||||
|
|
|
@ -104,7 +104,7 @@ class _NotifierUserStream:
|
|||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
rooms: Collection[str],
|
||||
rooms: StrCollection,
|
||||
current_token: StreamToken,
|
||||
time_now_ms: int,
|
||||
):
|
||||
|
@ -457,7 +457,7 @@ class Notifier:
|
|||
stream_key: str,
|
||||
new_token: Union[int, RoomStreamToken],
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
rooms: Optional[Collection[str]] = None,
|
||||
rooms: Optional[StrCollection] = None,
|
||||
) -> None:
|
||||
"""Used to inform listeners that something has happened event wise.
|
||||
|
||||
|
@ -529,7 +529,7 @@ class Notifier:
|
|||
user_id: str,
|
||||
timeout: int,
|
||||
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
|
||||
room_ids: Optional[Collection[str]] = None,
|
||||
room_ids: Optional[StrCollection] = None,
|
||||
from_token: StreamToken = StreamToken.START,
|
||||
) -> T:
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
|
|
|
@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
|
|||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_patterns(
|
||||
path_regex: str,
|
||||
releases: Iterable[str] = ("r0", "v3"),
|
||||
releases: StrCollection = ("r0", "v3"),
|
||||
unstable: bool = True,
|
||||
v1: bool = False,
|
||||
) -> Iterable[Pattern]:
|
||||
|
|
|
@ -20,7 +20,6 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
|
@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
|
|||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import StateMap, StrCollection
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -197,7 +196,7 @@ class StateHandler:
|
|||
async def compute_state_after_events(
|
||||
self,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
event_ids: StrCollection,
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
await_full_state: bool = True,
|
||||
) -> StateMap[str]:
|
||||
|
@ -231,7 +230,7 @@ class StateHandler:
|
|||
return await ret.get_state(self._state_storage_controller, state_filter)
|
||||
|
||||
async def get_current_user_ids_in_room(
|
||||
self, room_id: str, latest_event_ids: Collection[str]
|
||||
self, room_id: str, latest_event_ids: StrCollection
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the users IDs who are currently in a room.
|
||||
|
@ -256,7 +255,7 @@ class StateHandler:
|
|||
return await self.store.get_joined_user_ids_from_state(room_id, state)
|
||||
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
self, room_id: str, event_ids: StrCollection
|
||||
) -> FrozenSet[str]:
|
||||
"""Get the hosts that were in a room at the given event ids
|
||||
|
||||
|
@ -470,7 +469,7 @@ class StateHandler:
|
|||
@trace
|
||||
@measure_func()
|
||||
async def resolve_state_groups_for_events(
|
||||
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
|
||||
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
|
||||
) -> _StateCacheEntry:
|
||||
"""Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
@ -882,7 +881,7 @@ class StateResolutionStore:
|
|||
store: "DataStore"
|
||||
|
||||
def get_events(
|
||||
self, event_ids: Collection[str], allow_rejected: bool = False
|
||||
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||
) -> Awaitable[Dict[str, EventBase]]:
|
||||
"""Get events from the database
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
|||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
|
@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
|
|||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,7 +44,7 @@ async def resolve_events_with_store(
|
|||
room_version: RoomVersion,
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
|
||||
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Args:
|
||||
|
|
|
@ -19,7 +19,6 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
|
@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
|
|||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
|
|||
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
|
||||
# TestStateResolutionStore in tests.
|
||||
def get_events(
|
||||
self, event_ids: Collection[str], allow_rejected: bool = False
|
||||
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||
) -> Awaitable[Dict[str, EventBase]]:
|
||||
...
|
||||
|
||||
|
@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
|
|||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
|
||||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
|
||||
|
||||
auth_difference_unpersisted_part: Collection[str] = union - intersection
|
||||
auth_difference_unpersisted_part: StrCollection = union - intersection
|
||||
else:
|
||||
auth_difference_unpersisted_part = ()
|
||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||
|
|
|
@ -47,7 +47,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
from synapse.types import JsonDict, StrCollection, StrSequence
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
@cached(max_entries=5000, iterable=True)
|
||||
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
|
||||
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
table="event_forward_extremities",
|
||||
keyvalues={"room_id": room_id},
|
||||
|
|
|
@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
|
|||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
|
||||
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -150,12 +150,12 @@ async def filter_events_for_client(
|
|||
|
||||
async def filter_event_for_clients_with_state(
|
||||
store: DataStore,
|
||||
user_ids: Collection[str],
|
||||
user_ids: StrCollection,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
is_peeking: bool = False,
|
||||
filter_send_to_client: bool = True,
|
||||
) -> Collection[str]:
|
||||
) -> StrCollection:
|
||||
"""
|
||||
Checks to see if an event is visible to the users in the list at the time of
|
||||
the event.
|
||||
|
|
Loading…
Reference in New Issue