Add some basic event validation
This commit is contained in:
parent
6a8148f15b
commit
b245ee34ed
|
@ -15,6 +15,7 @@
|
|||
|
||||
from synapse.types import EventID, RoomID, UserID
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
||||
|
||||
class EventValidator(object):
|
||||
|
@ -23,14 +24,19 @@ class EventValidator(object):
|
|||
EventID.from_string(event.event_id)
|
||||
RoomID.from_string(event.room_id)
|
||||
|
||||
hasattr(event, "auth_events")
|
||||
hasattr(event, "content")
|
||||
hasattr(event, "hashes")
|
||||
hasattr(event, "origin")
|
||||
hasattr(event, "prev_events")
|
||||
hasattr(event, "prev_events")
|
||||
hasattr(event, "sender")
|
||||
hasattr(event, "type")
|
||||
required = [
|
||||
# "auth_events",
|
||||
"content",
|
||||
# "hashes",
|
||||
"origin",
|
||||
# "prev_events",
|
||||
"sender",
|
||||
"type",
|
||||
]
|
||||
|
||||
for k in required:
|
||||
if not hasattr(event, k):
|
||||
raise SynapseError(400, "Event does not have key %s" % (k,))
|
||||
|
||||
# Check that the following keys have string values
|
||||
strings = [
|
||||
|
@ -46,6 +52,13 @@ class EventValidator(object):
|
|||
if not isinstance(getattr(event, s), basestring):
|
||||
raise SynapseError(400, "Not '%s' a string type" % (s,))
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if "membership" not in event.content:
|
||||
raise SynapseError(400, "Content has not membership key")
|
||||
|
||||
if event.content["membership"] not in Membership.LIST:
|
||||
raise SynapseError(400, "Invalid membership key")
|
||||
|
||||
# Check that the following keys have dictionary values
|
||||
# TODO
|
||||
|
||||
|
|
|
@ -19,6 +19,9 @@ from synapse.api.constants import EventTypes, Membership
|
|||
from synapse.api.errors import RoomError
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from synapse.events.validator import EventValidator
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
@ -33,6 +36,7 @@ class MessageHandler(BaseHandler):
|
|||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.validator = EventValidator()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_message(self, msg_id=None, room_id=None, sender_id=None,
|
||||
|
@ -137,6 +141,8 @@ class MessageHandler(BaseHandler):
|
|||
def handle_event(self, event_dict):
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
self.validator.validate(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
if membership == Membership.JOIN:
|
||||
|
@ -152,8 +158,6 @@ class MessageHandler(BaseHandler):
|
|||
builder=builder,
|
||||
)
|
||||
|
||||
# TODO: self.validator.validate(event)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.change_membership(event, context)
|
||||
|
|
Loading…
Reference in New Issue