Add a Homeserver.setup method.

This is for setting up dependencies that require work on startup. This
is useful for the DataStore, which wants to read a bunch from the database
before initializing.
Erik Johnston, 2016-01-26 15:51:06 +00:00
commit 87f9477b10, parent 9959d9ece8
9 changed files with 121 additions and 116 deletions
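In outline, the commit splits construction into two phases: `HomeServer.__init__` stays cheap, and anything that must touch the database moves into an explicit `setup()` call made during startup. A minimal sketch of the pattern, with illustrative names and a stand-in datastore rather than Synapse's real API:

class HomeServer(object):
    def __init__(self, hostname):
        # Phase 1: cheap construction, no I/O.
        self.hostname = hostname
        self.datastore = None

    def setup(self):
        # Phase 2: startup work that may hit the database. In the real
        # commit this is get_datastore(self), which reads the minimum
        # stream ordering before building the DataStore.
        self.datastore = {"min_stream_token": -1}  # stand-in only

hs = HomeServer("example.com")
hs.setup()  # explicit, so callers control when the expensive work runs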

synapse/app/homeserver.py

@@ -254,6 +254,17 @@ class SynapseHomeServer(HomeServer):
         except IncorrectDatabaseSetup as e:
             quit_with_error(e.message)
 
+    def get_db_conn(self):
+        db_conn = self.database_engine.module.connect(
+            **{
+                k: v for k, v in self.db_config.get("args", {}).items()
+                if not k.startswith("cp_")
+            }
+        )
+        self.database_engine.on_new_connection(db_conn)
+        return db_conn
+
 
 def quit_with_error(error_string):
     message_lines = error_string.split("\n")
@@ -390,13 +401,7 @@ def setup(config_options):
     logger.info("Preparing database: %s...", config.database_config['name'])
 
     try:
-        db_conn = database_engine.module.connect(
-            **{
-                k: v for k, v in config.database_config.get("args", {}).items()
-                if not k.startswith("cp_")
-            }
-        )
+        db_conn = hs.get_db_conn()
 
         database_engine.prepare_database(db_conn)
         hs.run_startup_checks(db_conn, database_engine)
@@ -411,14 +416,18 @@ def setup(config_options):
     logger.info("Database prepared in %s.", config.database_config['name'])
 
+    hs.setup()
     hs.start_listening()
 
-    hs.get_pusherpool().start()
-    hs.get_state_handler().start_caching()
-    hs.get_datastore().start_profiling()
-    hs.get_datastore().start_doing_background_updates()
-    hs.get_replication_layer().start_get_pdu_cache()
+    def start():
+        hs.get_pusherpool().start()
+        hs.get_state_handler().start_caching()
+        hs.get_datastore().start_profiling()
+        hs.get_datastore().start_doing_background_updates()
+        hs.get_replication_layer().start_get_pdu_cache()
+
+    reactor.callWhenRunning(start)
 
     return hs
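Note that the five start-up calls now run inside `reactor.callWhenRunning(start)`, so they fire only once the Twisted reactor is actually running, rather than during `setup(config_options)` itself. A standalone sketch of that scheduling pattern (assumes only Twisted; the printed message is illustrative):

from twisted.internet import reactor

def start():
    # Background startup tasks belong here; they run once the event
    # loop is live.
    print("startup tasks running")
    reactor.stop()  # only so this example terminates

reactor.callWhenRunning(start)
reactor.run()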

synapse/server.py

@@ -21,6 +21,7 @@
 # Imports required for the default HomeServer() implementation
 from twisted.web.client import BrowserLikePolicyForHTTPS
 from twisted.enterprise import adbapi
+from twisted.internet import defer
 
 from synapse.federation import initialize_http_replication
 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@@ -28,7 +29,7 @@ from synapse.notifier import Notifier
 from synapse.api.auth import Auth
 from synapse.handlers import Handlers
 from synapse.state import StateHandler
-from synapse.storage import DataStore
+from synapse.storage import get_datastore
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.streams.events import EventSources
@@ -40,6 +41,11 @@ from synapse.api.filtering import Filtering
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class HomeServer(object):
     """A basic homeserver object without lazy component builders.
@@ -102,10 +108,19 @@ class HomeServer(object):
         self.hostname = hostname
         self._building = {}
 
+        self.clock = Clock()
+        self.distributor = Distributor()
+        self.ratelimiter = Ratelimiter()
+
         # Other kwargs are explicit dependencies
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])
 
+    def setup(self):
+        logger.info("Setting up.")
+        self.datastore = get_datastore(self)
+        logger.info("Finished setting up.")
+
     def get_ip_from_request(self, request):
         # X-Forwarded-For is handled by our custom request type.
         return request.getClientIP()
@@ -116,15 +131,9 @@ class HomeServer(object):
     def is_mine_id(self, string):
         return string.split(":", 1)[1] == self.hostname
 
-    def build_clock(self):
-        return Clock()
-
     def build_replication_layer(self):
         return initialize_http_replication(self)
 
-    def build_datastore(self):
-        return DataStore(self)
-
     def build_handlers(self):
         return Handlers(self)
@@ -135,10 +144,9 @@ class HomeServer(object):
         return Auth(self)
 
     def build_http_client_context_factory(self):
-        config = self.get_config()
         return (
             InsecureInterceptableContextFactory()
-            if config.use_insecure_ssl_client_just_for_testing_do_not_use
+            if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
             else BrowserLikePolicyForHTTPS()
         )
@@ -157,15 +165,9 @@ class HomeServer(object):
     def build_state_handler(self):
         return StateHandler(self)
 
-    def build_distributor(self):
-        return Distributor()
-
     def build_event_sources(self):
         return EventSources(self)
 
-    def build_ratelimiter(self):
-        return Ratelimiter()
-
     def build_keyring(self):
         return Keyring(self)
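`Clock`, `Distributor` and `Ratelimiter` lose their `build_*` methods because they are cheap and dependency-free, so they can be constructed eagerly in `__init__`; the lazy-builder machinery is only worth keeping for components that are expensive or interdependent. A hypothetical sketch of that convention (Synapse generates these getters with its own helper; the caching dict and names here are illustrative):

class LazyBuilders(object):
    def __init__(self):
        self._building = {}

    def _get(self, name):
        # Build on first access, then cache.
        if name not in self._building:
            self._building[name] = getattr(self, "build_%s" % name)()
        return self._building[name]

    def build_clock(self):
        return object()  # stand-in for Clock()

    def get_clock(self):
        return self._get("clock")

hs = LazyBuilders()
assert hs.get_clock() is hs.get_clock()  # built once, cached thereafter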

synapse/storage/__init__.py

@@ -46,6 +46,9 @@ from .tags import TagsStore
 from .account_data import AccountDataStore
 
+from util.id_generators import IdGenerator, StreamIdGenerator
+
 
 import logging
@@ -58,6 +61,22 @@ logger = logging.getLogger(__name__)
 
 LAST_SEEN_GRANULARITY = 120*1000
 
+
+def get_datastore(hs):
+    logger.info("getting called!")
+
+    conn = hs.get_db_conn()
+    try:
+        cur = conn.cursor()
+        cur.execute("SELECT MIN(stream_ordering) FROM events",)
+        rows = cur.fetchall()
+        min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+        min_token = min(min_token, -1)
+
+        return DataStore(conn, hs, min_token)
+    finally:
+        conn.close()
+
 
 class DataStore(RoomMemberStore, RoomStore,
                 RegistrationStore, StreamStore, ProfileStore,
                 PresenceStore, TransactionStore,
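`get_datastore` computes the starting token for backfilled events up front: the `min(min_token, -1)` clamp ensures backfill always proceeds through negative stream orderings, even on a populated database whose minimum ordering is positive. A worked sketch of the arithmetic (values illustrative), including how the `events.py` change below consumes it:

def initial_min_token(db_min_stream_ordering):
    min_token = db_min_stream_ordering if db_min_stream_ordering else -1
    return min(min_token, -1)  # clamp: backfill always goes negative

assert initial_min_token(None) == -1   # empty events table
assert initial_min_token(1) == -1      # normal DB: orderings start at 1
assert initial_min_token(-5) == -5     # prior backfill already went lower

# Backfilling 3 events then assigns orderings counting downwards:
min_stream_token = -1
start = min_stream_token - 1           # -2
min_stream_token -= 3 + 1              # -5
assert list(range(start, min_stream_token, -1)) == [-2, -3, -4]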
@@ -79,18 +98,36 @@ class DataStore(RoomMemberStore, RoomStore,
                 EventPushActionsStore
                 ):
 
-    def __init__(self, hs):
-        super(DataStore, self).__init__(hs)
+    def __init__(self, db_conn, hs, min_stream_token):
         self.hs = hs
 
-        self.min_token_deferred = self._get_min_token()
-        self.min_token = None
+        self.min_stream_token = min_stream_token
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
         )
 
+        self._stream_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering"
+        )
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
+        self._pushers_id_gen = IdGenerator("pushers", "id", self)
+        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+
+        super(DataStore, self).__init__(hs)
+
     @defer.inlineCallbacks
     def insert_client_ip(self, user, access_token, ip, user_agent):
         now = int(self._clock.time_msec())
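Moving `super(DataStore, self).__init__(hs)` to the end is deliberate: the ID generators must already be on `self` when the parent `__init__` chain runs, because `ReceiptsStore.__init__` (below) immediately calls `self._receipts_id_gen.get_max_token(None)`. A minimal sketch of that ordering constraint (class names illustrative):

class Base(object):
    def __init__(self):
        # Relies on an attribute the subclass must set first.
        print(self.token)

class Child(Base):
    def __init__(self):
        self.token = 42                 # set before delegating up
        super(Child, self).__init__()   # Base now sees self.token

Child()  # prints 42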

synapse/storage/_base.py

@@ -15,13 +15,11 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
-from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
-        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
-        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
-        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
-        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
-        self._pushers_id_gen = IdGenerator("pushers", "id", self)
-        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
-        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
-        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -345,7 +333,8 @@ class SQLBaseStore(object):
 
         defer.returnValue(result)
 
-    def cursor_to_dict(self, cursor):
+    @staticmethod
+    def cursor_to_dict(cursor):
         """Converts a SQL cursor into an list of dicts.
 
         Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
             if not or_ignore:
                 raise
 
-    @log_function
-    def _simple_insert_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_txn(txn, table, values):
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
 
         txn.execute(sql, vals)
 
-    def _simple_insert_many_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_many_txn(txn, table, values):
         if not values:
             return
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
             table, keyvalues, retcol, allow_none=allow_none,
         )
 
-    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
-                                      allow_none=False):
-        ret = self._simple_select_onecol_txn(
+    @classmethod
+    def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
+                                      allow_none=False):
+        ret = cls._simple_select_onecol_txn(
             txn,
             table=table,
             keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
         else:
             raise StoreError(404, "No row found")
 
-    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+    @staticmethod
+    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
         sql = (
             "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
         ) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
             table, keyvalues, retcols
         )
 
-    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+    @classmethod
+    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
@@ -627,7 +620,7 @@ class SQLBaseStore(object):
             )
             txn.execute(sql)
 
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
 
     @defer.inlineCallbacks
     def _simple_select_many_batch(self, table, column, iterable, retcols,
@@ -662,7 +655,8 @@ class SQLBaseStore(object):
 
         defer.returnValue(results)
 
-    def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+    @classmethod
+    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
@@ -699,7 +693,7 @@ class SQLBaseStore(object):
         )
 
         txn.execute(sql, values)
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
@@ -726,7 +720,8 @@ class SQLBaseStore(object):
             table, keyvalues, updatevalues,
         )
 
-    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+    @staticmethod
+    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
         update_sql = "UPDATE %s SET %s WHERE %s" % (
             table,
             ", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -743,7 +738,8 @@ class SQLBaseStore(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched")
 
-    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
-                               allow_none=False):
+    @staticmethod
+    def _simple_select_one_txn(txn, table, keyvalues, retcols,
+                               allow_none=False):
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
@@ -784,7 +780,8 @@ class SQLBaseStore(object):
                 raise StoreError(500, "more than one row matched")
         return self.runInteraction(desc, func)
 
-    def _simple_delete_txn(self, txn, table, keyvalues):
+    @staticmethod
+    def _simple_delete_txn(txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)

synapse/storage/events.py

@@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
             return
 
         if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            start = self.min_token - 1
-            self.min_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_token, -1)
+            start = self.min_stream_token - 1
+            self.min_stream_token -= len(events_and_contexts) + 1
+            stream_orderings = range(start, self.min_stream_token, -1)
 
             @contextmanager
             def stream_ordering_manager():
@@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
                            is_new_state=True, current_state=None):
         stream_ordering = None
         if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            self.min_token -= 1
-            stream_ordering = self.min_token
+            self.min_stream_token -= 1
+            stream_ordering = self.min_stream_token
 
         if stream_ordering is None:
             stream_ordering_manager = yield self._stream_id_gen.get_next(self)

synapse/storage/receipts.py

@@ -31,7 +31,9 @@ class ReceiptsStore(SQLBaseStore):
     def __init__(self, hs):
         super(ReceiptsStore, self).__init__(hs)
 
-        self._receipts_stream_cache = _RoomStreamChangeCache()
+        self._receipts_stream_cache = _RoomStreamChangeCache(
+            self._receipts_id_gen.get_max_token(None)
+        )
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
@@ -377,11 +379,11 @@ class _RoomStreamChangeCache(object):
     may have changed since that key. If the key is too old then the cache
     will simply return all rooms.
     """
-    def __init__(self, size_of_cache=10000):
+    def __init__(self, current_key, size_of_cache=10000):
         self._size_of_cache = size_of_cache
         self._room_to_key = {}
         self._cache = sorteddict()
-        self._earliest_key = None
+        self._earliest_key = current_key
         self.name = "ReceiptsRoomChangeCache"
         caches_by_name[self.name] = self._cache

synapse/storage/stream.py

@@ -444,19 +444,6 @@ class StreamStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0] if rows else 0
 
-    @defer.inlineCallbacks
-    def _get_min_token(self):
-        row = yield self._execute(
-            "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
-        )
-
-        self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
-        self.min_token = min(self.min_token, -1)
-
-        logger.debug("min_token is: %s", self.min_token)
-
-        defer.returnValue(self.min_token)
-
     @staticmethod
     def _set_before_and_after(events, rows):
         for event, row in zip(events, rows):

synapse/storage/tags.py

@@ -16,7 +16,6 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
 
 import ujson as json
 import logging
@@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
 
 class TagsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(TagsStore, self).__init__(hs)
-
-        self._account_data_id_gen = StreamIdGenerator(
-            "account_data_max_stream_id", "stream_id"
-        )
-
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream

synapse/storage/util/id_generators.py

@@ -72,28 +72,24 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next_txn(txn) as stream_id:
             # ... persist event ...
     """
-    def __init__(self, table, column):
+    def __init__(self, db_conn, table, column):
         self.table = table
         self.column = column
 
         self._lock = threading.Lock()
 
-        self._current_max = None
+        cur = db_conn.cursor()
+        self._current_max = self._get_or_compute_current_max(cur)
+        cur.close()
+
         self._unfinished_ids = deque()
 
-    @defer.inlineCallbacks
     def get_next(self, store):
         """
         Usage:
             with yield stream_id_gen.get_next as stream_id:
                 # ... persist event ...
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
@@ -108,21 +104,14 @@ class StreamIdGenerator(object):
             with self._lock:
                 self._unfinished_ids.remove(next_id)
 
-        defer.returnValue(manager())
+        return manager()
 
-    @defer.inlineCallbacks
     def get_next_mult(self, store, n):
         """
         Usage:
             with yield stream_id_gen.get_next(store, n) as stream_ids:
                 # ... persist events ...
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             next_ids = range(self._current_max + 1, self._current_max + n + 1)
             self._current_max += n
@@ -139,24 +128,17 @@ class StreamIdGenerator(object):
             for next_id in next_ids:
                 self._unfinished_ids.remove(next_id)
 
-        defer.returnValue(manager())
+        return manager()
 
-    @defer.inlineCallbacks
     def get_max_token(self, store):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             if self._unfinished_ids:
-                defer.returnValue(self._unfinished_ids[0] - 1)
+                return self._unfinished_ids[0] - 1
 
-            defer.returnValue(self._current_max)
+            return self._current_max
 
     def _get_or_compute_current_max(self, txn):
         with self._lock:
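With `_current_max` computed synchronously at construction, `get_next`, `get_next_mult` and `get_max_token` no longer need `@defer.inlineCallbacks` or a first-use `runInteraction`; this is also why `ReceiptsStore` above can pass `get_max_token(None)`, since the `store` argument is now unused. A minimal sqlite sketch of the eager initialisation (the MAX query is a stand-in; the real `_get_or_compute_current_max` is shown truncated above):

import sqlite3
import threading
from collections import deque

class StreamIdGenerator(object):
    def __init__(self, db_conn, table, column):
        self.table = table
        self.column = column
        self._lock = threading.Lock()
        # Read the current maximum once, at construction, instead of
        # lazily inside a deferred on first use.
        cur = db_conn.cursor()
        cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
        self._current_max = cur.fetchone()[0] or 0
        cur.close()
        self._unfinished_ids = deque()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (stream_ordering INTEGER)")
conn.execute("INSERT INTO events VALUES (7)")
gen = StreamIdGenerator(conn, "events", "stream_ordering")
assert gen._current_max == 7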