Fix tests
This commit is contained in:
parent
50943ab942
commit
3f11953fcb
|
@ -67,6 +67,8 @@ class StateGroupStore(object):
|
|||
self._event_to_state_group = {}
|
||||
self._group_to_state = {}
|
||||
|
||||
self._event_id_to_event = {}
|
||||
|
||||
self._next_group = 1
|
||||
|
||||
def get_state_groups_ids(self, room_id, event_ids):
|
||||
|
@ -96,6 +98,16 @@ class StateGroupStore(object):
|
|||
|
||||
self._event_to_state_group[event.event_id] = state_group
|
||||
|
||||
def get_events(self, event_ids, **kwargs):
|
||||
return {
|
||||
e_id: self._event_id_to_event[e_id] for e_id in event_ids
|
||||
if e_id in self._event_id_to_event
|
||||
}
|
||||
|
||||
def register_events(self, events):
|
||||
for e in events:
|
||||
self._event_id_to_event[e.event_id] = e
|
||||
|
||||
|
||||
class DictObj(dict):
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -138,6 +150,7 @@ class StateTestCase(unittest.TestCase):
|
|||
spec_set=[
|
||||
"get_state_groups_ids",
|
||||
"add_event_hashes",
|
||||
"get_events",
|
||||
]
|
||||
)
|
||||
hs = Mock(spec_set=[
|
||||
|
@ -240,6 +253,8 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
self.store.get_events = store.get_events
|
||||
store.register_events(graph.walk())
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
self.assertSetEqual(
|
||||
{"START", "A", "C"},
|
||||
{e.event_id for e in context_store["D"].current_state.values()}
|
||||
{e_id for e_id in context_store["D"].current_state_ids.values()}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -304,6 +319,8 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
self.store.get_events = store.get_events
|
||||
store.register_events(graph.walk())
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
self.assertSetEqual(
|
||||
{"START", "A", "B", "C"},
|
||||
{e.event_id for e in context_store["E"].current_state.values()}
|
||||
{e for e in context_store["E"].current_state_ids.values()}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -385,6 +402,8 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
self.store.get_events = store.get_events
|
||||
store.register_events(graph.walk())
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
self.assertSetEqual(
|
||||
{"A1", "A2", "A3", "A5", "B"},
|
||||
{e.event_id for e in context_store["D"].current_state.values()}
|
||||
{e for e in context_store["D"].current_state_ids.values()}
|
||||
)
|
||||
|
||||
def _add_depths(self, nodes, edges):
|
||||
|
@ -522,6 +541,11 @@ class StateTestCase(unittest.TestCase):
|
|||
create_event(type="test4", state_key=""),
|
||||
]
|
||||
|
||||
store = StateGroupStore()
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
self.store.get_events = store.get_events
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
|
@ -550,6 +574,11 @@ class StateTestCase(unittest.TestCase):
|
|||
create_event(type="test4", state_key=""),
|
||||
]
|
||||
|
||||
store = StateGroupStore()
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
self.store.get_events = store.get_events
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
|
@ -585,9 +614,16 @@ class StateTestCase(unittest.TestCase):
|
|||
create_event(type="test1", state_key="1", depth=2),
|
||||
]
|
||||
|
||||
store = StateGroupStore()
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
self.store.get_events = store.get_events
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")])
|
||||
self.assertEqual(
|
||||
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
|
||||
)
|
||||
|
||||
# Reverse the depth to make sure we are actually using the depths
|
||||
# during state resolution.
|
||||
|
@ -604,9 +640,14 @@ class StateTestCase(unittest.TestCase):
|
|||
create_event(type="test1", state_key="1", depth=1),
|
||||
]
|
||||
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")])
|
||||
self.assertEqual(
|
||||
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
|
||||
)
|
||||
|
||||
def _get_context(self, event, old_state_1, old_state_2):
|
||||
group_name_1 = "group_name_1"
|
||||
|
|
Loading…
Reference in New Issue