Move database preparing code out of homserver.py into storage where it belongs
This commit is contained in:
parent
80b5470663
commit
ce55a8cc4b
|
@ -14,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.storage import read_schema
|
from synapse.storage import prepare_database
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
@ -36,7 +36,6 @@ from daemonize import Daemonize
|
||||||
import twisted.manhole.telnet
|
import twisted.manhole.telnet
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
@ -44,22 +43,6 @@ import sys
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SCHEMAS = [
|
|
||||||
"transactions",
|
|
||||||
"pdu",
|
|
||||||
"users",
|
|
||||||
"profiles",
|
|
||||||
"presence",
|
|
||||||
"im",
|
|
||||||
"room_aliases",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Remember to update this number every time an incompatible change is made to
|
|
||||||
# database schema files, so the users will be informed on server restarts.
|
|
||||||
SCHEMA_VERSION = 3
|
|
||||||
|
|
||||||
|
|
||||||
class SynapseHomeServer(HomeServer):
|
class SynapseHomeServer(HomeServer):
|
||||||
|
|
||||||
def build_http_client(self):
|
def build_http_client(self):
|
||||||
|
@ -80,52 +63,12 @@ class SynapseHomeServer(HomeServer):
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_db_pool(self):
|
def build_db_pool(self):
|
||||||
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
return adbapi.ConnectionPool(
|
||||||
don't have to worry about overwriting existing content.
|
"sqlite3", self.get_db_name(),
|
||||||
"""
|
check_same_thread=False,
|
||||||
logging.info("Preparing database: %s...", self.db_name)
|
cp_min=1,
|
||||||
|
cp_max=1
|
||||||
with sqlite3.connect(self.db_name) as db_conn:
|
)
|
||||||
c = db_conn.cursor()
|
|
||||||
c.execute("PRAGMA user_version")
|
|
||||||
row = c.fetchone()
|
|
||||||
|
|
||||||
if row and row[0]:
|
|
||||||
user_version = row[0]
|
|
||||||
|
|
||||||
if user_version > SCHEMA_VERSION:
|
|
||||||
raise ValueError("Cannot use this database as it is too " +
|
|
||||||
"new for the server to understand"
|
|
||||||
)
|
|
||||||
elif user_version < SCHEMA_VERSION:
|
|
||||||
logging.info("Upgrading database from version %d",
|
|
||||||
user_version
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run every version since after the current version.
|
|
||||||
for v in range(user_version + 1, SCHEMA_VERSION + 1):
|
|
||||||
sql_script = read_schema("delta/v%d" % (v))
|
|
||||||
c.executescript(sql_script)
|
|
||||||
|
|
||||||
db_conn.commit()
|
|
||||||
|
|
||||||
else:
|
|
||||||
for sql_loc in SCHEMAS:
|
|
||||||
sql_script = read_schema(sql_loc)
|
|
||||||
|
|
||||||
c.executescript(sql_script)
|
|
||||||
db_conn.commit()
|
|
||||||
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
|
|
||||||
|
|
||||||
c.close()
|
|
||||||
|
|
||||||
logging.info("Database prepared in %s.", self.db_name)
|
|
||||||
|
|
||||||
pool = adbapi.ConnectionPool(
|
|
||||||
'sqlite3', self.db_name, check_same_thread=False,
|
|
||||||
cp_min=1, cp_max=1)
|
|
||||||
|
|
||||||
return pool
|
|
||||||
|
|
||||||
def create_resource_tree(self, web_client, redirect_root_to_web_client):
|
def create_resource_tree(self, web_client, redirect_root_to_web_client):
|
||||||
"""Create the resource tree for this Home Server.
|
"""Create the resource tree for this Home Server.
|
||||||
|
@ -270,6 +213,8 @@ def setup():
|
||||||
)
|
)
|
||||||
hs.start_listening(config.bind_port, config.unsecure_port)
|
hs.start_listening(config.bind_port, config.unsecure_port)
|
||||||
|
|
||||||
|
prepare_database(hs.get_db_name())
|
||||||
|
|
||||||
hs.get_db_pool()
|
hs.get_db_pool()
|
||||||
|
|
||||||
if config.manhole:
|
if config.manhole:
|
||||||
|
|
|
@ -57,6 +57,7 @@ class BaseHomeServer(object):
|
||||||
DEPENDENCIES = [
|
DEPENDENCIES = [
|
||||||
'clock',
|
'clock',
|
||||||
'http_client',
|
'http_client',
|
||||||
|
'db_name',
|
||||||
'db_pool',
|
'db_pool',
|
||||||
'persistence_service',
|
'persistence_service',
|
||||||
'replication_layer',
|
'replication_layer',
|
||||||
|
|
|
@ -43,10 +43,28 @@ from .keys import KeyStore
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMAS = [
|
||||||
|
"transactions",
|
||||||
|
"pdu",
|
||||||
|
"users",
|
||||||
|
"profiles",
|
||||||
|
"presence",
|
||||||
|
"im",
|
||||||
|
"room_aliases",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Remember to update this number every time an incompatible change is made to
|
||||||
|
# database schema files, so the users will be informed on server restarts.
|
||||||
|
SCHEMA_VERSION = 3
|
||||||
|
|
||||||
|
|
||||||
class _RollbackButIsFineException(Exception):
|
class _RollbackButIsFineException(Exception):
|
||||||
""" This exception is used to rollback a transaction without implying
|
""" This exception is used to rollback a transaction without implying
|
||||||
something went wrong.
|
something went wrong.
|
||||||
|
@ -350,3 +368,46 @@ def read_schema(schema):
|
||||||
"""
|
"""
|
||||||
with open(schema_path(schema)) as schema_file:
|
with open(schema_path(schema)) as schema_file:
|
||||||
return schema_file.read()
|
return schema_file.read()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_database(db_name):
|
||||||
|
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
||||||
|
don't have to worry about overwriting existing content.
|
||||||
|
"""
|
||||||
|
logging.info("Preparing database: %s...", db_name)
|
||||||
|
|
||||||
|
with sqlite3.connect(db_name) as db_conn:
|
||||||
|
c = db_conn.cursor()
|
||||||
|
c.execute("PRAGMA user_version")
|
||||||
|
row = c.fetchone()
|
||||||
|
|
||||||
|
if row and row[0]:
|
||||||
|
user_version = row[0]
|
||||||
|
|
||||||
|
if user_version > SCHEMA_VERSION:
|
||||||
|
raise ValueError("Cannot use this database as it is too " +
|
||||||
|
"new for the server to understand"
|
||||||
|
)
|
||||||
|
elif user_version < SCHEMA_VERSION:
|
||||||
|
logging.info("Upgrading database from version %d",
|
||||||
|
user_version
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run every version since after the current version.
|
||||||
|
for v in range(user_version + 1, SCHEMA_VERSION + 1):
|
||||||
|
sql_script = read_schema("delta/v%d" % (v))
|
||||||
|
c.executescript(sql_script)
|
||||||
|
|
||||||
|
db_conn.commit()
|
||||||
|
|
||||||
|
else:
|
||||||
|
for sql_loc in SCHEMAS:
|
||||||
|
sql_script = read_schema(sql_loc)
|
||||||
|
|
||||||
|
c.executescript(sql_script)
|
||||||
|
db_conn.commit()
|
||||||
|
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
|
||||||
|
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
logging.info("Database prepared in %s.", db_name)
|
||||||
|
|
Loading…
Reference in New Issue