Use abstract base class to access stream IDs
This commit is contained in:
parent
f5ac4dc2d4
commit
e316bbb4c0
|
@ -31,11 +31,16 @@ from synapse.storage.receipts import ReceiptsWorkerStore
|
|||
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
receipts_id_gen = SlavedIdTracker(
|
||||
# We instansiate this first as the ReceiptsWorkerStore constructor
|
||||
# needs to be able to call get_max_receipt_stream_id
|
||||
self._receipts_id_gen = SlavedIdTracker(
|
||||
db_conn, "receipts_linearized", "stream_id"
|
||||
)
|
||||
|
||||
super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
|
||||
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
|
||||
|
||||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||
|
|
|
@ -21,6 +21,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
|
@ -29,21 +30,30 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ReceiptsWorkerStore(SQLBaseStore):
|
||||
def __init__(self, receipts_id_gen, db_conn, hs):
|
||||
"""
|
||||
Args:
|
||||
receipts_id_gen (StreamIdGenerator|SlavedIdTracker)
|
||||
db_conn: Database connection
|
||||
hs (Homeserver)
|
||||
"""
|
||||
"""This is an abstract base class where subclasses must implement
|
||||
`get_max_receipt_stream_id` which can be called in the initializer.
|
||||
"""
|
||||
|
||||
# This ABCMeta metaclass ensures that we cannot be instantiated without
|
||||
# the abstract methods being implemented.
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
|
||||
|
||||
self._receipts_id_gen = receipts_id_gen
|
||||
|
||||
self._receipts_stream_cache = StreamChangeCache(
|
||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
||||
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_max_receipt_stream_id(self):
|
||||
"""Get the current max stream ID for receipts stream
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
pass
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_users_with_read_receipts_in_room(self, room_id):
|
||||
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
||||
|
@ -260,9 +270,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
}
|
||||
defer.returnValue(results)
|
||||
|
||||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
@ -288,11 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
class ReceiptsStore(ReceiptsWorkerStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
receipts_id_gen = StreamIdGenerator(
|
||||
# We instansiate this first as the ReceiptsWorkerStore constructor
|
||||
# needs to be able to call get_max_receipt_stream_id
|
||||
self._receipts_id_gen = StreamIdGenerator(
|
||||
db_conn, "receipts_linearized", "stream_id"
|
||||
)
|
||||
|
||||
super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
|
||||
super(ReceiptsStore, self).__init__(db_conn, hs)
|
||||
|
||||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
||||
user_id):
|
||||
|
|
Loading…
Reference in New Issue