Make context.auth_events grap auth events from current state. Otherwise auth is wrong.

This commit is contained in:
Erik Johnston 2015-03-16 00:18:08 +00:00
parent ab8229479b
commit ea8590cf66
2 changed files with 12 additions and 18 deletions

View File

@ -28,6 +28,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
class Auth(object): class Auth(object):
def __init__(self, hs): def __init__(self, hs):
@ -427,7 +433,7 @@ class Auth(object):
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()
if v.event_id in auth_ids if v.type in AuthEventTypes
} }
def compute_auth_events(self, event, current_state): def compute_auth_events(self, event, current_state):

View File

@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
SIZE_OF_CACHE = 1000 SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20 EVICTION_TIMEOUT_SECONDS = 20
@ -187,17 +182,10 @@ class StateHandler(object):
replaces = context.current_state[key] replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: e for k, e in context.current_state.items()
for k, v in context.current_state.items() if k[0] in AuthEventTypes
if v.event_id in auth_ids
} }
else:
context.auth_events = {}
context.prev_state_events = prev_state context.prev_state_events = prev_state
defer.returnValue(context) defer.returnValue(context)