diff --git a/changelog.d/5047.feature b/changelog.d/5047.feature
new file mode 100644
index 0000000000..12766a82a7
--- /dev/null
+++ b/changelog.d/5047.feature
@@ -0,0 +1 @@
+Add time-based account expiration.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 8888fd49c4..ab02e8f20e 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -646,11 +646,31 @@ uploads_path: "DATADIR/uploads"
#
#enable_registration: false
-# Optional account validity parameter. This allows for, e.g., accounts to
-# be denied any request after a given period.
+# Optional account validity configuration. This allows for accounts to be denied
+# any request after a given period.
+#
+# ``enabled`` defines whether the account validity feature is enabled. Defaults
+# to False.
+#
+# ``period`` allows setting the period after which an account is valid
+# after its registration. When renewing the account, its validity period
+# will be extended by this amount of time. This parameter is required when using
+# the account validity feature.
+#
+# ``renew_at`` is the amount of time before an account's expiry date at which
+# Synapse will send an email to the account's email address with a renewal link.
+# This needs the ``email`` and ``public_baseurl`` configuration sections to be
+# filled.
+#
+# ``renew_email_subject`` is the subject of the email sent out with the renewal
+# link. ``%(app)s`` can be used as a placeholder for the ``app_name`` parameter
+# from the ``email`` section.
#
#account_validity:
+# enabled: True
# period: 6w
+# renew_at: 1w
+# renew_email_subject: "Renew your %(app)s account"
# The user must provide all of the below types of 3PID when registering.
#
@@ -897,7 +917,7 @@ password_config:
-# Enable sending emails for notification events
+# Enable sending emails for notification events or expiry notices
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
@@ -919,6 +939,9 @@ password_config:
# #template_dir: res/templates
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
+# # Templates for account expiry notices.
+# expiry_template_html: notice_expiry.html
+# expiry_template_text: notice_expiry.txt
# notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 976e0dd18b..4482962510 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -230,8 +230,9 @@ class Auth(object):
# Deny the request if the user account has expired.
if self._account_validity.enabled:
- expiration_ts = yield self.store.get_expiration_ts_for_user(user)
- if self.clock.time_msec() >= expiration_ts:
+ user_id = user.to_string()
+ expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+ if expiration_ts and self.clock.time_msec() >= expiration_ts:
raise AuthError(
403,
"User account has expired",
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 93d70cff14..60827be72f 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -71,6 +71,8 @@ class EmailConfig(Config):
self.email_notif_from = email_config["notif_from"]
self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"]
+ self.email_expiry_template_html = email_config["expiry_template_html"]
+ self.email_expiry_template_text = email_config["expiry_template_text"]
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
@@ -120,7 +122,7 @@ class EmailConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
- # Enable sending emails for notification events
+ # Enable sending emails for notification events or expiry notices
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
@@ -142,6 +144,9 @@ class EmailConfig(Config):
# #template_dir: res/templates
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
+ # # Templates for account expiry notices.
+ # expiry_template_html: notice_expiry.html
+ # expiry_template_text: notice_expiry.txt
# notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
"""
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index dd242b1211..1309bce3ee 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -21,12 +21,26 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
- def __init__(self, config):
- self.enabled = (len(config) > 0)
+ def __init__(self, config, synapse_config):
+ self.enabled = config.get("enabled", False)
+ self.renew_by_email_enabled = ("renew_at" in config)
- period = config.get("period", None)
- if period:
- self.period = self.parse_duration(period)
+ if self.enabled:
+ if "period" in config:
+ self.period = self.parse_duration(config["period"])
+ else:
+ raise ConfigError("'period' is required when using account validity")
+
+ if "renew_at" in config:
+ self.renew_at = self.parse_duration(config["renew_at"])
+
+ if "renew_email_subject" in config:
+ self.renew_email_subject = config["renew_email_subject"]
+ else:
+ self.renew_email_subject = "Renew your %(app)s account"
+
+ if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
+ raise ConfigError("Can't send renewal emails without 'public_baseurl'")
class RegistrationConfig(Config):
@@ -40,7 +54,9 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
- self.account_validity = AccountValidityConfig(config.get("account_validity", {}))
+ self.account_validity = AccountValidityConfig(
+ config.get("account_validity", {}), config,
+ )
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
@@ -87,11 +103,31 @@ class RegistrationConfig(Config):
#
#enable_registration: false
- # Optional account validity parameter. This allows for, e.g., accounts to
- # be denied any request after a given period.
+ # Optional account validity configuration. This allows for accounts to be denied
+ # any request after a given period.
+ #
+ # ``enabled`` defines whether the account validity feature is enabled. Defaults
+ # to False.
+ #
+ # ``period`` allows setting the period after which an account is valid
+ # after its registration. When renewing the account, its validity period
+ # will be extended by this amount of time. This parameter is required when using
+ # the account validity feature.
+ #
+ # ``renew_at`` is the amount of time before an account's expiry date at which
+ # Synapse will send an email to the account's email address with a renewal link.
+ # This needs the ``email`` and ``public_baseurl`` configuration sections to be
+ # filled.
+ #
+ # ``renew_email_subject`` is the subject of the email sent out with the renewal
+ # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
+ # from the ``email`` section.
#
#account_validity:
+ # enabled: True
# period: 6w
+ # renew_at: 1w
+ # renew_email_subject: "Renew your %%(app)s account"
# The user must provide all of the below types of 3PID when registering.
#
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
new file mode 100644
index 0000000000..e82049e42d
--- /dev/null
+++ b/synapse/handlers/account_validity.py
@@ -0,0 +1,228 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.mime.multipart
+import email.utils
+import logging
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.types import UserID
+from synapse.util import stringutils
+from synapse.util.logcontext import make_deferred_yieldable
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except ImportError:
+ load_jinja2_templates = None
+
+logger = logging.getLogger(__name__)
+
+
+class AccountValidityHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = self.hs.get_datastore()
+ self.sendmail = self.hs.get_sendmail()
+ self.clock = self.hs.get_clock()
+
+ self._account_validity = self.hs.config.account_validity
+
+ if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
+ # Don't do email-specific configuration if renewal by email is disabled.
+ try:
+ app_name = self.hs.config.email_app_name
+
+ self._subject = self._account_validity.renew_email_subject % {
+ "app": app_name,
+ }
+
+ self._from_string = self.hs.config.email_notif_from % {
+ "app": app_name,
+ }
+ except Exception:
+ # If substitution failed, fall back to the bare strings.
+ self._subject = self._account_validity.renew_email_subject
+ self._from_string = self.hs.config.email_notif_from
+
+ self._raw_from = email.utils.parseaddr(self._from_string)[1]
+
+ self._template_html, self._template_text = load_jinja2_templates(
+ config=self.hs.config,
+ template_html_name=self.hs.config.email_expiry_template_html,
+ template_text_name=self.hs.config.email_expiry_template_text,
+ )
+
+ # Check the renewal emails to send and send them every 30min.
+ self.clock.looping_call(
+ self.send_renewal_emails,
+ 30 * 60 * 1000,
+ )
+
+ @defer.inlineCallbacks
+ def send_renewal_emails(self):
+ """Gets the list of users whose account is expiring in the amount of time
+ configured in the ``renew_at`` parameter from the ``account_validity``
+ configuration, and sends renewal emails to all of these users as long as they
+ have an email 3PID attached to their account.
+ """
+ expiring_users = yield self.store.get_users_expiring_soon()
+
+ if expiring_users:
+ for user in expiring_users:
+ yield self._send_renewal_email(
+ user_id=user["user_id"],
+ expiration_ts=user["expiration_ts_ms"],
+ )
+
+ @defer.inlineCallbacks
+ def _send_renewal_email(self, user_id, expiration_ts):
+ """Sends out a renewal email to every email address attached to the given user
+ with a unique link allowing them to renew their account.
+
+ Args:
+ user_id (str): ID of the user to send email(s) to.
+ expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ this user's account (used in the email templates).
+ """
+ addresses = yield self._get_email_addresses_for_user(user_id)
+
+ # Stop right here if the user doesn't have at least one email address.
+ # In this case, they will have to ask their server admin to renew their
+ # account manually.
+ if not addresses:
+ return
+
+ try:
+ user_display_name = yield self.store.get_profile_displayname(
+ UserID.from_string(user_id).localpart
+ )
+ if user_display_name is None:
+ user_display_name = user_id
+ except StoreError:
+ user_display_name = user_id
+
+ renewal_token = yield self._get_renewal_token(user_id)
+ url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
+ self.hs.config.public_baseurl,
+ renewal_token,
+ )
+
+ template_vars = {
+ "display_name": user_display_name,
+ "expiration_ts": expiration_ts,
+ "url": url,
+ }
+
+ html_text = self._template_html.render(**template_vars)
+ html_part = MIMEText(html_text, "html", "utf8")
+
+ plain_text = self._template_text.render(**template_vars)
+ text_part = MIMEText(plain_text, "plain", "utf8")
+
+ for address in addresses:
+ raw_to = email.utils.parseaddr(address)[1]
+
+ multipart_msg = MIMEMultipart('alternative')
+ multipart_msg['Subject'] = self._subject
+ multipart_msg['From'] = self._from_string
+ multipart_msg['To'] = address
+ multipart_msg['Date'] = email.utils.formatdate()
+ multipart_msg['Message-ID'] = email.utils.make_msgid()
+ multipart_msg.attach(text_part)
+ multipart_msg.attach(html_part)
+
+ logger.info("Sending renewal email to %s", address)
+
+ yield make_deferred_yieldable(self.sendmail(
+ self.hs.config.email_smtp_host,
+ self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
+ reactor=self.hs.get_reactor(),
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security
+ ))
+
+ yield self.store.set_renewal_mail_status(
+ user_id=user_id,
+ email_sent=True,
+ )
+
+ @defer.inlineCallbacks
+ def _get_email_addresses_for_user(self, user_id):
+ """Retrieve the list of email addresses attached to a user's account.
+
+ Args:
+ user_id (str): ID of the user to lookup email addresses for.
+
+ Returns:
+ defer.Deferred[list[str]]: Email addresses for this account.
+ """
+ threepids = yield self.store.user_get_threepids(user_id)
+
+ addresses = []
+ for threepid in threepids:
+ if threepid["medium"] == "email":
+ addresses.append(threepid["address"])
+
+ defer.returnValue(addresses)
+
+ @defer.inlineCallbacks
+ def _get_renewal_token(self, user_id):
+ """Generates a 32-byte long random string that will be inserted into the
+ user's renewal email's unique link, then saves it into the database.
+
+ Args:
+ user_id (str): ID of the user to generate a string for.
+
+ Returns:
+ defer.Deferred[str]: The generated string.
+
+ Raises:
+ StoreError(500): Couldn't generate a unique string after 5 attempts.
+ """
+ attempts = 0
+ while attempts < 5:
+ try:
+ renewal_token = stringutils.random_string(32)
+ yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+ defer.returnValue(renewal_token)
+ except StoreError:
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a unique string as refresh string.")
+
+ @defer.inlineCallbacks
+ def renew_account(self, renewal_token):
+ """Renews the account attached to a given renewal token by pushing back the
+ expiration date by the current validity period in the server's configuration.
+
+ Args:
+ renewal_token (str): Token sent with the renewal request.
+ """
+ user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+
+ logger.debug("Renewing an account for user %s", user_id)
+
+ new_expiration_date = self.clock.time_msec() + self._account_validity.period
+
+ yield self.store.renew_account_for_user(
+ user_id=user_id,
+ new_expiration_ts=new_expiration_date,
+ )
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 1eb5be0957..c269bcf4a4 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -521,11 +521,11 @@ def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
-def load_jinja2_templates(config):
+def load_jinja2_templates(config, template_html_name, template_text_name):
"""Load the jinja2 email templates from disk
Returns:
- (notif_template_html, notif_template_text)
+ (template_html, template_text)
"""
logger.info("loading email templates from '%s'", config.email_template_dir)
loader = jinja2.FileSystemLoader(config.email_template_dir)
@@ -533,14 +533,10 @@ def load_jinja2_templates(config):
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
- notif_template_html = env.get_template(
- config.email_notif_template_html
- )
- notif_template_text = env.get_template(
- config.email_notif_template_text
- )
+ template_html = env.get_template(template_html_name)
+ template_text = env.get_template(template_text_name)
- return notif_template_html, notif_template_text
+ return template_html, template_text
def _create_mxc_to_http_filter(config):
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index b33f2a357b..14bc7823cf 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -44,7 +44,11 @@ class PusherFactory(object):
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
- templates = load_jinja2_templates(hs.config)
+ templates = load_jinja2_templates(
+ config=hs.config,
+ template_html_name=hs.config.email_notif_template_html,
+ template_text_name=hs.config.email_notif_template_text,
+ )
self.notif_template_html, self.notif_template_text = templates
self.pusher_types["email"] = self._create_email_pusher
diff --git a/synapse/res/templates/mail-expiry.css b/synapse/res/templates/mail-expiry.css
new file mode 100644
index 0000000000..3dea486467
--- /dev/null
+++ b/synapse/res/templates/mail-expiry.css
@@ -0,0 +1,4 @@
+.noticetext {
+ margin-top: 10px;
+ margin-bottom: 10px;
+}
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
new file mode 100644
index 0000000000..f0d7c66e1b
--- /dev/null
+++ b/synapse/res/templates/notice_expiry.html
@@ -0,0 +1,43 @@
+
+
+
+
+
diff --git a/synapse/res/templates/notice_expiry.txt b/synapse/res/templates/notice_expiry.txt
new file mode 100644
index 0000000000..41f1c4279c
--- /dev/null
+++ b/synapse/res/templates/notice_expiry.txt
@@ -0,0 +1,7 @@
+Hi {{ display_name }},
+
+Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
+
+To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
+
+{{ url }}
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 91f5247d52..a66885d349 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -33,6 +33,7 @@ from synapse.rest.client.v1 import (
from synapse.rest.client.v2_alpha import (
account,
account_data,
+ account_validity,
auth,
capabilities,
devices,
@@ -109,3 +110,4 @@ class ClientRestResource(JsonResource):
groups.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
+ account_validity.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
new file mode 100644
index 0000000000..1ff6a6b638
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.server import finish_request
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class AccountValidityRenewServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account_validity/renew$")
+ SUCCESS_HTML = b"Your account has been successfully renewed."
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(AccountValidityRenewServlet, self).__init__()
+
+ self.hs = hs
+ self.account_activity_handler = hs.get_account_validity_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ if b"token" not in request.args:
+ raise SynapseError(400, "Missing renewal token")
+ renewal_token = request.args[b"token"][0]
+
+ yield self.account_activity_handler.renew_account(renewal_token.decode('utf8'))
+
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (
+ len(AccountValidityRenewServlet.SUCCESS_HTML),
+ ))
+ request.write(AccountValidityRenewServlet.SUCCESS_HTML)
+ finish_request(request)
+ defer.returnValue(None)
+
+
+def register_servlets(hs, http_server):
+ AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/server.py b/synapse/server.py
index dc8f1ccb8c..8c30ac2fa5 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -47,6 +47,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers
+from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
@@ -183,6 +184,7 @@ class HomeServer(object):
'room_context_handler',
'sendmail',
'registration_handler',
+ 'account_validity_handler',
]
REQUIRED_ON_MASTER_STARTUP = [
@@ -506,6 +508,9 @@ class HomeServer(object):
def build_registration_handler(self):
return RegistrationHandler(self)
+ def build_account_validity_handler(self):
+ return AccountValidityHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 643f7a3808..a1085ad80c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
+ self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
@@ -87,25 +88,156 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user):
+ def get_expiration_ts_for_user(self, user_id):
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user (str): The ID of the user.
+ user_id (str): The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ otherwise int representation of the timestamp (as a number of
+ milliseconds since epoch).
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
- keyvalues={"user_id": user.to_string()},
+ keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
- desc="get_expiration_date_for_user",
+ desc="get_expiration_ts_for_user",
)
defer.returnValue(res)
+ @defer.inlineCallbacks
+ def renew_account_for_user(self, user_id, new_expiration_ts):
+ """Updates the account validity table with a new timestamp for a given
+ user, removes the existing renewal token from this user, and unsets the
+ flag indicating that an email has been sent for renewing this account.
+
+ Args:
+ user_id (str): ID of the user whose account validity to renew.
+ new_expiration_ts: New expiration date, as a timestamp in milliseconds
+ since epoch.
+ """
+ def renew_account_for_user_txn(txn):
+ self._simple_update_txn(
+ txn=txn,
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={
+ "expiration_ts_ms": new_expiration_ts,
+ "email_sent": False,
+ "renewal_token": None,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_expiration_ts_for_user, (user_id,),
+ )
+
+ yield self.runInteraction(
+ "renew_account_for_user",
+ renew_account_for_user_txn,
+ )
+
+ @defer.inlineCallbacks
+ def set_renewal_token_for_user(self, user_id, renewal_token):
+ """Defines a renewal token for a given user.
+
+ Args:
+ user_id (str): ID of the user to set the renewal token for.
+ renewal_token (str): Random unique string that will be used to renew the
+ user's account.
+
+ Raises:
+ StoreError: The provided token is already set for another user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"renewal_token": renewal_token},
+ desc="set_renewal_token_for_user",
+ )
+
+ @defer.inlineCallbacks
+ def get_user_from_renewal_token(self, renewal_token):
+ """Get a user ID from a renewal token.
+
+ Args:
+ renewal_token (str): The renewal token to perform the lookup with.
+
+ Returns:
+ defer.Deferred[str]: The ID of the user to which the token belongs.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcol="user_id",
+ desc="get_user_from_renewal_token",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_renewal_token_for_user(self, user_id):
+ """Get the renewal token associated with a given user ID.
+
+ Args:
+ user_id (str): The user ID to lookup a token for.
+
+ Returns:
+ defer.Deferred[str]: The renewal token associated with this user ID.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="renewal_token",
+ desc="get_renewal_token_for_user",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_expiring_soon(self):
+ """Selects users whose account will expire in the [now, now + renew_at] time
+ window (see configuration for account_validity for information on what renew_at
+ refers to).
+
+ Returns:
+ Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ """
+ def select_users_txn(txn, now_ms, renew_at):
+ sql = (
+ "SELECT user_id, expiration_ts_ms FROM account_validity"
+ " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ )
+ values = [False, now_ms, renew_at]
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
+ res = yield self.runInteraction(
+ "get_users_expiring_soon",
+ select_users_txn,
+ self.clock.time_msec(), self.config.account_validity.renew_at,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_renewal_mail_status(self, user_id, email_sent):
+ """Sets or unsets the flag that indicates whether a renewal email has been sent
+ to the user (and the user hasn't renewed their account yet).
+
+ Args:
+ user_id (str): ID of the user to set/unset the flag for.
+ email_sent (bool): Flag which indicates whether a renewal email has been sent
+ to this user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"email_sent": email_sent},
+ desc="set_renewal_mail_status",
+ )
+
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
@@ -584,20 +716,22 @@ class RegistrationStore(
},
)
- if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- }
- )
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+ if self._account_validity.enabled:
+ now_ms = self.clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ }
+ )
+
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity.sql
index 57249262d7..2357626000 100644
--- a/synapse/storage/schema/delta/54/account_validity.sql
+++ b/synapse/storage/schema/delta/54/account_validity.sql
@@ -13,8 +13,15 @@
* limitations under the License.
*/
+DROP TABLE IF EXISTS account_validity;
+
-- Track what users are in public rooms.
CREATE TABLE IF NOT EXISTS account_validity (
user_id TEXT PRIMARY KEY,
- expiration_ts_ms BIGINT NOT NULL
+ expiration_ts_ms BIGINT NOT NULL,
+ email_sent BOOLEAN NOT NULL,
+ renewal_token TEXT
);
+
+CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms)
+CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d3611ed21f..8fb5140a05 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,14 +1,22 @@
import datetime
import json
+import os
+
+import pkg_resources
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import admin, login
-from synapse.rest.client.v2_alpha import register, sync
+from synapse.rest.client.v2_alpha import account_validity, register, sync
from tests import unittest
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except ImportError:
+ load_jinja2_templates = None
+
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@@ -197,6 +205,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ # Test for account expiring after a week.
config.enable_registration = True
config.account_validity.enabled = True
config.account_validity.period = 604800000 # Time in ms for 1 week
@@ -228,3 +237,92 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
)
+
+
+class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
+
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ register.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # Test for account expiring after a week and renewal emails being sent 2
+ # days before expiry.
+ config.enable_registration = True
+ config.account_validity.enabled = True
+ config.account_validity.renew_by_email_enabled = True
+ config.account_validity.period = 604800000 # Time in ms for 1 week
+ config.account_validity.renew_at = 172800000 # Time in ms for 2 days
+ config.account_validity.renew_email_subject = "Renew your account"
+
+ # Email config.
+ self.email_attempts = []
+
+ def sendmail(*args, **kwargs):
+ self.email_attempts.append((args, kwargs))
+ return
+
+ config.email_template_dir = os.path.abspath(
+ pkg_resources.resource_filename('synapse', 'res/templates')
+ )
+ config.email_expiry_template_html = "notice_expiry.html"
+ config.email_expiry_template_text = "notice_expiry.txt"
+ config.email_smtp_host = "127.0.0.1"
+ config.email_smtp_port = 20
+ config.require_transport_security = False
+ config.email_smtp_user = None
+ config.email_smtp_pass = None
+ config.email_notif_from = "test@example.com"
+
+ self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+
+ self.store = self.hs.get_datastore()
+
+ return self.hs
+
+ def test_renewal_email(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ # We need to manually add an email address otherwise the handler will do
+ # nothing.
+ now = self.hs.clock.time_msec()
+ self.get_success(self.store.user_add_threepid(
+ user_id=user_id, medium="email", address="kermit@example.com",
+ validated_at=now, added_at=now,
+ ))
+
+ # The specific endpoint doesn't matter, all we need is an authenticated
+ # endpoint.
+ request, channel = self.make_request(
+ b"GET", "/sync", access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Move 6 days forward. This should trigger a renewal email to be sent.
+ self.reactor.advance(datetime.timedelta(days=6).total_seconds())
+ self.assertEqual(len(self.email_attempts), 1)
+
+ # Retrieving the URL from the email is too much pain for now, so we
+ # retrieve the token from the DB.
+ renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
+ url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+ request, channel = self.make_request(b"GET", url)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Move 3 days forward. If the renewal failed, every authed request with
+ # our access token should be denied from now, otherwise they should
+ # succeed.
+ self.reactor.advance(datetime.timedelta(days=3).total_seconds())
+ request, channel = self.make_request(
+ b"GET", "/sync", access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)