SIGHUP for TLS cert reloading (#4495)

This commit is contained in:
Amber Brown 2019-01-30 11:00:02 +00:00 committed by GitHub
parent bc5f6e1797
commit f6813919e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 81 additions and 20 deletions

1
.gitignore vendored
View File

@ -12,6 +12,7 @@ dbs/
dist/ dist/
docs/build/ docs/build/
*.egg-info *.egg-info
pip-wheel-metadata/
cmdclient_config.json cmdclient_config.json
homeserver*.db homeserver*.db

1
changelog.d/4495.feature Normal file
View File

@ -0,0 +1 @@
Synapse will now reload TLS certificates from disk upon SIGHUP.

View File

@ -143,6 +143,9 @@ def listen_metrics(bind_addresses, port):
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
""" """
Create a TCP socket for a port and several addresses Create a TCP socket for a port and several addresses
Returns:
list (empty)
""" """
for address in bind_addresses: for address in bind_addresses:
try: try:
@ -155,25 +158,37 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
except error.CannotListenError as e: except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses) check_bind_error(e, address, bind_addresses)
logger.info("Synapse now listening on TCP port %d", port)
return []
def listen_ssl( def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
): ):
""" """
Create an SSL socket for a port and several addresses Create an TLS-over-TCP socket for a port and several addresses
Returns:
list of twisted.internet.tcp.Port listening for TLS connections
""" """
r = []
for address in bind_addresses: for address in bind_addresses:
try: try:
reactor.listenSSL( r.append(
port, reactor.listenSSL(
factory, port,
context_factory, factory,
backlog, context_factory,
address backlog,
address
)
) )
except error.CannotListenError as e: except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses) check_bind_error(e, address, bind_addresses)
logger.info("Synapse now listening on port %d (TLS)", port)
return r
def check_bind_error(e, address, bind_addresses): def check_bind_error(e, address, bind_addresses):
""" """

View File

@ -17,6 +17,7 @@
import gc import gc
import logging import logging
import os import os
import signal
import sys import sys
import traceback import traceback
@ -27,6 +28,7 @@ from prometheus_client import Gauge
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.server import GzipEncoderFactory from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File from twisted.web.static import File
@ -84,6 +86,7 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore DATASTORE_CLASS = DataStore
_listening_services = []
def _listener_http(self, config, listener_config): def _listener_http(self, config, listener_config):
port = listener_config["port"] port = listener_config["port"]
@ -121,7 +124,7 @@ class SynapseHomeServer(HomeServer):
root_resource = create_resource_tree(resources, root_resource) root_resource = create_resource_tree(resources, root_resource)
if tls: if tls:
listen_ssl( return listen_ssl(
bind_addresses, bind_addresses,
port, port,
SynapseSite( SynapseSite(
@ -135,7 +138,7 @@ class SynapseHomeServer(HomeServer):
) )
else: else:
listen_tcp( return listen_tcp(
bind_addresses, bind_addresses,
port, port,
SynapseSite( SynapseSite(
@ -146,7 +149,6 @@ class SynapseHomeServer(HomeServer):
self.version_string, self.version_string,
) )
) )
logger.info("Synapse now listening on port %d", port)
def _configure_named_resource(self, name, compress=False): def _configure_named_resource(self, name, compress=False):
"""Build a resource map for a named resource """Build a resource map for a named resource
@ -242,7 +244,9 @@ class SynapseHomeServer(HomeServer):
for listener in config.listeners: for listener in config.listeners:
if listener["type"] == "http": if listener["type"] == "http":
self._listener_http(config, listener) self._listening_services.extend(
self._listener_http(config, listener)
)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
listen_tcp( listen_tcp(
listener["bind_addresses"], listener["bind_addresses"],
@ -322,7 +326,19 @@ def setup(config_options):
# generating config files and shouldn't try to continue. # generating config files and shouldn't try to continue.
sys.exit(0) sys.exit(0)
synapse.config.logger.setup_logging(config, use_worker_options=False) sighup_callbacks = []
synapse.config.logger.setup_logging(
config,
use_worker_options=False,
register_sighup=sighup_callbacks.append
)
def handle_sighup(*args, **kwargs):
for i in sighup_callbacks:
i(*args, **kwargs)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, handle_sighup)
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
@ -359,6 +375,31 @@ def setup(config_options):
hs.setup() hs.setup()
def refresh_certificate(*args):
"""
Refresh the TLS certificates that Synapse is using by re-reading them
from disk and updating the TLS context factories to use them.
"""
logging.info("Reloading certificate from disk...")
hs.config.read_certificate_from_disk()
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
logging.info("Certificate reloaded.")
logging.info("Updating context factories...")
for i in hs._listening_services:
if isinstance(i.factory, TLSMemoryBIOFactory):
i.factory = TLSMemoryBIOFactory(
hs.tls_server_context_factory,
False,
i.factory.wrappedFactory
)
logging.info("Context factories updated.")
sighup_callbacks.append(refresh_certificate)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(): def start():
try: try:

View File

@ -127,7 +127,7 @@ class LoggingConfig(Config):
) )
def setup_logging(config, use_worker_options=False): def setup_logging(config, use_worker_options=False, register_sighup=None):
""" Set up python logging """ Set up python logging
Args: Args:
@ -136,7 +136,16 @@ def setup_logging(config, use_worker_options=False):
use_worker_options (bool): True to use 'worker_log_config' and use_worker_options (bool): True to use 'worker_log_config' and
'worker_log_file' options instead of 'log_config' and 'log_file'. 'worker_log_file' options instead of 'log_config' and 'log_file'.
register_sighup (func | None): Function to call to register a
sighup handler.
""" """
if not register_sighup:
if getattr(signal, "SIGHUP"):
register_sighup = lambda x: signal.signal(signal.SIGHUP, x)
else:
register_sighup = lambda x: None
log_config = (config.worker_log_config if use_worker_options log_config = (config.worker_log_config if use_worker_options
else config.log_config) else config.log_config)
log_file = (config.worker_log_file if use_worker_options log_file = (config.worker_log_file if use_worker_options
@ -198,13 +207,7 @@ def setup_logging(config, use_worker_options=False):
load_log_config() load_log_config()
# TODO(paul): obviously this is a terrible mechanism for register_sighup(sighup)
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
# make sure that the first thing we log is a thing we can grep backwards # make sure that the first thing we log is a thing we can grep backwards
# for # for