Add types to synapse.util. (#10601)
This commit is contained in:
parent
ceab5a4bfa
commit
524b8ead77
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations to the synapse.util package.
|
75
mypy.ini
75
mypy.ini
|
@ -74,17 +74,7 @@ files =
|
||||||
synapse/storage/util,
|
synapse/storage/util,
|
||||||
synapse/streams,
|
synapse/streams,
|
||||||
synapse/types.py,
|
synapse/types.py,
|
||||||
synapse/util/async_helpers.py,
|
synapse/util,
|
||||||
synapse/util/caches,
|
|
||||||
synapse/util/daemonize.py,
|
|
||||||
synapse/util/hash.py,
|
|
||||||
synapse/util/iterutils.py,
|
|
||||||
synapse/util/linked_list.py,
|
|
||||||
synapse/util/metrics.py,
|
|
||||||
synapse/util/macaroons.py,
|
|
||||||
synapse/util/module_loader.py,
|
|
||||||
synapse/util/msisdn.py,
|
|
||||||
synapse/util/stringutils.py,
|
|
||||||
synapse/visibility.py,
|
synapse/visibility.py,
|
||||||
tests/replication,
|
tests/replication,
|
||||||
tests/test_event_auth.py,
|
tests/test_event_auth.py,
|
||||||
|
@ -102,6 +92,69 @@ files =
|
||||||
[mypy-synapse.rest.client.*]
|
[mypy-synapse.rest.client.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.batching_queue]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.caches.dictionary_cache]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.file_consumer]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.frozenutils]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.hash]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.httpresourcetree]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.iterutils]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.linked_list]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.logcontext]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.logformatter]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.macaroons]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.manhole]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.module_loader]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.msisdn]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.ratelimitutils]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.retryutils]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.rlimit]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.stringutils]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.templates]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.threepids]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.util.wheel_timer]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-pymacaroons.*]
|
[mypy-pymacaroons.*]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
|
|
@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory):
|
||||||
def buildProtocol(self, addr) -> RedisProtocol: ...
|
def buildProtocol(self, addr) -> RedisProtocol: ...
|
||||||
|
|
||||||
class SubscriberFactory(RedisFactory):
|
class SubscriberFactory(RedisFactory):
|
||||||
def __init__(self): ...
|
def __init__(self) -> None: ...
|
||||||
|
|
|
@ -46,7 +46,7 @@ class Ratelimiter:
|
||||||
# * How many times an action has occurred since a point in time
|
# * How many times an action has occurred since a point in time
|
||||||
# * The point in time
|
# * The point in time
|
||||||
# * The rate_hz of this particular entry. This can vary per request
|
# * The rate_hz of this particular entry. This can vary per request
|
||||||
self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
|
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
|
||||||
|
|
||||||
async def can_do_action(
|
async def can_do_action(
|
||||||
self,
|
self,
|
||||||
|
@ -56,7 +56,7 @@ class Ratelimiter:
|
||||||
burst_count: Optional[int] = None,
|
burst_count: Optional[int] = None,
|
||||||
update: bool = True,
|
update: bool = True,
|
||||||
n_actions: int = 1,
|
n_actions: int = 1,
|
||||||
_time_now_s: Optional[int] = None,
|
_time_now_s: Optional[float] = None,
|
||||||
) -> Tuple[bool, float]:
|
) -> Tuple[bool, float]:
|
||||||
"""Can the entity (e.g. user or IP address) perform the action?
|
"""Can the entity (e.g. user or IP address) perform the action?
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ class Ratelimiter:
|
||||||
|
|
||||||
return allowed, time_allowed
|
return allowed, time_allowed
|
||||||
|
|
||||||
def _prune_message_counts(self, time_now_s: int):
|
def _prune_message_counts(self, time_now_s: float):
|
||||||
"""Remove message count entries that have not exceeded their defined
|
"""Remove message count entries that have not exceeded their defined
|
||||||
rate_hz limit
|
rate_hz limit
|
||||||
|
|
||||||
|
@ -188,7 +188,7 @@ class Ratelimiter:
|
||||||
burst_count: Optional[int] = None,
|
burst_count: Optional[int] = None,
|
||||||
update: bool = True,
|
update: bool = True,
|
||||||
n_actions: int = 1,
|
n_actions: int = 1,
|
||||||
_time_now_s: Optional[int] = None,
|
_time_now_s: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""Checks if an action can be performed. If not, raises a LimitExceededError
|
"""Checks if an action can be performed. If not, raises a LimitExceededError
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,18 +31,13 @@ class RateLimitConfig:
|
||||||
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
|
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
class FederationRateLimitConfig:
|
class FederationRateLimitConfig:
|
||||||
_items_and_default = {
|
window_size: int = 1000
|
||||||
"window_size": 1000,
|
sleep_limit: int = 10
|
||||||
"sleep_limit": 10,
|
sleep_delay: int = 500
|
||||||
"sleep_delay": 500,
|
reject_limit: int = 50
|
||||||
"reject_limit": 50,
|
concurrent: int = 3
|
||||||
"concurrent": 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
for i in self._items_and_default.keys():
|
|
||||||
setattr(self, i, kwargs.get(i) or self._items_and_default[i])
|
|
||||||
|
|
||||||
|
|
||||||
class RatelimitConfig(Config):
|
class RatelimitConfig(Config):
|
||||||
|
@ -69,11 +66,15 @@ class RatelimitConfig(Config):
|
||||||
else:
|
else:
|
||||||
self.rc_federation = FederationRateLimitConfig(
|
self.rc_federation = FederationRateLimitConfig(
|
||||||
**{
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in {
|
||||||
"window_size": config.get("federation_rc_window_size"),
|
"window_size": config.get("federation_rc_window_size"),
|
||||||
"sleep_limit": config.get("federation_rc_sleep_limit"),
|
"sleep_limit": config.get("federation_rc_sleep_limit"),
|
||||||
"sleep_delay": config.get("federation_rc_sleep_delay"),
|
"sleep_delay": config.get("federation_rc_sleep_delay"),
|
||||||
"reject_limit": config.get("federation_rc_reject_limit"),
|
"reject_limit": config.get("federation_rc_reject_limit"),
|
||||||
"concurrent": config.get("federation_rc_concurrent"),
|
"concurrent": config.get("federation_rc_concurrent"),
|
||||||
|
}.items()
|
||||||
|
if v is not None
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ from prometheus_client import Counter
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.interfaces import IDelayedCall
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
|
@ -284,7 +285,9 @@ class FederationSender(AbstractFederationSender):
|
||||||
)
|
)
|
||||||
|
|
||||||
# wake up destinations that have outstanding PDUs to be caught up
|
# wake up destinations that have outstanding PDUs to be caught up
|
||||||
self._catchup_after_startup_timer = self.clock.call_later(
|
self._catchup_after_startup_timer: Optional[
|
||||||
|
IDelayedCall
|
||||||
|
] = self.clock.call_later(
|
||||||
CATCH_UP_STARTUP_DELAY_SEC,
|
CATCH_UP_STARTUP_DELAY_SEC,
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
"wake_destinations_needing_catchup",
|
"wake_destinations_needing_catchup",
|
||||||
|
@ -406,7 +409,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
ts = await self.store.get_received_ts(event.event_id)
|
ts = await self.store.get_received_ts(event.event_id)
|
||||||
|
assert ts is not None
|
||||||
synapse.metrics.event_processing_lag_by_event.labels(
|
synapse.metrics.event_processing_lag_by_event.labels(
|
||||||
"federation_sender"
|
"federation_sender"
|
||||||
).observe((now - ts) / 1000)
|
).observe((now - ts) / 1000)
|
||||||
|
@ -435,6 +438,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
if events:
|
if events:
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
ts = await self.store.get_received_ts(events[-1].event_id)
|
ts = await self.store.get_received_ts(events[-1].event_id)
|
||||||
|
assert ts is not None
|
||||||
|
|
||||||
synapse.metrics.event_processing_lag.labels(
|
synapse.metrics.event_processing_lag.labels(
|
||||||
"federation_sender"
|
"federation_sender"
|
||||||
|
|
|
@ -398,6 +398,7 @@ class AccountValidityHandler:
|
||||||
"""
|
"""
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
if expiration_ts is None:
|
if expiration_ts is None:
|
||||||
|
assert self._account_validity_period is not None
|
||||||
expiration_ts = now + self._account_validity_period
|
expiration_ts = now + self._account_validity_period
|
||||||
|
|
||||||
await self.store.set_account_validity_for_user(
|
await self.store.set_account_validity_for_user(
|
||||||
|
|
|
@ -131,6 +131,8 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
ts = await self.store.get_received_ts(event.event_id)
|
ts = await self.store.get_received_ts(event.event_id)
|
||||||
|
assert ts is not None
|
||||||
|
|
||||||
synapse.metrics.event_processing_lag_by_event.labels(
|
synapse.metrics.event_processing_lag_by_event.labels(
|
||||||
"appservice_sender"
|
"appservice_sender"
|
||||||
).observe((now - ts) / 1000)
|
).observe((now - ts) / 1000)
|
||||||
|
@ -166,6 +168,7 @@ class ApplicationServicesHandler:
|
||||||
if events:
|
if events:
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
ts = await self.store.get_received_ts(events[-1].event_id)
|
ts = await self.store.get_received_ts(events[-1].event_id)
|
||||||
|
assert ts is not None
|
||||||
|
|
||||||
synapse.metrics.event_processing_lag.labels(
|
synapse.metrics.event_processing_lag.labels(
|
||||||
"appservice_sender"
|
"appservice_sender"
|
||||||
|
|
|
@ -28,6 +28,7 @@ from bisect import bisect
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
|
@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.wheel_timer = WheelTimer()
|
self.wheel_timer: WheelTimer[str] = WheelTimer()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self._presence_enabled = hs.config.use_presence
|
self._presence_enabled = hs.config.use_presence
|
||||||
|
|
||||||
|
@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
prev_state = await self.current_state_for_user(user_id)
|
prev_state = await self.current_state_for_user(user_id)
|
||||||
|
|
||||||
new_fields = {"last_active_ts": self.clock.time_msec()}
|
new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
|
||||||
if prev_state.state == PresenceState.UNAVAILABLE:
|
if prev_state.state == PresenceState.UNAVAILABLE:
|
||||||
new_fields["state"] = PresenceState.ONLINE
|
new_fields["state"] = PresenceState.ONLINE
|
||||||
|
|
||||||
|
|
|
@ -73,7 +73,7 @@ class FollowerTypingHandler:
|
||||||
self._room_typing: Dict[str, Set[str]] = {}
|
self._room_typing: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
self._member_last_federation_poke: Dict[RoomMember, int] = {}
|
self._member_last_federation_poke: Dict[RoomMember, int] = {}
|
||||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
|
||||||
self._latest_room_serial = 0
|
self._latest_room_serial = 0
|
||||||
|
|
||||||
self.clock.looping_call(self._handle_timeouts, 5000)
|
self.clock.looping_call(self._handle_timeouts, 5000)
|
||||||
|
|
|
@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||||
# Artificially delay requests if rate > sleep_limit/window_size
|
# Artificially delay requests if rate > sleep_limit/window_size
|
||||||
sleep_limit=1,
|
sleep_limit=1,
|
||||||
# Amount of artificial delay to apply
|
# Amount of artificial delay to apply
|
||||||
sleep_msec=1000,
|
sleep_delay=1000,
|
||||||
# Error with 429 if more than reject_limit requests are queued
|
# Error with 429 if more than reject_limit requests are queued
|
||||||
reject_limit=1,
|
reject_limit=1,
|
||||||
# Allow 1 request at a time
|
# Allow 1 request at a time
|
||||||
concurrent_requests=1,
|
concurrent=1,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet):
|
||||||
Returns:
|
Returns:
|
||||||
dictionary for response from /register
|
dictionary for response from /register
|
||||||
"""
|
"""
|
||||||
result = {"user_id": user_id, "home_server": self.hs.hostname}
|
result: JsonDict = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
}
|
||||||
if not params.get("inhibit_login", False):
|
if not params.get("inhibit_login", False):
|
||||||
device_id = params.get("device_id")
|
device_id = params.get("device_id")
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
|
@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
user_id, device_id, initial_display_name, is_guest=True
|
user_id, device_id, initial_display_name, is_guest=True
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result: JsonDict = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
|
|
@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
|
||||||
yield hs.config.sso.sso_template_dir
|
yield hs.config.sso.sso_template_dir
|
||||||
yield hs.config.sso.default_template_dir
|
yield hs.config.sso.default_template_dir
|
||||||
|
|
||||||
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
|
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
|
||||||
|
|
||||||
async def _async_render_GET(self, request: Request) -> None:
|
async def _async_render_GET(self, request: Request) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
|
||||||
yield hs.config.sso.sso_template_dir
|
yield hs.config.sso.sso_template_dir
|
||||||
yield hs.config.sso.default_template_dir
|
yield hs.config.sso.default_template_dir
|
||||||
|
|
||||||
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
|
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
|
||||||
|
|
||||||
async def _async_render_GET(self, request: Request) -> None:
|
async def _async_render_GET(self, request: Request) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
delta equal to 10% of the validity period.
|
delta equal to 10% of the validity period.
|
||||||
"""
|
"""
|
||||||
now_ms = self._clock.time_msec()
|
now_ms = self._clock.time_msec()
|
||||||
|
assert self._account_validity_period is not None
|
||||||
expiration_ts = now_ms + self._account_validity_period
|
expiration_ts = now_ms + self._account_validity_period
|
||||||
|
|
||||||
if use_delta:
|
if use_delta:
|
||||||
|
|
|
@ -38,6 +38,7 @@ from twisted.internet.interfaces import (
|
||||||
IReactorCore,
|
IReactorCore,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
IReactorTCP,
|
IReactorTCP,
|
||||||
|
IReactorThreads,
|
||||||
IReactorTime,
|
IReactorTime,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -63,7 +64,12 @@ JsonDict = Dict[str, Any]
|
||||||
# Note that this seems to require inheriting *directly* from Interface in order
|
# Note that this seems to require inheriting *directly* from Interface in order
|
||||||
# for mypy-zope to realize it is an interface.
|
# for mypy-zope to realize it is an interface.
|
||||||
class ISynapseReactor(
|
class ISynapseReactor(
|
||||||
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
|
IReactorTCP,
|
||||||
|
IReactorPluggableNameResolver,
|
||||||
|
IReactorTime,
|
||||||
|
IReactorCore,
|
||||||
|
IReactorThreads,
|
||||||
|
Interface,
|
||||||
):
|
):
|
||||||
"""The interfaces necessary for Synapse to function."""
|
"""The interfaces necessary for Synapse to function."""
|
||||||
|
|
||||||
|
|
|
@ -15,27 +15,35 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Pattern
|
import typing
|
||||||
|
from typing import Any, Callable, Dict, Generator, Pattern
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from twisted.internet import defer, task
|
from twisted.internet import defer, task
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.internet.interfaces import IDelayedCall, IReactorTime
|
||||||
|
from twisted.internet.task import LoopingCall
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging import context
|
from synapse.logging import context
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_WILDCARD_RUN = re.compile(r"([\?\*]+)")
|
_WILDCARD_RUN = re.compile(r"([\?\*]+)")
|
||||||
|
|
||||||
|
|
||||||
def _reject_invalid_json(val):
|
def _reject_invalid_json(val: Any) -> None:
|
||||||
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
|
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
|
||||||
raise ValueError("Invalid JSON value: '%s'" % val)
|
raise ValueError("Invalid JSON value: '%s'" % val)
|
||||||
|
|
||||||
|
|
||||||
def _handle_frozendict(obj):
|
def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
|
||||||
"""Helper for json_encoder. Makes frozendicts serializable by returning
|
"""Helper for json_encoder. Makes frozendicts serializable by returning
|
||||||
the underlying dict
|
the underlying dict
|
||||||
"""
|
"""
|
||||||
|
@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder(
|
||||||
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
||||||
|
|
||||||
|
|
||||||
def unwrapFirstError(failure):
|
def unwrapFirstError(failure: Failure) -> Failure:
|
||||||
# defer.gatherResults and DeferredLists wrap failures.
|
# defer.gatherResults and DeferredLists wrap failures.
|
||||||
failure.trap(defer.FirstError)
|
failure.trap(defer.FirstError)
|
||||||
return failure.value.subFailure
|
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
|
@ -75,25 +83,25 @@ class Clock:
|
||||||
reactor: The Twisted reactor to use.
|
reactor: The Twisted reactor to use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_reactor = attr.ib()
|
_reactor: IReactorTime = attr.ib()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations
|
||||||
def sleep(self, seconds):
|
def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
|
||||||
d = defer.Deferred()
|
d: defer.Deferred[float] = defer.Deferred()
|
||||||
with context.PreserveLoggingContext():
|
with context.PreserveLoggingContext():
|
||||||
self._reactor.callLater(seconds, d.callback, seconds)
|
self._reactor.callLater(seconds, d.callback, seconds)
|
||||||
res = yield d
|
res = yield d
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def time(self):
|
def time(self) -> float:
|
||||||
"""Returns the current system time in seconds since epoch."""
|
"""Returns the current system time in seconds since epoch."""
|
||||||
return self._reactor.seconds()
|
return self._reactor.seconds()
|
||||||
|
|
||||||
def time_msec(self):
|
def time_msec(self) -> int:
|
||||||
"""Returns the current system time in milliseconds since epoch."""
|
"""Returns the current system time in milliseconds since epoch."""
|
||||||
return int(self.time() * 1000)
|
return int(self.time() * 1000)
|
||||||
|
|
||||||
def looping_call(self, f, msec, *args, **kwargs):
|
def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall:
|
||||||
"""Call a function repeatedly.
|
"""Call a function repeatedly.
|
||||||
|
|
||||||
Waits `msec` initially before calling `f` for the first time.
|
Waits `msec` initially before calling `f` for the first time.
|
||||||
|
@ -102,8 +110,8 @@ class Clock:
|
||||||
other than trivial, you probably want to wrap it in run_as_background_process.
|
other than trivial, you probably want to wrap it in run_as_background_process.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
f(function): The function to call repeatedly.
|
f: The function to call repeatedly.
|
||||||
msec(float): How long to wait between calls in milliseconds.
|
msec: How long to wait between calls in milliseconds.
|
||||||
*args: Postional arguments to pass to function.
|
*args: Postional arguments to pass to function.
|
||||||
**kwargs: Key arguments to pass to function.
|
**kwargs: Key arguments to pass to function.
|
||||||
"""
|
"""
|
||||||
|
@ -113,7 +121,7 @@ class Clock:
|
||||||
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
|
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
|
||||||
return call
|
return call
|
||||||
|
|
||||||
def call_later(self, delay, callback, *args, **kwargs):
|
def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall:
|
||||||
"""Call something later
|
"""Call something later
|
||||||
|
|
||||||
Note that the function will be called with no logcontext, so if it is anything
|
Note that the function will be called with no logcontext, so if it is anything
|
||||||
|
@ -133,7 +141,7 @@ class Clock:
|
||||||
with context.PreserveLoggingContext():
|
with context.PreserveLoggingContext():
|
||||||
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||||
|
|
||||||
def cancel_call_later(self, timer, ignore_errs=False):
|
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
|
||||||
try:
|
try:
|
||||||
timer.cancel()
|
timer.cancel()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -37,6 +37,7 @@ import attr
|
||||||
from typing_extensions import ContextManager
|
from typing_extensions import ContextManager
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.base import ReactorBase
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
from twisted.internet.interfaces import IReactorTime
|
from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
@ -268,6 +269,7 @@ class Linearizer:
|
||||||
if not clock:
|
if not clock:
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
|
|
||||||
|
assert isinstance(reactor, ReactorBase)
|
||||||
clock = Clock(reactor)
|
clock = Clock(reactor)
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
self.max_count = max_count
|
self.max_count = max_count
|
||||||
|
@ -411,7 +413,7 @@ class ReadWriteLock:
|
||||||
# writers and readers have been resolved. The new writer replaces the latest
|
# writers and readers have been resolved. The new writer replaces the latest
|
||||||
# writer.
|
# writer.
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# Latest readers queued
|
# Latest readers queued
|
||||||
self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
|
self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
|
||||||
|
|
||||||
|
@ -503,7 +505,7 @@ def timeout_deferred(
|
||||||
|
|
||||||
timed_out = [False]
|
timed_out = [False]
|
||||||
|
|
||||||
def time_it_out():
|
def time_it_out() -> None:
|
||||||
timed_out[0] = True
|
timed_out[0] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -550,19 +552,21 @@ def timeout_deferred(
|
||||||
return new_d
|
return new_d
|
||||||
|
|
||||||
|
|
||||||
|
# This class can't be generic because it uses slots with attrs.
|
||||||
|
# See: https://github.com/python-attrs/attrs/issues/313
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class DoneAwaitable:
|
class DoneAwaitable: # should be: Generic[R]
|
||||||
"""Simple awaitable that returns the provided value."""
|
"""Simple awaitable that returns the provided value."""
|
||||||
|
|
||||||
value = attr.ib()
|
value = attr.ib(type=Any) # should be: R
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> "DoneAwaitable":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self) -> None:
|
||||||
raise StopIteration(self.value)
|
raise StopIteration(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]):
|
||||||
|
|
||||||
# First we create a defer and add it and the value to the list of
|
# First we create a defer and add it and the value to the list of
|
||||||
# pending items.
|
# pending items.
|
||||||
d = defer.Deferred()
|
d: defer.Deferred[R] = defer.Deferred()
|
||||||
self._next_values.setdefault(key, []).append((value, d))
|
self._next_values.setdefault(key, []).append((value, d))
|
||||||
|
|
||||||
# If we're not currently processing the key fire off a background
|
# If we're not currently processing the key fire off a background
|
||||||
|
|
|
@ -64,32 +64,32 @@ class CacheMetric:
|
||||||
evicted_size = attr.ib(default=0)
|
evicted_size = attr.ib(default=0)
|
||||||
memory_usage = attr.ib(default=None)
|
memory_usage = attr.ib(default=None)
|
||||||
|
|
||||||
def inc_hits(self):
|
def inc_hits(self) -> None:
|
||||||
self.hits += 1
|
self.hits += 1
|
||||||
|
|
||||||
def inc_misses(self):
|
def inc_misses(self) -> None:
|
||||||
self.misses += 1
|
self.misses += 1
|
||||||
|
|
||||||
def inc_evictions(self, size=1):
|
def inc_evictions(self, size: int = 1) -> None:
|
||||||
self.evicted_size += size
|
self.evicted_size += size
|
||||||
|
|
||||||
def inc_memory_usage(self, memory: int):
|
def inc_memory_usage(self, memory: int) -> None:
|
||||||
if self.memory_usage is None:
|
if self.memory_usage is None:
|
||||||
self.memory_usage = 0
|
self.memory_usage = 0
|
||||||
|
|
||||||
self.memory_usage += memory
|
self.memory_usage += memory
|
||||||
|
|
||||||
def dec_memory_usage(self, memory: int):
|
def dec_memory_usage(self, memory: int) -> None:
|
||||||
self.memory_usage -= memory
|
self.memory_usage -= memory
|
||||||
|
|
||||||
def clear_memory_usage(self):
|
def clear_memory_usage(self) -> None:
|
||||||
if self.memory_usage is not None:
|
if self.memory_usage is not None:
|
||||||
self.memory_usage = 0
|
self.memory_usage = 0
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def collect(self):
|
def collect(self) -> None:
|
||||||
try:
|
try:
|
||||||
if self._cache_type == "response_cache":
|
if self._cache_type == "response_cache":
|
||||||
response_cache_size.labels(self._cache_name).set(len(self._cache))
|
response_cache_size.labels(self._cache_name).set(len(self._cache))
|
||||||
|
|
|
@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
TreeCache, "MutableMapping[KT, CacheEntry]"
|
TreeCache, "MutableMapping[KT, CacheEntry]"
|
||||||
] = cache_type()
|
] = cache_type()
|
||||||
|
|
||||||
def metrics_cb():
|
def metrics_cb() -> None:
|
||||||
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
|
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
|
||||||
|
|
||||||
# cache is used for completed results and maps to the result itself, rather than
|
# cache is used for completed results and maps to the result itself, rather than
|
||||||
|
@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
def max_entries(self):
|
def max_entries(self):
|
||||||
return self.cache.max_size
|
return self.cache.max_size
|
||||||
|
|
||||||
def check_thread(self):
|
def check_thread(self) -> None:
|
||||||
expected_thread = self.thread
|
expected_thread = self.thread
|
||||||
if expected_thread is None:
|
if expected_thread is None:
|
||||||
self.thread = threading.current_thread()
|
self.thread = threading.current_thread()
|
||||||
|
@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
|
|
||||||
self._pending_deferred_cache[key] = entry
|
self._pending_deferred_cache[key] = entry
|
||||||
|
|
||||||
def compare_and_pop():
|
def compare_and_pop() -> bool:
|
||||||
"""Check if our entry is still the one in _pending_deferred_cache, and
|
"""Check if our entry is still the one in _pending_deferred_cache, and
|
||||||
if so, pop it.
|
if so, pop it.
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def cb(result):
|
def cb(result) -> None:
|
||||||
if compare_and_pop():
|
if compare_and_pop():
|
||||||
self.cache.set(key, result, entry.callbacks)
|
self.cache.set(key, result, entry.callbacks)
|
||||||
else:
|
else:
|
||||||
|
@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
# not have been. Either way, let's double-check now.
|
# not have been. Either way, let's double-check now.
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
|
|
||||||
def eb(_fail):
|
def eb(_fail) -> None:
|
||||||
compare_and_pop()
|
compare_and_pop()
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
|
|
||||||
|
@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
for entry in iterate_tree_cache_entry(entry):
|
for entry in iterate_tree_cache_entry(entry):
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
|
|
||||||
def invalidate_all(self):
|
def invalidate_all(self) -> None:
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
for entry in self._pending_deferred_cache.values():
|
for entry in self._pending_deferred_cache.values():
|
||||||
|
@ -332,7 +332,7 @@ class CacheEntry:
|
||||||
self.callbacks = set(callbacks)
|
self.callbacks = set(callbacks)
|
||||||
self.invalidated = False
|
self.invalidated = False
|
||||||
|
|
||||||
def invalidate(self):
|
def invalidate(self) -> None:
|
||||||
if not self.invalidated:
|
if not self.invalidated:
|
||||||
self.invalidated = True
|
self.invalidated = True
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
|
@ -27,10 +27,14 @@ logger = logging.getLogger(__name__)
|
||||||
KT = TypeVar("KT")
|
KT = TypeVar("KT")
|
||||||
# The type of the dictionary keys.
|
# The type of the dictionary keys.
|
||||||
DKT = TypeVar("DKT")
|
DKT = TypeVar("DKT")
|
||||||
|
# The type of the dictionary values.
|
||||||
|
DV = TypeVar("DV")
|
||||||
|
|
||||||
|
|
||||||
|
# This class can't be generic because it uses slots with attrs.
|
||||||
|
# See: https://github.com/python-attrs/attrs/issues/313
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class DictionaryEntry:
|
class DictionaryEntry: # should be: Generic[DKT, DV].
|
||||||
"""Returned when getting an entry from the cache
|
"""Returned when getting an entry from the cache
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -43,10 +47,10 @@ class DictionaryEntry:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
full = attr.ib(type=bool)
|
full = attr.ib(type=bool)
|
||||||
known_absent = attr.ib()
|
known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
|
||||||
value = attr.ib()
|
value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.value)
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,7 +60,7 @@ class _Sentinel(enum.Enum):
|
||||||
sentinel = object()
|
sentinel = object()
|
||||||
|
|
||||||
|
|
||||||
class DictionaryCache(Generic[KT, DKT]):
|
class DictionaryCache(Generic[KT, DKT, DV]):
|
||||||
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
||||||
fetching a subset of dictionary keys for a particular key.
|
fetching a subset of dictionary keys for a particular key.
|
||||||
"""
|
"""
|
||||||
|
@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key
|
key
|
||||||
dict_key: If given a set of keys then return only those keys
|
dict_keys: If given a set of keys then return only those keys
|
||||||
that exist in the cache.
|
that exist in the cache.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]):
|
||||||
self,
|
self,
|
||||||
sequence: int,
|
sequence: int,
|
||||||
key: KT,
|
key: KT,
|
||||||
value: Dict[DKT, Any],
|
value: Dict[DKT, DV],
|
||||||
fetched_keys: Optional[Set[DKT]] = None,
|
fetched_keys: Optional[Set[DKT]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Updates the entry in the cache
|
"""Updates the entry in the cache
|
||||||
|
@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]):
|
||||||
self._update_or_insert(key, value, fetched_keys)
|
self._update_or_insert(key, value, fetched_keys)
|
||||||
|
|
||||||
def _update_or_insert(
|
def _update_or_insert(
|
||||||
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
|
self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
|
||||||
) -> None:
|
) -> None:
|
||||||
# We pop and reinsert as we need to tell the cache the size may have
|
# We pop and reinsert as we need to tell the cache the size may have
|
||||||
# changed
|
# changed
|
||||||
|
|
||||||
entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
|
entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
|
||||||
entry.value.update(value)
|
entry.value.update(value)
|
||||||
entry.known_absent.update(known_absent)
|
entry.known_absent.update(known_absent)
|
||||||
self.cache[key] = entry
|
self.cache[key] = entry
|
||||||
|
|
||||||
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
|
def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
|
||||||
self.cache[key] = DictionaryEntry(True, known_absent, value)
|
self.cache[key] = DictionaryEntry(True, known_absent, value)
|
||||||
|
|
|
@ -35,6 +35,7 @@ from typing import (
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
|
from twisted.internet.interfaces import IReactorTime
|
||||||
|
|
||||||
from synapse.config import cache as cache_config
|
from synapse.config import cache as cache_config
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
|
@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]):
|
||||||
# Default `clock` to something sensible. Note that we rename it to
|
# Default `clock` to something sensible. Note that we rename it to
|
||||||
# `real_clock` so that mypy doesn't think its still `Optional`.
|
# `real_clock` so that mypy doesn't think its still `Optional`.
|
||||||
if clock is None:
|
if clock is None:
|
||||||
real_clock = Clock(reactor)
|
real_clock = Clock(cast(IReactorTime, reactor))
|
||||||
else:
|
else:
|
||||||
real_clock = clock
|
real_clock = clock
|
||||||
|
|
||||||
|
@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]):
|
||||||
|
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
||||||
def evict():
|
def evict() -> None:
|
||||||
while cache_len() > self.max_size:
|
while cache_len() > self.max_size:
|
||||||
# Get the last node in the list (i.e. the oldest node).
|
# Get the last node in the list (i.e. the oldest node).
|
||||||
todelete = list_root.prev_node
|
todelete = list_root.prev_node
|
||||||
|
|
|
@ -195,7 +195,7 @@ class StreamChangeCache:
|
||||||
for entity in r:
|
for entity in r:
|
||||||
del self._entity_to_key[entity]
|
del self._entity_to_key[entity]
|
||||||
|
|
||||||
def _evict(self):
|
def _evict(self) -> None:
|
||||||
while len(self._cache) > self._max_size:
|
while len(self._cache) > self._max_size:
|
||||||
k, r = self._cache.popitem(0)
|
k, r = self._cache.popitem(0)
|
||||||
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
|
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
|
||||||
|
|
|
@ -35,17 +35,17 @@ class TreeCache:
|
||||||
root = {key_1: {key_2: _value}}
|
root = {key_1: {key_2: _value}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.size = 0
|
self.size: int = 0
|
||||||
self.root = TreeCacheNode()
|
self.root = TreeCacheNode()
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value) -> None:
|
||||||
return self.set(key, value)
|
self.set(key, value)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key) -> bool:
|
||||||
return self.get(key, SENTINEL) is not SENTINEL
|
return self.get(key, SENTINEL) is not SENTINEL
|
||||||
|
|
||||||
def set(self, key, value):
|
def set(self, key, value) -> None:
|
||||||
if isinstance(value, TreeCacheNode):
|
if isinstance(value, TreeCacheNode):
|
||||||
# this would mean we couldn't tell where our tree ended and the value
|
# this would mean we couldn't tell where our tree ended and the value
|
||||||
# started.
|
# started.
|
||||||
|
@ -73,7 +73,7 @@ class TreeCache:
|
||||||
return default
|
return default
|
||||||
return node.get(key[-1], default)
|
return node.get(key[-1], default)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
self.size = 0
|
self.size = 0
|
||||||
self.root = TreeCacheNode()
|
self.root = TreeCacheNode()
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ class TreeCache:
|
||||||
def values(self):
|
def values(self):
|
||||||
return iterate_tree_cache_entry(self.root)
|
return iterate_tree_cache_entry(self.root)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
|
||||||
signal.signal(signal.SIGTERM, sigterm)
|
signal.signal(signal.SIGTERM, sigterm)
|
||||||
|
|
||||||
# Cleanup pid file at exit.
|
# Cleanup pid file at exit.
|
||||||
def exit():
|
def exit() -> None:
|
||||||
logger.warning("Stopping daemon.")
|
logger.warning("Stopping daemon.")
|
||||||
os.remove(pid_file)
|
os.remove(pid_file)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -37,11 +38,11 @@ class Distributor:
|
||||||
model will do for today.
|
model will do for today.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.signals = {}
|
self.signals: Dict[str, Signal] = {}
|
||||||
self.pre_registration = {}
|
self.pre_registration: Dict[str, List[Callable]] = {}
|
||||||
|
|
||||||
def declare(self, name):
|
def declare(self, name: str) -> None:
|
||||||
if name in self.signals:
|
if name in self.signals:
|
||||||
raise KeyError("%r already has a signal named %s" % (self, name))
|
raise KeyError("%r already has a signal named %s" % (self, name))
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ class Distributor:
|
||||||
for observer in self.pre_registration[name]:
|
for observer in self.pre_registration[name]:
|
||||||
signal.observe(observer)
|
signal.observe(observer)
|
||||||
|
|
||||||
def observe(self, name, observer):
|
def observe(self, name: str, observer: Callable) -> None:
|
||||||
if name in self.signals:
|
if name in self.signals:
|
||||||
self.signals[name].observe(observer)
|
self.signals[name].observe(observer)
|
||||||
else:
|
else:
|
||||||
|
@ -62,7 +63,7 @@ class Distributor:
|
||||||
self.pre_registration[name] = []
|
self.pre_registration[name] = []
|
||||||
self.pre_registration[name].append(observer)
|
self.pre_registration[name].append(observer)
|
||||||
|
|
||||||
def fire(self, name, *args, **kwargs):
|
def fire(self, name: str, *args, **kwargs) -> None:
|
||||||
"""Dispatches the given signal to the registered observers.
|
"""Dispatches the given signal to the registered observers.
|
||||||
|
|
||||||
Runs the observers as a background process. Does not return a deferred.
|
Runs the observers as a background process. Does not return a deferred.
|
||||||
|
@ -83,18 +84,18 @@ class Signal:
|
||||||
method into all of the observers.
|
method into all of the observers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
self.name: str = name
|
||||||
self.observers = []
|
self.observers: List[Callable] = []
|
||||||
|
|
||||||
def observe(self, observer):
|
def observe(self, observer: Callable) -> None:
|
||||||
"""Adds a new callable to the observer list which will be invoked by
|
"""Adds a new callable to the observer list which will be invoked by
|
||||||
the 'fire' method.
|
the 'fire' method.
|
||||||
|
|
||||||
Each observer callable may return a Deferred."""
|
Each observer callable may return a Deferred."""
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
|
|
||||||
def fire(self, *args, **kwargs):
|
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
|
||||||
"""Invokes every callable in the observer list, passing in the args and
|
"""Invokes every callable in the observer list, passing in the args and
|
||||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||||
not an error to fire a signal with no observers.
|
not an error to fire a signal with no observers.
|
||||||
|
|
|
@ -13,10 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import queue
|
import queue
|
||||||
|
from typing import BinaryIO, Optional, Union, cast
|
||||||
|
|
||||||
from twisted.internet import threads
|
from twisted.internet import threads
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.internet.interfaces import IPullProducer, IPushProducer
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
|
|
||||||
|
|
||||||
class BackgroundFileConsumer:
|
class BackgroundFileConsumer:
|
||||||
|
@ -24,9 +28,9 @@ class BackgroundFileConsumer:
|
||||||
and pull producers
|
and pull producers
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_obj (file): The file like object to write to. Closed when
|
file_obj: The file like object to write to. Closed when
|
||||||
finished.
|
finished.
|
||||||
reactor (twisted.internet.reactor): the Twisted reactor to use
|
reactor: the Twisted reactor to use
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# For PushProducers pause if we have this many unwritten slices
|
# For PushProducers pause if we have this many unwritten slices
|
||||||
|
@ -34,13 +38,13 @@ class BackgroundFileConsumer:
|
||||||
# And resume once the size of the queue is less than this
|
# And resume once the size of the queue is less than this
|
||||||
_RESUME_ON_QUEUE_SIZE = 2
|
_RESUME_ON_QUEUE_SIZE = 2
|
||||||
|
|
||||||
def __init__(self, file_obj, reactor):
|
def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None:
|
||||||
self._file_obj = file_obj
|
self._file_obj: BinaryIO = file_obj
|
||||||
|
|
||||||
self._reactor = reactor
|
self._reactor: ISynapseReactor = reactor
|
||||||
|
|
||||||
# Producer we're registered with
|
# Producer we're registered with
|
||||||
self._producer = None
|
self._producer: Optional[Union[IPushProducer, IPullProducer]] = None
|
||||||
|
|
||||||
# True if PushProducer, false if PullProducer
|
# True if PushProducer, false if PullProducer
|
||||||
self.streaming = False
|
self.streaming = False
|
||||||
|
@ -51,20 +55,22 @@ class BackgroundFileConsumer:
|
||||||
|
|
||||||
# Queue of slices of bytes to be written. When producer calls
|
# Queue of slices of bytes to be written. When producer calls
|
||||||
# unregister a final None is sent.
|
# unregister a final None is sent.
|
||||||
self._bytes_queue = queue.Queue()
|
self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
|
||||||
|
|
||||||
# Deferred that is resolved when finished writing
|
# Deferred that is resolved when finished writing
|
||||||
self._finished_deferred = None
|
self._finished_deferred: Optional[Deferred[None]] = None
|
||||||
|
|
||||||
# If the _writer thread throws an exception it gets stored here.
|
# If the _writer thread throws an exception it gets stored here.
|
||||||
self._write_exception = None
|
self._write_exception: Optional[Exception] = None
|
||||||
|
|
||||||
def registerProducer(self, producer, streaming):
|
def registerProducer(
|
||||||
|
self, producer: Union[IPushProducer, IPullProducer], streaming: bool
|
||||||
|
) -> None:
|
||||||
"""Part of IConsumer interface
|
"""Part of IConsumer interface
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
producer (IProducer)
|
producer
|
||||||
streaming (bool): True if push based producer, False if pull
|
streaming: True if push based producer, False if pull
|
||||||
based.
|
based.
|
||||||
"""
|
"""
|
||||||
if self._producer:
|
if self._producer:
|
||||||
|
@ -81,29 +87,33 @@ class BackgroundFileConsumer:
|
||||||
if not streaming:
|
if not streaming:
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
|
|
||||||
def unregisterProducer(self):
|
def unregisterProducer(self) -> None:
|
||||||
"""Part of IProducer interface"""
|
"""Part of IProducer interface"""
|
||||||
self._producer = None
|
self._producer = None
|
||||||
|
assert self._finished_deferred is not None
|
||||||
if not self._finished_deferred.called:
|
if not self._finished_deferred.called:
|
||||||
self._bytes_queue.put_nowait(None)
|
self._bytes_queue.put_nowait(None)
|
||||||
|
|
||||||
def write(self, bytes):
|
def write(self, write_bytes: bytes) -> None:
|
||||||
"""Part of IProducer interface"""
|
"""Part of IProducer interface"""
|
||||||
if self._write_exception:
|
if self._write_exception:
|
||||||
raise self._write_exception
|
raise self._write_exception
|
||||||
|
|
||||||
|
assert self._finished_deferred is not None
|
||||||
if self._finished_deferred.called:
|
if self._finished_deferred.called:
|
||||||
raise Exception("consumer has closed")
|
raise Exception("consumer has closed")
|
||||||
|
|
||||||
self._bytes_queue.put_nowait(bytes)
|
self._bytes_queue.put_nowait(write_bytes)
|
||||||
|
|
||||||
# If this is a PushProducer and the queue is getting behind
|
# If this is a PushProducer and the queue is getting behind
|
||||||
# then we pause the producer.
|
# then we pause the producer.
|
||||||
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
|
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
|
||||||
self._paused_producer = True
|
self._paused_producer = True
|
||||||
self._producer.pauseProducing()
|
assert self._producer is not None
|
||||||
|
# cast safe because `streaming` means this is an IPushProducer
|
||||||
|
cast(IPushProducer, self._producer).pauseProducing()
|
||||||
|
|
||||||
def _writer(self):
|
def _writer(self) -> None:
|
||||||
"""This is run in a background thread to write to the file."""
|
"""This is run in a background thread to write to the file."""
|
||||||
try:
|
try:
|
||||||
while self._producer or not self._bytes_queue.empty():
|
while self._producer or not self._bytes_queue.empty():
|
||||||
|
@ -130,11 +140,11 @@ class BackgroundFileConsumer:
|
||||||
finally:
|
finally:
|
||||||
self._file_obj.close()
|
self._file_obj.close()
|
||||||
|
|
||||||
def wait(self):
|
def wait(self) -> "Deferred[None]":
|
||||||
"""Returns a deferred that resolves when finished writing to file"""
|
"""Returns a deferred that resolves when finished writing to file"""
|
||||||
return make_deferred_yieldable(self._finished_deferred)
|
return make_deferred_yieldable(self._finished_deferred)
|
||||||
|
|
||||||
def _resume_paused_producer(self):
|
def _resume_paused_producer(self) -> None:
|
||||||
"""Gets called if we should resume producing after being paused"""
|
"""Gets called if we should resume producing after being paused"""
|
||||||
if self._paused_producer and self._producer:
|
if self._paused_producer and self._producer:
|
||||||
self._paused_producer = False
|
self._paused_producer = False
|
||||||
|
|
|
@ -11,11 +11,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
|
||||||
def freeze(o):
|
def freeze(o: Any) -> Any:
|
||||||
if isinstance(o, dict):
|
if isinstance(o, dict):
|
||||||
return frozendict({k: freeze(v) for k, v in o.items()})
|
return frozendict({k: freeze(v) for k, v in o.items()})
|
||||||
|
|
||||||
|
@ -33,7 +34,7 @@ def freeze(o):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def unfreeze(o):
|
def unfreeze(o: Any) -> Any:
|
||||||
if isinstance(o, (dict, frozendict)):
|
if isinstance(o, (dict, frozendict)):
|
||||||
return {k: unfreeze(v) for k, v in o.items()}
|
return {k: unfreeze(v) for k, v in o.items()}
|
||||||
|
|
||||||
|
|
|
@ -13,42 +13,43 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from twisted.web.resource import NoResource
|
from twisted.web.resource import NoResource, Resource
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_resource_tree(desired_tree, root_resource):
|
def create_resource_tree(
|
||||||
|
desired_tree: Dict[str, Resource], root_resource: Resource
|
||||||
|
) -> Resource:
|
||||||
"""Create the resource tree for this homeserver.
|
"""Create the resource tree for this homeserver.
|
||||||
|
|
||||||
This in unduly complicated because Twisted does not support putting
|
This in unduly complicated because Twisted does not support putting
|
||||||
child resources more than 1 level deep at a time.
|
child resources more than 1 level deep at a time.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
web_client (bool): True to enable the web client.
|
desired_tree: Dict from desired paths to desired resources.
|
||||||
root_resource (twisted.web.resource.Resource): The root
|
root_resource: The root resource to add the tree to.
|
||||||
resource to add the tree to.
|
|
||||||
Returns:
|
Returns:
|
||||||
twisted.web.resource.Resource: the ``root_resource`` with a tree of
|
The ``root_resource`` with a tree of child resources added to it.
|
||||||
child resources added to it.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ideally we'd just use getChild and putChild but getChild doesn't work
|
# ideally we'd just use getChild and putChild but getChild doesn't work
|
||||||
# unless you give it a Request object IN ADDITION to the name :/ So
|
# unless you give it a Request object IN ADDITION to the name :/ So
|
||||||
# instead, we'll store a copy of this mapping so we can actually add
|
# instead, we'll store a copy of this mapping so we can actually add
|
||||||
# extra resources to existing nodes. See self._resource_id for the key.
|
# extra resources to existing nodes. See self._resource_id for the key.
|
||||||
resource_mappings = {}
|
resource_mappings: Dict[str, Resource] = {}
|
||||||
for full_path, res in desired_tree.items():
|
for full_path_str, res in desired_tree.items():
|
||||||
# twisted requires all resources to be bytes
|
# twisted requires all resources to be bytes
|
||||||
full_path = full_path.encode("utf-8")
|
full_path = full_path_str.encode("utf-8")
|
||||||
|
|
||||||
logger.info("Attaching %s to path %s", res, full_path)
|
logger.info("Attaching %s to path %s", res, full_path)
|
||||||
last_resource = root_resource
|
last_resource = root_resource
|
||||||
for path_seg in full_path.split(b"/")[1:-1]:
|
for path_seg in full_path.split(b"/")[1:-1]:
|
||||||
if path_seg not in last_resource.listNames():
|
if path_seg not in last_resource.listNames():
|
||||||
# resource doesn't exist, so make a "dummy resource"
|
# resource doesn't exist, so make a "dummy resource"
|
||||||
child_resource = NoResource()
|
child_resource: Resource = NoResource()
|
||||||
last_resource.putChild(path_seg, child_resource)
|
last_resource.putChild(path_seg, child_resource)
|
||||||
res_id = _resource_id(last_resource, path_seg)
|
res_id = _resource_id(last_resource, path_seg)
|
||||||
resource_mappings[res_id] = child_resource
|
resource_mappings[res_id] = child_resource
|
||||||
|
@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource):
|
||||||
return root_resource
|
return root_resource
|
||||||
|
|
||||||
|
|
||||||
def _resource_id(resource, path_seg):
|
def _resource_id(resource: Resource, path_seg: bytes) -> str:
|
||||||
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
||||||
later.
|
later.
|
||||||
|
|
||||||
|
@ -96,4 +97,4 @@ def _resource_id(resource, path_seg):
|
||||||
Returns:
|
Returns:
|
||||||
str: A unique string which can be a key to the child Resource.
|
str: A unique string which can be a key to the child Resource.
|
||||||
"""
|
"""
|
||||||
return "%s-%s" % (resource, path_seg)
|
return "%s-%r" % (resource, path_seg)
|
||||||
|
|
|
@ -74,7 +74,7 @@ class ListNode(Generic[P]):
|
||||||
new_node._refs_insert_after(node)
|
new_node._refs_insert_after(node)
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
def remove_from_list(self):
|
def remove_from_list(self) -> None:
|
||||||
"""Remove this node from the list."""
|
"""Remove this node from the list."""
|
||||||
with self._LOCK:
|
with self._LOCK:
|
||||||
self._refs_remove_node_from_list()
|
self._refs_remove_node_from_list()
|
||||||
|
@ -84,7 +84,7 @@ class ListNode(Generic[P]):
|
||||||
# immediately rather than at the next GC.
|
# immediately rather than at the next GC.
|
||||||
self.cache_entry = None
|
self.cache_entry = None
|
||||||
|
|
||||||
def move_after(self, node: "ListNode"):
|
def move_after(self, node: "ListNode") -> None:
|
||||||
"""Move this node from its current location in the list to after the
|
"""Move this node from its current location in the list to after the
|
||||||
given node.
|
given node.
|
||||||
"""
|
"""
|
||||||
|
@ -103,7 +103,7 @@ class ListNode(Generic[P]):
|
||||||
# Insert self back into the list, after target node
|
# Insert self back into the list, after target node
|
||||||
self._refs_insert_after(node)
|
self._refs_insert_after(node)
|
||||||
|
|
||||||
def _refs_remove_node_from_list(self):
|
def _refs_remove_node_from_list(self) -> None:
|
||||||
"""Internal method to *just* remove the node from the list, without
|
"""Internal method to *just* remove the node from the list, without
|
||||||
e.g. clearing out the cache entry.
|
e.g. clearing out the cache entry.
|
||||||
"""
|
"""
|
||||||
|
@ -122,7 +122,7 @@ class ListNode(Generic[P]):
|
||||||
self.prev_node = None
|
self.prev_node = None
|
||||||
self.next_node = None
|
self.next_node = None
|
||||||
|
|
||||||
def _refs_insert_after(self, node: "ListNode"):
|
def _refs_insert_after(self, node: "ListNode") -> None:
|
||||||
"""Internal method to insert the node after the given node."""
|
"""Internal method to insert the node after the given node."""
|
||||||
|
|
||||||
# This method should only be called when we're not already in the list.
|
# This method should only be called when we're not already in the list.
|
||||||
|
|
|
@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N
|
||||||
should be considered expired. Normally the current time.
|
should be considered expired. Normally the current time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def verify_expiry_caveat(caveat: str):
|
def verify_expiry_caveat(caveat: str) -> bool:
|
||||||
time_msec = get_time_ms()
|
time_msec = get_time_ms()
|
||||||
prefix = "time < "
|
prefix = "time < "
|
||||||
if not caveat.startswith(prefix):
|
if not caveat.startswith(prefix):
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from twisted.conch import manhole_ssh
|
from twisted.conch import manhole_ssh
|
||||||
from twisted.conch.insults import insults
|
from twisted.conch.insults import insults
|
||||||
|
@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
|
||||||
from twisted.conch.ssh.keys import Key
|
from twisted.conch.ssh.keys import Key
|
||||||
from twisted.cred import checkers, portal
|
from twisted.cred import checkers, portal
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.protocol import Factory
|
||||||
|
|
||||||
|
from synapse.config.server import ManholeConfig
|
||||||
|
|
||||||
PUBLIC_KEY = (
|
PUBLIC_KEY = (
|
||||||
"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
|
"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
|
||||||
|
@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
|
||||||
-----END RSA PRIVATE KEY-----"""
|
-----END RSA PRIVATE KEY-----"""
|
||||||
|
|
||||||
|
|
||||||
def manhole(settings, globals):
|
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
|
||||||
"""Starts a ssh listener with password authentication using
|
"""Starts a ssh listener with password authentication using
|
||||||
the given username and password. Clients connecting to the ssh
|
the given username and password. Clients connecting to the ssh
|
||||||
listener will find themselves in a colored python shell with
|
listener will find themselves in a colored python shell with
|
||||||
the supplied globals.
|
the supplied globals.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
username(str): The username ssh clients should auth with.
|
username: The username ssh clients should auth with.
|
||||||
password(str): The password ssh clients should auth with.
|
password: The password ssh clients should auth with.
|
||||||
globals(dict): The variables to expose in the shell.
|
globals: The variables to expose in the shell.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
|
A factory to pass to ``listenTCP``
|
||||||
"""
|
"""
|
||||||
username = settings.username
|
username = settings.username
|
||||||
password = settings.password
|
password = settings.password.encode("ascii")
|
||||||
priv_key = settings.priv_key
|
priv_key = settings.priv_key
|
||||||
if priv_key is None:
|
if priv_key is None:
|
||||||
priv_key = Key.fromString(PRIVATE_KEY)
|
priv_key = Key.fromString(PRIVATE_KEY)
|
||||||
|
@ -84,19 +88,22 @@ def manhole(settings, globals):
|
||||||
if pub_key is None:
|
if pub_key is None:
|
||||||
pub_key = Key.fromString(PUBLIC_KEY)
|
pub_key = Key.fromString(PUBLIC_KEY)
|
||||||
|
|
||||||
if not isinstance(password, bytes):
|
|
||||||
password = password.encode("ascii")
|
|
||||||
|
|
||||||
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
|
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
|
||||||
|
|
||||||
rlm = manhole_ssh.TerminalRealm()
|
rlm = manhole_ssh.TerminalRealm()
|
||||||
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
|
# mypy ignored here because:
|
||||||
|
# - can't deduce types of lambdas
|
||||||
|
# - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol]
|
||||||
|
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment]
|
||||||
SynapseManhole, dict(globals, __name__="__console__")
|
SynapseManhole, dict(globals, __name__="__console__")
|
||||||
)
|
)
|
||||||
|
|
||||||
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
|
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
|
||||||
factory.privateKeys[b"ssh-rsa"] = priv_key
|
|
||||||
factory.publicKeys[b"ssh-rsa"] = pub_key
|
# conch has the wrong type on these dicts (says bytes to bytes,
|
||||||
|
# should be bytes to Keys judging by how it's used).
|
||||||
|
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
|
||||||
|
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
|
||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
@ -104,7 +111,7 @@ def manhole(settings, globals):
|
||||||
class SynapseManhole(ColoredManhole):
|
class SynapseManhole(ColoredManhole):
|
||||||
"""Overrides connectionMade to create our own ManholeInterpreter"""
|
"""Overrides connectionMade to create our own ManholeInterpreter"""
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
super().connectionMade()
|
super().connectionMade()
|
||||||
|
|
||||||
# replace the manhole interpreter with our own impl
|
# replace the manhole interpreter with our own impl
|
||||||
|
@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole):
|
||||||
|
|
||||||
|
|
||||||
class SynapseManholeInterpreter(ManholeInterpreter):
|
class SynapseManholeInterpreter(ManholeInterpreter):
|
||||||
def showsyntaxerror(self, filename=None):
|
def showsyntaxerror(self, filename: Optional[str] = None) -> None:
|
||||||
"""Display the syntax error that just occurred.
|
"""Display the syntax error that just occurred.
|
||||||
|
|
||||||
Overrides the base implementation, ignoring sys.excepthook. We always want
|
Overrides the base implementation, ignoring sys.excepthook. We always want
|
||||||
any syntax errors to be sent to the terminal, rather than sentry.
|
any syntax errors to be sent to the terminal, rather than sentry.
|
||||||
"""
|
"""
|
||||||
type, value, tb = sys.exc_info()
|
type, value, tb = sys.exc_info()
|
||||||
|
assert value is not None
|
||||||
sys.last_type = type
|
sys.last_type = type
|
||||||
sys.last_value = value
|
sys.last_value = value
|
||||||
sys.last_traceback = tb
|
sys.last_traceback = tb
|
||||||
|
@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
|
||||||
lines = traceback.format_exception_only(type, value)
|
lines = traceback.format_exception_only(type, value)
|
||||||
self.write("".join(lines))
|
self.write("".join(lines))
|
||||||
|
|
||||||
def showtraceback(self):
|
def showtraceback(self) -> None:
|
||||||
"""Display the exception that just occurred.
|
"""Display the exception that just occurred.
|
||||||
|
|
||||||
Overrides the base implementation, ignoring sys.excepthook. We always want
|
Overrides the base implementation, ignoring sys.excepthook. We always want
|
||||||
|
@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter):
|
||||||
"""
|
"""
|
||||||
sys.last_type, sys.last_value, last_tb = ei = sys.exc_info()
|
sys.last_type, sys.last_value, last_tb = ei = sys.exc_info()
|
||||||
sys.last_traceback = last_tb
|
sys.last_traceback = last_tb
|
||||||
|
assert last_tb is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# We remove the first stack item because it is our own code.
|
# We remove the first stack item because it is our own code.
|
||||||
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
|
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
|
||||||
self.write("".join(lines))
|
self.write("".join(lines))
|
||||||
finally:
|
finally:
|
||||||
last_tb = ei = None
|
# On the line below, last_tb and ei appear to be dead.
|
||||||
|
# It's unclear whether there is a reason behind this line.
|
||||||
|
# It conceivably could be because an exception raised in this block
|
||||||
|
# will keep the local frame (containing these local variables) around.
|
||||||
|
# This was adapted taken from CPython's Lib/code.py; see here:
|
||||||
|
# https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150
|
||||||
|
last_tb = ei = None # type: ignore
|
||||||
|
|
||||||
def displayhook(self, obj):
|
def displayhook(self, obj: Any) -> None:
|
||||||
"""
|
"""
|
||||||
We override the displayhook so that we automatically convert coroutines
|
We override the displayhook so that we automatically convert coroutines
|
||||||
into Deferreds. (Our superclass' displayhook will take care of the rest,
|
into Deferreds. (Our superclass' displayhook will take care of the rest,
|
||||||
|
|
|
@ -24,7 +24,7 @@ from twisted.python.failure import Failure
|
||||||
_already_patched = False
|
_already_patched = False
|
||||||
|
|
||||||
|
|
||||||
def do_patch():
|
def do_patch() -> None:
|
||||||
"""
|
"""
|
||||||
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
|
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
|
||||||
"""
|
"""
|
||||||
|
@ -107,7 +107,7 @@ def do_patch():
|
||||||
_already_patched = True
|
_already_patched = True
|
||||||
|
|
||||||
|
|
||||||
def _check_yield_points(f: Callable, changes: List[str]):
|
def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
|
||||||
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
|
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
|
||||||
checking that after every yield the log contexts are correct.
|
checking that after every yield the log contexts are correct.
|
||||||
|
|
||||||
|
|
|
@ -15,33 +15,36 @@
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
|
from typing import Any, DefaultDict, Iterator, List, Set
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
from synapse.api.errors import LimitExceededError
|
||||||
|
from synapse.config.ratelimiting import FederationRateLimitConfig
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
run_in_background,
|
run_in_background,
|
||||||
)
|
)
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from contextlib import _GeneratorContextManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FederationRateLimiter:
|
class FederationRateLimiter:
|
||||||
def __init__(self, clock, config):
|
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
|
||||||
"""
|
def new_limiter() -> "_PerHostRatelimiter":
|
||||||
Args:
|
|
||||||
clock (Clock)
|
|
||||||
config (FederationRateLimitConfig)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def new_limiter():
|
|
||||||
return _PerHostRatelimiter(clock=clock, config=config)
|
return _PerHostRatelimiter(clock=clock, config=config)
|
||||||
|
|
||||||
self.ratelimiters = collections.defaultdict(new_limiter)
|
self.ratelimiters: DefaultDict[
|
||||||
|
str, "_PerHostRatelimiter"
|
||||||
|
] = collections.defaultdict(new_limiter)
|
||||||
|
|
||||||
def ratelimit(self, host):
|
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
|
||||||
"""Used to ratelimit an incoming request from a given host
|
"""Used to ratelimit an incoming request from a given host
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
@ -60,11 +63,11 @@ class FederationRateLimiter:
|
||||||
|
|
||||||
|
|
||||||
class _PerHostRatelimiter:
|
class _PerHostRatelimiter:
|
||||||
def __init__(self, clock, config):
|
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
clock (Clock)
|
clock
|
||||||
config (FederationRateLimitConfig)
|
config
|
||||||
"""
|
"""
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
|
|
||||||
|
@ -75,21 +78,23 @@ class _PerHostRatelimiter:
|
||||||
self.concurrent_requests = config.concurrent
|
self.concurrent_requests = config.concurrent
|
||||||
|
|
||||||
# request_id objects for requests which have been slept
|
# request_id objects for requests which have been slept
|
||||||
self.sleeping_requests = set()
|
self.sleeping_requests: Set[object] = set()
|
||||||
|
|
||||||
# map from request_id object to Deferred for requests which are ready
|
# map from request_id object to Deferred for requests which are ready
|
||||||
# for processing but have been queued
|
# for processing but have been queued
|
||||||
self.ready_request_queue = collections.OrderedDict()
|
self.ready_request_queue: collections.OrderedDict[
|
||||||
|
object, defer.Deferred[None]
|
||||||
|
] = collections.OrderedDict()
|
||||||
|
|
||||||
# request id objects for requests which are in progress
|
# request id objects for requests which are in progress
|
||||||
self.current_processing = set()
|
self.current_processing: Set[object] = set()
|
||||||
|
|
||||||
# times at which we have recently (within the last window_size ms)
|
# times at which we have recently (within the last window_size ms)
|
||||||
# received requests.
|
# received requests.
|
||||||
self.request_times = []
|
self.request_times: List[int] = []
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def ratelimit(self):
|
def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
|
||||||
# `contextlib.contextmanager` takes a generator and turns it into a
|
# `contextlib.contextmanager` takes a generator and turns it into a
|
||||||
# context manager. The generator should only yield once with a value
|
# context manager. The generator should only yield once with a value
|
||||||
# to be returned by manager.
|
# to be returned by manager.
|
||||||
|
@ -102,7 +107,7 @@ class _PerHostRatelimiter:
|
||||||
finally:
|
finally:
|
||||||
self._on_exit(request_id)
|
self._on_exit(request_id)
|
||||||
|
|
||||||
def _on_enter(self, request_id):
|
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
# remove any entries from request_times which aren't within the window
|
# remove any entries from request_times which aren't within the window
|
||||||
|
@ -120,9 +125,9 @@ class _PerHostRatelimiter:
|
||||||
|
|
||||||
self.request_times.append(time_now)
|
self.request_times.append(time_now)
|
||||||
|
|
||||||
def queue_request():
|
def queue_request() -> "defer.Deferred[None]":
|
||||||
if len(self.current_processing) >= self.concurrent_requests:
|
if len(self.current_processing) >= self.concurrent_requests:
|
||||||
queue_defer = defer.Deferred()
|
queue_defer: defer.Deferred[None] = defer.Deferred()
|
||||||
self.ready_request_queue[request_id] = queue_defer
|
self.ready_request_queue[request_id] = queue_defer
|
||||||
logger.info(
|
logger.info(
|
||||||
"Ratelimiter: queueing request (queue now %i items)",
|
"Ratelimiter: queueing request (queue now %i items)",
|
||||||
|
@ -145,7 +150,7 @@ class _PerHostRatelimiter:
|
||||||
|
|
||||||
self.sleeping_requests.add(request_id)
|
self.sleeping_requests.add(request_id)
|
||||||
|
|
||||||
def on_wait_finished(_):
|
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
|
||||||
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
|
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
|
||||||
self.sleeping_requests.discard(request_id)
|
self.sleeping_requests.discard(request_id)
|
||||||
queue_defer = queue_request()
|
queue_defer = queue_request()
|
||||||
|
@ -155,19 +160,19 @@ class _PerHostRatelimiter:
|
||||||
else:
|
else:
|
||||||
ret_defer = queue_request()
|
ret_defer = queue_request()
|
||||||
|
|
||||||
def on_start(r):
|
def on_start(r: object) -> object:
|
||||||
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
|
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
|
||||||
self.current_processing.add(request_id)
|
self.current_processing.add(request_id)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def on_err(r):
|
def on_err(r: object) -> object:
|
||||||
# XXX: why is this necessary? this is called before we start
|
# XXX: why is this necessary? this is called before we start
|
||||||
# processing the request so why would the request be in
|
# processing the request so why would the request be in
|
||||||
# current_processing?
|
# current_processing?
|
||||||
self.current_processing.discard(request_id)
|
self.current_processing.discard(request_id)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def on_both(r):
|
def on_both(r: object) -> object:
|
||||||
# Ensure that we've properly cleaned up.
|
# Ensure that we've properly cleaned up.
|
||||||
self.sleeping_requests.discard(request_id)
|
self.sleeping_requests.discard(request_id)
|
||||||
self.ready_request_queue.pop(request_id, None)
|
self.ready_request_queue.pop(request_id, None)
|
||||||
|
@ -177,7 +182,7 @@ class _PerHostRatelimiter:
|
||||||
ret_defer.addBoth(on_both)
|
ret_defer.addBoth(on_both)
|
||||||
return make_deferred_yieldable(ret_defer)
|
return make_deferred_yieldable(ret_defer)
|
||||||
|
|
||||||
def _on_exit(self, request_id):
|
def _on_exit(self, request_id: object) -> None:
|
||||||
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
|
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
|
||||||
self.current_processing.discard(request_id)
|
self.current_processing.discard(request_id)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -13,9 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
import synapse.logging.context
|
import synapse.logging.context
|
||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62
|
||||||
|
|
||||||
|
|
||||||
class NotRetryingDestination(Exception):
|
class NotRetryingDestination(Exception):
|
||||||
def __init__(self, retry_last_ts, retry_interval, destination):
|
def __init__(self, retry_last_ts: int, retry_interval: int, destination: str):
|
||||||
"""Raised by the limiter (and federation client) to indicate that we are
|
"""Raised by the limiter (and federation client) to indicate that we are
|
||||||
are deliberately not attempting to contact a given server.
|
are deliberately not attempting to contact a given server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
retry_last_ts (int): the unix ts in milliseconds of our last attempt
|
retry_last_ts: the unix ts in milliseconds of our last attempt
|
||||||
to contact the server. 0 indicates that the last attempt was
|
to contact the server. 0 indicates that the last attempt was
|
||||||
successful or that we've never actually attempted to connect.
|
successful or that we've never actually attempted to connect.
|
||||||
retry_interval (int): the time in milliseconds to wait until the next
|
retry_interval: the time in milliseconds to wait until the next
|
||||||
attempt.
|
attempt.
|
||||||
destination (str): the domain in question
|
destination: the domain in question
|
||||||
"""
|
"""
|
||||||
|
|
||||||
msg = "Not retrying server %s." % (destination,)
|
msg = "Not retrying server %s." % (destination,)
|
||||||
|
@ -51,7 +55,13 @@ class NotRetryingDestination(Exception):
|
||||||
self.destination = destination
|
self.destination = destination
|
||||||
|
|
||||||
|
|
||||||
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
|
async def get_retry_limiter(
|
||||||
|
destination: str,
|
||||||
|
clock: Clock,
|
||||||
|
store: DataStore,
|
||||||
|
ignore_backoff: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "RetryDestinationLimiter":
|
||||||
"""For a given destination check if we have previously failed to
|
"""For a given destination check if we have previously failed to
|
||||||
send a request there and are waiting before retrying the destination.
|
send a request there and are waiting before retrying the destination.
|
||||||
If we are not ready to retry the destination, this will raise a
|
If we are not ready to retry the destination, this will raise a
|
||||||
|
@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
|
||||||
CodeMessageException with code < 500)
|
CodeMessageException with code < 500)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): name of homeserver
|
destination: name of homeserver
|
||||||
clock (synapse.util.clock): timing source
|
clock: timing source
|
||||||
store (synapse.storage.transactions.TransactionStore): datastore
|
store: datastore
|
||||||
ignore_backoff (bool): true to ignore the historical backoff data and
|
ignore_backoff: true to ignore the historical backoff data and
|
||||||
try the request anyway. We will still reset the retry_interval on success.
|
try the request anyway. We will still reset the retry_interval on success.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
|
||||||
class RetryDestinationLimiter:
|
class RetryDestinationLimiter:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
destination,
|
destination: str,
|
||||||
clock,
|
clock: Clock,
|
||||||
store,
|
store: DataStore,
|
||||||
failure_ts,
|
failure_ts: Optional[int],
|
||||||
retry_interval,
|
retry_interval: int,
|
||||||
backoff_on_404=False,
|
backoff_on_404: bool = False,
|
||||||
backoff_on_failure=True,
|
backoff_on_failure: bool = True,
|
||||||
):
|
):
|
||||||
"""Marks the destination as "down" if an exception is thrown in the
|
"""Marks the destination as "down" if an exception is thrown in the
|
||||||
context, except for CodeMessageException with code < 500.
|
context, except for CodeMessageException with code < 500.
|
||||||
|
@ -128,17 +138,17 @@ class RetryDestinationLimiter:
|
||||||
If no exception is raised, marks the destination as "up".
|
If no exception is raised, marks the destination as "up".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str)
|
destination
|
||||||
clock (Clock)
|
clock
|
||||||
store (DataStore)
|
store
|
||||||
failure_ts (int|None): when this destination started failing (in ms since
|
failure_ts: when this destination started failing (in ms since
|
||||||
the epoch), or zero if the last request was successful
|
the epoch), or zero if the last request was successful
|
||||||
retry_interval (int): The next retry interval taken from the
|
retry_interval: The next retry interval taken from the
|
||||||
database in milliseconds, or zero if the last request was
|
database in milliseconds, or zero if the last request was
|
||||||
successful.
|
successful.
|
||||||
backoff_on_404 (bool): Back off if we get a 404
|
backoff_on_404: Back off if we get a 404
|
||||||
|
|
||||||
backoff_on_failure (bool): set to False if we should not increase the
|
backoff_on_failure: set to False if we should not increase the
|
||||||
retry interval on a failure.
|
retry interval on a failure.
|
||||||
"""
|
"""
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
|
@ -150,10 +160,15 @@ class RetryDestinationLimiter:
|
||||||
self.backoff_on_404 = backoff_on_404
|
self.backoff_on_404 = backoff_on_404
|
||||||
self.backoff_on_failure = backoff_on_failure
|
self.backoff_on_failure = backoff_on_failure
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
valid_err_code = False
|
valid_err_code = False
|
||||||
if exc_type is None:
|
if exc_type is None:
|
||||||
valid_err_code = True
|
valid_err_code = True
|
||||||
|
@ -161,7 +176,7 @@ class RetryDestinationLimiter:
|
||||||
# avoid treating exceptions which don't derive from Exception as
|
# avoid treating exceptions which don't derive from Exception as
|
||||||
# failures; this is mostly so as not to catch defer._DefGen.
|
# failures; this is mostly so as not to catch defer._DefGen.
|
||||||
valid_err_code = True
|
valid_err_code = True
|
||||||
elif issubclass(exc_type, CodeMessageException):
|
elif isinstance(exc_val, CodeMessageException):
|
||||||
# Some error codes are perfectly fine for some APIs, whereas other
|
# Some error codes are perfectly fine for some APIs, whereas other
|
||||||
# APIs may expect to never received e.g. a 404. It's important to
|
# APIs may expect to never received e.g. a 404. It's important to
|
||||||
# handle 404 as some remote servers will return a 404 when the HS
|
# handle 404 as some remote servers will return a 404 when the HS
|
||||||
|
@ -216,7 +231,7 @@ class RetryDestinationLimiter:
|
||||||
if self.failure_ts is None:
|
if self.failure_ts is None:
|
||||||
self.failure_ts = retry_last_ts
|
self.failure_ts = retry_last_ts
|
||||||
|
|
||||||
async def store_retry_timings():
|
async def store_retry_timings() -> None:
|
||||||
try:
|
try:
|
||||||
await self.store.set_destination_retry_timings(
|
await self.store.set_destination_retry_timings(
|
||||||
self.destination,
|
self.destination,
|
||||||
|
|
|
@ -18,7 +18,7 @@ import resource
|
||||||
logger = logging.getLogger("synapse.app.homeserver")
|
logger = logging.getLogger("synapse.app.homeserver")
|
||||||
|
|
||||||
|
|
||||||
def change_resource_limit(soft_file_no):
|
def change_resource_limit(soft_file_no: int) -> None:
|
||||||
try:
|
try:
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
|
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
@ -25,9 +25,9 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
def build_jinja_env(
|
def build_jinja_env(
|
||||||
template_search_directories: Iterable[str],
|
template_search_directories: Sequence[str],
|
||||||
config: "HomeServerConfig",
|
config: "HomeServerConfig",
|
||||||
autoescape: Union[bool, Callable[[str], bool], None] = None,
|
autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None,
|
||||||
) -> jinja2.Environment:
|
) -> jinja2.Environment:
|
||||||
"""Set up a Jinja2 environment to load templates from the given search path
|
"""Set up a Jinja2 environment to load templates from the given search path
|
||||||
|
|
||||||
|
@ -110,5 +110,5 @@ def _create_mxc_to_http_filter(
|
||||||
return mxc_to_http_filter
|
return mxc_to_http_filter
|
||||||
|
|
||||||
|
|
||||||
def _format_ts_filter(value: int, format: str):
|
def _format_ts_filter(value: int, format: str) -> str:
|
||||||
return time.strftime(format, time.localtime(value / 1000))
|
return time.strftime(format, time.localtime(value / 1000))
|
||||||
|
|
|
@ -14,6 +14,10 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import typing
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -28,13 +32,13 @@ logger = logging.getLogger(__name__)
|
||||||
MAX_EMAIL_ADDRESS_LENGTH = 500
|
MAX_EMAIL_ADDRESS_LENGTH = 500
|
||||||
|
|
||||||
|
|
||||||
def check_3pid_allowed(hs, medium, address):
|
def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
|
||||||
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer): server
|
hs: server
|
||||||
medium (str): 3pid medium - e.g. email, msisdn
|
medium: 3pid medium - e.g. email, msisdn
|
||||||
address (str): address within that medium (e.g. "wotan@matrix.org")
|
address: address within that medium (e.g. "wotan@matrix.org")
|
||||||
msisdns need to first have been canonicalised
|
msisdns need to first have been canonicalised
|
||||||
Returns:
|
Returns:
|
||||||
bool: whether the 3PID medium/address is allowed to be added to this HS
|
bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||||
|
|
|
@ -19,7 +19,7 @@ import subprocess
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_version_string(module):
|
def get_version_string(module) -> str:
|
||||||
"""Given a module calculate a git-aware version string for it.
|
"""Given a module calculate a git-aware version string for it.
|
||||||
|
|
||||||
If called on a module not in a git checkout will return `__verison__`.
|
If called on a module not in a git checkout will return `__verison__`.
|
||||||
|
|
|
@ -11,38 +11,41 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Generic, List, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class _Entry:
|
class _Entry(Generic[T]):
|
||||||
__slots__ = ["end_key", "queue"]
|
__slots__ = ["end_key", "queue"]
|
||||||
|
|
||||||
def __init__(self, end_key):
|
def __init__(self, end_key: int) -> None:
|
||||||
self.end_key = end_key
|
self.end_key: int = end_key
|
||||||
self.queue = []
|
self.queue: List[T] = []
|
||||||
|
|
||||||
|
|
||||||
class WheelTimer:
|
class WheelTimer(Generic[T]):
|
||||||
"""Stores arbitrary objects that will be returned after their timers have
|
"""Stores arbitrary objects that will be returned after their timers have
|
||||||
expired.
|
expired.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bucket_size=5000):
|
def __init__(self, bucket_size: int = 5000) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
bucket_size (int): Size of buckets in ms. Corresponds roughly to the
|
bucket_size: Size of buckets in ms. Corresponds roughly to the
|
||||||
accuracy of the timer.
|
accuracy of the timer.
|
||||||
"""
|
"""
|
||||||
self.bucket_size = bucket_size
|
self.bucket_size: int = bucket_size
|
||||||
self.entries = []
|
self.entries: List[_Entry[T]] = []
|
||||||
self.current_tick = 0
|
self.current_tick: int = 0
|
||||||
|
|
||||||
def insert(self, now, obj, then):
|
def insert(self, now: int, obj: T, then: int) -> None:
|
||||||
"""Inserts object into timer.
|
"""Inserts object into timer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
now (int): Current time in msec
|
now: Current time in msec
|
||||||
obj (object): Object to be inserted
|
obj: Object to be inserted
|
||||||
then (int): When to return the object strictly after.
|
then: When to return the object strictly after.
|
||||||
"""
|
"""
|
||||||
then_key = int(then / self.bucket_size) + 1
|
then_key = int(then / self.bucket_size) + 1
|
||||||
|
|
||||||
|
@ -70,7 +73,7 @@ class WheelTimer:
|
||||||
|
|
||||||
self.entries[-1].queue.append(obj)
|
self.entries[-1].queue.append(obj)
|
||||||
|
|
||||||
def fetch(self, now):
|
def fetch(self, now: int) -> List[T]:
|
||||||
"""Fetch any objects that have timed out
|
"""Fetch any objects that have timed out
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -87,5 +90,5 @@ class WheelTimer:
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return sum(len(entry.queue) for entry in self.entries)
|
return sum(len(entry.queue) for entry in self.entries)
|
||||||
|
|
|
@ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource):
|
||||||
FederationRateLimitConfig(
|
FederationRateLimitConfig(
|
||||||
window_size=1,
|
window_size=1,
|
||||||
sleep_limit=1,
|
sleep_limit=1,
|
||||||
sleep_msec=1,
|
sleep_delay=1,
|
||||||
reject_limit=1000,
|
reject_limit=1000,
|
||||||
concurrent_requests=1000,
|
concurrent=1000,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue