diff --git a/synapse/storage/chunk_ordered_table.py b/synapse/storage/chunk_ordered_table.py index 1d49e24f95..90ca7b7fa9 100644 --- a/synapse/storage/chunk_ordered_table.py +++ b/synapse/storage/chunk_ordered_table.py @@ -21,6 +21,7 @@ from gmpy2 import mpq as Fraction from fractions import Fraction as FractionPy from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine from synapse.util.katriel_bodlaender import OrderedListStore import synapse.metrics @@ -91,12 +92,13 @@ class ChunkDBOrderedListStore(OrderedListStore): this. """ def __init__(self, - txn, room_id, clock, + txn, room_id, clock, database_engine, rebalance_max_denominator=100, max_denominator=100000): self.txn = txn self.room_id = room_id self.clock = clock + self.database_engine = database_engine self.rebalance_md = rebalance_max_denominator self.max_denominator = max_denominator @@ -390,69 +392,43 @@ class ChunkDBOrderedListStore(OrderedListStore): logger.info("Rebalancing room %s, chunk %s", self.room_id, node_id) old_order = self._get_order(node_id) - new_order = FractionPy( - int(old_order.numerator), - int(old_order.denominator), - ).limit_denominator( - self.rebalance_md, - ) - new_order = Fraction(new_order.numerator, new_order.denominator) - if new_order < old_order: - new_order += Fraction(1, new_order.denominator) - count_nodes = [node_id] - next_id = node_id - while True: - next_id = self.get_next(next_id) + a, b, c, d = find_farey_terms(old_order, self.rebalance_md) + n = max(b, d) - if not next_id: - max_order = None - break - - count_nodes.append(next_id) - - max_order = self._get_order(next_id) - - if len(count_nodes) < self.rebalance_md * (max_order - new_order): - break - - if len(count_nodes) == 1: - orders = [new_order] - if max_order: - orders = stern_brocot_range(len(count_nodes), new_order, max_order) - orders.sort(reverse=True) - else: - orders = [ - Fraction(int(math.ceil(new_order)) + i, 1) - for i in xrange(0, len(count_nodes)) - ] - orders.reverse() - - assert len(count_nodes) == len(orders) - - next_id = node_id - prev_order = old_order - while orders: - order = orders.pop() - - if max_order: - assert old_order <= new_order <= max_order - else: - assert old_order <= new_order - - assert prev_order < order - - SQLBaseStore._simple_update_txn( - self.txn, - table="chunk_linearized", - keyvalues={"chunk_id": next_id}, - updatevalues={ - "numerator": int(order.numerator), - "denominator": int(order.denominator), - }, + with_sql = """ + WITH RECURSIVE chunks (chunk_id, next, n, a, b, c, d) AS ( + SELECT chunk_id, next_chunk_id, ?, ?, ?, ?, ? + FROM chunk_linearized WHERE chunk_id = ? + UNION ALL + SELECT n.chunk_id, n.next_chunk_id, n, c, d, ((n + b) / d) * c - a, ((n + b) / d) * d - b + FROM chunks AS c + INNER JOIN chunk_linearized AS l ON l.chunk_id = c.chunk_id + INNER JOIN chunk_linearized AS n ON n.chunk_id = l.next_chunk_id + WHERE c * 1.0 / d > n.numerator * 1.0 / n.denominator ) + """ - next_id = self.get_next(next_id) + if isinstance(self.database_engine, PostgresEngine): + sql = with_sql + """ + UPDATE chunk_linearized AS l + SET numerator = a, denominator = b + FROM chunks AS c + WHERE c.chunk_id = l.chunk_id + """ + else: + sql = with_sql + """ + UPDATE chunk_linearized + SET (numerator, denominator) = ( + SELECT a, b FROM chunks + WHERE chunks.chunk_id = chunk_linearized.chunk_id + ) + WHERE chunk_id in (SELECT chunk_id FROM chunks) + """ + + self.txn.execute(sql, ( + n, a, b, c, d, node_id + )) rebalance_counter.inc() @@ -512,7 +488,7 @@ def stern_brocot_single(min_frac, max_frac): def stern_brocot_range_depth(min_frac, max_denom): assert 0 < min_frac - states = stern_brocot_single(min_frac) + states = stern_brocot_singless(min_frac) while len(states): a, b, c, d = states.pop() @@ -534,22 +510,51 @@ def stern_brocot_range_depth(min_frac, max_denom): -def stern_brocot_single(min_frac): +def stern_brocot_state(min_frac, target_d): assert 0 <= min_frac states = [] a, b, c, d = 0, 1, 1, 0 - states.append((a, b, c, d)) - while True: f = Fraction(a + c, b + d) + + if b + d >= target_d: + return a + c, b + d, c, d + if f < min_frac: a, b, c, d = a + c, b + d, c, d elif f == min_frac: - return states + return a + c, b + d, c, d else: a, b, c, d = a, b, a + c, b + d - states.append((a, b, c, d)) + +def find_farey_terms(min_frac, max_denom): + states = deque([(0, 1, 1, 0)]) + + while True: + a, b, c, d = states.popleft() + + left = a / float(b) + mid = (a + c) / float(b + d) + right = c / float(d) if d > 0 else None + + if min_frac < left: + if b >= max_denom or d >= max_denom: + return a, b, c, d + if b + d >= max_denom: + return a + c, b + d, c, d + + states.append((a, b, a + c, b + d)) + elif min_frac < mid: + if b + d >= max_denom: + return a + c, b + d, c, d + + states.append((a, b, a + c, b + d)) + states.append((a + c, b + d, c, d)) + elif right is None or min_frac < right: + states.append((a + c, b + d, c, d)) + else: + states.append((a + c, b + d, c, d)) diff --git a/tests/storage/test_chunk_linearizer_table.py b/tests/storage/test_chunk_linearizer_table.py index f02067404d..5fbb96efd1 100644 --- a/tests/storage/test_chunk_linearizer_table.py +++ b/tests/storage/test_chunk_linearizer_table.py @@ -44,7 +44,9 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): def test_txn(txn): table = ChunkDBOrderedListStore( - txn, room_id, self.clock, 5, 100, + txn, room_id, self.clock, + self.store.database_engine, + 5, 100, ) table.add_node("A") @@ -71,7 +73,9 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): def test_txn(txn): table = ChunkDBOrderedListStore( - txn, room_id, self.clock, 5, 100, + txn, room_id, self.clock, + self.store.database_engine, + 5, 100, ) nodes = [(i, "node_%d" % (i,)) for i in xrange(1, 1000)] @@ -116,7 +120,9 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): def test_txn(txn): table = ChunkDBOrderedListStore( - txn, room_id, self.clock, 5, 1000, + txn, room_id, self.clock, + self.store.database_engine, + 5, 1000, ) table.add_node("a") @@ -152,7 +158,9 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): def test_txn(txn): table = ChunkDBOrderedListStore( - txn, room_id, self.clock, 5, 100, + txn, room_id, self.clock, + self.store.database_engine, + 5, 100, ) table.add_node("a") @@ -193,7 +201,9 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): def test_txn(txn): table = ChunkDBOrderedListStore( - txn, room_id, self.clock, 5, 100, + txn, room_id, self.clock, + self.store.database_engine, + 5, 100, ) table.add_node("A") @@ -216,7 +226,9 @@ class ChunkLinearizerStoreTestCase(tests.unittest.TestCase): def test_txn(txn): table = ChunkDBOrderedListStore( - txn, room_id, self.clock, 5, 100, + txn, room_id, self.clock, + self.store.database_engine, + 5, 100, ) table.add_node("A")