Add ratelimiting function to basehandler
This commit is contained in:
parent
dd2cd9312a
commit
c7a7cdf734
|
@ -28,6 +28,7 @@ class Codes(object):
|
|||
UNKNOWN = "M_UNKNOWN"
|
||||
NOT_FOUND = "M_NOT_FOUND"
|
||||
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
||||
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
||||
|
||||
|
||||
class CodeMessageException(Exception):
|
||||
|
|
|
@ -247,6 +247,7 @@ def setup():
|
|||
upload_dir=os.path.abspath("uploads"),
|
||||
db_name=config.database_path,
|
||||
tls_context_factory=tls_context_factory,
|
||||
config=config,
|
||||
)
|
||||
|
||||
hs.register_servlets()
|
||||
|
|
|
@ -17,8 +17,10 @@ from .tls import TlsConfig
|
|||
from .server import ServerConfig
|
||||
from .logger import LoggingConfig
|
||||
from .database import DatabaseConfig
|
||||
from .ratelimiting import RatelimitConfig
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig):
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig):
|
||||
pass
|
||||
|
||||
if __name__=='__main__':
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.api.errors import cs_error, Codes
|
||||
|
||||
class BaseHandler(object):
|
||||
|
||||
|
@ -25,8 +26,24 @@ class BaseHandler(object):
|
|||
self.room_lock = hs.get_room_lock_manager()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
||||
def ratelimit(self, user_id):
|
||||
time_now = self.clock.time()
|
||||
allowed, time_allowed = self.ratelimiter.send_message(
|
||||
user_id, time_now,
|
||||
msg_rate_hz=self.hs.config.rc_messages_per_second,
|
||||
burst_count=self.hs.config.rc_messsage_burst_count,
|
||||
)
|
||||
if not allowed:
|
||||
raise cs_error(
|
||||
"Limit exceeded",
|
||||
Codes.M_LIMIT_EXCEEDED,
|
||||
retry_after_ms=1000*(time_allowed - time_now),
|
||||
)
|
||||
|
||||
|
||||
class BaseRoomHandler(BaseHandler):
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from synapse.util import Clock
|
|||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.lockutils import LockManager
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
|
||||
|
||||
class BaseHomeServer(object):
|
||||
|
@ -73,6 +74,7 @@ class BaseHomeServer(object):
|
|||
'resource_for_web_client',
|
||||
'resource_for_content_repo',
|
||||
'event_sources',
|
||||
'ratelimiter',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
|
@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer):
|
|||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
def build_ratelimiter(self):
|
||||
return Ratelimiter()
|
||||
|
||||
def register_servlets(self):
|
||||
""" Register all servlets associated with this HomeServer.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue