Fix ACL filtering

This commit is contained in:
Erik Johnston 2015-08-11 17:19:21 +01:00
parent 8a345190cc
commit 40affadaaa
4 changed files with 29 additions and 19 deletions

View File

@ -229,7 +229,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
states = yield self.store.get_state_for_events( events_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events), room_id, frozenset(e.event_id for e in events),
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@ -237,8 +237,6 @@ class FederationHandler(BaseHandler):
) )
) )
events_and_states = zip(events, states)
def redact_disallowed(event_and_state): def redact_disallowed(event_and_state):
event, state = event_and_state event, state = event_and_state
@ -275,9 +273,10 @@ class FederationHandler(BaseHandler):
return event return event
res = map(redact_disallowed, events_and_states) res = map(redact_disallowed, [
(e, events_to_state[e.event_id])
logger.info("_filter_events_for_server %r", res) for e in events
])
defer.returnValue(res) defer.returnValue(res)

View File

@ -137,7 +137,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events): def _filter_events_for_client(self, user_id, room_id, events):
states = yield self.store.get_state_for_events( event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events), room_id, frozenset(e.event_id for e in events),
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@ -145,8 +145,6 @@ class MessageHandler(BaseHandler):
) )
) )
events_and_states = zip(events, states)
def allowed(event_and_state): def allowed(event_and_state):
event, state = event_and_state event, state = event_and_state
@ -179,10 +177,17 @@ class MessageHandler(BaseHandler):
return True return True
events_and_states = filter(allowed, events_and_states) event_and_state = filter(
allowed,
[
(e, event_id_to_state[e.event_id])
for e in events
]
)
defer.returnValue([ defer.returnValue([
ev ev
for ev, _ in events_and_states for ev, _ in event_and_state
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -294,7 +294,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events): def _filter_events_for_client(self, user_id, room_id, events):
states = yield self.store.get_state_for_events( event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events), room_id, frozenset(e.event_id for e in events),
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@ -302,8 +302,6 @@ class SyncHandler(BaseHandler):
) )
) )
events_and_states = zip(events, states)
def allowed(event_and_state): def allowed(event_and_state):
event, state = event_and_state event, state = event_and_state
@ -335,10 +333,18 @@ class SyncHandler(BaseHandler):
return membership == Membership.INVITE return membership == Membership.INVITE
return True return True
events_and_states = filter(allowed, events_and_states)
event_and_state = filter(
allowed,
[
(e, event_id_to_state[e.event_id])
for e in events
]
)
defer.returnValue([ defer.returnValue([
ev ev
for ev, _ in events_and_states for ev, _ in event_and_state
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -316,10 +316,10 @@ class StateStore(SQLBaseStore):
for event_id, state_ids in event_to_state_ids.items() for event_id, state_ids in event_to_state_ids.items()
} }
defer.returnValue([ defer.returnValue({
event_to_state[event] event: event_to_state[event]
for event in event_ids for event in event_ids
]) })
def _make_group_id(clock): def _make_group_id(clock):