Add ratelimiting function to basehandler

This commit is contained in:
Mark Haines 2014-09-02 17:57:04 +01:00
parent dd2cd9312a
commit c7a7cdf734
5 changed files with 27 additions and 1 deletions

View File

@ -28,6 +28,7 @@ class Codes(object):
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
class CodeMessageException(Exception): class CodeMessageException(Exception):

View File

@ -247,6 +247,7 @@ def setup():
upload_dir=os.path.abspath("uploads"), upload_dir=os.path.abspath("uploads"),
db_name=config.database_path, db_name=config.database_path,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config,
) )
hs.register_servlets() hs.register_servlets()

View File

@ -17,8 +17,10 @@ from .tls import TlsConfig
from .server import ServerConfig from .server import ServerConfig
from .logger import LoggingConfig from .logger import LoggingConfig
from .database import DatabaseConfig from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig): class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig):
pass pass
if __name__=='__main__': if __name__=='__main__':

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import cs_error, Codes
class BaseHandler(object): class BaseHandler(object):
@ -25,8 +26,24 @@ class BaseHandler(object):
self.room_lock = hs.get_room_lock_manager() self.room_lock = hs.get_room_lock_manager()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs self.hs = hs
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_messsage_burst_count,
)
if not allowed:
raise cs_error(
"Limit exceeded",
Codes.M_LIMIT_EXCEEDED,
retry_after_ms=1000*(time_allowed - time_now),
)
class BaseRoomHandler(BaseHandler): class BaseRoomHandler(BaseHandler):

View File

@ -32,6 +32,7 @@ from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
class BaseHomeServer(object): class BaseHomeServer(object):
@ -73,6 +74,7 @@ class BaseHomeServer(object):
'resource_for_web_client', 'resource_for_web_client',
'resource_for_content_repo', 'resource_for_content_repo',
'event_sources', 'event_sources',
'ratelimiter',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer):
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)
def build_ratelimiter(self):
return Ratelimiter()
def register_servlets(self): def register_servlets(self):
""" Register all servlets associated with this HomeServer. """ Register all servlets associated with this HomeServer.
""" """