208 lines
6.5 KiB
Python
208 lines
6.5 KiB
Python
import logging
|
|
import os
|
|
import re
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
|
|
import mysql.connector
|
|
import mysql.connector
|
|
from mysql.connector import Error
|
|
|
|
from server import opts
|
|
from .logging import formatter
|
|
|
|
# Directory containing this module (symlinks resolved); used to locate bundled files such as sql/database.sql.
current_dir = Path(os.path.realpath(__file__)).parent
|
|
|
|
|
|
class DatabaseConnection:
    """
    Context manager that opens a MySQL connection on entry and closes it on exit.

    Each of ``host``, ``user``, ``password`` and ``database`` falls back to the
    matching key of ``opts.mysql`` when the argument is falsy.
    """

    def __init__(self, host=None, user=None, password=None, database=None):
        supplied = {
            'host': host,
            'user': user,
            'password': password,
            'database': database,
        }
        for key, value in supplied.items():
            # opts.mysql is only consulted for settings the caller did not provide.
            setattr(self, key, value if value else opts.mysql[key])
        self.connection = None

    def __enter__(self):
        """Open the connection and hand the raw connector object to the ``with`` body."""
        connect_kwargs = {
            'host': self.host,
            'user': self.user,
            'password': self.password,
            'database': self.database,
        }
        self.connection = mysql.connector.connect(**connect_kwargs)
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection if one was opened; exceptions are not suppressed."""
        active = self.connection
        if not active:
            return
        active.close()
|
|
|
|
|
|
def test_mysql_connection() -> "tuple[bool, Exception | None]":
    """
    Check that the configured MySQL server is reachable and usable.

    Connects with the default ``DatabaseConnection`` settings, then creates and
    drops a temporary table to confirm the account can execute statements.

    Returns:
        ``(success, error)``. ``success`` is True only when the connection was
        live and both statements ran. ``error`` is the ``mysql.connector.Error``
        that was raised, or None (including the case where the connection
        reported itself as not connected).
    """
    success = False
    error = None
    try:
        # The context manager closes the connection on exit, so no extra
        # close() is needed afterwards.
        with DatabaseConnection() as conn:
            if conn.is_connected():
                cursor = conn.cursor()
                cursor.execute("CREATE TEMPORARY TABLE test_table(id INT)")
                cursor.execute("DROP TEMPORARY TABLE test_table")
                cursor.close()
                success = True
    except Error as e:
        success = False
        error = e
    return success, error
|
|
|
|
|
|
def init_db():
    """
    Create the database schema by running the bundled ``sql/database.sql`` script.

    The script text is split on ``;`` and every non-blank piece is executed in
    order. A failing statement is logged and the remaining statements are still
    attempted, so all failures are reported in one run. If anything failed,
    nothing is committed and the process exits with status 1.

    Raises:
        SystemExit: with code 1 when at least one statement failed.
    """
    sql_script = current_dir / 'sql' / 'database.sql'
    log = logging.getLogger('MAIN')

    # Read the script before connecting so an unreadable file never opens a connection.
    # NOTE: splitting on ';' is naive and will break statements that contain a
    # semicolon inside a string literal or a stored-routine body.
    statements = [part for part in sql_script.read_text().split(';') if part.strip() != '']

    failed = False
    with DatabaseConnection() as conn:
        cursor = conn.cursor()
        for statement in statements:
            try:
                cursor.execute(statement)
            except Exception as e:
                log.critical(f'failed to execute setup SQL. {e.__class__.__name__} - {e}')
                failed = True
        if failed:
            log.critical('The setup SQL failed to run. Please erase the existing tables and either re-run the program or execute the SQL script manually.')
            # quit() is injected by the `site` module for interactive use and may
            # be absent (python -S, frozen builds); raise SystemExit directly.
            raise SystemExit(1)
        conn.commit()
|
|
|
|
|
|
def check_if_database_exists(partial: bool = False):
    """
    Compare the tables present in the database with those the setup script creates.

    Expected table names are taken from ``sql/database.sql``: only lines that
    consist exactly of ``CREATE TABLE `name``` (nothing after the closing
    backtick) are recognised.

    Args:
        partial: When True, every expected table is checked individually and the
            missing ones are reported. When False, the database counts as
            existing as soon as it contains any table at all.

    Returns:
        ``(ok, tables)``. If the database has no tables: ``(False, expected)``.
        Otherwise ``ok`` is True when nothing is missing and ``tables`` lists
        the missing table names (always empty when ``partial`` is False).
    """
    # Get the tables that should be in the DB based on the creation SQL script
    pattern = re.compile(r'^CREATE TABLE `(.*?)`$')
    sql_script = current_dir / 'sql' / 'database.sql'
    should_exist = []
    with open(sql_script) as script:  # close the file handle deterministically
        for line in script:
            should_exist.extend(match.group(1) for match in pattern.finditer(line))

    with DatabaseConnection() as conn:
        cursor = conn.cursor()
        cursor.execute("show tables;")
        result = cursor.fetchall()

    if not result:
        # No tables in DB
        return False, should_exist

    missing_tables = []
    if partial:
        # A list (not a set) keeps the original ==-based comparison and does not
        # require the driver's name values to be hashable.
        existing = [row[0] for row in result]
        missing_tables = [name for name in should_exist if name not in existing]
    return not missing_tables, missing_tables
|
|
|
|
|
|
def get_console_logger(logger: logging.Logger = None, debug: bool = False, stream_handler: bool = False):
    """
    Build a console (stream) handler using the shared formatter.

    The target logger is the one passed in, or the 'MAIN' logger when none is
    given. Its level is always set: DEBUG when ``debug`` is true, INFO otherwise.

    With ``stream_handler=True`` the new handler is returned without being
    attached; otherwise it is attached and the logger itself is returned.
    """
    target = logger if logger else logging.getLogger('MAIN')
    target.setLevel(logging.DEBUG if debug else logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)

    if stream_handler:
        return handler
    target.addHandler(handler)
    return target
|
|
|
|
|
|
def db_logger(name, table, job_id: str = None, level: int = None, console: bool = False):
    """
    Return the logger called ``name`` with a database handler attached.

    Args:
        name: Logger name; also passed to the database handler.
        table: Target selector passed to ``MySQLHandler`` (``'logs'`` or ``'jobs'``).
        job_id: Job identifier passed to the handler.
        level: Explicit logging level. When falsy, DEBUG is used if
            ``opts.verbose`` is set and INFO otherwise.
        console: Also attach a console stream handler.

    Returns:
        The configured ``logging.Logger``. A new handler is added on every call.
    """
    logger = logging.getLogger(name)
    if not level:
        level = logging.DEBUG if opts.verbose else logging.INFO
    logger.setLevel(level)

    # Database handler
    db_handler = MySQLHandler(name, table, job_id)
    db_handler.setFormatter(formatter)
    logger.addHandler(db_handler)

    if console:
        console_handler = get_console_logger(logger, opts.verbose, stream_handler=True)
        logger.addHandler(console_handler)
        # get_console_logger() resets the logger's level from opts.verbose as a
        # side effect; re-apply the level chosen above so an explicit `level`
        # argument is honoured.
        logger.setLevel(level)
    return logger
|
|
|
|
|
|
class MySQLHandler(logging.Handler):
    """
    Logging handler that writes each record to a MySQL table.

    ``table='logs'`` inserts into ``logging_logs``; ``table='jobs'`` inserts into
    ``logging_job_output`` and additionally stores ``job_id``. Inserts are
    submitted to a small thread pool so ``emit`` does not wait on the database.
    """

    # mysql-connector substitutes only %s (or %(name)s) markers, whatever the
    # Python type of the bound value, so the time column also uses %s.
    _INSERT_LOGS = "INSERT INTO logging_logs (level, name, time, message) VALUES (%s, %s, %s, %s)"
    _INSERT_JOBS = "INSERT INTO logging_job_output (job_id, name, level, time, message) VALUES (%s, %s, %s, %s, %s)"

    def __init__(self, name, table, job_id: str = None):
        """
        Args:
            name: Value stored in the ``name`` column of every row.
            table: ``'logs'`` or ``'jobs'``.
            job_id: Value stored in the ``job_id`` column (``'jobs'`` only).

        Raises:
            ValueError: if ``table`` is not one of the two supported values.
        """
        super().__init__()
        self.name = name
        self.job_id = job_id
        if table not in ['logs', 'jobs']:
            raise ValueError(f'table value must be `logs` or `jobs`, not {table}')
        self.table = table
        self.executor = ThreadPoolExecutor(max_workers=5)

    def emit(self, record):
        """Queue the record for insertion on a worker thread."""
        self.executor.submit(self._emit, record)

    def _build_insert(self, record):
        """Return ``(sql, params)`` for this record, with params in column order."""
        # record.created is passed through unchanged (a float epoch timestamp).
        if self.table == 'logs':
            return self._INSERT_LOGS, (
                record.levelname, self.name, record.created, record.getMessage()
            )
        if self.table == 'jobs':
            return self._INSERT_JOBS, (
                self.job_id, self.name, record.levelname, record.created, record.getMessage()
            )
        raise ValueError

    def _emit(self, record):
        """Insert one record; runs on an executor thread."""
        try:
            sql, params = self._build_insert(record)
            with DatabaseConnection() as conn:
                cursor = conn.cursor()
                cursor.execute(sql, params)
                conn.commit()
        except Exception:
            # The future returned by submit() is discarded, so an exception here
            # would otherwise vanish; report it through logging's standard hook.
            self.handleError(record)
|
|
|
|
|
|
def query(query_str: str, values: tuple = None, commit: bool = False, dictionary: bool = False):
    """
    Run one SQL statement on a fresh connection.

    Args:
        query_str: SQL text, optionally containing parameter markers.
        values: Parameters bound to the markers; omitted from the call when falsy.
        commit: Force a commit after executing. INSERT and UPDATE statements
            are committed automatically, regardless of letter case or leading
            whitespace.
        dictionary: Passed to ``conn.cursor``.

    Returns:
        None when the statement was committed, otherwise ``cursor.fetchall()``.
    """
    # Normalise before sniffing the verb so "insert ..." or a statement that
    # starts with a newline is still committed instead of reaching fetchall().
    is_write = query_str.lstrip().upper().startswith(('INSERT', 'UPDATE'))
    with DatabaseConnection() as conn:
        cursor = conn.cursor(dictionary=dictionary)
        if values:
            cursor.execute(query_str, values)
        else:
            cursor.execute(query_str)
        if commit or is_write:
            conn.commit()
            return None
        return cursor.fetchall()
|