Make RateLimiter class check for ratelimit overrides (#9711)
This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited. We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits. Fixes #9663
This commit is contained in:
parent
3a446c21f8
commit
963f4309fe
|
@ -0,0 +1 @@
|
||||||
|
Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.
|
|
@ -17,6 +17,7 @@ from collections import OrderedDict
|
||||||
from typing import Hashable, Optional, Tuple
|
from typing import Hashable, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
from synapse.api.errors import LimitExceededError
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import Requester
|
from synapse.types import Requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -31,10 +32,13 @@ class Ratelimiter:
|
||||||
burst_count: How many actions that can be performed before being limited.
|
burst_count: How many actions that can be performed before being limited.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
|
def __init__(
|
||||||
|
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
|
||||||
|
):
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
self.rate_hz = rate_hz
|
self.rate_hz = rate_hz
|
||||||
self.burst_count = burst_count
|
self.burst_count = burst_count
|
||||||
|
self.store = store
|
||||||
|
|
||||||
# A ordered dictionary keeping track of actions, when they were last
|
# A ordered dictionary keeping track of actions, when they were last
|
||||||
# performed and how often. Each entry is a mapping from a key of arbitrary type
|
# performed and how often. Each entry is a mapping from a key of arbitrary type
|
||||||
|
@ -46,45 +50,10 @@ class Ratelimiter:
|
||||||
OrderedDict()
|
OrderedDict()
|
||||||
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
|
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
|
||||||
|
|
||||||
def can_requester_do_action(
|
async def can_do_action(
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Optional[Requester],
|
||||||
rate_hz: Optional[float] = None,
|
key: Optional[Hashable] = None,
|
||||||
burst_count: Optional[int] = None,
|
|
||||||
update: bool = True,
|
|
||||||
_time_now_s: Optional[int] = None,
|
|
||||||
) -> Tuple[bool, float]:
|
|
||||||
"""Can the requester perform the action?
|
|
||||||
|
|
||||||
Args:
|
|
||||||
requester: The requester to key off when rate limiting. The user property
|
|
||||||
will be used.
|
|
||||||
rate_hz: The long term number of actions that can be performed in a second.
|
|
||||||
Overrides the value set during instantiation if set.
|
|
||||||
burst_count: How many actions that can be performed before being limited.
|
|
||||||
Overrides the value set during instantiation if set.
|
|
||||||
update: Whether to count this check as performing the action
|
|
||||||
_time_now_s: The current time. Optional, defaults to the current time according
|
|
||||||
to self.clock. Only used by tests.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing:
|
|
||||||
* A bool indicating if they can perform the action now
|
|
||||||
* The reactor timestamp for when the action can be performed next.
|
|
||||||
-1 if rate_hz is less than or equal to zero
|
|
||||||
"""
|
|
||||||
# Disable rate limiting of users belonging to any AS that is configured
|
|
||||||
# not to be rate limited in its registration file (rate_limited: true|false).
|
|
||||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
|
||||||
return True, -1.0
|
|
||||||
|
|
||||||
return self.can_do_action(
|
|
||||||
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
|
|
||||||
)
|
|
||||||
|
|
||||||
def can_do_action(
|
|
||||||
self,
|
|
||||||
key: Hashable,
|
|
||||||
rate_hz: Optional[float] = None,
|
rate_hz: Optional[float] = None,
|
||||||
burst_count: Optional[int] = None,
|
burst_count: Optional[int] = None,
|
||||||
update: bool = True,
|
update: bool = True,
|
||||||
|
@ -92,9 +61,16 @@ class Ratelimiter:
|
||||||
) -> Tuple[bool, float]:
|
) -> Tuple[bool, float]:
|
||||||
"""Can the entity (e.g. user or IP address) perform the action?
|
"""Can the entity (e.g. user or IP address) perform the action?
|
||||||
|
|
||||||
|
Checks if the user has ratelimiting disabled in the database by looking
|
||||||
|
for null/zero values in the `ratelimit_override` table. (Non-zero
|
||||||
|
values aren't honoured, as they're specific to the event sending
|
||||||
|
ratelimiter, rather than all ratelimiters)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The key we should use when rate limiting. Can be a user ID
|
requester: The requester that is doing the action, if any. Used to check
|
||||||
(when sending events), an IP address, etc.
|
if the user has ratelimits disabled in the database.
|
||||||
|
key: An arbitrary key used to classify an action. Defaults to the
|
||||||
|
requester's user ID.
|
||||||
rate_hz: The long term number of actions that can be performed in a second.
|
rate_hz: The long term number of actions that can be performed in a second.
|
||||||
Overrides the value set during instantiation if set.
|
Overrides the value set during instantiation if set.
|
||||||
burst_count: How many actions that can be performed before being limited.
|
burst_count: How many actions that can be performed before being limited.
|
||||||
|
@ -109,6 +85,30 @@ class Ratelimiter:
|
||||||
* The reactor timestamp for when the action can be performed next.
|
* The reactor timestamp for when the action can be performed next.
|
||||||
-1 if rate_hz is less than or equal to zero
|
-1 if rate_hz is less than or equal to zero
|
||||||
"""
|
"""
|
||||||
|
if key is None:
|
||||||
|
if not requester:
|
||||||
|
raise ValueError("Must supply at least one of `requester` or `key`")
|
||||||
|
|
||||||
|
key = requester.user.to_string()
|
||||||
|
|
||||||
|
if requester:
|
||||||
|
# Disable rate limiting of users belonging to any AS that is configured
|
||||||
|
# not to be rate limited in its registration file (rate_limited: true|false).
|
||||||
|
if requester.app_service and not requester.app_service.is_rate_limited():
|
||||||
|
return True, -1.0
|
||||||
|
|
||||||
|
# Check if ratelimiting has been disabled for the user.
|
||||||
|
#
|
||||||
|
# Note that we don't use the returned rate/burst count, as the table
|
||||||
|
# is specifically for the event sending ratelimiter. Instead, we
|
||||||
|
# only use it to (somewhat cheekily) infer whether the user should
|
||||||
|
# be subject to any rate limiting or not.
|
||||||
|
override = await self.store.get_ratelimit_for_user(
|
||||||
|
requester.authenticated_entity
|
||||||
|
)
|
||||||
|
if override and not override.messages_per_second:
|
||||||
|
return True, -1.0
|
||||||
|
|
||||||
# Override default values if set
|
# Override default values if set
|
||||||
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
|
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
|
||||||
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
|
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
|
||||||
|
@ -175,9 +175,10 @@ class Ratelimiter:
|
||||||
else:
|
else:
|
||||||
del self.actions[key]
|
del self.actions[key]
|
||||||
|
|
||||||
def ratelimit(
|
async def ratelimit(
|
||||||
self,
|
self,
|
||||||
key: Hashable,
|
requester: Optional[Requester],
|
||||||
|
key: Optional[Hashable] = None,
|
||||||
rate_hz: Optional[float] = None,
|
rate_hz: Optional[float] = None,
|
||||||
burst_count: Optional[int] = None,
|
burst_count: Optional[int] = None,
|
||||||
update: bool = True,
|
update: bool = True,
|
||||||
|
@ -185,8 +186,16 @@ class Ratelimiter:
|
||||||
):
|
):
|
||||||
"""Checks if an action can be performed. If not, raises a LimitExceededError
|
"""Checks if an action can be performed. If not, raises a LimitExceededError
|
||||||
|
|
||||||
|
Checks if the user has ratelimiting disabled in the database by looking
|
||||||
|
for null/zero values in the `ratelimit_override` table. (Non-zero
|
||||||
|
values aren't honoured, as they're specific to the event sending
|
||||||
|
ratelimiter, rather than all ratelimiters)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: An arbitrary key used to classify an action
|
requester: The requester that is doing the action, if any. Used to check for
|
||||||
|
if the user has ratelimits disabled.
|
||||||
|
key: An arbitrary key used to classify an action. Defaults to the
|
||||||
|
requester's user ID.
|
||||||
rate_hz: The long term number of actions that can be performed in a second.
|
rate_hz: The long term number of actions that can be performed in a second.
|
||||||
Overrides the value set during instantiation if set.
|
Overrides the value set during instantiation if set.
|
||||||
burst_count: How many actions that can be performed before being limited.
|
burst_count: How many actions that can be performed before being limited.
|
||||||
|
@ -201,7 +210,8 @@ class Ratelimiter:
|
||||||
"""
|
"""
|
||||||
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
|
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
|
||||||
|
|
||||||
allowed, time_allowed = self.can_do_action(
|
allowed, time_allowed = await self.can_do_action(
|
||||||
|
requester,
|
||||||
key,
|
key,
|
||||||
rate_hz=rate_hz,
|
rate_hz=rate_hz,
|
||||||
burst_count=burst_count,
|
burst_count=burst_count,
|
||||||
|
|
|
@ -870,6 +870,7 @@ class FederationHandlerRegistry:
|
||||||
|
|
||||||
# A rate limiter for incoming room key requests per origin.
|
# A rate limiter for incoming room key requests per origin.
|
||||||
self._room_key_request_rate_limiter = Ratelimiter(
|
self._room_key_request_rate_limiter = Ratelimiter(
|
||||||
|
store=hs.get_datastore(),
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.config.rc_key_requests.per_second,
|
rate_hz=self.config.rc_key_requests.per_second,
|
||||||
burst_count=self.config.rc_key_requests.burst_count,
|
burst_count=self.config.rc_key_requests.burst_count,
|
||||||
|
@ -930,7 +931,9 @@ class FederationHandlerRegistry:
|
||||||
# the limit, drop them.
|
# the limit, drop them.
|
||||||
if (
|
if (
|
||||||
edu_type == EduTypes.RoomKeyRequest
|
edu_type == EduTypes.RoomKeyRequest
|
||||||
and not self._room_key_request_rate_limiter.can_do_action(origin)
|
and not await self._room_key_request_rate_limiter.can_do_action(
|
||||||
|
None, origin
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ class BaseHandler:
|
||||||
|
|
||||||
# The rate_hz and burst_count are overridden on a per-user basis
|
# The rate_hz and burst_count are overridden on a per-user basis
|
||||||
self.request_ratelimiter = Ratelimiter(
|
self.request_ratelimiter = Ratelimiter(
|
||||||
clock=self.clock, rate_hz=0, burst_count=0
|
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
|
||||||
)
|
)
|
||||||
self._rc_message = self.hs.config.rc_message
|
self._rc_message = self.hs.config.rc_message
|
||||||
|
|
||||||
|
@ -57,6 +57,7 @@ class BaseHandler:
|
||||||
# by the presence of rate limits in the config
|
# by the presence of rate limits in the config
|
||||||
if self.hs.config.rc_admin_redaction:
|
if self.hs.config.rc_admin_redaction:
|
||||||
self.admin_redaction_ratelimiter = Ratelimiter(
|
self.admin_redaction_ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||||
|
@ -91,11 +92,6 @@ class BaseHandler:
|
||||||
if app_service is not None:
|
if app_service is not None:
|
||||||
return # do not ratelimit app service senders
|
return # do not ratelimit app service senders
|
||||||
|
|
||||||
# Disable rate limiting of users belonging to any AS that is configured
|
|
||||||
# not to be rate limited in its registration file (rate_limited: true|false).
|
|
||||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
|
||||||
return
|
|
||||||
|
|
||||||
messages_per_second = self._rc_message.per_second
|
messages_per_second = self._rc_message.per_second
|
||||||
burst_count = self._rc_message.burst_count
|
burst_count = self._rc_message.burst_count
|
||||||
|
|
||||||
|
@ -113,11 +109,11 @@ class BaseHandler:
|
||||||
if is_admin_redaction and self.admin_redaction_ratelimiter:
|
if is_admin_redaction and self.admin_redaction_ratelimiter:
|
||||||
# If we have separate config for admin redactions, use a separate
|
# If we have separate config for admin redactions, use a separate
|
||||||
# ratelimiter as to not have user_ids clash
|
# ratelimiter as to not have user_ids clash
|
||||||
self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
|
await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
|
||||||
else:
|
else:
|
||||||
# Override rate and burst count per-user
|
# Override rate and burst count per-user
|
||||||
self.request_ratelimiter.ratelimit(
|
await self.request_ratelimiter.ratelimit(
|
||||||
user_id,
|
requester,
|
||||||
rate_hz=messages_per_second,
|
rate_hz=messages_per_second,
|
||||||
burst_count=burst_count,
|
burst_count=burst_count,
|
||||||
update=update,
|
update=update,
|
||||||
|
|
|
@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
|
||||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||||
# as per `rc_login.failed_attempts`.
|
# as per `rc_login.failed_attempts`.
|
||||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||||
|
@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# Ratelimitier for failed /login attempts
|
# Ratelimitier for failed /login attempts
|
||||||
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||||
|
@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
# Check if we should be ratelimited due to too many previous failed attempts
|
# Check if we should be ratelimited due to too many previous failed attempts
|
||||||
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
|
await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
|
||||||
|
|
||||||
# build a list of supported flows
|
# build a list of supported flows
|
||||||
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||||
|
@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
except LoginError:
|
except LoginError:
|
||||||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
||||||
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
|
await self._failed_uia_attempts_ratelimiter.can_do_action(
|
||||||
|
requester,
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# find the completed login type
|
# find the completed login type
|
||||||
|
@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
|
||||||
# We also apply account rate limiting using the 3PID as a key, as
|
# We also apply account rate limiting using the 3PID as a key, as
|
||||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
await self._failed_login_attempts_ratelimiter.ratelimit(
|
||||||
(medium, address), update=False
|
None, (medium, address), update=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for login providers that support 3pid login types
|
# Check for login providers that support 3pid login types
|
||||||
|
@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
|
||||||
# this code path, which is fine as then the per-user ratelimit
|
# this code path, which is fine as then the per-user ratelimit
|
||||||
# will kick in below.
|
# will kick in below.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
await self._failed_login_attempts_ratelimiter.can_do_action(
|
||||||
(medium, address)
|
None, (medium, address)
|
||||||
)
|
)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# Check if we've hit the failed ratelimit (but don't update it)
|
# Check if we've hit the failed ratelimit (but don't update it)
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
await self._failed_login_attempts_ratelimiter.ratelimit(
|
||||||
qualified_user_id.lower(), update=False
|
None, qualified_user_id.lower(), update=False
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
|
||||||
# exception and masking the LoginError. The actual ratelimiting
|
# exception and masking the LoginError. The actual ratelimiting
|
||||||
# should have happened above.
|
# should have happened above.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
await self._failed_login_attempts_ratelimiter.can_do_action(
|
||||||
qualified_user_id.lower()
|
None, qualified_user_id.lower()
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
@ -81,6 +81,7 @@ class DeviceMessageHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self._ratelimiter = Ratelimiter(
|
self._ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=hs.config.rc_key_requests.per_second,
|
rate_hz=hs.config.rc_key_requests.per_second,
|
||||||
burst_count=hs.config.rc_key_requests.burst_count,
|
burst_count=hs.config.rc_key_requests.burst_count,
|
||||||
|
@ -191,8 +192,8 @@ class DeviceMessageHandler:
|
||||||
if (
|
if (
|
||||||
message_type == EduTypes.RoomKeyRequest
|
message_type == EduTypes.RoomKeyRequest
|
||||||
and user_id != sender_user_id
|
and user_id != sender_user_id
|
||||||
and self._ratelimiter.can_do_action(
|
and await self._ratelimiter.can_do_action(
|
||||||
(sender_user_id, requester.device_id)
|
requester, (sender_user_id, requester.device_id)
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -1711,7 +1711,7 @@ class FederationHandler(BaseHandler):
|
||||||
member_handler = self.hs.get_room_member_handler()
|
member_handler = self.hs.get_room_member_handler()
|
||||||
# We don't rate limit based on room ID, as that should be done by
|
# We don't rate limit based on room ID, as that should be done by
|
||||||
# sending server.
|
# sending server.
|
||||||
member_handler.ratelimit_invite(None, event.state_key)
|
await member_handler.ratelimit_invite(None, None, event.state_key)
|
||||||
|
|
||||||
# keep a record of the room version, if we don't yet know it.
|
# keep a record of the room version, if we don't yet know it.
|
||||||
# (this may get overwritten if we later get a different room version in a
|
# (this may get overwritten if we later get a different room version in a
|
||||||
|
|
|
@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
|
||||||
|
|
||||||
# Ratelimiters for `/requestToken` endpoints.
|
# Ratelimiters for `/requestToken` endpoints.
|
||||||
self._3pid_validation_ratelimiter_ip = Ratelimiter(
|
self._3pid_validation_ratelimiter_ip = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
||||||
)
|
)
|
||||||
self._3pid_validation_ratelimiter_address = Ratelimiter(
|
self._3pid_validation_ratelimiter_address = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
def ratelimit_request_token_requests(
|
async def ratelimit_request_token_requests(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
medium: str,
|
medium: str,
|
||||||
|
@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
|
||||||
address: The actual threepid ID, e.g. the phone number or email address
|
address: The actual threepid ID, e.g. the phone number or email address
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
|
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||||
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
|
None, (medium, request.getClientIP())
|
||||||
|
)
|
||||||
|
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||||
|
None, (medium, address)
|
||||||
|
)
|
||||||
|
|
||||||
async def threepid_from_creds(
|
async def threepid_from_creds(
|
||||||
self, id_server: str, creds: Dict[str, str]
|
self, id_server: str, creds: Dict[str, str]
|
||||||
|
|
|
@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem registering.
|
SynapseError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
self.check_registration_ratelimit(address)
|
await self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
result = await self.spam_checker.check_registration_for_spam(
|
result = await self.spam_checker.check_registration_for_spam(
|
||||||
threepid,
|
threepid,
|
||||||
|
@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE,
|
errcode=Codes.EXCLUSIVE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
async def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
||||||
"""A simple helper method to check whether the registration rate limit has been hit
|
"""A simple helper method to check whether the registration rate limit has been hit
|
||||||
for a given IP address
|
for a given IP address
|
||||||
|
|
||||||
|
@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
if not address:
|
if not address:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.ratelimiter.ratelimit(address)
|
await self.ratelimiter.ratelimit(None, address)
|
||||||
|
|
||||||
async def register_with_store(
|
async def register_with_store(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
||||||
|
|
||||||
self._join_rate_limiter_local = Ratelimiter(
|
self._join_rate_limiter_local = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
||||||
)
|
)
|
||||||
self._join_rate_limiter_remote = Ratelimiter(
|
self._join_rate_limiter_remote = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._invites_per_room_limiter = Ratelimiter(
|
self._invites_per_room_limiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
|
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
|
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
|
||||||
)
|
)
|
||||||
self._invites_per_user_limiter = Ratelimiter(
|
self._invites_per_user_limiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
|
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
|
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
|
||||||
|
@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
async def forget(self, user: UserID, room_id: str) -> None:
|
async def forget(self, user: UserID, room_id: str) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
|
async def ratelimit_invite(
|
||||||
|
self,
|
||||||
|
requester: Optional[Requester],
|
||||||
|
room_id: Optional[str],
|
||||||
|
invitee_user_id: str,
|
||||||
|
):
|
||||||
"""Ratelimit invites by room and by target user.
|
"""Ratelimit invites by room and by target user.
|
||||||
|
|
||||||
If room ID is missing then we just rate limit by target user.
|
If room ID is missing then we just rate limit by target user.
|
||||||
"""
|
"""
|
||||||
if room_id:
|
if room_id:
|
||||||
self._invites_per_room_limiter.ratelimit(room_id)
|
await self._invites_per_room_limiter.ratelimit(requester, room_id)
|
||||||
|
|
||||||
self._invites_per_user_limiter.ratelimit(invitee_user_id)
|
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
|
||||||
|
|
||||||
async def _local_membership_update(
|
async def _local_membership_update(
|
||||||
self,
|
self,
|
||||||
|
@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
(
|
(
|
||||||
allowed,
|
allowed,
|
||||||
time_allowed,
|
time_allowed,
|
||||||
) = self._join_rate_limiter_local.can_requester_do_action(requester)
|
) = await self._join_rate_limiter_local.can_do_action(requester)
|
||||||
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise LimitExceededError(
|
raise LimitExceededError(
|
||||||
|
@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
if effective_membership_state == Membership.INVITE:
|
if effective_membership_state == Membership.INVITE:
|
||||||
target_id = target.to_string()
|
target_id = target.to_string()
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
# Don't ratelimit application services.
|
await self.ratelimit_invite(requester, room_id, target_id)
|
||||||
if not requester.app_service or requester.app_service.is_rate_limited():
|
|
||||||
self.ratelimit_invite(room_id, target_id)
|
|
||||||
|
|
||||||
# block any attempts to invite the server notices mxid
|
# block any attempts to invite the server notices mxid
|
||||||
if target_id == self._server_notices_mxid:
|
if target_id == self._server_notices_mxid:
|
||||||
|
@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
(
|
(
|
||||||
allowed,
|
allowed,
|
||||||
time_allowed,
|
time_allowed,
|
||||||
) = self._join_rate_limiter_remote.can_requester_do_action(
|
) = await self._join_rate_limiter_remote.can_do_action(
|
||||||
requester,
|
requester,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
self.registration_handler.check_registration_ratelimit(content["address"])
|
await self.registration_handler.check_registration_ratelimit(content["address"])
|
||||||
|
|
||||||
await self.registration_handler.register_with_store(
|
await self.registration_handler.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):
|
||||||
|
|
||||||
self._well_known_builder = WellKnownBuilder(hs)
|
self._well_known_builder = WellKnownBuilder(hs)
|
||||||
self._address_ratelimiter = Ratelimiter(
|
self._address_ratelimiter = Ratelimiter(
|
||||||
|
store=hs.get_datastore(),
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=self.hs.config.rc_login_address.per_second,
|
rate_hz=self.hs.config.rc_login_address.per_second,
|
||||||
burst_count=self.hs.config.rc_login_address.burst_count,
|
burst_count=self.hs.config.rc_login_address.burst_count,
|
||||||
)
|
)
|
||||||
self._account_ratelimiter = Ratelimiter(
|
self._account_ratelimiter = Ratelimiter(
|
||||||
|
store=hs.get_datastore(),
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=self.hs.config.rc_login_account.per_second,
|
rate_hz=self.hs.config.rc_login_account.per_second,
|
||||||
burst_count=self.hs.config.rc_login_account.burst_count,
|
burst_count=self.hs.config.rc_login_account.burst_count,
|
||||||
|
@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
|
||||||
appservice = self.auth.get_appservice_by_req(request)
|
appservice = self.auth.get_appservice_by_req(request)
|
||||||
|
|
||||||
if appservice.is_rate_limited():
|
if appservice.is_rate_limited():
|
||||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientIP()
|
||||||
|
)
|
||||||
|
|
||||||
result = await self._do_appservice_login(login_submission, appservice)
|
result = await self._do_appservice_login(login_submission, appservice)
|
||||||
elif self.jwt_enabled and (
|
elif self.jwt_enabled and (
|
||||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||||
):
|
):
|
||||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||||
result = await self._do_jwt_login(login_submission)
|
result = await self._do_jwt_login(login_submission)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||||
result = await self._do_token_login(login_submission)
|
result = await self._do_token_login(login_submission)
|
||||||
else:
|
else:
|
||||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||||
result = await self._do_other_login(login_submission)
|
result = await self._do_other_login(login_submission)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
|
||||||
# too often. This happens here rather than before as we don't
|
# too often. This happens here rather than before as we don't
|
||||||
# necessarily know the user before now.
|
# necessarily know the user before now.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._account_ratelimiter.ratelimit(user_id.lower())
|
await self._account_ratelimiter.ratelimit(None, user_id.lower())
|
||||||
|
|
||||||
if create_non_existent_users:
|
if create_non_existent_users:
|
||||||
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||||
|
|
|
@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||||
# Raise if the provided next_link value isn't valid
|
# Raise if the provided next_link value isn't valid
|
||||||
assert_valid_next_link(self.hs, next_link)
|
assert_valid_next_link(self.hs, next_link)
|
||||||
|
|
||||||
self.identity_handler.ratelimit_request_token_requests(request, "email", email)
|
await self.identity_handler.ratelimit_request_token_requests(
|
||||||
|
request, "email", email
|
||||||
|
)
|
||||||
|
|
||||||
# The email will be sent to the stored address.
|
# The email will be sent to the stored address.
|
||||||
# This avoids a potential account hijack by requesting a password reset to
|
# This avoids a potential account hijack by requesting a password reset to
|
||||||
|
@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||||
Codes.THREEPID_DENIED,
|
Codes.THREEPID_DENIED,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.identity_handler.ratelimit_request_token_requests(request, "email", email)
|
await self.identity_handler.ratelimit_request_token_requests(
|
||||||
|
request, "email", email
|
||||||
|
)
|
||||||
|
|
||||||
if next_link:
|
if next_link:
|
||||||
# Raise if the provided next_link value isn't valid
|
# Raise if the provided next_link value isn't valid
|
||||||
|
@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
Codes.THREEPID_DENIED,
|
Codes.THREEPID_DENIED,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.identity_handler.ratelimit_request_token_requests(
|
await self.identity_handler.ratelimit_request_token_requests(
|
||||||
request, "msisdn", msisdn
|
request, "msisdn", msisdn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||||
Codes.THREEPID_DENIED,
|
Codes.THREEPID_DENIED,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.identity_handler.ratelimit_request_token_requests(request, "email", email)
|
await self.identity_handler.ratelimit_request_token_requests(
|
||||||
|
request, "email", email
|
||||||
|
)
|
||||||
|
|
||||||
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
"email", email
|
"email", email
|
||||||
|
@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||||
Codes.THREEPID_DENIED,
|
Codes.THREEPID_DENIED,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.identity_handler.ratelimit_request_token_requests(
|
await self.identity_handler.ratelimit_request_token_requests(
|
||||||
request, "msisdn", msisdn
|
request, "msisdn", msisdn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
client_addr = request.getClientIP()
|
client_addr = request.getClientIP()
|
||||||
|
|
||||||
self.ratelimiter.ratelimit(client_addr, update=False)
|
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||||
|
|
||||||
kind = b"user"
|
kind = b"user"
|
||||||
if b"kind" in request.args:
|
if b"kind" in request.args:
|
||||||
|
|
|
@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_registration_ratelimiter(self) -> Ratelimiter:
|
def get_registration_ratelimiter(self) -> Ratelimiter:
|
||||||
return Ratelimiter(
|
return Ratelimiter(
|
||||||
|
store=self.get_datastore(),
|
||||||
clock=self.get_clock(),
|
clock=self.get_clock(),
|
||||||
rate_hz=self.config.rc_registration.per_second,
|
rate_hz=self.config.rc_registration.per_second,
|
||||||
burst_count=self.config.rc_registration.burst_count,
|
burst_count=self.config.rc_registration.burst_count,
|
||||||
|
|
|
@ -5,38 +5,25 @@ from synapse.types import create_requester
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class TestRatelimiter(unittest.TestCase):
|
class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
def test_allowed_via_can_do_action(self):
|
def test_allowed_via_can_do_action(self):
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
self.assertTrue(allowed)
|
)
|
||||||
self.assertEquals(10.0, time_allowed)
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(None, key="test_id", _time_now_s=0)
|
||||||
allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
|
|
||||||
self.assertFalse(allowed)
|
|
||||||
self.assertEquals(10.0, time_allowed)
|
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
|
|
||||||
self.assertTrue(allowed)
|
|
||||||
self.assertEquals(20.0, time_allowed)
|
|
||||||
|
|
||||||
def test_allowed_user_via_can_requester_do_action(self):
|
|
||||||
user_requester = create_requester("@user:example.com")
|
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
|
||||||
user_requester, _time_now_s=0
|
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(10.0, time_allowed)
|
self.assertEquals(10.0, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
user_requester, _time_now_s=5
|
limiter.can_do_action(None, key="test_id", _time_now_s=5)
|
||||||
)
|
)
|
||||||
self.assertFalse(allowed)
|
self.assertFalse(allowed)
|
||||||
self.assertEquals(10.0, time_allowed)
|
self.assertEquals(10.0, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
user_requester, _time_now_s=10
|
limiter.can_do_action(None, key="test_id", _time_now_s=10)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(20.0, time_allowed)
|
self.assertEquals(20.0, time_allowed)
|
||||||
|
@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
|
||||||
)
|
)
|
||||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||||
|
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
as_requester, _time_now_s=0
|
)
|
||||||
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(as_requester, _time_now_s=0)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(10.0, time_allowed)
|
self.assertEquals(10.0, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
as_requester, _time_now_s=5
|
limiter.can_do_action(as_requester, _time_now_s=5)
|
||||||
)
|
)
|
||||||
self.assertFalse(allowed)
|
self.assertFalse(allowed)
|
||||||
self.assertEquals(10.0, time_allowed)
|
self.assertEquals(10.0, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
as_requester, _time_now_s=10
|
limiter.can_do_action(as_requester, _time_now_s=10)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(20.0, time_allowed)
|
self.assertEquals(20.0, time_allowed)
|
||||||
|
@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
|
||||||
)
|
)
|
||||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||||
|
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
as_requester, _time_now_s=0
|
)
|
||||||
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(as_requester, _time_now_s=0)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(-1, time_allowed)
|
self.assertEquals(-1, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
as_requester, _time_now_s=5
|
limiter.can_do_action(as_requester, _time_now_s=5)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(-1, time_allowed)
|
self.assertEquals(-1, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_requester_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
as_requester, _time_now_s=10
|
limiter.can_do_action(as_requester, _time_now_s=10)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(-1, time_allowed)
|
self.assertEquals(-1, time_allowed)
|
||||||
|
|
||||||
def test_allowed_via_ratelimit(self):
|
def test_allowed_via_ratelimit(self):
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
|
)
|
||||||
|
|
||||||
# Shouldn't raise
|
# Shouldn't raise
|
||||||
limiter.ratelimit(key="test_id", _time_now_s=0)
|
self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
|
||||||
|
|
||||||
# Should raise
|
# Should raise
|
||||||
with self.assertRaises(LimitExceededError) as context:
|
with self.assertRaises(LimitExceededError) as context:
|
||||||
limiter.ratelimit(key="test_id", _time_now_s=5)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key="test_id", _time_now_s=5)
|
||||||
|
)
|
||||||
self.assertEqual(context.exception.retry_after_ms, 5000)
|
self.assertEqual(context.exception.retry_after_ms, 5000)
|
||||||
|
|
||||||
# Shouldn't raise
|
# Shouldn't raise
|
||||||
limiter.ratelimit(key="test_id", _time_now_s=10)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key="test_id", _time_now_s=10)
|
||||||
|
)
|
||||||
|
|
||||||
def test_allowed_via_can_do_action_and_overriding_parameters(self):
|
def test_allowed_via_can_do_action_and_overriding_parameters(self):
|
||||||
"""Test that we can override options of can_do_action that would otherwise fail
|
"""Test that we can override options of can_do_action that would otherwise fail
|
||||||
an action
|
an action
|
||||||
"""
|
"""
|
||||||
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
|
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
|
)
|
||||||
|
|
||||||
# First attempt should be allowed
|
# First attempt should be allowed
|
||||||
allowed, time_allowed = limiter.can_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
("test_id",),
|
limiter.can_do_action(
|
||||||
_time_now_s=0,
|
None,
|
||||||
|
("test_id",),
|
||||||
|
_time_now_s=0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(10.0, time_allowed)
|
self.assertEqual(10.0, time_allowed)
|
||||||
|
|
||||||
# Second attempt, 1s later, will fail
|
# Second attempt, 1s later, will fail
|
||||||
allowed, time_allowed = limiter.can_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
("test_id",),
|
limiter.can_do_action(
|
||||||
_time_now_s=1,
|
None,
|
||||||
|
("test_id",),
|
||||||
|
_time_now_s=1,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertFalse(allowed)
|
self.assertFalse(allowed)
|
||||||
self.assertEqual(10.0, time_allowed)
|
self.assertEqual(10.0, time_allowed)
|
||||||
|
|
||||||
# But, if we allow 10 actions/sec for this request, we should be allowed
|
# But, if we allow 10 actions/sec for this request, we should be allowed
|
||||||
# to continue.
|
# to continue.
|
||||||
allowed, time_allowed = limiter.can_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
("test_id",), _time_now_s=1, rate_hz=10.0
|
limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(1.1, time_allowed)
|
self.assertEqual(1.1, time_allowed)
|
||||||
|
|
||||||
# Similarly if we allow a burst of 10 actions
|
# Similarly if we allow a burst of 10 actions
|
||||||
allowed, time_allowed = limiter.can_do_action(
|
allowed, time_allowed = self.get_success_or_raise(
|
||||||
("test_id",), _time_now_s=1, burst_count=10
|
limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(1.0, time_allowed)
|
self.assertEqual(1.0, time_allowed)
|
||||||
|
@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
|
||||||
fail an action
|
fail an action
|
||||||
"""
|
"""
|
||||||
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
|
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
|
)
|
||||||
|
|
||||||
# First attempt should be allowed
|
# First attempt should be allowed
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=0)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
|
||||||
|
)
|
||||||
|
|
||||||
# Second attempt, 1s later, will fail
|
# Second attempt, 1s later, will fail
|
||||||
with self.assertRaises(LimitExceededError) as context:
|
with self.assertRaises(LimitExceededError) as context:
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=1)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
|
||||||
|
)
|
||||||
self.assertEqual(context.exception.retry_after_ms, 9000)
|
self.assertEqual(context.exception.retry_after_ms, 9000)
|
||||||
|
|
||||||
# But, if we allow 10 actions/sec for this request, we should be allowed
|
# But, if we allow 10 actions/sec for this request, we should be allowed
|
||||||
# to continue.
|
# to continue.
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
|
||||||
|
)
|
||||||
|
|
||||||
# Similarly if we allow a burst of 10 actions
|
# Similarly if we allow a burst of 10 actions
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
|
||||||
|
)
|
||||||
|
|
||||||
def test_pruning(self):
|
def test_pruning(self):
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
limiter.can_do_action(key="test_id_1", _time_now_s=0)
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
|
)
|
||||||
|
self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIn("test_id_1", limiter.actions)
|
self.assertIn("test_id_1", limiter.actions)
|
||||||
|
|
||||||
limiter.can_do_action(key="test_id_2", _time_now_s=10)
|
self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertNotIn("test_id_1", limiter.actions)
|
self.assertNotIn("test_id_1", limiter.actions)
|
||||||
|
|
||||||
|
def test_db_user_override(self):
|
||||||
|
"""Test that users that have ratelimiting disabled in the DB aren't
|
||||||
|
ratelimited.
|
||||||
|
"""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
|
||||||
|
user_id = "@user:test"
|
||||||
|
requester = create_requester(user_id)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
table="ratelimit_override",
|
||||||
|
values={
|
||||||
|
"user_id": user_id,
|
||||||
|
"messages_per_second": None,
|
||||||
|
"burst_count": None,
|
||||||
|
},
|
||||||
|
desc="test_db_user_override",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
|
||||||
|
|
||||||
|
# Shouldn't raise
|
||||||
|
for _ in range(20):
|
||||||
|
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
|
||||||
|
|
Loading…
Reference in New Issue