Change slave storage to use new replication interface

The TCP replication protocol uses a slightly different API and a different
set of streams than HTTP replication, so the slaved stores need updating to match.

This breaks HTTP replication.
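
In practice this swaps the pull-style HTTP hook on every slaved store for a
push-style one. A rough sketch of the two shapes, as seen in the diffs below
(bodies elided; row fields illustrative):

    # Old (HTTP): periodically fetch a dict of named streams and walk it.
    def process_replication(self, result):
        stream = result.get("caches")
        if stream:
            for row in stream["rows"]:
                ...  # rows are bare tuples/lists
            self._cache_id_gen.advance(int(stream["position"]))
        return defer.succeed(None)

    # New (TCP): the replication client calls in once per batch of rows on a
    # single stream, passing the stream's current token and parsed row objects.
    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "caches":
            self._cache_id_gen.advance(token)
            for row in rows:
                ...  # rows have named attributes, e.g. row.cache_func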
Erik Johnston 2017-03-27 16:32:07 +01:00
parent 52bfa604e1
commit 3a1f3f8388
11 changed files with 128 additions and 179 deletions

synapse/replication/slave/storage/_base.py

@@ -15,7 +15,6 @@
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
 from ._slaved_id_tracker import SlavedIdTracker
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None

-        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
-        self.http_client = hs.get_simple_http_client()
+        self.hs = hs

     def stream_positions(self):
         pos = {}
@@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos

-    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "caches":
+            self._cache_id_gen.advance(token)
+            for row in rows:
                 try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
+                    getattr(self, row.cache_func).invalidate(tuple(row.keys))
                 except AttributeError:
                     # We probably haven't pulled in the cache in this worker,
                     # which is fine.
                     pass
-            self._cache_id_gen.advance(int(stream["position"]))
-        return defer.succeed(None)

     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         txn.call_after(self._send_invalidation_poke, cache_func, keys)

-    @defer.inlineCallbacks
     def _send_invalidation_poke(self, cache_func, keys):
-        try:
-            yield self.http_client.post_json_get_json(self.expire_cache_url, {
-                "invalidate": [{
-                    "name": cache_func.__name__,
-                    "keys": list(keys),
-                }]
-            })
-        except:
-            logger.exception("Failed to poke on expire_cache")
+        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
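
Two things to note in this file. First, rows now arrive as structured objects
rather than positional lists, which is why the new body reads row.cache_func
instead of unpacking a tuple; a namedtuple along these lines (illustrative,
not the exact row class the TCP streams define) would satisfy the code:

    from collections import namedtuple

    # Hypothetical stand-in for the caches stream's row type.
    CachesRow = namedtuple("CachesRow", ("cache_func", "keys", "invalidation_ts"))

    row = CachesRow(
        cache_func="get_rooms_for_user",  # name of the cached method to poke
        keys=["@alice:example.com"],      # key to invalidate within that cache
        invalidation_ts=0,
    )
    store.process_replication_rows("caches", token=1234, rows=[row])

Second, the invalidation poke becomes a single synchronous call on the TCP
replication handler instead of an HTTP POST, which is what lets the
defer.inlineCallbacks decorator, the twisted import and the error-swallowing
try/except all go away.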

synapse/replication/slave/storage/account_data.py

@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
         result["tag_account_data"] = position
         return result

-    def process_replication(self, result):
-        stream = result.get("user_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id, data_type = row[:3]
-                self.get_global_account_data_by_type_for_user.invalidate(
-                    (data_type, user_id,)
-                )
-                self.get_account_data_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        stream = result.get("room_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_account_data_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        stream = result.get("tag_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_tags_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        return super(SlavedAccountDataStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "tag_account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    row.user_id, token
+                )
+        elif stream_name == "account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id,)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    row.user_id, token
+                )
+        return super(SlavedAccountDataStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
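
Every overriding store now finishes by chaining to super(). Worker store
classes mix several of these slaved stores together, so chaining lets each
class in the MRO react to the streams it cares about and pass the batch
along. The pattern, reduced to a sketch (class and stream names illustrative):

    class SlavedFooStore(BaseSlavedStore):
        def process_replication_rows(self, stream_name, token, rows):
            if stream_name == "foo":
                self._foo_id_gen.advance(token)
            # Unconditionally chain up: another mixin may handle this stream.
            return super(SlavedFooStore, self).process_replication_rows(
                stream_name, token, rows
            )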

synapse/replication/slave/storage/deviceinbox.py

@@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                entity = row[1]
-
-                if entity.startswith("@"):
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "to_device":
+            self._device_inbox_id_gen.advance(token)
+            for row in rows:
+                if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
                 else:
                     self._device_federation_outbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
-
-        return super(SlavedDeviceInboxStore, self).process_replication(result)
+        return super(SlavedDeviceInboxStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

synapse/replication/slave/storage/devices.py

@@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
         result["device_lists"] = self._device_list_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("device_lists")
-        if stream:
-            self._device_list_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                destination = row[2]
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "device_lists":
+            self._device_list_id_gen.advance(token)
+            for row in rows:
                 self._device_list_stream_cache.entity_has_changed(
-                    user_id, stream_id
+                    row.user_id, token
                 )

-                if destination:
+                if row.destination:
                     self._device_list_federation_stream_cache.entity_has_changed(
-                        destination, stream_id
+                        row.destination, token
                     )
-
-        return super(SlavedDeviceStore, self).process_replication(result)
+        return super(SlavedDeviceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

synapse/replication/slave/storage/events.py

@@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
         result["backfill"] = -self._backfill_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("events")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-
-            if stream["rows"]:
-                logger.info("Got %d event rows", len(stream["rows"]))
-
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=False,
-                )
-
-        stream = result.get("backfill")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=True,
-                )
-
-        stream = result.get("forward_ex_outliers")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        stream = result.get("backward_ex_outliers")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        return super(SlavedEventStore, self).process_replication(result)
-
-    def _process_replication_row(self, row, backfilled):
-        stream_ordering = row[0] if not backfilled else -row[0]
-        self.invalidate_caches_for_event(
-            stream_ordering, row[1], row[2], row[3], row[4], row[5],
-            backfilled=backfilled,
-        )
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=False,
+                )
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    -token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=True,
+                )
+        return super(SlavedEventStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
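
Note the sign convention preserved from the old code: backfill positions are
negative internally but travel over replication as positive tokens, so
stream_positions() advertises the negated current token and this handler
negates again on receipt, both to advance the generator and as the stream
ordering for cache invalidation. A tiny worked example (values made up):

    backfill_current = -5            # backfill orderings count downwards
    advertised = -backfill_current   # stream_positions() reports 5

    token = 6                        # next token received on "backfill"
    internal = -token                # advance(-token): generator moves to -6
    stream_ordering = -token         # backfilled events sort before ordering 0
    assert internal == stream_ordering == -6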

synapse/replication/slave/storage/presence.py

@@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         result["presence"] = position
         return result

-    def process_replication(self, result):
-        stream = result.get("presence")
-        if stream:
-            self._presence_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "presence":
+            self._presence_id_gen.advance(token)
+            for row in rows:
                 self.presence_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-                self._get_presence_for_user.invalidate((user_id,))
-
-        return super(SlavedPresenceStore, self).process_replication(result)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super(SlavedPresenceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

synapse/replication/slave/storage/push_rule.py

@@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("push_rules")
-        if stream:
-            for row in stream["rows"]:
-                position = row[0]
-                user_id = row[2]
-                self.get_push_rules_for_user.invalidate((user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "push_rules":
+            self._push_rules_stream_id_gen.advance(token)
+            for row in rows:
+                self.get_push_rules_for_user.invalidate((row.user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-            self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPushRuleStore, self).process_replication(result)
+        return super(SlavedPushRuleStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

synapse/replication/slave/storage/pushers.py

@@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        stream = result.get("deleted_pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPusherStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            self._pushers_id_gen.advance(token)
+        return super(SlavedPusherStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

synapse/replication/slave/storage/receipts.py

@@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
         result["receipts"] = self._receipts_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("receipts")
-        if stream:
-            self._receipts_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, room_id, receipt_type, user_id = row[:4]
-                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
-                self._receipts_stream_cache.entity_has_changed(room_id, position)
-
-        return super(SlavedReceiptsStore, self).process_replication(result)
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self.get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "receipts":
+            self._receipts_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+        return super(SlavedReceiptsStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

synapse/replication/slave/storage/room.py

@@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
         result["public_rooms"] = self._public_room_id_gen.get_current_token()
         return result

-    def process_replication(self, result):
-        stream = result.get("public_rooms")
-        if stream:
-            self._public_room_id_gen.advance(int(stream["position"]))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "public_rooms":
+            self._public_room_id_gen.advance(token)

-        return super(RoomStore, self).process_replication(result)
+        return super(RoomStore, self).process_replication_rows(
+            stream_name, token, rows
+        )

tests/replication/slave/storage/_base.py

@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from tests import unittest

 from mock import Mock, NonCallableMock
 from tests.utils import setup_test_homeserver
-from synapse.replication.resource import ReplicationResource
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.client import (
+    ReplicationClientHandler, ReplicationClientFactory,
+)


 class BaseSlavedStoreTestCase(unittest.TestCase):
@@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         )
         self.hs.get_ratelimiter().send_message.return_value = (True, 0)

-        self.replication = ReplicationResource(self.hs)
-
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0

+        server_factory = ReplicationStreamProtocolFactory(self.hs)
+        listener = reactor.listenUNIX("\0xxx", server_factory)
+        self.addCleanup(listener.stopListening)
+        self.streamer = server_factory.streamer
+
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        client_factory = ReplicationClientFactory(
+            self.hs, "client_name", self.replication_handler
+        )
+        client_connector = reactor.connectUNIX("\0xxx", client_factory)
+        self.addCleanup(client_factory.stopTrying)
+        self.addCleanup(client_connector.disconnect)
+
     @defer.inlineCallbacks
     def replicate(self):
-        streams = self.slaved_store.stream_positions()
-        writer = yield self.replication.replicate(streams, 100)
-        result = writer.finish()
-        yield self.slaved_store.process_replication(result)
+        yield self.streamer.on_notifier_poke()
+        d = self.replication_handler.await_sync("replication_test")
+        self.streamer.send_sync_to_all_connections("replication_test")
+        yield d

     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
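
The rewritten harness drives the real TCP protocol end to end: it listens
with a ReplicationStreamProtocolFactory on "\0xxx" (a Linux abstract-namespace
UNIX socket; the leading NUL keeps it out of the filesystem) and connects a
genuine replication client to it. replicate() pokes the streamer to flush any
pending RDATA, then round-trips a SYNC token, so once its deferred fires the
slaved store has processed everything sent before it. A test built on this
base class would look roughly like the following (illustrative; the
master-store call and expected value are made up):

    @defer.inlineCallbacks
    def test_tags_replicate(self):
        # Write on the master, push it over replication, assert on the slave.
        yield self.master_store.add_tag_to_room(
            "@user:test", "!room:test", "tag", {},
        )
        yield self.replicate()
        yield self.check(
            "get_tags_for_user", ["@user:test"],
            {"!room:test": {"tag": {}}},
        )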