diff --git a/synapse/storage/chunk_ordered_table.py b/synapse/storage/chunk_ordered_table.py index 79d0ca44ec..87a57f87b3 100644 --- a/synapse/storage/chunk_ordered_table.py +++ b/synapse/storage/chunk_ordered_table.py @@ -281,10 +281,32 @@ class ChunkDBOrderedListStore(OrderedListStore): # We pick the interval to try and minimise the number of decimal # places, i.e. we round to nearest float with `rebalance_digits` and # use that as one side of the interval + order = self._get_order(node_id) + rebalance_digits = self.rebalance_digits a = round(order, self.rebalance_digits) - min_order = a - 10 ** -self.rebalance_digits - max_order = a + 10 ** -self.rebalance_digits + diff = 10 ** - self.rebalance_digits + + while True: + min_order = a - diff + max_order = a + diff + + sql = """ + SELECT count(chunk_id) FROM chunk_linearized + WHERE ordering >= ? AND ordering <= ? AND room_id = ? + """ + self.txn.execute(sql, ( + min_order - self.min_difference, + max_order + self.min_difference, + self.room_id, + )) + + cnt, = self.txn.fetchone() + step = (max_order - min_order) / cnt + if step > 1 / self.min_difference: + break + + diff *= 2 # Now we get all the nodes in the range. We add the minimum difference # to the bounds to ensure that we don't accidentally move a node to be @@ -292,6 +314,7 @@ class ChunkDBOrderedListStore(OrderedListStore): sql = """ SELECT chunk_id FROM chunk_linearized WHERE ordering >= ? AND ordering <= ? AND room_id = ? + ORDER BY ordering ASC """ self.txn.execute(sql, ( min_order - self.min_difference, diff --git a/synapse/util/katriel_bodlaender.py b/synapse/util/katriel_bodlaender.py index 16126ec936..d030e37013 100644 --- a/synapse/util/katriel_bodlaender.py +++ b/synapse/util/katriel_bodlaender.py @@ -112,6 +112,12 @@ class OrderedListStore(object): pe_s = self.get_nodes_with_edges_to(s) fe_t = self.get_nodes_with_edges_from(t) + for n, _ in pe_s: + assert n not in to_s + + for n, _ in fe_t: + assert n not in from_t + l_s = len(pe_s) l_t = len(fe_t) @@ -145,15 +151,19 @@ class OrderedListStore(object): if t is None: t = self.get_next(source) + for node_id in to_s: + self._delete_ordering(node_id) + while to_s: s1 = to_s.pop() - self._delete_ordering(s1) self._insert_after(s1, s) s = s1 + for node_id in from_t: + self._delete_ordering(node_id) + while from_t: t1 = from_t.pop() - self._delete_ordering(t1) self._insert_before(t1, t) t = t1 diff --git a/tests/storage/test_chunk_linearizer_table.py b/tests/storage/test_chunk_linearizer_table.py index beb1ac9a42..9cac62061b 100644 --- a/tests/storage/test_chunk_linearizer_table.py +++ b/tests/storage/test_chunk_linearizer_table.py @@ -48,6 +48,7 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): table.add_node("A") table._insert_after("B", "A") table._insert_before("C", "A") + table._insert_after("D", "A") sql = """ SELECT chunk_id FROM chunk_linearized @@ -58,7 +59,7 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): ordered = [r for r, in txn] - self.assertEqual(["C", "A", "B"], ordered) + self.assertEqual(["C", "A", "D", "B"], ordered) yield self.store.runInteraction("test", test_txn) @@ -183,3 +184,44 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): self.assertEqual(expected, ordered) yield self.store.runInteraction("test", test_txn) + + @defer.inlineCallbacks + def test_get_edges_to(self): + room_id = "foo_room4" + + def test_txn(txn): + table = ChunkDBOrderedListStore( + txn, room_id, self.clock, 1, 100, + ) + + table.add_node("A") + table._insert_after("B", "A") + table._add_edge_to_graph("A", "B") + table._insert_before("C", "A") + table._add_edge_to_graph("C", "A") + + nodes = table.get_nodes_with_edges_from("A") + self.assertEqual([n for _, n in nodes], ["B"]) + + nodes = table.get_nodes_with_edges_to("A") + self.assertEqual([n for _, n in nodes], ["C"]) + + yield self.store.runInteraction("test", test_txn) + + @defer.inlineCallbacks + def test_get_next_and_prev(self): + room_id = "foo_room5" + + def test_txn(txn): + table = ChunkDBOrderedListStore( + txn, room_id, self.clock, 1, 100, + ) + + table.add_node("A") + table._insert_after("B", "A") + table._insert_before("C", "A") + + self.assertEqual(table.get_next("A"), "B") + self.assertEqual(table.get_prev("A"), "C") + + yield self.store.runInteraction("test", test_txn) diff --git a/tests/util/test_katriel_bodlaender.py b/tests/util/test_katriel_bodlaender.py index 5768408604..72126bdea9 100644 --- a/tests/util/test_katriel_bodlaender.py +++ b/tests/util/test_katriel_bodlaender.py @@ -56,3 +56,29 @@ class KatrielBodlaenderTests(unittest.TestCase): store.add_edge("node_4", "node_3") self.assertEqual(list(reversed(nodes)), store.list) + + def test_divergent_graph(self): + store = InMemoryOrderedListStore() + + nodes = [ + "node_1", + "node_2", + "node_3", + "node_4", + "node_5", + "node_6", + ] + + for node in reversed(nodes): + store.add_node(node) + + store.add_edge("node_2", "node_3") + store.add_edge("node_2", "node_5") + store.add_edge("node_1", "node_2") + store.add_edge("node_3", "node_4") + store.add_edge("node_1", "node_3") + store.add_edge("node_4", "node_5") + store.add_edge("node_5", "node_6") + store.add_edge("node_4", "node_6") + + self.assertEqual(nodes, store.list)