Change slave storage to use new replication interface
As the TCP replication uses a slightly different API and streams than the HTTP replication. This breaks HTTP replication.
This commit is contained in:
parent
52bfa604e1
commit
3a1f3f8388
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
self._cache_id_gen = None
|
self._cache_id_gen = None
|
||||||
|
|
||||||
self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
|
self.hs = hs
|
||||||
self.http_client = hs.get_simple_http_client()
|
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
pos = {}
|
pos = {}
|
||||||
|
@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
|
||||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("caches")
|
if stream_name == "caches":
|
||||||
if stream:
|
self._cache_id_gen.advance(token)
|
||||||
for row in stream["rows"]:
|
for row in rows:
|
||||||
(
|
|
||||||
position, cache_func, keys, invalidation_ts,
|
|
||||||
) = row
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
getattr(self, cache_func).invalidate(tuple(keys))
|
getattr(self, row.cache_func).invalidate(tuple(row.keys))
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# We probably haven't pulled in the cache in this worker,
|
# We probably haven't pulled in the cache in this worker,
|
||||||
# which is fine.
|
# which is fine.
|
||||||
pass
|
pass
|
||||||
self._cache_id_gen.advance(int(stream["position"]))
|
|
||||||
return defer.succeed(None)
|
|
||||||
|
|
||||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||||
txn.call_after(cache_func.invalidate, keys)
|
txn.call_after(cache_func.invalidate, keys)
|
||||||
txn.call_after(self._send_invalidation_poke, cache_func, keys)
|
txn.call_after(self._send_invalidation_poke, cache_func, keys)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _send_invalidation_poke(self, cache_func, keys):
|
def _send_invalidation_poke(self, cache_func, keys):
|
||||||
try:
|
self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
|
||||||
yield self.http_client.post_json_get_json(self.expire_cache_url, {
|
|
||||||
"invalidate": [{
|
|
||||||
"name": cache_func.__name__,
|
|
||||||
"keys": list(keys),
|
|
||||||
}]
|
|
||||||
})
|
|
||||||
except:
|
|
||||||
logger.exception("Failed to poke on expire_cache")
|
|
||||||
|
|
|
@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
|
||||||
result["tag_account_data"] = position
|
result["tag_account_data"] = position
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("user_account_data")
|
if stream_name == "tag_account_data":
|
||||||
if stream:
|
self._account_data_id_gen.advance(token)
|
||||||
self._account_data_id_gen.advance(int(stream["position"]))
|
for row in rows:
|
||||||
for row in stream["rows"]:
|
self.get_tags_for_user.invalidate((row.user_id,))
|
||||||
position, user_id, data_type = row[:3]
|
self._account_data_stream_cache.entity_has_changed(
|
||||||
|
row.user_id, token
|
||||||
|
)
|
||||||
|
elif stream_name == "account_data":
|
||||||
|
self._account_data_id_gen.advance(token)
|
||||||
|
for row in rows:
|
||||||
|
if not row.room_id:
|
||||||
self.get_global_account_data_by_type_for_user.invalidate(
|
self.get_global_account_data_by_type_for_user.invalidate(
|
||||||
(data_type, user_id,)
|
(row.data_type, row.user_id,)
|
||||||
)
|
)
|
||||||
self.get_account_data_for_user.invalidate((user_id,))
|
self.get_account_data_for_user.invalidate((row.user_id,))
|
||||||
self._account_data_stream_cache.entity_has_changed(
|
self._account_data_stream_cache.entity_has_changed(
|
||||||
user_id, position
|
row.user_id, token
|
||||||
)
|
)
|
||||||
|
return super(SlavedAccountDataStore, self).process_replication_rows(
|
||||||
stream = result.get("room_account_data")
|
stream_name, token, rows
|
||||||
if stream:
|
|
||||||
self._account_data_id_gen.advance(int(stream["position"]))
|
|
||||||
for row in stream["rows"]:
|
|
||||||
position, user_id = row[:2]
|
|
||||||
self.get_account_data_for_user.invalidate((user_id,))
|
|
||||||
self._account_data_stream_cache.entity_has_changed(
|
|
||||||
user_id, position
|
|
||||||
)
|
)
|
||||||
|
|
||||||
stream = result.get("tag_account_data")
|
|
||||||
if stream:
|
|
||||||
self._account_data_id_gen.advance(int(stream["position"]))
|
|
||||||
for row in stream["rows"]:
|
|
||||||
position, user_id = row[:2]
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
|
||||||
self._account_data_stream_cache.entity_has_changed(
|
|
||||||
user_id, position
|
|
||||||
)
|
|
||||||
|
|
||||||
return super(SlavedAccountDataStore, self).process_replication(result)
|
|
||||||
|
|
|
@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
|
||||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("to_device")
|
if stream_name == "to_device":
|
||||||
if stream:
|
self._device_inbox_id_gen.advance(token)
|
||||||
self._device_inbox_id_gen.advance(int(stream["position"]))
|
for row in rows:
|
||||||
for row in stream["rows"]:
|
if row.entity.startswith("@"):
|
||||||
stream_id = row[0]
|
|
||||||
entity = row[1]
|
|
||||||
|
|
||||||
if entity.startswith("@"):
|
|
||||||
self._device_inbox_stream_cache.entity_has_changed(
|
self._device_inbox_stream_cache.entity_has_changed(
|
||||||
entity, stream_id
|
row.entity, token
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._device_federation_outbox_stream_cache.entity_has_changed(
|
self._device_federation_outbox_stream_cache.entity_has_changed(
|
||||||
entity, stream_id
|
row.entity, token
|
||||||
|
)
|
||||||
|
return super(SlavedDeviceInboxStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
return super(SlavedDeviceInboxStore, self).process_replication(result)
|
|
||||||
|
|
|
@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
|
||||||
result["device_lists"] = self._device_list_id_gen.get_current_token()
|
result["device_lists"] = self._device_list_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("device_lists")
|
if stream_name == "device_lists":
|
||||||
if stream:
|
self._device_list_id_gen.advance(token)
|
||||||
self._device_list_id_gen.advance(int(stream["position"]))
|
for row in rows:
|
||||||
for row in stream["rows"]:
|
|
||||||
stream_id = row[0]
|
|
||||||
user_id = row[1]
|
|
||||||
destination = row[2]
|
|
||||||
|
|
||||||
self._device_list_stream_cache.entity_has_changed(
|
self._device_list_stream_cache.entity_has_changed(
|
||||||
user_id, stream_id
|
row.user_id, token
|
||||||
)
|
)
|
||||||
|
|
||||||
if destination:
|
if row.destination:
|
||||||
self._device_list_federation_stream_cache.entity_has_changed(
|
self._device_list_federation_stream_cache.entity_has_changed(
|
||||||
destination, stream_id
|
row.destination, token
|
||||||
|
)
|
||||||
|
return super(SlavedDeviceStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
return super(SlavedDeviceStore, self).process_replication(result)
|
|
||||||
|
|
|
@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("events")
|
if stream_name == "events":
|
||||||
if stream:
|
self._stream_id_gen.advance(token)
|
||||||
self._stream_id_gen.advance(int(stream["position"]))
|
for row in rows:
|
||||||
|
|
||||||
if stream["rows"]:
|
|
||||||
logger.info("Got %d event rows", len(stream["rows"]))
|
|
||||||
|
|
||||||
for row in stream["rows"]:
|
|
||||||
self._process_replication_row(
|
|
||||||
row, backfilled=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
stream = result.get("backfill")
|
|
||||||
if stream:
|
|
||||||
self._backfill_id_gen.advance(-int(stream["position"]))
|
|
||||||
for row in stream["rows"]:
|
|
||||||
self._process_replication_row(
|
|
||||||
row, backfilled=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
stream = result.get("forward_ex_outliers")
|
|
||||||
if stream:
|
|
||||||
self._stream_id_gen.advance(int(stream["position"]))
|
|
||||||
for row in stream["rows"]:
|
|
||||||
event_id = row[1]
|
|
||||||
self._invalidate_get_event_cache(event_id)
|
|
||||||
|
|
||||||
stream = result.get("backward_ex_outliers")
|
|
||||||
if stream:
|
|
||||||
self._backfill_id_gen.advance(-int(stream["position"]))
|
|
||||||
for row in stream["rows"]:
|
|
||||||
event_id = row[1]
|
|
||||||
self._invalidate_get_event_cache(event_id)
|
|
||||||
|
|
||||||
return super(SlavedEventStore, self).process_replication(result)
|
|
||||||
|
|
||||||
def _process_replication_row(self, row, backfilled):
|
|
||||||
stream_ordering = row[0] if not backfilled else -row[0]
|
|
||||||
self.invalidate_caches_for_event(
|
self.invalidate_caches_for_event(
|
||||||
stream_ordering, row[1], row[2], row[3], row[4], row[5],
|
token, row.event_id, row.room_id, row.type, row.state_key,
|
||||||
backfilled=backfilled,
|
row.redacts,
|
||||||
|
backfilled=False,
|
||||||
|
)
|
||||||
|
elif stream_name == "backfill":
|
||||||
|
self._backfill_id_gen.advance(-token)
|
||||||
|
for row in rows:
|
||||||
|
self.invalidate_caches_for_event(
|
||||||
|
-token, row.event_id, row.room_id, row.type, row.state_key,
|
||||||
|
row.redacts,
|
||||||
|
backfilled=True,
|
||||||
|
)
|
||||||
|
return super(SlavedEventStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
|
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
|
||||||
|
|
|
@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||||
result["presence"] = position
|
result["presence"] = position
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("presence")
|
if stream_name == "presence":
|
||||||
if stream:
|
self._presence_id_gen.advance(token)
|
||||||
self._presence_id_gen.advance(int(stream["position"]))
|
for row in rows:
|
||||||
for row in stream["rows"]:
|
|
||||||
position, user_id = row[:2]
|
|
||||||
self.presence_stream_cache.entity_has_changed(
|
self.presence_stream_cache.entity_has_changed(
|
||||||
user_id, position
|
row.user_id, token
|
||||||
|
)
|
||||||
|
self._get_presence_for_user.invalidate((row.user_id,))
|
||||||
|
return super(SlavedPresenceStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
self._get_presence_for_user.invalidate((user_id,))
|
|
||||||
|
|
||||||
return super(SlavedPresenceStore, self).process_replication(result)
|
|
||||||
|
|
|
@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
|
||||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("push_rules")
|
if stream_name == "push_rules":
|
||||||
if stream:
|
self._push_rules_stream_id_gen.advance(token)
|
||||||
for row in stream["rows"]:
|
for row in rows:
|
||||||
position = row[0]
|
self.get_push_rules_for_user.invalidate((row.user_id,))
|
||||||
user_id = row[2]
|
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
|
||||||
self.get_push_rules_for_user.invalidate((user_id,))
|
|
||||||
self.get_push_rules_enabled_for_user.invalidate((user_id,))
|
|
||||||
self.push_rules_stream_cache.entity_has_changed(
|
self.push_rules_stream_cache.entity_has_changed(
|
||||||
user_id, position
|
row.user_id, token
|
||||||
|
)
|
||||||
|
return super(SlavedPushRuleStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
self._push_rules_stream_id_gen.advance(int(stream["position"]))
|
|
||||||
|
|
||||||
return super(SlavedPushRuleStore, self).process_replication(result)
|
|
||||||
|
|
|
@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
|
||||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("pushers")
|
if stream_name == "pushers":
|
||||||
if stream:
|
self._pushers_id_gen.advance(token)
|
||||||
self._pushers_id_gen.advance(int(stream["position"]))
|
return super(SlavedPusherStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
stream = result.get("deleted_pushers")
|
)
|
||||||
if stream:
|
|
||||||
self._pushers_id_gen.advance(int(stream["position"]))
|
|
||||||
|
|
||||||
return super(SlavedPusherStore, self).process_replication(result)
|
|
||||||
|
|
|
@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
|
||||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
|
||||||
stream = result.get("receipts")
|
|
||||||
if stream:
|
|
||||||
self._receipts_id_gen.advance(int(stream["position"]))
|
|
||||||
for row in stream["rows"]:
|
|
||||||
position, room_id, receipt_type, user_id = row[:4]
|
|
||||||
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
|
|
||||||
self._receipts_stream_cache.entity_has_changed(room_id, position)
|
|
||||||
|
|
||||||
return super(SlavedReceiptsStore, self).process_replication(result)
|
|
||||||
|
|
||||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||||
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
|
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||||
self.get_last_receipt_event_id_for_user.invalidate(
|
self.get_last_receipt_event_id_for_user.invalidate(
|
||||||
(user_id, room_id, receipt_type)
|
(user_id, room_id, receipt_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
|
if stream_name == "receipts":
|
||||||
|
self._receipts_id_gen.advance(token)
|
||||||
|
for row in rows:
|
||||||
|
self.invalidate_caches_for_receipt(
|
||||||
|
row.room_id, row.receipt_type, row.user_id
|
||||||
|
)
|
||||||
|
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
|
||||||
|
|
||||||
|
return super(SlavedReceiptsStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
|
)
|
||||||
|
|
|
@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
|
||||||
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
stream = result.get("public_rooms")
|
if stream_name == "public_rooms":
|
||||||
if stream:
|
self._public_room_id_gen.advance(token)
|
||||||
self._public_room_id_gen.advance(int(stream["position"]))
|
|
||||||
|
|
||||||
return super(RoomStore, self).process_replication(result)
|
return super(RoomStore, self).process_replication_rows(
|
||||||
|
stream_name, token, rows
|
||||||
|
)
|
||||||
|
|
|
@ -12,12 +12,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, reactor
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from synapse.replication.resource import ReplicationResource
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
|
from synapse.replication.tcp.client import (
|
||||||
|
ReplicationClientHandler, ReplicationClientFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseSlavedStoreTestCase(unittest.TestCase):
|
class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||||
|
@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
|
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
|
||||||
|
|
||||||
self.replication = ReplicationResource(self.hs)
|
|
||||||
|
|
||||||
self.master_store = self.hs.get_datastore()
|
self.master_store = self.hs.get_datastore()
|
||||||
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
|
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
|
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||||
|
listener = reactor.listenUNIX("\0xxx", server_factory)
|
||||||
|
self.addCleanup(listener.stopListening)
|
||||||
|
self.streamer = server_factory.streamer
|
||||||
|
|
||||||
|
self.replication_handler = ReplicationClientHandler(self.slaved_store)
|
||||||
|
client_factory = ReplicationClientFactory(
|
||||||
|
self.hs, "client_name", self.replication_handler
|
||||||
|
)
|
||||||
|
client_connector = reactor.connectUNIX("\0xxx", client_factory)
|
||||||
|
self.addCleanup(client_factory.stopTrying)
|
||||||
|
self.addCleanup(client_connector.disconnect)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def replicate(self):
|
def replicate(self):
|
||||||
streams = self.slaved_store.stream_positions()
|
yield self.streamer.on_notifier_poke()
|
||||||
writer = yield self.replication.replicate(streams, 100)
|
d = self.replication_handler.await_sync("replication_test")
|
||||||
result = writer.finish()
|
self.streamer.send_sync_to_all_connections("replication_test")
|
||||||
yield self.slaved_store.process_replication(result)
|
yield d
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check(self, method, args, expected_result=None):
|
def check(self, method, args, expected_result=None):
|
||||||
|
|
Loading…
Reference in New Issue