Change context.auth_events to what the auth_events would be bases on context.current_state, rather than based on the auth_events from the event.

This commit is contained in:
Erik Johnston 2015-02-04 14:06:42 +00:00
parent 03d415a6a2
commit 650e32d455
3 changed files with 15 additions and 9 deletions

View File

@ -358,7 +358,7 @@ class Auth(object):
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
yield run_on_reactor() yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context) auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids
@ -372,26 +372,26 @@ class Auth(object):
if v.event_id in auth_ids if v.event_id in auth_ids
} }
def compute_auth_events(self, event, context): def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return [] return []
auth_ids = [] auth_ids = []
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = context.current_state.get(key) power_level_event = current_state.get(key)
if power_level_event: if power_level_event:
auth_ids.append(power_level_event.event_id) auth_ids.append(power_level_event.event_id)
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key) join_rule_event = current_state.get(key)
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key) member_event = current_state.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = context.current_state.get(key) create_event = current_state.get(key)
if create_event: if create_event:
auth_ids.append(create_event.event_id) auth_ids.append(create_event.event_id)

View File

@ -842,7 +842,9 @@ class FederationHandler(BaseHandler):
logger.debug("Different auth: %s", different_auth) logger.debug("Different auth: %s", different_auth)
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events(event, context) auth_ids = self.auth.compute_auth_events(
event, context.current_state
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
try: try:

View File

@ -103,7 +103,9 @@ class StateHandler(object):
context.state_group = None context.state_group = None
if hasattr(event, "auth_events") and event.auth_events: if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0] auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()
@ -149,7 +151,9 @@ class StateHandler(object):
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events: if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0] auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()