Move database preparing code out of homserver.py into storage where it belongs

This commit is contained in:
Paul "LeoNerd" Evans 2014-09-10 15:42:15 +01:00
parent 80b5470663
commit ce55a8cc4b
3 changed files with 71 additions and 64 deletions

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import read_schema from synapse.storage import prepare_database
from synapse.server import HomeServer from synapse.server import HomeServer
@ -36,7 +36,6 @@ from daemonize import Daemonize
import twisted.manhole.telnet import twisted.manhole.telnet
import logging import logging
import sqlite3
import os import os
import re import re
import sys import sys
@ -44,22 +43,6 @@ import sys
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 3
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
@ -80,52 +63,12 @@ class SynapseHomeServer(HomeServer):
) )
def build_db_pool(self): def build_db_pool(self):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we return adbapi.ConnectionPool(
don't have to worry about overwriting existing content. "sqlite3", self.get_db_name(),
""" check_same_thread=False,
logging.info("Preparing database: %s...", self.db_name) cp_min=1,
cp_max=1
with sqlite3.connect(self.db_name) as db_conn: )
c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]:
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
user_version
)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
logging.info("Database prepared in %s.", self.db_name)
pool = adbapi.ConnectionPool(
'sqlite3', self.db_name, check_same_thread=False,
cp_min=1, cp_max=1)
return pool
def create_resource_tree(self, web_client, redirect_root_to_web_client): def create_resource_tree(self, web_client, redirect_root_to_web_client):
"""Create the resource tree for this Home Server. """Create the resource tree for this Home Server.
@ -270,6 +213,8 @@ def setup():
) )
hs.start_listening(config.bind_port, config.unsecure_port) hs.start_listening(config.bind_port, config.unsecure_port)
prepare_database(hs.get_db_name())
hs.get_db_pool() hs.get_db_pool()
if config.manhole: if config.manhole:

View File

@ -57,6 +57,7 @@ class BaseHomeServer(object):
DEPENDENCIES = [ DEPENDENCIES = [
'clock', 'clock',
'http_client', 'http_client',
'db_name',
'db_pool', 'db_pool',
'persistence_service', 'persistence_service',
'replication_layer', 'replication_layer',

View File

@ -43,10 +43,28 @@ from .keys import KeyStore
import json import json
import logging import logging
import os import os
import sqlite3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 3
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying
something went wrong. something went wrong.
@ -350,3 +368,46 @@ def read_schema(schema):
""" """
with open(schema_path(schema)) as schema_file: with open(schema_path(schema)) as schema_file:
return schema_file.read() return schema_file.read()
def prepare_database(db_name):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
don't have to worry about overwriting existing content.
"""
logging.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn:
c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]:
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
user_version
)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
logging.info("Database prepared in %s.", db_name)