Check whether to ratelimit sooner to avoid work
This commit is contained in:
parent
50ac1d843d
commit
550308c7a1
|
@ -23,7 +23,7 @@ class Ratelimiter(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.message_counts = collections.OrderedDict()
|
self.message_counts = collections.OrderedDict()
|
||||||
|
|
||||||
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
|
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
|
||||||
"""Can the user send a message?
|
"""Can the user send a message?
|
||||||
Args:
|
Args:
|
||||||
user_id: The user sending a message.
|
user_id: The user sending a message.
|
||||||
|
@ -32,12 +32,15 @@ class Ratelimiter(object):
|
||||||
second.
|
second.
|
||||||
burst_count: How many messages the user can send before being
|
burst_count: How many messages the user can send before being
|
||||||
limited.
|
limited.
|
||||||
|
update (bool): Whether to update the message rates or not. This is
|
||||||
|
useful to check if a message would be allowed to be sent before
|
||||||
|
its ready to be actually sent.
|
||||||
Returns:
|
Returns:
|
||||||
A pair of a bool indicating if they can send a message now and a
|
A pair of a bool indicating if they can send a message now and a
|
||||||
time in seconds of when they can next send a message.
|
time in seconds of when they can next send a message.
|
||||||
"""
|
"""
|
||||||
self.prune_message_counts(time_now_s)
|
self.prune_message_counts(time_now_s)
|
||||||
message_count, time_start, _ignored = self.message_counts.pop(
|
message_count, time_start, _ignored = self.message_counts.get(
|
||||||
user_id, (0., time_now_s, None),
|
user_id, (0., time_now_s, None),
|
||||||
)
|
)
|
||||||
time_delta = time_now_s - time_start
|
time_delta = time_now_s - time_start
|
||||||
|
@ -52,6 +55,7 @@ class Ratelimiter(object):
|
||||||
allowed = True
|
allowed = True
|
||||||
message_count += 1
|
message_count += 1
|
||||||
|
|
||||||
|
if update:
|
||||||
self.message_counts[user_id] = (
|
self.message_counts[user_id] = (
|
||||||
message_count, time_start, msg_rate_hz
|
message_count, time_start, msg_rate_hz
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
|
||||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
@ -239,6 +239,18 @@ class MessageHandler(BaseHandler):
|
||||||
"Tried to send member event through non-member codepath"
|
"Tried to send member event through non-member codepath"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
time_now = self.clock.time()
|
||||||
|
allowed, time_allowed = self.ratelimiter.send_message(
|
||||||
|
event.sender, time_now,
|
||||||
|
msg_rate_hz=self.hs.config.rc_messages_per_second,
|
||||||
|
burst_count=self.hs.config.rc_message_burst_count,
|
||||||
|
update=False,
|
||||||
|
)
|
||||||
|
if not allowed:
|
||||||
|
raise LimitExceededError(
|
||||||
|
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||||
|
)
|
||||||
|
|
||||||
user = UserID.from_string(event.sender)
|
user = UserID.from_string(event.sender)
|
||||||
|
|
||||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||||
|
|
Loading…
Reference in New Issue