Fix upgrade db script

This commit is contained in:
Erik Johnston 2019-12-10 13:32:34 +00:00
parent bc5cb8bfe8
commit 67c991b78f
1 changed files with 6 additions and 27 deletions

View File

@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("update_database") logger = logging.getLogger("update_database")
@ -35,21 +34,11 @@ logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer): class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore DATASTORE_CLASS = DataStore
def __init__(self, config, database_engine, db_conn, **kwargs): def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__( super(MockHomeserver, self).__init__(
config.server_name, config.server_name, reactor=reactor, config=config, **kwargs
reactor=reactor,
config=config,
database_engine=database_engine,
**kwargs
) )
self.database_engine = database_engine
self.db_conn = db_conn
def get_db_conn(self):
return self.db_conn
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -85,24 +74,14 @@ if __name__ == "__main__":
config = HomeServerConfig() config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "") config.parse_config_dict(hs_config, "", "")
# Create the database engine and a connection to it. # Instantiate and initialise the homeserver object.
database_engine = create_engine(config.database_config) hs = MockHomeserver(config)
db_conn = database_engine.module.connect(
**{
k: v
for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
db_conn = hs.get_db_conn()
# Update the database to the latest schema. # Update the database to the latest schema.
prepare_database(db_conn, database_engine, config=config) prepare_database(db_conn, hs.database_engine, config=config)
db_conn.commit() db_conn.commit()
# Instantiate and initialise the homeserver object.
hs = MockHomeserver(
config, database_engine, db_conn, db_config=config.database_config,
)
# setup instantiates the store within the homeserver object. # setup instantiates the store within the homeserver object.
hs.setup() hs.setup()
store = hs.get_datastore() store = hs.get_datastore()