Add ratelimiting function to basehandler
This commit is contained in:
parent
dd2cd9312a
commit
c7a7cdf734
|
@ -28,6 +28,7 @@ class Codes(object):
|
||||||
UNKNOWN = "M_UNKNOWN"
|
UNKNOWN = "M_UNKNOWN"
|
||||||
NOT_FOUND = "M_NOT_FOUND"
|
NOT_FOUND = "M_NOT_FOUND"
|
||||||
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
||||||
|
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(Exception):
|
class CodeMessageException(Exception):
|
||||||
|
|
|
@ -247,6 +247,7 @@ def setup():
|
||||||
upload_dir=os.path.abspath("uploads"),
|
upload_dir=os.path.abspath("uploads"),
|
||||||
db_name=config.database_path,
|
db_name=config.database_path,
|
||||||
tls_context_factory=tls_context_factory,
|
tls_context_factory=tls_context_factory,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
hs.register_servlets()
|
hs.register_servlets()
|
||||||
|
|
|
@ -17,8 +17,10 @@ from .tls import TlsConfig
|
||||||
from .server import ServerConfig
|
from .server import ServerConfig
|
||||||
from .logger import LoggingConfig
|
from .logger import LoggingConfig
|
||||||
from .database import DatabaseConfig
|
from .database import DatabaseConfig
|
||||||
|
from .ratelimiting import RatelimitConfig
|
||||||
|
|
||||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig):
|
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
|
RatelimitConfig):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if __name__=='__main__':
|
if __name__=='__main__':
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from synapse.api.errors import cs_error, Codes
|
||||||
|
|
||||||
class BaseHandler(object):
|
class BaseHandler(object):
|
||||||
|
|
||||||
|
@ -25,8 +26,24 @@ class BaseHandler(object):
|
||||||
self.room_lock = hs.get_room_lock_manager()
|
self.room_lock = hs.get_room_lock_manager()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
def ratelimit(self, user_id):
|
||||||
|
time_now = self.clock.time()
|
||||||
|
allowed, time_allowed = self.ratelimiter.send_message(
|
||||||
|
user_id, time_now,
|
||||||
|
msg_rate_hz=self.hs.config.rc_messages_per_second,
|
||||||
|
burst_count=self.hs.config.rc_messsage_burst_count,
|
||||||
|
)
|
||||||
|
if not allowed:
|
||||||
|
raise cs_error(
|
||||||
|
"Limit exceeded",
|
||||||
|
Codes.M_LIMIT_EXCEEDED,
|
||||||
|
retry_after_ms=1000*(time_allowed - time_now),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseRoomHandler(BaseHandler):
|
class BaseRoomHandler(BaseHandler):
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ from synapse.util import Clock
|
||||||
from synapse.util.distributor import Distributor
|
from synapse.util.distributor import Distributor
|
||||||
from synapse.util.lockutils import LockManager
|
from synapse.util.lockutils import LockManager
|
||||||
from synapse.streams.events import EventSources
|
from synapse.streams.events import EventSources
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
|
||||||
|
|
||||||
class BaseHomeServer(object):
|
class BaseHomeServer(object):
|
||||||
|
@ -73,6 +74,7 @@ class BaseHomeServer(object):
|
||||||
'resource_for_web_client',
|
'resource_for_web_client',
|
||||||
'resource_for_content_repo',
|
'resource_for_content_repo',
|
||||||
'event_sources',
|
'event_sources',
|
||||||
|
'ratelimiter',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
|
@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer):
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
|
def build_ratelimiter(self):
|
||||||
|
return Ratelimiter()
|
||||||
|
|
||||||
def register_servlets(self):
|
def register_servlets(self):
|
||||||
""" Register all servlets associated with this HomeServer.
|
""" Register all servlets associated with this HomeServer.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue