prepare_database() on db_conn, not plain name, so we can pass in the connection from outside
This commit is contained in:
parent
2faffc52ee
commit
55397f6347
|
@ -39,6 +39,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -208,7 +209,14 @@ def setup():
|
||||||
redirect_root_to_web_client=True,
|
redirect_root_to_web_client=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
prepare_database(hs.get_db_name())
|
db_name = hs.get_db_name()
|
||||||
|
|
||||||
|
logging.info("Preparing database: %s...", db_name)
|
||||||
|
|
||||||
|
with sqlite3.connect(db_name) as db_conn:
|
||||||
|
prepare_database(db_conn)
|
||||||
|
|
||||||
|
logging.info("Database prepared in %s.", db_name)
|
||||||
|
|
||||||
hs.get_db_pool()
|
hs.get_db_pool()
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,6 @@ from .keys import KeyStore
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -370,44 +369,40 @@ def read_schema(schema):
|
||||||
return schema_file.read()
|
return schema_file.read()
|
||||||
|
|
||||||
|
|
||||||
def prepare_database(db_name):
|
def prepare_database(db_conn):
|
||||||
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
||||||
don't have to worry about overwriting existing content.
|
don't have to worry about overwriting existing content.
|
||||||
"""
|
"""
|
||||||
logging.info("Preparing database: %s...", db_name)
|
c = db_conn.cursor()
|
||||||
|
c.execute("PRAGMA user_version")
|
||||||
|
row = c.fetchone()
|
||||||
|
|
||||||
with sqlite3.connect(db_name) as db_conn:
|
if row and row[0]:
|
||||||
c = db_conn.cursor()
|
user_version = row[0]
|
||||||
c.execute("PRAGMA user_version")
|
|
||||||
row = c.fetchone()
|
|
||||||
|
|
||||||
if row and row[0]:
|
if user_version > SCHEMA_VERSION:
|
||||||
user_version = row[0]
|
raise ValueError("Cannot use this database as it is too " +
|
||||||
|
"new for the server to understand"
|
||||||
if user_version > SCHEMA_VERSION:
|
)
|
||||||
raise ValueError("Cannot use this database as it is too " +
|
elif user_version < SCHEMA_VERSION:
|
||||||
"new for the server to understand"
|
logging.info("Upgrading database from version %d",
|
||||||
)
|
user_version
|
||||||
elif user_version < SCHEMA_VERSION:
|
)
|
||||||
logging.info("Upgrading database from version %d",
|
|
||||||
user_version
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run every version since after the current version.
|
|
||||||
for v in range(user_version + 1, SCHEMA_VERSION + 1):
|
|
||||||
sql_script = read_schema("delta/v%d" % (v))
|
|
||||||
c.executescript(sql_script)
|
|
||||||
|
|
||||||
db_conn.commit()
|
|
||||||
|
|
||||||
else:
|
|
||||||
for sql_loc in SCHEMAS:
|
|
||||||
sql_script = read_schema(sql_loc)
|
|
||||||
|
|
||||||
|
# Run every version since after the current version.
|
||||||
|
for v in range(user_version + 1, SCHEMA_VERSION + 1):
|
||||||
|
sql_script = read_schema("delta/v%d" % (v))
|
||||||
c.executescript(sql_script)
|
c.executescript(sql_script)
|
||||||
|
|
||||||
db_conn.commit()
|
db_conn.commit()
|
||||||
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
|
|
||||||
|
|
||||||
c.close()
|
else:
|
||||||
|
for sql_loc in SCHEMAS:
|
||||||
|
sql_script = read_schema(sql_loc)
|
||||||
|
|
||||||
|
c.executescript(sql_script)
|
||||||
|
db_conn.commit()
|
||||||
|
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
|
||||||
|
|
||||||
|
c.close()
|
||||||
|
|
||||||
logging.info("Database prepared in %s.", db_name)
|
|
||||||
|
|
Loading…
Reference in New Issue