diff --git a/changelog.d/4522.feature b/changelog.d/4522.feature new file mode 100644 index 0000000000..ef18daf60b --- /dev/null +++ b/changelog.d/4522.feature @@ -0,0 +1 @@ +Synapse's ACME support will now correctly reprovision a certificate that approaches its expiry while Synapse is running. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 3cbb003035..62c633146f 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -23,6 +23,7 @@ import psutil from daemonize import Daemonize from twisted.internet import error, reactor +from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.app import check_bind_error from synapse.crypto import context_factory @@ -220,6 +221,24 @@ def refresh_certificate(hs): ) logging.info("Certificate loaded.") + if hs._listening_services: + logging.info("Updating context factories...") + for i in hs._listening_services: + # When you listenSSL, it doesn't make an SSL port but a TCP one with + # a TLS wrapping factory around the factory you actually want to get + # requests. This factory attribute is public but missing from + # Twisted's documentation. + if isinstance(i.factory, TLSMemoryBIOFactory): + # We want to replace TLS factories with a new one, with the new + # TLS configuration. We do this by reaching in and pulling out + # the wrappedFactory, and then re-wrapping it. + i.factory = TLSMemoryBIOFactory( + hs.tls_server_context_factory, + False, + i.factory.wrappedFactory + ) + logging.info("Context factories updated.") + def start(hs, listeners=None): """ diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index d1cab07bb6..b4476bf16e 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -83,7 +83,6 @@ def gz_wrap(r): class SynapseHomeServer(HomeServer): DATASTORE_CLASS = DataStore - _listening_services = [] def _listener_http(self, config, listener_config): port = listener_config["port"] @@ -376,42 +375,73 @@ def setup(config_options): hs.setup() + @defer.inlineCallbacks + def do_acme(): + """ + Reprovision an ACME certificate, if it's required. + + Returns: + Deferred[bool]: Whether the cert has been updated. + """ + acme = hs.get_acme_handler() + + # Check how long the certificate is active for. + cert_days_remaining = hs.config.is_disk_cert_valid( + allow_self_signed=False + ) + + # We want to reprovision if cert_days_remaining is None (meaning no + # certificate exists), or the days remaining number it returns + # is less than our re-registration threshold. + provision = False + + if (cert_days_remaining is None): + provision = True + + if cert_days_remaining > hs.config.acme_reprovision_threshold: + provision = True + + if provision: + yield acme.provision_certificate() + + defer.returnValue(provision) + + @defer.inlineCallbacks + def reprovision_acme(): + """ + Provision a certificate from ACME, if required, and reload the TLS + certificate if it's renewed. + """ + reprovisioned = yield do_acme() + if reprovisioned: + _base.refresh_certificate(hs) + @defer.inlineCallbacks def start(): try: - # Check if the certificate is still valid. - cert_days_remaining = hs.config.is_disk_cert_valid() - + # Run the ACME provisioning code, if it's enabled. if hs.config.acme_enabled: - # If ACME is enabled, we might need to provision a certificate - # before starting. acme = hs.get_acme_handler() - # Start up the webservices which we will respond to ACME - # challenges with. + # challenges with, and then provision. yield acme.start_listening() + yield do_acme() - # We want to reprovision if cert_days_remaining is None (meaning no - # certificate exists), or the days remaining number it returns - # is less than our re-registration threshold. - if (cert_days_remaining is None) or ( - not cert_days_remaining > hs.config.acme_reprovision_threshold - ): - yield acme.provision_certificate() + # Check if it needs to be reprovisioned every day. + hs.get_clock().looping_call( + reprovision_acme, + 24 * 60 * 60 * 1000 + ) _base.start(hs, config.listeners) hs.get_pusherpool().start() hs.get_datastore().start_doing_background_updates() - except Exception as e: - # If a DeferredList failed (like in listening on the ACME listener), - # we need to print the subfailure explicitly. - if isinstance(e, defer.FirstError): - e.subFailure.printTraceback(sys.stderr) - sys.exit(1) - - # Something else went wrong when starting. Print it and bail out. + except Exception: + # Print the exception and bail out. traceback.print_exc(file=sys.stderr) + if reactor.running: + reactor.stop() sys.exit(1) reactor.callWhenRunning(start) @@ -420,7 +450,8 @@ def setup(config_options): class SynapseService(service.Service): - """A twisted Service class that will start synapse. Used to run synapse + """ + A twisted Service class that will start synapse. Used to run synapse via twistd and a .tac. """ def __init__(self, config): diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 81b3a659fe..9fcc79816d 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -64,10 +64,14 @@ class TlsConfig(Config): self.tls_certificate = None self.tls_private_key = None - def is_disk_cert_valid(self): + def is_disk_cert_valid(self, allow_self_signed=True): """ Is the certificate we have on disk valid, and if so, for how long? + Args: + allow_self_signed (bool): Should we allow the certificate we + read to be self signed? + Returns: int: Days remaining of certificate validity. None: No certificate exists. @@ -88,6 +92,12 @@ class TlsConfig(Config): logger.exception("Failed to parse existing certificate off disk!") raise + if not allow_self_signed: + if tls_certificate.get_subject() == tls_certificate.get_issuer(): + raise ValueError( + "TLS Certificate is self signed, and this is not permitted" + ) + # YYYYMMDDhhmmssZ -- in UTC expires_on = datetime.strptime( tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ" diff --git a/synapse/server.py b/synapse/server.py index 6c52101616..a2cf8a91cd 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -112,6 +112,8 @@ class HomeServer(object): Attributes: config (synapse.config.homeserver.HomeserverConfig): + _listening_services (list[twisted.internet.tcp.Port]): TCP ports that + we are listening on to provide HTTP services. """ __metaclass__ = abc.ABCMeta @@ -196,6 +198,7 @@ class HomeServer(object): self._reactor = reactor self.hostname = hostname self._building = {} + self._listening_services = [] self.clock = Clock(reactor) self.distributor = Distributor()