Add missing type hints to `synapse.app`. (#11287)
This commit is contained in:
parent
66c4b774fd
commit
5cace20bf1
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints to `synapse.app`.
|
19
mypy.ini
19
mypy.ini
|
@ -23,22 +23,6 @@ files =
|
||||||
# https://docs.python.org/3/library/re.html#re.X
|
# https://docs.python.org/3/library/re.html#re.X
|
||||||
exclude = (?x)
|
exclude = (?x)
|
||||||
^(
|
^(
|
||||||
|synapse/app/__init__.py
|
|
||||||
|synapse/app/_base.py
|
|
||||||
|synapse/app/admin_cmd.py
|
|
||||||
|synapse/app/appservice.py
|
|
||||||
|synapse/app/client_reader.py
|
|
||||||
|synapse/app/event_creator.py
|
|
||||||
|synapse/app/federation_reader.py
|
|
||||||
|synapse/app/federation_sender.py
|
|
||||||
|synapse/app/frontend_proxy.py
|
|
||||||
|synapse/app/generic_worker.py
|
|
||||||
|synapse/app/homeserver.py
|
|
||||||
|synapse/app/media_repository.py
|
|
||||||
|synapse/app/phone_stats_home.py
|
|
||||||
|synapse/app/pusher.py
|
|
||||||
|synapse/app/synchrotron.py
|
|
||||||
|synapse/app/user_dir.py
|
|
||||||
|synapse/storage/databases/__init__.py
|
|synapse/storage/databases/__init__.py
|
||||||
|synapse/storage/databases/main/__init__.py
|
|synapse/storage/databases/main/__init__.py
|
||||||
|synapse/storage/databases/main/account_data.py
|
|synapse/storage/databases/main/account_data.py
|
||||||
|
@ -179,6 +163,9 @@ exclude = (?x)
|
||||||
[mypy-synapse.api.*]
|
[mypy-synapse.api.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.app.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.crypto.*]
|
[mypy-synapse.crypto.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Container
|
||||||
|
|
||||||
from synapse import python_dependencies # noqa: E402
|
from synapse import python_dependencies # noqa: E402
|
||||||
|
|
||||||
|
@ -27,7 +28,9 @@ except python_dependencies.DependencyException as e:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def check_bind_error(e, address, bind_addresses):
|
def check_bind_error(
|
||||||
|
e: Exception, address: str, bind_addresses: Container[str]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
This method checks an exception occurred while binding on 0.0.0.0.
|
This method checks an exception occurred while binding on 0.0.0.0.
|
||||||
If :: is specified in the bind addresses a warning is shown.
|
If :: is specified in the bind addresses a warning is shown.
|
||||||
|
@ -38,9 +41,9 @@ def check_bind_error(e, address, bind_addresses):
|
||||||
When binding on 0.0.0.0 after :: this can safely be ignored.
|
When binding on 0.0.0.0 after :: this can safely be ignored.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
e (Exception): Exception that was caught.
|
e: Exception that was caught.
|
||||||
address (str): Address on which binding was attempted.
|
address: Address on which binding was attempted.
|
||||||
bind_addresses (list): Addresses on which the service listens.
|
bind_addresses: Addresses on which the service listens.
|
||||||
"""
|
"""
|
||||||
if address == "0.0.0.0" and "::" in bind_addresses:
|
if address == "0.0.0.0" and "::" in bind_addresses:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
@ -22,13 +22,27 @@ import socket
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
NoReturn,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from cryptography.utils import CryptographyDeprecationWarning
|
from cryptography.utils import CryptographyDeprecationWarning
|
||||||
from typing_extensions import NoReturn
|
|
||||||
|
|
||||||
import twisted
|
import twisted
|
||||||
from twisted.internet import defer, error, reactor
|
from twisted.internet import defer, error, reactor as _reactor
|
||||||
|
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP
|
||||||
|
from twisted.internet.protocol import ServerFactory
|
||||||
|
from twisted.internet.tcp import Port
|
||||||
from twisted.logger import LoggingFile, LogLevel
|
from twisted.logger import LoggingFile, LogLevel
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||||
from twisted.python.threadpool import ThreadPool
|
from twisted.python.threadpool import ThreadPool
|
||||||
|
@ -48,6 +62,7 @@ from synapse.logging.context import PreserveLoggingContext
|
||||||
from synapse.metrics import register_threadpool
|
from synapse.metrics import register_threadpool
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.metrics.jemalloc import setup_jemalloc_stats
|
from synapse.metrics.jemalloc import setup_jemalloc_stats
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
||||||
from synapse.util.daemonize import daemonize_process
|
from synapse.util.daemonize import daemonize_process
|
||||||
from synapse.util.gai_resolver import GAIResolver
|
from synapse.util.gai_resolver import GAIResolver
|
||||||
|
@ -57,33 +72,44 @@ from synapse.util.versionstring import get_version_string
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
# Twisted injects the global reactor to make it easier to import, this confuses
|
||||||
|
# mypy which thinks it is a module. Tell it that it a more proper type.
|
||||||
|
reactor = cast(ISynapseReactor, _reactor)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# list of tuples of function, args list, kwargs dict
|
# list of tuples of function, args list, kwargs dict
|
||||||
_sighup_callbacks = []
|
_sighup_callbacks: List[
|
||||||
|
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
|
||||||
|
] = []
|
||||||
|
|
||||||
|
|
||||||
def register_sighup(func, *args, **kwargs):
|
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Register a function to be called when a SIGHUP occurs.
|
Register a function to be called when a SIGHUP occurs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (function): Function to be called when sent a SIGHUP signal.
|
func: Function to be called when sent a SIGHUP signal.
|
||||||
*args, **kwargs: args and kwargs to be passed to the target function.
|
*args, **kwargs: args and kwargs to be passed to the target function.
|
||||||
"""
|
"""
|
||||||
_sighup_callbacks.append((func, args, kwargs))
|
_sighup_callbacks.append((func, args, kwargs))
|
||||||
|
|
||||||
|
|
||||||
def start_worker_reactor(appname, config, run_command=reactor.run):
|
def start_worker_reactor(
|
||||||
|
appname: str,
|
||||||
|
config: HomeServerConfig,
|
||||||
|
run_command: Callable[[], None] = reactor.run,
|
||||||
|
) -> None:
|
||||||
"""Run the reactor in the main process
|
"""Run the reactor in the main process
|
||||||
|
|
||||||
Daemonizes if necessary, and then configures some resources, before starting
|
Daemonizes if necessary, and then configures some resources, before starting
|
||||||
the reactor. Pulls configuration from the 'worker' settings in 'config'.
|
the reactor. Pulls configuration from the 'worker' settings in 'config'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
appname (str): application name which will be sent to syslog
|
appname: application name which will be sent to syslog
|
||||||
config (synapse.config.Config): config object
|
config: config object
|
||||||
run_command (Callable[]): callable that actually runs the reactor
|
run_command: callable that actually runs the reactor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger = logging.getLogger(config.worker.worker_app)
|
logger = logging.getLogger(config.worker.worker_app)
|
||||||
|
@ -101,32 +127,32 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
|
||||||
|
|
||||||
|
|
||||||
def start_reactor(
|
def start_reactor(
|
||||||
appname,
|
appname: str,
|
||||||
soft_file_limit,
|
soft_file_limit: int,
|
||||||
gc_thresholds,
|
gc_thresholds: Tuple[int, int, int],
|
||||||
pid_file,
|
pid_file: str,
|
||||||
daemonize,
|
daemonize: bool,
|
||||||
print_pidfile,
|
print_pidfile: bool,
|
||||||
logger,
|
logger: logging.Logger,
|
||||||
run_command=reactor.run,
|
run_command: Callable[[], None] = reactor.run,
|
||||||
):
|
) -> None:
|
||||||
"""Run the reactor in the main process
|
"""Run the reactor in the main process
|
||||||
|
|
||||||
Daemonizes if necessary, and then configures some resources, before starting
|
Daemonizes if necessary, and then configures some resources, before starting
|
||||||
the reactor
|
the reactor
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
appname (str): application name which will be sent to syslog
|
appname: application name which will be sent to syslog
|
||||||
soft_file_limit (int):
|
soft_file_limit:
|
||||||
gc_thresholds:
|
gc_thresholds:
|
||||||
pid_file (str): name of pid file to write to if daemonize is True
|
pid_file: name of pid file to write to if daemonize is True
|
||||||
daemonize (bool): true to run the reactor in a background process
|
daemonize: true to run the reactor in a background process
|
||||||
print_pidfile (bool): whether to print the pid file, if daemonize is True
|
print_pidfile: whether to print the pid file, if daemonize is True
|
||||||
logger (logging.Logger): logger instance to pass to Daemonize
|
logger: logger instance to pass to Daemonize
|
||||||
run_command (Callable[]): callable that actually runs the reactor
|
run_command: callable that actually runs the reactor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def run():
|
def run() -> None:
|
||||||
logger.info("Running")
|
logger.info("Running")
|
||||||
setup_jemalloc_stats()
|
setup_jemalloc_stats()
|
||||||
change_resource_limit(soft_file_limit)
|
change_resource_limit(soft_file_limit)
|
||||||
|
@ -185,7 +211,7 @@ def redirect_stdio_to_logs() -> None:
|
||||||
print("Redirected stdout/stderr to logs")
|
print("Redirected stdout/stderr to logs")
|
||||||
|
|
||||||
|
|
||||||
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
|
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
|
||||||
"""Register a callback with the reactor, to be called once it is running
|
"""Register a callback with the reactor, to be called once it is running
|
||||||
|
|
||||||
This can be used to initialise parts of the system which require an asynchronous
|
This can be used to initialise parts of the system which require an asynchronous
|
||||||
|
@ -195,7 +221,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
|
||||||
will exit.
|
will exit.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def wrapper():
|
async def wrapper() -> None:
|
||||||
try:
|
try:
|
||||||
await cb(*args, **kwargs)
|
await cb(*args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -224,7 +250,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
|
||||||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
||||||
|
|
||||||
|
|
||||||
def listen_metrics(bind_addresses, port):
|
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
|
||||||
"""
|
"""
|
||||||
Start Prometheus metrics server.
|
Start Prometheus metrics server.
|
||||||
"""
|
"""
|
||||||
|
@ -236,11 +262,11 @@ def listen_metrics(bind_addresses, port):
|
||||||
|
|
||||||
|
|
||||||
def listen_manhole(
|
def listen_manhole(
|
||||||
bind_addresses: Iterable[str],
|
bind_addresses: Collection[str],
|
||||||
port: int,
|
port: int,
|
||||||
manhole_settings: ManholeConfig,
|
manhole_settings: ManholeConfig,
|
||||||
manhole_globals: dict,
|
manhole_globals: dict,
|
||||||
):
|
) -> None:
|
||||||
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
|
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
|
||||||
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
|
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
|
||||||
# suppress the warning for now.
|
# suppress the warning for now.
|
||||||
|
@ -259,12 +285,18 @@ def listen_manhole(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
|
def listen_tcp(
|
||||||
|
bind_addresses: Collection[str],
|
||||||
|
port: int,
|
||||||
|
factory: ServerFactory,
|
||||||
|
reactor: IReactorTCP = reactor,
|
||||||
|
backlog: int = 50,
|
||||||
|
) -> List[Port]:
|
||||||
"""
|
"""
|
||||||
Create a TCP socket for a port and several addresses
|
Create a TCP socket for a port and several addresses
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[twisted.internet.tcp.Port]: listening for TCP connections
|
list of twisted.internet.tcp.Port listening for TCP connections
|
||||||
"""
|
"""
|
||||||
r = []
|
r = []
|
||||||
for address in bind_addresses:
|
for address in bind_addresses:
|
||||||
|
@ -273,12 +305,19 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
|
||||||
except error.CannotListenError as e:
|
except error.CannotListenError as e:
|
||||||
check_bind_error(e, address, bind_addresses)
|
check_bind_error(e, address, bind_addresses)
|
||||||
|
|
||||||
return r
|
# IReactorTCP returns an object implementing IListeningPort from listenTCP,
|
||||||
|
# but we know it will be a Port instance.
|
||||||
|
return r # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
def listen_ssl(
|
def listen_ssl(
|
||||||
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
|
bind_addresses: Collection[str],
|
||||||
):
|
port: int,
|
||||||
|
factory: ServerFactory,
|
||||||
|
context_factory: IOpenSSLContextFactory,
|
||||||
|
reactor: IReactorSSL = reactor,
|
||||||
|
backlog: int = 50,
|
||||||
|
) -> List[Port]:
|
||||||
"""
|
"""
|
||||||
Create an TLS-over-TCP socket for a port and several addresses
|
Create an TLS-over-TCP socket for a port and several addresses
|
||||||
|
|
||||||
|
@ -294,10 +333,13 @@ def listen_ssl(
|
||||||
except error.CannotListenError as e:
|
except error.CannotListenError as e:
|
||||||
check_bind_error(e, address, bind_addresses)
|
check_bind_error(e, address, bind_addresses)
|
||||||
|
|
||||||
return r
|
# IReactorSSL incorrectly declares that an int is returned from listenSSL,
|
||||||
|
# it actually returns an object implementing IListeningPort, but we know it
|
||||||
|
# will be a Port instance.
|
||||||
|
return r # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
def refresh_certificate(hs: "HomeServer"):
|
def refresh_certificate(hs: "HomeServer") -> None:
|
||||||
"""
|
"""
|
||||||
Refresh the TLS certificates that Synapse is using by re-reading them from
|
Refresh the TLS certificates that Synapse is using by re-reading them from
|
||||||
disk and updating the TLS context factories to use them.
|
disk and updating the TLS context factories to use them.
|
||||||
|
@ -329,7 +371,7 @@ def refresh_certificate(hs: "HomeServer"):
|
||||||
logger.info("Context factories updated.")
|
logger.info("Context factories updated.")
|
||||||
|
|
||||||
|
|
||||||
async def start(hs: "HomeServer"):
|
async def start(hs: "HomeServer") -> None:
|
||||||
"""
|
"""
|
||||||
Start a Synapse server or worker.
|
Start a Synapse server or worker.
|
||||||
|
|
||||||
|
@ -360,7 +402,7 @@ async def start(hs: "HomeServer"):
|
||||||
if hasattr(signal, "SIGHUP"):
|
if hasattr(signal, "SIGHUP"):
|
||||||
|
|
||||||
@wrap_as_background_process("sighup")
|
@wrap_as_background_process("sighup")
|
||||||
def handle_sighup(*args, **kwargs):
|
def handle_sighup(*args: Any, **kwargs: Any) -> None:
|
||||||
# Tell systemd our state, if we're using it. This will silently fail if
|
# Tell systemd our state, if we're using it. This will silently fail if
|
||||||
# we're not using systemd.
|
# we're not using systemd.
|
||||||
sdnotify(b"RELOADING=1")
|
sdnotify(b"RELOADING=1")
|
||||||
|
@ -373,7 +415,7 @@ async def start(hs: "HomeServer"):
|
||||||
# We defer running the sighup handlers until next reactor tick. This
|
# We defer running the sighup handlers until next reactor tick. This
|
||||||
# is so that we're in a sane state, e.g. flushing the logs may fail
|
# is so that we're in a sane state, e.g. flushing the logs may fail
|
||||||
# if the sighup happens in the middle of writing a log entry.
|
# if the sighup happens in the middle of writing a log entry.
|
||||||
def run_sighup(*args, **kwargs):
|
def run_sighup(*args: Any, **kwargs: Any) -> None:
|
||||||
# `callFromThread` should be "signal safe" as well as thread
|
# `callFromThread` should be "signal safe" as well as thread
|
||||||
# safe.
|
# safe.
|
||||||
reactor.callFromThread(handle_sighup, *args, **kwargs)
|
reactor.callFromThread(handle_sighup, *args, **kwargs)
|
||||||
|
@ -436,12 +478,8 @@ async def start(hs: "HomeServer"):
|
||||||
atexit.register(gc.freeze)
|
atexit.register(gc.freeze)
|
||||||
|
|
||||||
|
|
||||||
def setup_sentry(hs: "HomeServer"):
|
def setup_sentry(hs: "HomeServer") -> None:
|
||||||
"""Enable sentry integration, if enabled in configuration
|
"""Enable sentry integration, if enabled in configuration"""
|
||||||
|
|
||||||
Args:
|
|
||||||
hs
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not hs.config.metrics.sentry_enabled:
|
if not hs.config.metrics.sentry_enabled:
|
||||||
return
|
return
|
||||||
|
@ -466,7 +504,7 @@ def setup_sentry(hs: "HomeServer"):
|
||||||
scope.set_tag("worker_name", name)
|
scope.set_tag("worker_name", name)
|
||||||
|
|
||||||
|
|
||||||
def setup_sdnotify(hs: "HomeServer"):
|
def setup_sdnotify(hs: "HomeServer") -> None:
|
||||||
"""Adds process state hooks to tell systemd what we are up to."""
|
"""Adds process state hooks to tell systemd what we are up to."""
|
||||||
|
|
||||||
# Tell systemd our state, if we're using it. This will silently fail if
|
# Tell systemd our state, if we're using it. This will silently fail if
|
||||||
|
@ -481,7 +519,7 @@ def setup_sdnotify(hs: "HomeServer"):
|
||||||
sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
|
sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
|
||||||
|
|
||||||
|
|
||||||
def sdnotify(state):
|
def sdnotify(state: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
|
Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
|
||||||
|
|
||||||
|
@ -490,7 +528,7 @@ def sdnotify(state):
|
||||||
package which many OSes don't include as a matter of principle.
|
package which many OSes don't include as a matter of principle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state (bytes): notification to send
|
state: notification to send
|
||||||
"""
|
"""
|
||||||
if not isinstance(state, bytes):
|
if not isinstance(state, bytes):
|
||||||
raise TypeError("sdnotify should be called with a bytes")
|
raise TypeError("sdnotify should be called with a bytes")
|
||||||
|
|
|
@ -17,6 +17,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from twisted.internet import defer, task
|
from twisted.internet import defer, task
|
||||||
|
|
||||||
|
@ -25,6 +26,7 @@ from synapse.app import _base
|
||||||
from synapse.config._base import ConfigError
|
from synapse.config._base import ConfigError
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.logger import setup_logging
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.handlers.admin import ExfiltrationWriter
|
from synapse.handlers.admin import ExfiltrationWriter
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||||
|
@ -40,6 +42,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.room import RoomWorkerStore
|
from synapse.storage.databases.main.room import RoomWorkerStore
|
||||||
|
from synapse.types import StateMap
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
|
@ -65,16 +68,11 @@ class AdminCmdSlavedStore(
|
||||||
|
|
||||||
|
|
||||||
class AdminCmdServer(HomeServer):
|
class AdminCmdServer(HomeServer):
|
||||||
DATASTORE_CLASS = AdminCmdSlavedStore
|
DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def export_data_command(hs: HomeServer, args):
|
async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None:
|
||||||
"""Export data for a user.
|
"""Export data for a user."""
|
||||||
|
|
||||||
Args:
|
|
||||||
hs
|
|
||||||
args (argparse.Namespace)
|
|
||||||
"""
|
|
||||||
|
|
||||||
user_id = args.user_id
|
user_id = args.user_id
|
||||||
directory = args.output_directory
|
directory = args.output_directory
|
||||||
|
@ -92,12 +90,12 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||||
Note: This writes to disk on the main reactor thread.
|
Note: This writes to disk on the main reactor thread.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user whose data is being exfiltrated.
|
user_id: The user whose data is being exfiltrated.
|
||||||
directory (str|None): The directory to write the data to, if None then
|
directory: The directory to write the data to, if None then will write
|
||||||
will write to a temporary directory.
|
to a temporary directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, user_id, directory=None):
|
def __init__(self, user_id: str, directory: Optional[str] = None):
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
|
|
||||||
if directory:
|
if directory:
|
||||||
|
@ -111,7 +109,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||||
if list(os.listdir(self.base_directory)):
|
if list(os.listdir(self.base_directory)):
|
||||||
raise Exception("Directory must be empty")
|
raise Exception("Directory must be empty")
|
||||||
|
|
||||||
def write_events(self, room_id, events):
|
def write_events(self, room_id: str, events: List[EventBase]) -> None:
|
||||||
room_directory = os.path.join(self.base_directory, "rooms", room_id)
|
room_directory = os.path.join(self.base_directory, "rooms", room_id)
|
||||||
os.makedirs(room_directory, exist_ok=True)
|
os.makedirs(room_directory, exist_ok=True)
|
||||||
events_file = os.path.join(room_directory, "events")
|
events_file = os.path.join(room_directory, "events")
|
||||||
|
@ -120,7 +118,9 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||||
for event in events:
|
for event in events:
|
||||||
print(json.dumps(event.get_pdu_json()), file=f)
|
print(json.dumps(event.get_pdu_json()), file=f)
|
||||||
|
|
||||||
def write_state(self, room_id, event_id, state):
|
def write_state(
|
||||||
|
self, room_id: str, event_id: str, state: StateMap[EventBase]
|
||||||
|
) -> None:
|
||||||
room_directory = os.path.join(self.base_directory, "rooms", room_id)
|
room_directory = os.path.join(self.base_directory, "rooms", room_id)
|
||||||
state_directory = os.path.join(room_directory, "state")
|
state_directory = os.path.join(room_directory, "state")
|
||||||
os.makedirs(state_directory, exist_ok=True)
|
os.makedirs(state_directory, exist_ok=True)
|
||||||
|
@ -131,7 +131,9 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||||
for event in state.values():
|
for event in state.values():
|
||||||
print(json.dumps(event.get_pdu_json()), file=f)
|
print(json.dumps(event.get_pdu_json()), file=f)
|
||||||
|
|
||||||
def write_invite(self, room_id, event, state):
|
def write_invite(
|
||||||
|
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> None:
|
||||||
self.write_events(room_id, [event])
|
self.write_events(room_id, [event])
|
||||||
|
|
||||||
# We write the invite state somewhere else as they aren't full events
|
# We write the invite state somewhere else as they aren't full events
|
||||||
|
@ -145,7 +147,9 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||||
for event in state.values():
|
for event in state.values():
|
||||||
print(json.dumps(event), file=f)
|
print(json.dumps(event), file=f)
|
||||||
|
|
||||||
def write_knock(self, room_id, event, state):
|
def write_knock(
|
||||||
|
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||||
|
) -> None:
|
||||||
self.write_events(room_id, [event])
|
self.write_events(room_id, [event])
|
||||||
|
|
||||||
# We write the knock state somewhere else as they aren't full events
|
# We write the knock state somewhere else as they aren't full events
|
||||||
|
@ -159,11 +163,11 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||||
for event in state.values():
|
for event in state.values():
|
||||||
print(json.dumps(event), file=f)
|
print(json.dumps(event), file=f)
|
||||||
|
|
||||||
def finished(self):
|
def finished(self) -> str:
|
||||||
return self.base_directory
|
return self.base_directory
|
||||||
|
|
||||||
|
|
||||||
def start(config_options):
|
def start(config_options: List[str]) -> None:
|
||||||
parser = argparse.ArgumentParser(description="Synapse Admin Command")
|
parser = argparse.ArgumentParser(description="Synapse Admin Command")
|
||||||
HomeServerConfig.add_arguments_to_parser(parser)
|
HomeServerConfig.add_arguments_to_parser(parser)
|
||||||
|
|
||||||
|
@ -231,7 +235,7 @@ def start(config_options):
|
||||||
# We also make sure that `_base.start` gets run before we actually run the
|
# We also make sure that `_base.start` gets run before we actually run the
|
||||||
# command.
|
# command.
|
||||||
|
|
||||||
async def run():
|
async def run() -> None:
|
||||||
with LoggingContext("command"):
|
with LoggingContext("command"):
|
||||||
await _base.start(ss)
|
await _base.start(ss)
|
||||||
await args.func(ss, args)
|
await args.func(ss, args)
|
||||||
|
|
|
@ -14,11 +14,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import address
|
from twisted.internet import address
|
||||||
from twisted.web.resource import IResource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Request
|
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.events
|
import synapse.events
|
||||||
|
@ -44,7 +43,7 @@ from synapse.config.server import ListenerConfig
|
||||||
from synapse.federation.transport.server import TransportLayerServer
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
from synapse.http.server import JsonResource, OptionsResource
|
from synapse.http.server import JsonResource, OptionsResource
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.http.site import SynapseSite
|
from synapse.http.site import SynapseRequest, SynapseSite
|
||||||
from synapse.logging.context import LoggingContext
|
from synapse.logging.context import LoggingContext
|
||||||
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
|
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
|
||||||
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
|
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
|
||||||
|
@ -119,6 +118,7 @@ from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
||||||
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
||||||
from synapse.storage.databases.main.user_directory import UserDirectoryStore
|
from synapse.storage.databases.main.user_directory import UserDirectoryStore
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
|
@ -143,7 +143,9 @@ class KeyUploadServlet(RestServlet):
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_simple_http_client()
|
||||||
self.main_uri = hs.config.worker.worker_main_http_uri
|
self.main_uri = hs.config.worker.worker_main_http_uri
|
||||||
|
|
||||||
async def on_POST(self, request: Request, device_id: Optional[str]):
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, device_id: Optional[str]
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
@ -187,9 +189,8 @@ class KeyUploadServlet(RestServlet):
|
||||||
# If the header exists, add to the comma-separated list of the first
|
# If the header exists, add to the comma-separated list of the first
|
||||||
# instance of the header. Otherwise, generate a new header.
|
# instance of the header. Otherwise, generate a new header.
|
||||||
if x_forwarded_for:
|
if x_forwarded_for:
|
||||||
x_forwarded_for = [
|
x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host]
|
||||||
x_forwarded_for[0] + b", " + previous_host
|
x_forwarded_for.extend(x_forwarded_for[1:])
|
||||||
] + x_forwarded_for[1:]
|
|
||||||
else:
|
else:
|
||||||
x_forwarded_for = [previous_host]
|
x_forwarded_for = [previous_host]
|
||||||
headers[b"X-Forwarded-For"] = x_forwarded_for
|
headers[b"X-Forwarded-For"] = x_forwarded_for
|
||||||
|
@ -253,13 +254,16 @@ class GenericWorkerSlavedStore(
|
||||||
SessionStore,
|
SessionStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
):
|
):
|
||||||
pass
|
# Properties that multiple storage classes define. Tell mypy what the
|
||||||
|
# expected type is.
|
||||||
|
server_name: str
|
||||||
|
config: HomeServerConfig
|
||||||
|
|
||||||
|
|
||||||
class GenericWorkerServer(HomeServer):
|
class GenericWorkerServer(HomeServer):
|
||||||
DATASTORE_CLASS = GenericWorkerSlavedStore
|
DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore
|
||||||
|
|
||||||
def _listen_http(self, listener_config: ListenerConfig):
|
def _listen_http(self, listener_config: ListenerConfig) -> None:
|
||||||
port = listener_config.port
|
port = listener_config.port
|
||||||
bind_addresses = listener_config.bind_addresses
|
bind_addresses = listener_config.bind_addresses
|
||||||
|
|
||||||
|
@ -267,10 +271,10 @@ class GenericWorkerServer(HomeServer):
|
||||||
|
|
||||||
site_tag = listener_config.http_options.tag
|
site_tag = listener_config.http_options.tag
|
||||||
if site_tag is None:
|
if site_tag is None:
|
||||||
site_tag = port
|
site_tag = str(port)
|
||||||
|
|
||||||
# We always include a health resource.
|
# We always include a health resource.
|
||||||
resources: Dict[str, IResource] = {"/health": HealthResource()}
|
resources: Dict[str, Resource] = {"/health": HealthResource()}
|
||||||
|
|
||||||
for res in listener_config.http_options.resources:
|
for res in listener_config.http_options.resources:
|
||||||
for name in res.names:
|
for name in res.names:
|
||||||
|
@ -386,7 +390,7 @@ class GenericWorkerServer(HomeServer):
|
||||||
|
|
||||||
logger.info("Synapse worker now listening on port %d", port)
|
logger.info("Synapse worker now listening on port %d", port)
|
||||||
|
|
||||||
def start_listening(self):
|
def start_listening(self) -> None:
|
||||||
for listener in self.config.worker.worker_listeners:
|
for listener in self.config.worker.worker_listeners:
|
||||||
if listener.type == "http":
|
if listener.type == "http":
|
||||||
self._listen_http(listener)
|
self._listen_http(listener)
|
||||||
|
@ -411,7 +415,7 @@ class GenericWorkerServer(HomeServer):
|
||||||
self.get_tcp_replication().start_replication(self)
|
self.get_tcp_replication().start_replication(self)
|
||||||
|
|
||||||
|
|
||||||
def start(config_options):
|
def start(config_options: List[str]) -> None:
|
||||||
try:
|
try:
|
||||||
config = HomeServerConfig.load_config("Synapse worker", config_options)
|
config = HomeServerConfig.load_config("Synapse worker", config_options)
|
||||||
except ConfigError as e:
|
except ConfigError as e:
|
||||||
|
|
|
@ -16,10 +16,10 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Iterator
|
from typing import Dict, Iterable, Iterator, List
|
||||||
|
|
||||||
from twisted.internet import reactor
|
from twisted.internet.tcp import Port
|
||||||
from twisted.web.resource import EncodingResourceWrapper, IResource
|
from twisted.web.resource import EncodingResourceWrapper, Resource
|
||||||
from twisted.web.server import GzipEncoderFactory
|
from twisted.web.server import GzipEncoderFactory
|
||||||
from twisted.web.static import File
|
from twisted.web.static import File
|
||||||
|
|
||||||
|
@ -76,23 +76,27 @@ from synapse.util.versionstring import get_version_string
|
||||||
logger = logging.getLogger("synapse.app.homeserver")
|
logger = logging.getLogger("synapse.app.homeserver")
|
||||||
|
|
||||||
|
|
||||||
def gz_wrap(r):
|
def gz_wrap(r: Resource) -> Resource:
|
||||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||||
|
|
||||||
|
|
||||||
class SynapseHomeServer(HomeServer):
|
class SynapseHomeServer(HomeServer):
|
||||||
DATASTORE_CLASS = DataStore
|
DATASTORE_CLASS = DataStore # type: ignore
|
||||||
|
|
||||||
def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
|
def _listener_http(
|
||||||
|
self, config: HomeServerConfig, listener_config: ListenerConfig
|
||||||
|
) -> Iterable[Port]:
|
||||||
port = listener_config.port
|
port = listener_config.port
|
||||||
bind_addresses = listener_config.bind_addresses
|
bind_addresses = listener_config.bind_addresses
|
||||||
tls = listener_config.tls
|
tls = listener_config.tls
|
||||||
|
# Must exist since this is an HTTP listener.
|
||||||
|
assert listener_config.http_options is not None
|
||||||
site_tag = listener_config.http_options.tag
|
site_tag = listener_config.http_options.tag
|
||||||
if site_tag is None:
|
if site_tag is None:
|
||||||
site_tag = str(port)
|
site_tag = str(port)
|
||||||
|
|
||||||
# We always include a health resource.
|
# We always include a health resource.
|
||||||
resources = {"/health": HealthResource()}
|
resources: Dict[str, Resource] = {"/health": HealthResource()}
|
||||||
|
|
||||||
for res in listener_config.http_options.resources:
|
for res in listener_config.http_options.resources:
|
||||||
for name in res.names:
|
for name in res.names:
|
||||||
|
@ -111,7 +115,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
|
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
|
||||||
)
|
)
|
||||||
handler = handler_cls(config, module_api)
|
handler = handler_cls(config, module_api)
|
||||||
if IResource.providedBy(handler):
|
if isinstance(handler, Resource):
|
||||||
resource = handler
|
resource = handler
|
||||||
elif hasattr(handler, "handle_request"):
|
elif hasattr(handler, "handle_request"):
|
||||||
resource = AdditionalResource(self, handler.handle_request)
|
resource = AdditionalResource(self, handler.handle_request)
|
||||||
|
@ -128,7 +132,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
|
|
||||||
# try to find something useful to redirect '/' to
|
# try to find something useful to redirect '/' to
|
||||||
if WEB_CLIENT_PREFIX in resources:
|
if WEB_CLIENT_PREFIX in resources:
|
||||||
root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
|
root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
|
||||||
elif STATIC_PREFIX in resources:
|
elif STATIC_PREFIX in resources:
|
||||||
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
|
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
|
||||||
else:
|
else:
|
||||||
|
@ -145,6 +149,8 @@ class SynapseHomeServer(HomeServer):
|
||||||
)
|
)
|
||||||
|
|
||||||
if tls:
|
if tls:
|
||||||
|
# refresh_certificate should have been called before this.
|
||||||
|
assert self.tls_server_context_factory is not None
|
||||||
ports = listen_ssl(
|
ports = listen_ssl(
|
||||||
bind_addresses,
|
bind_addresses,
|
||||||
port,
|
port,
|
||||||
|
@ -165,20 +171,21 @@ class SynapseHomeServer(HomeServer):
|
||||||
|
|
||||||
return ports
|
return ports
|
||||||
|
|
||||||
def _configure_named_resource(self, name, compress=False):
|
def _configure_named_resource(
|
||||||
|
self, name: str, compress: bool = False
|
||||||
|
) -> Dict[str, Resource]:
|
||||||
"""Build a resource map for a named resource
|
"""Build a resource map for a named resource
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): named resource: one of "client", "federation", etc
|
name: named resource: one of "client", "federation", etc
|
||||||
compress (bool): whether to enable gzip compression for this
|
compress: whether to enable gzip compression for this resource
|
||||||
resource
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Resource]: map from path to HTTP resource
|
map from path to HTTP resource
|
||||||
"""
|
"""
|
||||||
resources = {}
|
resources: Dict[str, Resource] = {}
|
||||||
if name == "client":
|
if name == "client":
|
||||||
client_resource = ClientRestResource(self)
|
client_resource: Resource = ClientRestResource(self)
|
||||||
if compress:
|
if compress:
|
||||||
client_resource = gz_wrap(client_resource)
|
client_resource = gz_wrap(client_resource)
|
||||||
|
|
||||||
|
@ -207,7 +214,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
if name == "consent":
|
if name == "consent":
|
||||||
from synapse.rest.consent.consent_resource import ConsentResource
|
from synapse.rest.consent.consent_resource import ConsentResource
|
||||||
|
|
||||||
consent_resource = ConsentResource(self)
|
consent_resource: Resource = ConsentResource(self)
|
||||||
if compress:
|
if compress:
|
||||||
consent_resource = gz_wrap(consent_resource)
|
consent_resource = gz_wrap(consent_resource)
|
||||||
resources.update({"/_matrix/consent": consent_resource})
|
resources.update({"/_matrix/consent": consent_resource})
|
||||||
|
@ -277,7 +284,7 @@ class SynapseHomeServer(HomeServer):
|
||||||
|
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
def start_listening(self):
|
def start_listening(self) -> None:
|
||||||
if self.config.redis.redis_enabled:
|
if self.config.redis.redis_enabled:
|
||||||
# If redis is enabled we connect via the replication command handler
|
# If redis is enabled we connect via the replication command handler
|
||||||
# in the same way as the workers (since we're effectively a client
|
# in the same way as the workers (since we're effectively a client
|
||||||
|
@ -303,7 +310,9 @@ class SynapseHomeServer(HomeServer):
|
||||||
ReplicationStreamProtocolFactory(self),
|
ReplicationStreamProtocolFactory(self),
|
||||||
)
|
)
|
||||||
for s in services:
|
for s in services:
|
||||||
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
|
self.get_reactor().addSystemEventTrigger(
|
||||||
|
"before", "shutdown", s.stopListening
|
||||||
|
)
|
||||||
elif listener.type == "metrics":
|
elif listener.type == "metrics":
|
||||||
if not self.config.metrics.enable_metrics:
|
if not self.config.metrics.enable_metrics:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -318,14 +327,13 @@ class SynapseHomeServer(HomeServer):
|
||||||
logger.warning("Unrecognized listener type: %s", listener.type)
|
logger.warning("Unrecognized listener type: %s", listener.type)
|
||||||
|
|
||||||
|
|
||||||
def setup(config_options):
|
def setup(config_options: List[str]) -> SynapseHomeServer:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
config_options_options: The options passed to Synapse. Usually
|
config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`.
|
||||||
`sys.argv[1:]`.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
HomeServer
|
A homeserver instance.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
config = HomeServerConfig.load_or_generate_config(
|
config = HomeServerConfig.load_or_generate_config(
|
||||||
|
@ -364,7 +372,7 @@ def setup(config_options):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
handle_startup_exception(e)
|
handle_startup_exception(e)
|
||||||
|
|
||||||
async def start():
|
async def start() -> None:
|
||||||
# Load the OIDC provider metadatas, if OIDC is enabled.
|
# Load the OIDC provider metadatas, if OIDC is enabled.
|
||||||
if hs.config.oidc.oidc_enabled:
|
if hs.config.oidc.oidc_enabled:
|
||||||
oidc = hs.get_oidc_handler()
|
oidc = hs.get_oidc_handler()
|
||||||
|
@ -404,39 +412,15 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
|
||||||
|
|
||||||
yield ":\n %s" % (e.msg,)
|
yield ":\n %s" % (e.msg,)
|
||||||
|
|
||||||
e = e.__cause__
|
parent_e = e.__cause__
|
||||||
indent = 1
|
indent = 1
|
||||||
while e:
|
while parent_e:
|
||||||
indent += 1
|
indent += 1
|
||||||
yield ":\n%s%s" % (" " * indent, str(e))
|
yield ":\n%s%s" % (" " * indent, str(parent_e))
|
||||||
e = e.__cause__
|
parent_e = parent_e.__cause__
|
||||||
|
|
||||||
|
|
||||||
def run(hs: HomeServer):
|
def run(hs: HomeServer) -> None:
|
||||||
PROFILE_SYNAPSE = False
|
|
||||||
if PROFILE_SYNAPSE:
|
|
||||||
|
|
||||||
def profile(func):
|
|
||||||
from cProfile import Profile
|
|
||||||
from threading import current_thread
|
|
||||||
|
|
||||||
def profiled(*args, **kargs):
|
|
||||||
profile = Profile()
|
|
||||||
profile.enable()
|
|
||||||
func(*args, **kargs)
|
|
||||||
profile.disable()
|
|
||||||
ident = current_thread().ident
|
|
||||||
profile.dump_stats(
|
|
||||||
"/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
|
|
||||||
)
|
|
||||||
|
|
||||||
return profiled
|
|
||||||
|
|
||||||
from twisted.python.threadpool import ThreadPool
|
|
||||||
|
|
||||||
ThreadPool._worker = profile(ThreadPool._worker)
|
|
||||||
reactor.run = profile(reactor.run)
|
|
||||||
|
|
||||||
_base.start_reactor(
|
_base.start_reactor(
|
||||||
"synapse-homeserver",
|
"synapse-homeserver",
|
||||||
soft_file_limit=hs.config.server.soft_file_limit,
|
soft_file_limit=hs.config.server.soft_file_limit,
|
||||||
|
@ -448,7 +432,7 @@ def run(hs: HomeServer):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
with LoggingContext("main"):
|
with LoggingContext("main"):
|
||||||
# check base requirements
|
# check base requirements
|
||||||
check_requirements()
|
check_requirements()
|
||||||
|
|
|
@ -15,11 +15,12 @@ import logging
|
||||||
import math
|
import math
|
||||||
import resource
|
import resource
|
||||||
import sys
|
import sys
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Sized, Tuple
|
||||||
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -28,7 +29,7 @@ logger = logging.getLogger("synapse.app.homeserver")
|
||||||
|
|
||||||
# Contains the list of processes we will be monitoring
|
# Contains the list of processes we will be monitoring
|
||||||
# currently either 0 or 1
|
# currently either 0 or 1
|
||||||
_stats_process = []
|
_stats_process: List[Tuple[int, "resource.struct_rusage"]] = []
|
||||||
|
|
||||||
# Gauges to expose monthly active user control metrics
|
# Gauges to expose monthly active user control metrics
|
||||||
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
|
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
|
||||||
|
@ -45,9 +46,15 @@ registered_reserved_users_mau_gauge = Gauge(
|
||||||
|
|
||||||
|
|
||||||
@wrap_as_background_process("phone_stats_home")
|
@wrap_as_background_process("phone_stats_home")
|
||||||
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
|
async def phone_stats_home(
|
||||||
|
hs: "HomeServer",
|
||||||
|
stats: JsonDict,
|
||||||
|
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
|
||||||
|
) -> None:
|
||||||
logger.info("Gathering stats for reporting")
|
logger.info("Gathering stats for reporting")
|
||||||
now = int(hs.get_clock().time())
|
now = int(hs.get_clock().time())
|
||||||
|
# Ensure the homeserver has started.
|
||||||
|
assert hs.start_time is not None
|
||||||
uptime = int(now - hs.start_time)
|
uptime = int(now - hs.start_time)
|
||||||
if uptime < 0:
|
if uptime < 0:
|
||||||
uptime = 0
|
uptime = 0
|
||||||
|
@ -146,15 +153,15 @@ async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process
|
||||||
logger.warning("Error reporting stats: %s", e)
|
logger.warning("Error reporting stats: %s", e)
|
||||||
|
|
||||||
|
|
||||||
def start_phone_stats_home(hs: "HomeServer"):
|
def start_phone_stats_home(hs: "HomeServer") -> None:
|
||||||
"""
|
"""
|
||||||
Start the background tasks which report phone home stats.
|
Start the background tasks which report phone home stats.
|
||||||
"""
|
"""
|
||||||
clock = hs.get_clock()
|
clock = hs.get_clock()
|
||||||
|
|
||||||
stats = {}
|
stats: JsonDict = {}
|
||||||
|
|
||||||
def performance_stats_init():
|
def performance_stats_init() -> None:
|
||||||
_stats_process.clear()
|
_stats_process.clear()
|
||||||
_stats_process.append(
|
_stats_process.append(
|
||||||
(int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
|
(int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
|
||||||
|
@ -170,10 +177,10 @@ def start_phone_stats_home(hs: "HomeServer"):
|
||||||
hs.get_datastore().reap_monthly_active_users()
|
hs.get_datastore().reap_monthly_active_users()
|
||||||
|
|
||||||
@wrap_as_background_process("generate_monthly_active_users")
|
@wrap_as_background_process("generate_monthly_active_users")
|
||||||
async def generate_monthly_active_users():
|
async def generate_monthly_active_users() -> None:
|
||||||
current_mau_count = 0
|
current_mau_count = 0
|
||||||
current_mau_count_by_service = {}
|
current_mau_count_by_service = {}
|
||||||
reserved_users = ()
|
reserved_users: Sized = ()
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
|
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
|
||||||
current_mau_count = await store.get_monthly_active_count()
|
current_mau_count = await store.get_monthly_active_count()
|
||||||
|
|
|
@ -234,7 +234,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def write_invite(
|
def write_invite(
|
||||||
self, room_id: str, event: EventBase, state: StateMap[dict]
|
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write an invite for the room, with associated invite state.
|
"""Write an invite for the room, with associated invite state.
|
||||||
|
|
||||||
|
@ -248,7 +248,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def write_knock(
|
def write_knock(
|
||||||
self, room_id: str, event: EventBase, state: StateMap[dict]
|
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write a knock for the room, with associated knock state.
|
"""Write a knock for the room, with associated knock state.
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import time
|
||||||
from logging import Handler, LogRecord
|
from logging import Handler, LogRecord
|
||||||
from logging.handlers import MemoryHandler
|
from logging.handlers import MemoryHandler
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional
|
from typing import Optional, cast
|
||||||
|
|
||||||
from twisted.internet.interfaces import IReactorCore
|
from twisted.internet.interfaces import IReactorCore
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
|
||||||
if reactor is None:
|
if reactor is None:
|
||||||
from twisted.internet import reactor as global_reactor
|
from twisted.internet import reactor as global_reactor
|
||||||
|
|
||||||
reactor_to_use = global_reactor # type: ignore[assignment]
|
reactor_to_use = cast(IReactorCore, global_reactor)
|
||||||
else:
|
else:
|
||||||
reactor_to_use = reactor
|
reactor_to_use = reactor
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ import attr
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.web.resource import IResource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
@ -196,7 +196,7 @@ class ModuleApi:
|
||||||
"""
|
"""
|
||||||
return self._password_auth_provider.register_password_auth_provider_callbacks
|
return self._password_auth_provider.register_password_auth_provider_callbacks
|
||||||
|
|
||||||
def register_web_resource(self, path: str, resource: IResource):
|
def register_web_resource(self, path: str, resource: Resource):
|
||||||
"""Registers a web resource to be served at the given path.
|
"""Registers a web resource to be served at the given path.
|
||||||
|
|
||||||
This function should be called during initialisation of the module.
|
This function should be called during initialisation of the module.
|
||||||
|
|
|
@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import ServerFactory
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.replication.tcp.commands import PositionCommand
|
from synapse.replication.tcp.commands import PositionCommand
|
||||||
|
@ -38,7 +38,7 @@ stream_updates_counter = Counter(
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ReplicationStreamProtocolFactory(Factory):
|
class ReplicationStreamProtocolFactory(ServerFactory):
|
||||||
"""Factory for new replication connections."""
|
"""Factory for new replication connections."""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
|
|
@ -33,9 +33,10 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import twisted.internet.tcp
|
from twisted.internet.interfaces import IOpenSSLContextFactory
|
||||||
|
from twisted.internet.tcp import Port
|
||||||
from twisted.web.iweb import IPolicyForHTTPS
|
from twisted.web.iweb import IPolicyForHTTPS
|
||||||
from twisted.web.resource import IResource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
from synapse.api.filtering import Filtering
|
from synapse.api.filtering import Filtering
|
||||||
|
@ -206,7 +207,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
config (synapse.config.homeserver.HomeserverConfig):
|
config (synapse.config.homeserver.HomeserverConfig):
|
||||||
_listening_services (list[twisted.internet.tcp.Port]): TCP ports that
|
_listening_services (list[Port]): TCP ports that
|
||||||
we are listening on to provide HTTP services.
|
we are listening on to provide HTTP services.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -225,6 +226,8 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
# instantiated during setup() for future return by get_datastore()
|
# instantiated during setup() for future return by get_datastore()
|
||||||
DATASTORE_CLASS = abc.abstractproperty()
|
DATASTORE_CLASS = abc.abstractproperty()
|
||||||
|
|
||||||
|
tls_server_context_factory: Optional[IOpenSSLContextFactory]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
|
@ -247,7 +250,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
# the key we use to sign events and requests
|
# the key we use to sign events and requests
|
||||||
self.signing_key = config.key.signing_key[0]
|
self.signing_key = config.key.signing_key[0]
|
||||||
self.config = config
|
self.config = config
|
||||||
self._listening_services: List[twisted.internet.tcp.Port] = []
|
self._listening_services: List[Port] = []
|
||||||
self.start_time: Optional[int] = None
|
self.start_time: Optional[int] = None
|
||||||
|
|
||||||
self._instance_id = random_string(5)
|
self._instance_id = random_string(5)
|
||||||
|
@ -257,10 +260,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
self.datastores: Optional[Databases] = None
|
self.datastores: Optional[Databases] = None
|
||||||
|
|
||||||
self._module_web_resources: Dict[str, IResource] = {}
|
self._module_web_resources: Dict[str, Resource] = {}
|
||||||
self._module_web_resources_consumed = False
|
self._module_web_resources_consumed = False
|
||||||
|
|
||||||
def register_module_web_resource(self, path: str, resource: IResource):
|
def register_module_web_resource(self, path: str, resource: Resource):
|
||||||
"""Allows a module to register a web resource to be served at the given path.
|
"""Allows a module to register a web resource to be served at the given path.
|
||||||
|
|
||||||
If multiple modules register a resource for the same path, the module that
|
If multiple modules register a resource for the same path, the module that
|
||||||
|
|
|
@ -38,6 +38,7 @@ from zope.interface import Interface
|
||||||
from twisted.internet.interfaces import (
|
from twisted.internet.interfaces import (
|
||||||
IReactorCore,
|
IReactorCore,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
|
IReactorSSL,
|
||||||
IReactorTCP,
|
IReactorTCP,
|
||||||
IReactorThreads,
|
IReactorThreads,
|
||||||
IReactorTime,
|
IReactorTime,
|
||||||
|
@ -66,6 +67,7 @@ JsonDict = Dict[str, Any]
|
||||||
# for mypy-zope to realize it is an interface.
|
# for mypy-zope to realize it is an interface.
|
||||||
class ISynapseReactor(
|
class ISynapseReactor(
|
||||||
IReactorTCP,
|
IReactorTCP,
|
||||||
|
IReactorSSL,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
IReactorTime,
|
IReactorTime,
|
||||||
IReactorCore,
|
IReactorCore,
|
||||||
|
|
|
@ -31,13 +31,13 @@ from typing import (
|
||||||
Set,
|
Set,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import ContextManager
|
from typing_extensions import ContextManager
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.base import ReactorBase
|
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
from twisted.internet.interfaces import IReactorTime
|
from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
@ -271,8 +271,7 @@ class Linearizer:
|
||||||
if not clock:
|
if not clock:
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
|
|
||||||
assert isinstance(reactor, ReactorBase)
|
clock = Clock(cast(IReactorTime, reactor))
|
||||||
clock = Clock(reactor)
|
|
||||||
self._clock = clock
|
self._clock = clock
|
||||||
self.max_count = max_count
|
self.max_count = max_count
|
||||||
|
|
||||||
|
|
|
@ -92,9 +92,9 @@ def _resource_id(resource: Resource, path_seg: bytes) -> str:
|
||||||
the mapping should looks like _resource_id(A,C) = B.
|
the mapping should looks like _resource_id(A,C) = B.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
resource (Resource): The *parent* Resourceb
|
resource: The *parent* Resourceb
|
||||||
path_seg (str): The name of the child Resource to be attached.
|
path_seg: The name of the child Resource to be attached.
|
||||||
Returns:
|
Returns:
|
||||||
str: A unique string which can be a key to the child Resource.
|
A unique string which can be a key to the child Resource.
|
||||||
"""
|
"""
|
||||||
return "%s-%r" % (resource, path_seg)
|
return "%s-%r" % (resource, path_seg)
|
||||||
|
|
|
@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
|
||||||
from twisted.conch.ssh.keys import Key
|
from twisted.conch.ssh.keys import Key
|
||||||
from twisted.cred import checkers, portal
|
from twisted.cred import checkers, portal
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import ServerFactory
|
||||||
|
|
||||||
from synapse.config.server import ManholeConfig
|
from synapse.config.server import ManholeConfig
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
|
||||||
-----END RSA PRIVATE KEY-----"""
|
-----END RSA PRIVATE KEY-----"""
|
||||||
|
|
||||||
|
|
||||||
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
|
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
|
||||||
"""Starts a ssh listener with password authentication using
|
"""Starts a ssh listener with password authentication using
|
||||||
the given username and password. Clients connecting to the ssh
|
the given username and password. Clients connecting to the ssh
|
||||||
listener will find themselves in a colored python shell with
|
listener will find themselves in a colored python shell with
|
||||||
|
@ -105,7 +105,8 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
|
||||||
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
|
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
|
||||||
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
|
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
|
||||||
|
|
||||||
return factory
|
# ConchFactory is a Factory, not a ServerFactory, but they are identical.
|
||||||
|
return factory # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
class SynapseManhole(ColoredManhole):
|
class SynapseManhole(ColoredManhole):
|
||||||
|
|
Loading…
Reference in New Issue