diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9b02ce0dfd..47dcc6544d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -133,7 +133,7 @@ class MessageHandler(BaseRoomHandler): if stamp_event: event.content["hsob_ts"] = int(self.clock.time_msec()) - yield self.state_handler.handle_new_event(event) + yield self.state_handler.handle_new_event(event, snapshot) yield self._on_new_room_event(event, snapshot) @@ -362,6 +362,13 @@ class RoomCreationHandler(BaseRoomHandler): content=config, ) + snapshot = yield self.store.snapshot_room( + room_id=room_id, + user_id=user_id, + state_type=RoomConfigEvent.TYPE, + state_key="", + ) + if room_alias: yield self.store.create_room_alias_association( room_id=room_id, @@ -369,11 +376,11 @@ class RoomCreationHandler(BaseRoomHandler): servers=[self.hs.hostname], ) - yield self.state_handler.handle_new_event(config_event) + yield self.state_handler.handle_new_event(config_event, snapshot) # store_id = persist... federation_handler = self.hs.get_handlers().federation_handler - yield federation_handler.handle_new_event(config_event) + yield federation_handler.handle_new_event(config_event, snapshot) # self.notifier.on_new_room_event(event, store_id) content = {"membership": Membership.JOIN} diff --git a/synapse/state.py b/synapse/state.py index ca8e1ca630..e1a1a159bb 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -45,7 +45,7 @@ class StateHandler(object): @defer.inlineCallbacks @log_function - def handle_new_event(self, event): + def handle_new_event(self, event, snapshot): """ Given an event this works out if a) we have sufficient power level to update the state and b) works out what the prev_state should be. @@ -70,25 +70,13 @@ class StateHandler(object): # Now I need to fill out the prev state and work out if it has auth # (w.r.t. to power levels) - results = yield self.store.get_latest_pdus_in_context( - event.room_id - ) + snapshot.fill_out_prev_events(event) - event.prev_events = [ - encode_event_id(p_id, origin) for p_id, origin, _ in results - ] event.prev_events = [ e for e in event.prev_events if e != event.event_id ] - if results: - event.depth = max([int(v) for _, _, v in results]) + 1 - else: - event.depth = 0 - - current_state = yield self.store.get_current_state_pdu( - key.context, key.type, key.state_key - ) + current_state = snapshot.prev_state_pdu if current_state: event.prev_state = encode_event_id( diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index 613f5c307e..a84dbcc471 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -330,6 +330,7 @@ class RoomCreationTest(unittest.TestCase): db_pool=None, datastore=NonCallableMock(spec_set=[ "store_room", + "snapshot_room", ]), http_client=NonCallableMock(spec_set=[]), notifier=NonCallableMock(spec_set=["on_new_room_event"]), diff --git a/tests/test_state.py b/tests/test_state.py index e64d15a3a2..58fd0bf3be 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -243,21 +243,24 @@ class StateTestCase(unittest.TestCase): state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) - tup = ("pdu_id", "origin.com", 5) - pdus = [tup] + snapshot = Mock() + snapshot.prev_state_pdu = state_pdu + event_id = "pdu_id@origin.com" - self.persistence.get_latest_pdus_in_context.return_value = pdus - self.persistence.get_current_state_pdu.return_value = state_pdu + def fill_out_prev_events(event): + event.prev_events = [event_id] + event.depth = 6 + snapshot.fill_out_prev_events = fill_out_prev_events - yield self.state.handle_new_event(event) + yield self.state.handle_new_event(event, snapshot) - self.assertLess(tup[2], event.depth) + self.assertLess(5, event.depth) self.assertEquals(1, len(event.prev_events)) prev_id = event.prev_events[0] - self.assertEqual(encode_event_id(tup[0], tup[1]), prev_id) + self.assertEqual(event_id, prev_id) self.assertEqual( encode_event_id(state_pdu.pdu_id, state_pdu.origin),