Don't push if an user account has expired (#8353)
This commit is contained in:
parent
4bb203ea4f
commit
916bb9d0d1
|
@ -0,0 +1 @@
|
|||
Don't send push notifications to expired user accounts.
|
|
@ -218,11 +218,7 @@ class Auth:
|
|||
# Deny the request if the user account has expired.
|
||||
if self._account_validity.enabled and not allow_expired:
|
||||
user_id = user.to_string()
|
||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||
if (
|
||||
expiration_ts is not None
|
||||
and self.clock.time_msec() >= expiration_ts
|
||||
):
|
||||
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
||||
raise AuthError(
|
||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||
)
|
||||
|
|
|
@ -60,6 +60,8 @@ class PusherPool:
|
|||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
|
||||
# We shard the handling of push notifications by user ID.
|
||||
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
@ -202,6 +204,14 @@ class PusherPool:
|
|||
)
|
||||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
if self._account_validity.enabled:
|
||||
expired = await self.store.is_account_expired(
|
||||
u, self.clock.time_msec()
|
||||
)
|
||||
if expired:
|
||||
continue
|
||||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
p.on_new_notifications(max_stream_id)
|
||||
|
@ -222,6 +232,14 @@ class PusherPool:
|
|||
)
|
||||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
if self._account_validity.enabled:
|
||||
expired = await self.store.is_account_expired(
|
||||
u, self.clock.time_msec()
|
||||
)
|
||||
if expired:
|
||||
continue
|
||||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
||||
|
|
|
@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
desc="get_expiration_ts_for_user",
|
||||
)
|
||||
|
||||
async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
|
||||
"""
|
||||
Returns whether an user account is expired.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
current_ts: The current timestamp
|
||||
|
||||
Returns:
|
||||
Whether the user account has expired
|
||||
"""
|
||||
expiration_ts = await self.get_expiration_ts_for_user(user_id)
|
||||
return expiration_ts is not None and current_ts >= expiration_ts
|
||||
|
||||
async def set_account_validity_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
Loading…
Reference in New Issue