Use StrCollection in additional places. (#16301)
This commit is contained in:
parent
e9addf6a01
commit
d38d0dffc9
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
|
@ -27,9 +27,7 @@ from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
|
||||||
List,
|
List,
|
||||||
NoReturn,
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
|
||||||
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
||||||
load_legacy_third_party_event_rules,
|
load_legacy_third_party_event_rules,
|
||||||
)
|
)
|
||||||
from synapse.types import ISynapseReactor
|
from synapse.types import ISynapseReactor, StrCollection
|
||||||
from synapse.util import SYNAPSE_VERSION
|
from synapse.util import SYNAPSE_VERSION
|
||||||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
||||||
from synapse.util.daemonize import daemonize_process
|
from synapse.util.daemonize import daemonize_process
|
||||||
|
@ -278,7 +276,7 @@ def register_start(
|
||||||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
||||||
|
|
||||||
|
|
||||||
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
|
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
|
||||||
"""
|
"""
|
||||||
Start Prometheus metrics server.
|
Start Prometheus metrics server.
|
||||||
"""
|
"""
|
||||||
|
@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
def listen_manhole(
|
def listen_manhole(
|
||||||
bind_addresses: Collection[str],
|
bind_addresses: StrCollection,
|
||||||
port: int,
|
port: int,
|
||||||
manhole_settings: ManholeConfig,
|
manhole_settings: ManholeConfig,
|
||||||
manhole_globals: dict,
|
manhole_globals: dict,
|
||||||
|
@ -339,7 +337,7 @@ def listen_manhole(
|
||||||
|
|
||||||
|
|
||||||
def listen_tcp(
|
def listen_tcp(
|
||||||
bind_addresses: Collection[str],
|
bind_addresses: StrCollection,
|
||||||
port: int,
|
port: int,
|
||||||
factory: ServerFactory,
|
factory: ServerFactory,
|
||||||
reactor: IReactorTCP = reactor,
|
reactor: IReactorTCP = reactor,
|
||||||
|
@ -448,7 +446,7 @@ def listen_http(
|
||||||
|
|
||||||
|
|
||||||
def listen_ssl(
|
def listen_ssl(
|
||||||
bind_addresses: Collection[str],
|
bind_addresses: StrCollection,
|
||||||
port: int,
|
port: int,
|
||||||
factory: ServerFactory,
|
factory: ServerFactory,
|
||||||
context_factory: IOpenSSLContextFactory,
|
context_factory: IOpenSSLContextFactory,
|
||||||
|
|
|
@ -26,7 +26,6 @@ from textwrap import dedent
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Collection,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
|
@ -384,7 +383,7 @@ class RootConfig:
|
||||||
|
|
||||||
config_classes: List[Type[Config]] = []
|
config_classes: List[Type[Config]] = []
|
||||||
|
|
||||||
def __init__(self, config_files: Collection[str] = ()):
|
def __init__(self, config_files: StrSequence = ()):
|
||||||
# Capture absolute paths here, so we can reload config after we daemonize.
|
# Capture absolute paths here, so we can reload config after we daemonize.
|
||||||
self.config_files = [os.path.abspath(path) for path in config_files]
|
self.config_files = [os.path.abspath(path) for path in config_files]
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
|
||||||
def keys(self) -> Iterable[str]:
|
def keys(self) -> Iterable[str]:
|
||||||
return self._dict.keys()
|
return self._dict.keys()
|
||||||
|
|
||||||
def prev_event_ids(self) -> Sequence[str]:
|
def prev_event_ids(self) -> List[str]:
|
||||||
"""Returns the list of prev event IDs. The order matches the order
|
"""Returns the list of prev event IDs. The order matches the order
|
||||||
specified in the event, though there is no meaning to it.
|
specified in the event, though there is no meaning to it.
|
||||||
|
|
||||||
|
@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
|
||||||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
|
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
|
||||||
return self._event_id
|
return self._event_id
|
||||||
|
|
||||||
def prev_event_ids(self) -> Sequence[str]:
|
def prev_event_ids(self) -> List[str]:
|
||||||
"""Returns the list of prev event IDs. The order matches the order
|
"""Returns the list of prev event IDs. The order matches the order
|
||||||
specified in the event, though there is no meaning to it.
|
specified in the event, though there is no meaning to it.
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.types import SigningKey
|
from signedjson.types import SigningKey
|
||||||
|
@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
|
||||||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||||
from synapse.state import StateHandler
|
from synapse.state import StateHandler
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import EventID, JsonDict
|
from synapse.types import EventID, JsonDict, StrCollection
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
@ -103,7 +103,7 @@ class EventBuilder:
|
||||||
|
|
||||||
async def build(
|
async def build(
|
||||||
self,
|
self,
|
||||||
prev_event_ids: Collection[str],
|
prev_event_ids: StrCollection,
|
||||||
auth_event_ids: Optional[List[str]],
|
auth_event_ids: Optional[List[str]],
|
||||||
depth: Optional[int] = None,
|
depth: Optional[int] = None,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
|
@ -136,7 +136,7 @@ class EventBuilder:
|
||||||
|
|
||||||
format_version = self.room_version.event_format
|
format_version = self.room_version.event_format
|
||||||
# The types of auth/prev events changes between event versions.
|
# The types of auth/prev events changes between event versions.
|
||||||
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
|
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
|
||||||
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||||
if format_version == EventFormatVersions.ROOM_V1_V2:
|
if format_version == EventFormatVersions.ROOM_V1_V2:
|
||||||
auth_events = await self._store.add_event_hashes(auth_event_ids)
|
auth_events = await self._store.add_event_hashes(auth_event_ids)
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections.abc
|
import collections.abc
|
||||||
from typing import Iterable, List, Type, Union, cast
|
from typing import List, Type, Union, cast
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
from pydantic import Field, StrictBool, StrictStr
|
from pydantic import Field, StrictBool, StrictStr
|
||||||
|
@ -36,7 +36,7 @@ from synapse.events.utils import (
|
||||||
from synapse.federation.federation_server import server_matches_acl_event
|
from synapse.federation.federation_server import server_matches_acl_event
|
||||||
from synapse.http.servlet import validate_json_object
|
from synapse.http.servlet import validate_json_object
|
||||||
from synapse.rest.models import RequestBodyModel
|
from synapse.rest.models import RequestBodyModel
|
||||||
from synapse.types import EventID, JsonDict, RoomID, UserID
|
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
|
||||||
|
|
||||||
|
|
||||||
class EventValidator:
|
class EventValidator:
|
||||||
|
@ -225,7 +225,7 @@ class EventValidator:
|
||||||
|
|
||||||
self._ensure_state_event(event)
|
self._ensure_state_event(event)
|
||||||
|
|
||||||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
|
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
|
||||||
for s in keys:
|
for s in keys:
|
||||||
if s not in d:
|
if s not in d:
|
||||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||||
|
|
|
@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||||
from synapse.types import ISynapseReactor
|
from synapse.types import ISynapseReactor, StrSequence
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
|
|
||||||
|
@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
|
||||||
# the value actually has to be a List, but List is invariant so we can't specify that
|
# the value actually has to be a List, but List is invariant so we can't specify that
|
||||||
# the entries can either be Lists or bytes.
|
# the entries can either be Lists or bytes.
|
||||||
RawHeaderValue = Union[
|
RawHeaderValue = Union[
|
||||||
List[str],
|
StrSequence,
|
||||||
List[bytes],
|
List[bytes],
|
||||||
List[Union[str, bytes]],
|
List[Union[str, bytes]],
|
||||||
Tuple[str, ...],
|
|
||||||
Tuple[bytes, ...],
|
Tuple[bytes, ...],
|
||||||
Tuple[Union[str, bytes], ...],
|
Tuple[Union[str, bytes], ...],
|
||||||
]
|
]
|
||||||
|
|
|
@ -18,7 +18,6 @@ import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Iterable,
|
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -38,7 +37,7 @@ from twisted.web.server import Request
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http import redact_uri
|
from synapse.http import redact_uri
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.types import JsonDict, RoomAlias, RoomID
|
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -340,7 +339,7 @@ def parse_string(
|
||||||
name: str,
|
name: str,
|
||||||
default: str,
|
default: str,
|
||||||
*,
|
*,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> str:
|
) -> str:
|
||||||
...
|
...
|
||||||
|
@ -352,7 +351,7 @@ def parse_string(
|
||||||
name: str,
|
name: str,
|
||||||
*,
|
*,
|
||||||
required: Literal[True],
|
required: Literal[True],
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> str:
|
) -> str:
|
||||||
...
|
...
|
||||||
|
@ -365,7 +364,7 @@ def parse_string(
|
||||||
*,
|
*,
|
||||||
default: Optional[str] = None,
|
default: Optional[str] = None,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
...
|
...
|
||||||
|
@ -376,7 +375,7 @@ def parse_string(
|
||||||
name: str,
|
name: str,
|
||||||
default: Optional[str] = None,
|
default: Optional[str] = None,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -485,7 +484,7 @@ def parse_enum(
|
||||||
|
|
||||||
def _parse_string_value(
|
def _parse_string_value(
|
||||||
value: bytes,
|
value: bytes,
|
||||||
allowed_values: Optional[Iterable[str]],
|
allowed_values: Optional[StrCollection],
|
||||||
name: str,
|
name: str,
|
||||||
encoding: str,
|
encoding: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -511,7 +510,7 @@ def parse_strings_from_args(
|
||||||
args: Mapping[bytes, Sequence[bytes]],
|
args: Mapping[bytes, Sequence[bytes]],
|
||||||
name: str,
|
name: str,
|
||||||
*,
|
*,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[List[str]]:
|
) -> Optional[List[str]]:
|
||||||
...
|
...
|
||||||
|
@ -523,7 +522,7 @@ def parse_strings_from_args(
|
||||||
name: str,
|
name: str,
|
||||||
default: List[str],
|
default: List[str],
|
||||||
*,
|
*,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
...
|
...
|
||||||
|
@ -535,7 +534,7 @@ def parse_strings_from_args(
|
||||||
name: str,
|
name: str,
|
||||||
*,
|
*,
|
||||||
required: Literal[True],
|
required: Literal[True],
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
...
|
...
|
||||||
|
@ -548,7 +547,7 @@ def parse_strings_from_args(
|
||||||
default: Optional[List[str]] = None,
|
default: Optional[List[str]] = None,
|
||||||
*,
|
*,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[List[str]]:
|
) -> Optional[List[str]]:
|
||||||
...
|
...
|
||||||
|
@ -559,7 +558,7 @@ def parse_strings_from_args(
|
||||||
name: str,
|
name: str,
|
||||||
default: Optional[List[str]] = None,
|
default: Optional[List[str]] = None,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[List[str]]:
|
) -> Optional[List[str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -610,7 +609,7 @@ def parse_string_from_args(
|
||||||
name: str,
|
name: str,
|
||||||
default: Optional[str] = None,
|
default: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
...
|
...
|
||||||
|
@ -623,7 +622,7 @@ def parse_string_from_args(
|
||||||
default: Optional[str] = None,
|
default: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
required: Literal[True],
|
required: Literal[True],
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> str:
|
) -> str:
|
||||||
...
|
...
|
||||||
|
@ -635,7 +634,7 @@ def parse_string_from_args(
|
||||||
name: str,
|
name: str,
|
||||||
default: Optional[str] = None,
|
default: Optional[str] = None,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
...
|
...
|
||||||
|
@ -646,7 +645,7 @@ def parse_string_from_args(
|
||||||
name: str,
|
name: str,
|
||||||
default: Optional[str] = None,
|
default: Optional[str] = None,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[StrCollection] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
|
||||||
return validate_json_object(content, model_type)
|
return validate_json_object(content, model_type)
|
||||||
|
|
||||||
|
|
||||||
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
|
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
|
||||||
absent = []
|
absent = []
|
||||||
for k in required:
|
for k in required:
|
||||||
if k not in body:
|
if k not in body:
|
||||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
||||||
Iterable,
|
Iterable,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
|
||||||
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
|
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
|
||||||
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
|
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
|
||||||
from synapse.metrics._types import Collector
|
from synapse.metrics._types import Collector
|
||||||
|
from synapse.types import StrSequence
|
||||||
from synapse.util import SYNAPSE_VERSION
|
from synapse.util import SYNAPSE_VERSION
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -81,7 +81,7 @@ class LaterGauge(Collector):
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
desc: str
|
desc: str
|
||||||
labels: Optional[Sequence[str]] = attr.ib(hash=False)
|
labels: Optional[StrSequence] = attr.ib(hash=False)
|
||||||
# callback: should either return a value (if there are no labels for this metric),
|
# callback: should either return a value (if there are no labels for this metric),
|
||||||
# or dict mapping from a label tuple to a value
|
# or dict mapping from a label tuple to a value
|
||||||
caller: Callable[
|
caller: Callable[
|
||||||
|
@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
desc: str,
|
desc: str,
|
||||||
labels: Sequence[str],
|
labels: StrSequence,
|
||||||
sub_metrics: Sequence[str],
|
sub_metrics: StrSequence,
|
||||||
):
|
):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.desc = desc
|
self.desc = desc
|
||||||
|
|
|
@ -104,7 +104,7 @@ class _NotifierUserStream:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
rooms: Collection[str],
|
rooms: StrCollection,
|
||||||
current_token: StreamToken,
|
current_token: StreamToken,
|
||||||
time_now_ms: int,
|
time_now_ms: int,
|
||||||
):
|
):
|
||||||
|
@ -457,7 +457,7 @@ class Notifier:
|
||||||
stream_key: str,
|
stream_key: str,
|
||||||
new_token: Union[int, RoomStreamToken],
|
new_token: Union[int, RoomStreamToken],
|
||||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||||
rooms: Optional[Collection[str]] = None,
|
rooms: Optional[StrCollection] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Used to inform listeners that something has happened event wise.
|
"""Used to inform listeners that something has happened event wise.
|
||||||
|
|
||||||
|
@ -529,7 +529,7 @@ class Notifier:
|
||||||
user_id: str,
|
user_id: str,
|
||||||
timeout: int,
|
timeout: int,
|
||||||
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
|
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
|
||||||
room_ids: Optional[Collection[str]] = None,
|
room_ids: Optional[StrCollection] = None,
|
||||||
from_token: StreamToken = StreamToken.START,
|
from_token: StreamToken = StreamToken.START,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Wait until the callback returns a non empty response or the
|
"""Wait until the callback returns a non empty response or the
|
||||||
|
|
|
@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
|
||||||
|
|
||||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||||
from synapse.api.urls import CLIENT_API_PREFIX
|
from synapse.api.urls import CLIENT_API_PREFIX
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, StrCollection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def client_patterns(
|
def client_patterns(
|
||||||
path_regex: str,
|
path_regex: str,
|
||||||
releases: Iterable[str] = ("r0", "v3"),
|
releases: StrCollection = ("r0", "v3"),
|
||||||
unstable: bool = True,
|
unstable: bool = True,
|
||||||
v1: bool = False,
|
v1: bool = False,
|
||||||
) -> Iterable[Pattern]:
|
) -> Iterable[Pattern]:
|
||||||
|
|
|
@ -20,7 +20,6 @@ from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
|
||||||
DefaultDict,
|
DefaultDict,
|
||||||
Dict,
|
Dict,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
|
@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
|
||||||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
|
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
|
||||||
from synapse.state import v1, v2
|
from synapse.state import v1, v2
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap, StrCollection
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
@ -197,7 +196,7 @@ class StateHandler:
|
||||||
async def compute_state_after_events(
|
async def compute_state_after_events(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_ids: Collection[str],
|
event_ids: StrCollection,
|
||||||
state_filter: Optional[StateFilter] = None,
|
state_filter: Optional[StateFilter] = None,
|
||||||
await_full_state: bool = True,
|
await_full_state: bool = True,
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
|
@ -231,7 +230,7 @@ class StateHandler:
|
||||||
return await ret.get_state(self._state_storage_controller, state_filter)
|
return await ret.get_state(self._state_storage_controller, state_filter)
|
||||||
|
|
||||||
async def get_current_user_ids_in_room(
|
async def get_current_user_ids_in_room(
|
||||||
self, room_id: str, latest_event_ids: Collection[str]
|
self, room_id: str, latest_event_ids: StrCollection
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""
|
"""
|
||||||
Get the users IDs who are currently in a room.
|
Get the users IDs who are currently in a room.
|
||||||
|
@ -256,7 +255,7 @@ class StateHandler:
|
||||||
return await self.store.get_joined_user_ids_from_state(room_id, state)
|
return await self.store.get_joined_user_ids_from_state(room_id, state)
|
||||||
|
|
||||||
async def get_hosts_in_room_at_events(
|
async def get_hosts_in_room_at_events(
|
||||||
self, room_id: str, event_ids: Collection[str]
|
self, room_id: str, event_ids: StrCollection
|
||||||
) -> FrozenSet[str]:
|
) -> FrozenSet[str]:
|
||||||
"""Get the hosts that were in a room at the given event ids
|
"""Get the hosts that were in a room at the given event ids
|
||||||
|
|
||||||
|
@ -470,7 +469,7 @@ class StateHandler:
|
||||||
@trace
|
@trace
|
||||||
@measure_func()
|
@measure_func()
|
||||||
async def resolve_state_groups_for_events(
|
async def resolve_state_groups_for_events(
|
||||||
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
|
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
|
||||||
) -> _StateCacheEntry:
|
) -> _StateCacheEntry:
|
||||||
"""Given a list of event_ids this method fetches the state at each
|
"""Given a list of event_ids this method fetches the state at each
|
||||||
event, resolves conflicts between them and returns them.
|
event, resolves conflicts between them and returns them.
|
||||||
|
@ -882,7 +881,7 @@ class StateResolutionStore:
|
||||||
store: "DataStore"
|
store: "DataStore"
|
||||||
|
|
||||||
def get_events(
|
def get_events(
|
||||||
self, event_ids: Collection[str], allow_rejected: bool = False
|
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||||
) -> Awaitable[Dict[str, EventBase]]:
|
) -> Awaitable[Dict[str, EventBase]]:
|
||||||
"""Get events from the database
|
"""Get events from the database
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -45,7 +44,7 @@ async def resolve_events_with_store(
|
||||||
room_version: RoomVersion,
|
room_version: RoomVersion,
|
||||||
state_sets: Sequence[StateMap[str]],
|
state_sets: Sequence[StateMap[str]],
|
||||||
event_map: Optional[Dict[str, EventBase]],
|
event_map: Optional[Dict[str, EventBase]],
|
||||||
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
|
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -19,7 +19,6 @@ from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
|
||||||
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
|
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
|
||||||
# TestStateResolutionStore in tests.
|
# TestStateResolutionStore in tests.
|
||||||
def get_events(
|
def get_events(
|
||||||
self, event_ids: Collection[str], allow_rejected: bool = False
|
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||||
) -> Awaitable[Dict[str, EventBase]]:
|
) -> Awaitable[Dict[str, EventBase]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
|
||||||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
|
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
|
||||||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
|
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
|
||||||
|
|
||||||
auth_difference_unpersisted_part: Collection[str] = union - intersection
|
auth_difference_unpersisted_part: StrCollection = union - intersection
|
||||||
else:
|
else:
|
||||||
auth_difference_unpersisted_part = ()
|
auth_difference_unpersisted_part = ()
|
||||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||||
|
|
|
@ -47,7 +47,7 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
from synapse.types import JsonDict, StrCollection
|
from synapse.types import JsonDict, StrCollection, StrSequence
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=5000, iterable=True)
|
@cached(max_entries=5000, iterable=True)
|
||||||
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
|
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
|
||||||
return await self.db_pool.simple_select_onecol(
|
return await self.db_pool.simple_select_onecol(
|
||||||
table="event_forward_extremities",
|
table="event_forward_extremities",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
|
|
|
@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
from synapse.storage.controllers import StorageControllers
|
from synapse.storage.controllers import StorageControllers
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
|
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -150,12 +150,12 @@ async def filter_events_for_client(
|
||||||
|
|
||||||
async def filter_event_for_clients_with_state(
|
async def filter_event_for_clients_with_state(
|
||||||
store: DataStore,
|
store: DataStore,
|
||||||
user_ids: Collection[str],
|
user_ids: StrCollection,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
context: EventContext,
|
context: EventContext,
|
||||||
is_peeking: bool = False,
|
is_peeking: bool = False,
|
||||||
filter_send_to_client: bool = True,
|
filter_send_to_client: bool = True,
|
||||||
) -> Collection[str]:
|
) -> StrCollection:
|
||||||
"""
|
"""
|
||||||
Checks to see if an event is visible to the users in the list at the time of
|
Checks to see if an event is visible to the users in the list at the time of
|
||||||
the event.
|
the event.
|
||||||
|
|
Loading…
Reference in New Issue