Add management endpoints for account validity
This commit is contained in:
parent
20f0617e87
commit
eaf41a943b
|
@ -0,0 +1 @@
|
|||
Add time-based account expiration.
|
|
@ -0,0 +1,42 @@
|
|||
Account validity API
|
||||
====================
|
||||
|
||||
This API allows a server administrator to manage the validity of an account. To
|
||||
use it, you must enable the account validity feature (under
|
||||
``account_validity``) in Synapse's configuration.
|
||||
|
||||
Renew account
|
||||
-------------
|
||||
|
||||
This API extends the validity of an account by as much time as configured in the
|
||||
``period`` parameter from the ``account_validity`` configuration.
|
||||
|
||||
The API is::
|
||||
|
||||
POST /_matrix/client/unstable/account_validity/send_mail
|
||||
|
||||
with the following body:
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"user_id": "<user ID for the account to renew>",
|
||||
"expiration_ts": 0,
|
||||
"enable_renewal_emails": true
|
||||
}
|
||||
|
||||
|
||||
``expiration_ts`` is an optional parameter and overrides the expiration date,
|
||||
which otherwise defaults to now + validity period.
|
||||
|
||||
``enable_renewal_emails`` is also an optional parameter and enables/disables
|
||||
sending renewal emails to the user. Defaults to true.
|
||||
|
||||
The API returns with the new expiration date for this account, as a timestamp in
|
||||
milliseconds since epoch:
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"expiration_ts": 0
|
||||
}
|
|
@ -232,7 +232,7 @@ class Auth(object):
|
|||
if self._account_validity.enabled:
|
||||
user_id = user.to_string()
|
||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
|
||||
if expiration_ts and self.clock.time_msec() >= expiration_ts:
|
||||
if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
|
||||
raise AuthError(
|
||||
403,
|
||||
"User account has expired",
|
||||
|
|
|
@ -90,6 +90,11 @@ class AccountValidityHandler(object):
|
|||
expiration_ts=user["expiration_ts_ms"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_renewal_email_to_user(self, user_id):
|
||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
|
||||
yield self._send_renewal_email(user_id, expiration_ts)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_renewal_email(self, user_id, expiration_ts):
|
||||
"""Sends out a renewal email to every email address attached to the given user
|
||||
|
@ -217,12 +222,32 @@ class AccountValidityHandler(object):
|
|||
renewal_token (str): Token sent with the renewal request.
|
||||
"""
|
||||
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
|
||||
|
||||
logger.debug("Renewing an account for user %s", user_id)
|
||||
yield self.renew_account_for_user(user_id)
|
||||
|
||||
new_expiration_date = self.clock.time_msec() + self._account_validity.period
|
||||
@defer.inlineCallbacks
|
||||
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
|
||||
"""Renews the account attached to a given user by pushing back the
|
||||
expiration date by the current validity period in the server's
|
||||
configuration.
|
||||
|
||||
yield self.store.renew_account_for_user(
|
||||
Args:
|
||||
renewal_token (str): Token sent with the renewal request.
|
||||
expiration_ts (int): New expiration date. Defaults to now + validity period.
|
||||
email_sent (bool): Whether an email has been sent for this validity period.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[int]: New expiration date for this account, as a timestamp
|
||||
in milliseconds since epoch.
|
||||
"""
|
||||
if expiration_ts is None:
|
||||
expiration_ts = self.clock.time_msec() + self._account_validity.period
|
||||
|
||||
yield self.store.set_account_validity_for_user(
|
||||
user_id=user_id,
|
||||
new_expiration_ts=new_expiration_date,
|
||||
expiration_ts=expiration_ts,
|
||||
email_sent=email_sent,
|
||||
)
|
||||
|
||||
defer.returnValue(expiration_ts)
|
||||
|
|
|
@ -786,6 +786,44 @@ class SearchUsersRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class AccountValidityRenewServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/account_validity/validity$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(AccountValidityRenewServlet, self).__init__(hs)
|
||||
|
||||
self.hs = hs
|
||||
self.account_activity_handler = hs.get_account_validity_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if "user_id" not in body:
|
||||
raise SynapseError(400, "Missing property 'user_id' in the request body")
|
||||
|
||||
expiration_ts = yield self.account_activity_handler.renew_account_for_user(
|
||||
body["user_id"], body.get("expiration_ts"),
|
||||
not body.get("enable_renewal_emails", True),
|
||||
)
|
||||
|
||||
res = {
|
||||
"expiration_ts": expiration_ts,
|
||||
}
|
||||
defer.returnValue((200, res))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
WhoisRestServlet(hs).register(http_server)
|
||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||
|
@ -801,3 +839,4 @@ def register_servlets(hs, http_server):
|
|||
ListMediaInRoom(hs).register(http_server)
|
||||
UserRegisterServlet(hs).register(http_server)
|
||||
VersionServlet(hs).register(http_server)
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
|
@ -39,6 +39,7 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
|
||||
self.hs = hs
|
||||
self.account_activity_handler = hs.get_account_validity_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
@ -58,5 +59,33 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
defer.returnValue(None)
|
||||
|
||||
|
||||
class AccountValiditySendMailServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account_validity/send_mail$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(AccountValiditySendMailServlet, self).__init__()
|
||||
|
||||
self.hs = hs
|
||||
self.account_activity_handler = hs.get_account_validity_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.account_validity = self.hs.config.account_validity
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
if not self.account_validity.renew_by_email_enabled:
|
||||
raise AuthError(403, "Account renewal via email is disabled on this server.")
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
yield self.account_activity_handler.send_renewal_email_to_user(user_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
||||
AccountValiditySendMailServlet(hs).register(http_server)
|
||||
|
|
|
@ -108,25 +108,30 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def renew_account_for_user(self, user_id, new_expiration_ts):
|
||||
"""Updates the account validity table with a new timestamp for a given
|
||||
user, removes the existing renewal token from this user, and unsets the
|
||||
flag indicating that an email has been sent for renewing this account.
|
||||
def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
|
||||
renewal_token=None):
|
||||
"""Updates the account validity properties of the given account, with the
|
||||
given values.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user whose account validity to renew.
|
||||
new_expiration_ts: New expiration date, as a timestamp in milliseconds
|
||||
user_id (str): ID of the account to update properties for.
|
||||
expiration_ts (int): New expiration date, as a timestamp in milliseconds
|
||||
since epoch.
|
||||
email_sent (bool): True means a renewal email has been sent for this
|
||||
account and there's no need to send another one for the current validity
|
||||
period.
|
||||
renewal_token (str): Renewal token the user can use to extend the validity
|
||||
of their account. Defaults to no token.
|
||||
"""
|
||||
def renew_account_for_user_txn(txn):
|
||||
def set_account_validity_for_user_txn(txn):
|
||||
self._simple_update_txn(
|
||||
txn=txn,
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={
|
||||
"expiration_ts_ms": new_expiration_ts,
|
||||
"email_sent": False,
|
||||
"renewal_token": None,
|
||||
"expiration_ts_ms": expiration_ts,
|
||||
"email_sent": email_sent,
|
||||
"renewal_token": renewal_token,
|
||||
},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
|
@ -134,8 +139,8 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
"renew_account_for_user",
|
||||
renew_account_for_user_txn,
|
||||
"set_account_validity_for_user",
|
||||
set_account_validity_for_user_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -201,6 +201,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
|||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
account_validity.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
@ -238,6 +239,68 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
|||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
|
||||
)
|
||||
|
||||
def test_manual_renewal(self):
|
||||
user_id = self.register_user("kermit", "monkey")
|
||||
tok = self.login("kermit", "monkey")
|
||||
|
||||
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
|
||||
|
||||
# If we register the admin user at the beginning of the test, it will
|
||||
# expire at the same time as the normal user and the renewal request
|
||||
# will be denied.
|
||||
self.register_user("admin", "adminpassword", admin=True)
|
||||
admin_tok = self.login("admin", "adminpassword")
|
||||
|
||||
url = "/_matrix/client/unstable/admin/account_validity/validity"
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(
|
||||
b"POST", url, request_data, access_token=admin_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||
# endpoint.
|
||||
request, channel = self.make_request(
|
||||
b"GET", "/sync", access_token=tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
def test_manual_expire(self):
|
||||
user_id = self.register_user("kermit", "monkey")
|
||||
tok = self.login("kermit", "monkey")
|
||||
|
||||
self.register_user("admin", "adminpassword", admin=True)
|
||||
admin_tok = self.login("admin", "adminpassword")
|
||||
|
||||
url = "/_matrix/client/unstable/admin/account_validity/validity"
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
"expiration_ts": 0,
|
||||
"enable_renewal_emails": False,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(
|
||||
b"POST", url, request_data, access_token=admin_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||
# endpoint.
|
||||
request, channel = self.make_request(
|
||||
b"GET", "/sync", access_token=tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(
|
||||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
|
||||
)
|
||||
|
||||
|
||||
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
@ -287,6 +350,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
|||
return self.hs
|
||||
|
||||
def test_renewal_email(self):
|
||||
self.email_attempts = []
|
||||
|
||||
user_id = self.register_user("kermit", "monkey")
|
||||
tok = self.login("kermit", "monkey")
|
||||
# We need to manually add an email address otherwise the handler will do
|
||||
|
@ -297,14 +362,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
|||
validated_at=now, added_at=now,
|
||||
))
|
||||
|
||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||
# endpoint.
|
||||
request, channel = self.make_request(
|
||||
b"GET", "/sync", access_token=tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
# Move 6 days forward. This should trigger a renewal email to be sent.
|
||||
self.reactor.advance(datetime.timedelta(days=6).total_seconds())
|
||||
self.assertEqual(len(self.email_attempts), 1)
|
||||
|
@ -326,3 +383,25 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
def test_manual_email_send(self):
|
||||
self.email_attempts = []
|
||||
|
||||
user_id = self.register_user("kermit", "monkey")
|
||||
tok = self.login("kermit", "monkey")
|
||||
# We need to manually add an email address otherwise the handler will do
|
||||
# nothing.
|
||||
now = self.hs.clock.time_msec()
|
||||
self.get_success(self.store.user_add_threepid(
|
||||
user_id=user_id, medium="email", address="kermit@example.com",
|
||||
validated_at=now, added_at=now,
|
||||
))
|
||||
|
||||
request, channel = self.make_request(
|
||||
b"POST", "/_matrix/client/unstable/account_validity/send_mail",
|
||||
access_token=tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.assertEqual(len(self.email_attempts), 1)
|
||||
|
|
Loading…
Reference in New Issue