port to mysql, use vllm tokenizer endpoint
This commit is contained in:
parent 2d390e6268
commit 03e3ec5490
@@ -1,191 +0,0 @@
import json
import sqlite3
import time
from pathlib import Path

import tiktoken

from llm_server import opts

tokenizer = tiktoken.get_encoding("cl100k_base")


def init_db():
    if not Path(opts.database_path).exists():
        conn = sqlite3.connect(opts.database_path)
        c = conn.cursor()
        c.execute('''
        CREATE TABLE prompts (
            ip TEXT,
            token TEXT DEFAULT NULL,
            model TEXT,
            backend_mode TEXT,
            backend_url TEXT,
            request_url TEXT,
            generation_time FLOAT,
            prompt TEXT,
            prompt_tokens INTEGER,
            response TEXT,
            response_tokens INTEGER,
            response_status INTEGER,
            parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
            headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
            timestamp INTEGER
        )
        ''')
        c.execute('''
        CREATE TABLE token_auth (
            token TEXT UNIQUE,
            type TEXT NOT NULL,
            priority INTEGER default 9999,
            uses INTEGER default 0,
            max_uses INTEGER,
            expire INTEGER,
            disabled BOOLEAN default 0
        )
        ''')
        conn.commit()
        conn.close()


def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
    prompt_tokens = len(tokenizer.encode(prompt))

    if not is_error:
        if not response_tokens:
            response_tokens = len(tokenizer.encode(response, disallowed_special=()))
    else:
        response_tokens = None

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
        gen_time = round(gen_time, 3)
        if is_error:
            gen_time = None

    if not opts.log_prompts:
        prompt = None

    if not opts.log_prompts and not is_error:
        # TODO: test and verify this works as expected
        response = None

    timestamp = int(time.time())
    conn = sqlite3.connect(opts.database_path)
    c = conn.cursor()
    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
              (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    conn.commit()
    conn.close()


def is_valid_api_key(api_key):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
    row = cursor.fetchone()
    if row is not None:
        token, uses, max_uses, expire, disabled = row
        disabled = bool(disabled)
        if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
            return True
    return False


def increment_uses(api_key):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute("SELECT token FROM token_auth WHERE token = ?", (api_key,))
    row = cursor.fetchone()
    if row is not None:
        cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
        conn.commit()
        return True
    return False


def get_number_of_rows(table_name):
    conn = sqlite3.connect(opts.database_path)
    cur = conn.cursor()
    cur.execute(f'SELECT COUNT(*) FROM {table_name}')
    result = cur.fetchone()
    conn.close()
    return result[0]


def average_column(table_name, column_name):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
    result = cursor.fetchone()
    conn.close()
    return result[0]


def average_column_for_model(table_name, column_name, model_name):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = ?", (model_name,))
    result = cursor.fetchone()
    conn.close()
    return result[0]


def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    # cursor.execute(f"SELECT DISTINCT model, backend_mode FROM {table_name}")
    # models_backends = [(row[0], row[1]) for row in cursor.fetchall()]
    #
    # model_averages = {}
    # for model, backend in models_backends:
    #     if backend != backend_name:
    #         continue
    cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend_mode = ? AND backend_url = ? ORDER BY ROWID DESC", (model_name, backend_name, backend_url))
    results = cursor.fetchall()

    # if not results:
    #     continue

    total_weight = 0
    weighted_sum = 0
    for i, (value, rowid) in enumerate(results):
        if value is None or (exclude_zeros and value == 0):
            continue
        weight = i + 1
        total_weight += weight
        weighted_sum += weight * value

    # if total_weight == 0:
    #     continue

    if total_weight > 0:
        # Avoid division by zero
        calculated_avg = weighted_sum / total_weight
    else:
        calculated_avg = 0

    conn.close()

    return calculated_avg


def sum_column(table_name, column_name):
    conn = sqlite3.connect(opts.database_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
    result = cursor.fetchone()
    conn.close()
    return result[0] if result[0] else 0


def get_distinct_ips_24h():
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
    conn = sqlite3.connect(opts.database_path)
    cur = conn.cursor()
    cur.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= ?", (past_24_hours,))
    result = cur.fetchone()
    conn.close()
    return result[0] if result else 0
@@ -0,0 +1,25 @@
import pymysql
from dbutils.pooled_db import PooledDB


class DatabaseConnection:
    db_pool: PooledDB = None

    def init_db(self, host, username, password, database):
        self.db_pool = PooledDB(
            creator=pymysql,
            maxconnections=10,
            host=host,
            user=username,
            password=password,
            database=database,
            charset='utf8'
        )
        conn = self.db_pool.connection()
        del conn

    def connection(self):
        return self.db_pool.connection()


db_pool = DatabaseConnection()
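For orientation (not part of the diff): a minimal usage sketch of this pool, assuming init_db() has already been called once at startup the way server.py does further down. With DBUtils' PooledDB, closing a borrowed connection returns it to the pool rather than tearing it down. Credentials here are placeholders.

from llm_server.database.conn import db_pool

# Placeholder values; the real ones come from the config's mysql block.
db_pool.init_db('127.0.0.1', 'llm_server', 'changeme', 'llm_server')

conn = db_pool.connection()   # borrow a pooled connection
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM prompts")
print(cursor.fetchone()[0])
conn.close()                  # hands the connection back to the pool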
@@ -0,0 +1,41 @@
from llm_server.database.conn import db_pool


def create_db():
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS prompts (
        ip TEXT,
        token TEXT DEFAULT NULL,
        model TEXT,
        backend_mode TEXT,
        backend_url TEXT,
        request_url TEXT,
        generation_time FLOAT,
        prompt TEXT,
        prompt_tokens INTEGER,
        response TEXT,
        response_tokens INTEGER,
        response_status INTEGER,
        parameters TEXT,
        CHECK (parameters IS NULL OR JSON_VALID(parameters)),
        headers TEXT,
        CHECK (headers IS NULL OR JSON_VALID(headers)),
        timestamp INTEGER
    )
    ''')
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS token_auth (
        token TEXT,
        UNIQUE (token),
        type TEXT NOT NULL,
        priority INTEGER DEFAULT 9999,
        uses INTEGER DEFAULT 0,
        max_uses INTEGER,
        expire INTEGER,
        disabled BOOLEAN DEFAULT 0
    )
    ''')
    conn.commit()
@@ -0,0 +1,144 @@
import json
import time

import llm_server
from llm_server import opts
from llm_server.database.conn import db_pool
from llm_server.llm.vllm import tokenize


def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
    prompt_tokens = llm_server.llm.tokenizer(prompt)

    if not is_error:
        if not response_tokens:
            response_tokens = llm_server.llm.tokenizer(response)
    else:
        response_tokens = None

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
        gen_time = round(gen_time, 3)
        if is_error:
            gen_time = None

    if not opts.log_prompts:
        prompt = None

    if not opts.log_prompts and not is_error:
        # TODO: test and verify this works as expected
        response = None

    timestamp = int(time.time())
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute("""
        INSERT INTO prompts
        (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """,
        (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    conn.commit()


def is_valid_api_key(api_key):
    conn = db_pool.connection()
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            token, uses, max_uses, expire, disabled = row
            disabled = bool(disabled)
            if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        return False
    finally:
        conn.commit()


def increment_uses(api_key):
    conn = db_pool.connection()
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT token FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = %s", (api_key,))
            return True
        return False
    finally:
        conn.commit()


def get_number_of_rows(table_name):
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute(f'SELECT COUNT(*) FROM {table_name}')
    result = cursor.fetchone()
    conn.commit()
    return result[0]


def average_column(table_name, column_name):
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
    result = cursor.fetchone()
    conn.commit()
    return result[0]


def average_column_for_model(table_name, column_name, model_name):
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s", (model_name,))
    result = cursor.fetchone()
    conn.commit()
    return result[0]


def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False):
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,))
    results = cursor.fetchall()

    total_weight = 0
    weighted_sum = 0
    for i, (value, rowid) in enumerate(results):
        if value is None or (exclude_zeros and value == 0):
            continue
        weight = i + 1
        total_weight += weight
        weighted_sum += weight * value

    if total_weight > 0:
        # Avoid division by zero
        calculated_avg = weighted_sum / total_weight
    else:
        calculated_avg = 0

    conn.commit()
    return calculated_avg


def sum_column(table_name, column_name):
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
    result = cursor.fetchone()
    conn.commit()
    return result[0] if result[0] else 0


def get_distinct_ips_24h():
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
    conn = db_pool.connection()
    cursor = conn.cursor()
    cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s", (past_24_hours,))
    result = cursor.fetchone()
    conn.commit()
    return result[0] if result else 0
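A toy illustration (not from the repo) of the weighting scheme in weighted_average_column_for_model(): rows come back newest-first (ORDER BY id DESC) and row i receives weight i + 1, so older rows carry more weight in the average.

values = [2.0, 4.0, 6.0]  # hypothetical generation_time values, newest -> oldest
weights = [i + 1 for i in range(len(values))]  # [1, 2, 3]
weighted_avg = sum(w * v for w, v in zip(weights, values)) / sum(weights)
print(weighted_avg)  # 4.67 -- the oldest value (6.0) holds half the total weight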
@@ -5,7 +5,6 @@ from pathlib import Path
from flask import make_response


def resolve_path(*p: str):
    return Path(*p).expanduser().resolve().absolute()
@@ -0,0 +1 @@
tokenizer = None
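This module-level slot is populated once at startup (see the server.py hunk near the end of this commit) and then called as a function everywhere else. A minimal sketch of that wiring, using the same assignment server.py makes:

import llm_server
import llm_server.llm.vllm

# At startup, server.py binds the vLLM-backed counter into the shared slot:
llm_server.llm.tokenizer = llm_server.llm.vllm.tokenize
# Thereafter any module can count tokens through it:
n = llm_server.llm.tokenizer("some prompt")  # returns an int token count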
@@ -3,7 +3,7 @@ from typing import Tuple, Union
import flask

from llm_server import opts
-from llm_server.database import tokenizer
+from llm_server.llm.backend import tokenizer


class LLMBackend:
@@ -30,9 +30,8 @@ class LLMBackend:
    def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
        raise NotImplementedError

-    @staticmethod
-    def validate_prompt(prompt: str) -> Tuple[bool, Union[str, None]]:
-        prompt_len = len(tokenizer.encode(prompt))
-        if prompt_len > opts.context_size - 10:  # Our tokenizer isn't 100% accurate so we cut it down a bit. TODO: add a tokenizer endpoint to VLLM
+    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
+        prompt_len = len(tokenizer(prompt))
+        if prompt_len > opts.context_size - 10:
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
        return True, None
@@ -1,7 +1,7 @@
from flask import jsonify

from ..llm_backend import LLMBackend
-from ...database import log_prompt
+from ...database.database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
@@ -0,0 +1 @@
from .tokenize import tokenize
@@ -7,8 +7,8 @@ from uuid import uuid4
import requests

+import llm_server
from llm_server import opts
-from llm_server.database import tokenizer


# TODO: make the VLLM backend return TPS and time elapsed
@@ -43,8 +43,8 @@ def transform_to_text(json_request, api_response):
    if data['choices'][0]['finish_reason']:
        finish_reason = data['choices'][0]['finish_reason']

-    prompt_tokens = len(tokenizer.encode(prompt))
-    completion_tokens = len(tokenizer.encode(text))
+    prompt_tokens = llm_server.llm.tokenizer(prompt)   # tokenizer() already returns a count
+    completion_tokens = llm_server.llm.tokenizer(text)

    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
@@ -0,0 +1,19 @@
import traceback

import requests
import tiktoken

from llm_server import opts

tokenizer = tiktoken.get_encoding("cl100k_base")


def tokenize(prompt: str) -> int:
    try:
        r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
        j = r.json()
        return j['length']
    except:
        # Fall back to a local tiktoken estimate if the vLLM endpoint is unreachable.
        print(traceback.format_exc())
        return len(tokenizer.encode(prompt)) + 10
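A usage sketch, assuming opts.backend_url points at the patched vLLM API server (whose /tokenize endpoint is added later in this commit); the prompt text is arbitrary:

from llm_server.llm.vllm import tokenize

# Asks the backend's /tokenize endpoint for an exact count; on any failure it
# falls back to the local tiktoken estimate, padded by 10 tokens for safety.
n_tokens = tokenize("Once upon a time")
print(n_tokens)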
@@ -1,10 +1,13 @@
import traceback
from typing import Tuple, Union

import requests
from flask import jsonify
from vllm import SamplingParams

+import llm_server
from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
@@ -45,3 +48,19 @@ class VLLMBackend(LLMBackend):
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None

+    # def tokenize(self, prompt):
+    #     try:
+    #         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+    #         j = r.json()
+    #         return j['length']
+    #     except:
+    #         # Fall back to whatever the superclass is doing.
+    #         print(traceback.format_exc())
+    #         return super().tokenize(prompt)
+
+    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
+        prompt_len = llm_server.llm.tokenizer(prompt)
+        if prompt_len > opts.context_size:
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
+        return True, None
@@ -8,7 +8,6 @@ mode = 'oobabooga'
backend_url = None
context_size = 5555
max_new_tokens = 500
-database_path = './proxy-server.db'
auth_required = False
log_prompts = False
frontend_api_client = ''
@@ -8,7 +8,7 @@ from flask import make_response, Request
from flask import request, jsonify

from llm_server import opts
-from llm_server.database import is_valid_api_key
+from llm_server.database.database import is_valid_api_key


def cache_control(seconds):
@@ -4,7 +4,7 @@ import flask
from flask import jsonify

from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@@ -10,8 +10,9 @@ import requests
import tiktoken
from flask import jsonify

+import llm_server
from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@@ -130,8 +131,8 @@ def build_openai_response(prompt, response):
    if len(x) > 1:
        response = re.sub(r'\n$', '', y[0].strip(' '))

-    prompt_tokens = len(tokenizer.encode(prompt))
-    response_tokens = len(tokenizer.encode(response))
+    prompt_tokens = llm_server.llm.tokenizer(prompt)
+    response_tokens = llm_server.llm.tokenizer(response)
    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",
@@ -1,4 +1,3 @@
-import sqlite3
import time
from typing import Tuple, Union
@@ -6,7 +5,8 @@ import flask
from flask import Response, request

from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.conn import db_pool
+from llm_server.database.database import log_prompt
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
@@ -42,13 +42,14 @@ class RequestHandler:

    def get_priority(self):
        if self.token:
-            conn = sqlite3.connect(opts.database_path)
+            conn = db_pool.connection()
            cursor = conn.cursor()
-            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
+            cursor.execute("SELECT priority FROM token_auth WHERE token = %s", (self.token,))
            result = cursor.fetchone()
-            conn.close()

            if result:
                return result[0]
+            conn.commit()
        return DEFAULT_PRIORITY

    def get_parameters(self):
@@ -2,7 +2,7 @@ import time
from datetime import datetime

from llm_server import opts
-from llm_server.database import get_distinct_ips_24h, sum_column
+from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states
@@ -6,7 +6,7 @@ from flask import request
from ..helpers.client import format_sillytavern_err
from ... import opts
-from ...database import log_prompt
+from ...database.database import log_prompt
from ...helpers import indefinite_article
from ...stream import sock
@@ -2,7 +2,7 @@ import time
from threading import Thread

from llm_server import opts
-from llm_server.database import weighted_average_column_for_model
+from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
@@ -7,9 +7,10 @@ from typing import AsyncGenerator
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
-from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -24,6 +25,20 @@ async def generate(request: Request) -> Response:
    return JSONResponse({'model': served_model, 'timestamp': int(time.time())})


+@app.post("/tokenize")
+async def tokenize(request: Request) -> Response:
+    request_dict = await request.json()
+    to_tokenize = request_dict.get("input")
+    if not to_tokenize:
+        return JSONResponse({'error': 'must have input field'}, status_code=400)
+    tokens = tokenizer.tokenize(to_tokenize)
+    response = {}
+    if request_dict.get("return", False):
+        response['tokens'] = tokens
+    response['length'] = len(tokens)
+    return JSONResponse(response)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.
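For reference, a hypothetical client call against the new endpoint (host and port are placeholders):

import requests

# 'input' is required; 'return': True additionally echoes the token strings.
r = requests.post('http://localhost:8000/tokenize',
                  json={'input': 'Hello world', 'return': True})
print(r.json())  # {'tokens': [...], 'length': <token count>}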
@@ -82,11 +97,14 @@ if __name__ == "__main__":
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

-    served_model = Path(args.model).name

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

+    served_model = Path(args.model).name
+    tokenizer = get_tokenizer(engine_args.tokenizer,
+                              tokenizer_mode=args.tokenizer_mode,
+                              trust_remote_code=args.trust_remote_code)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
@@ -1,12 +1,19 @@
-flask
+flask~=2.3.3
flask_cors
-pyyaml
+pyyaml~=6.0.1
flask_caching
-requests
-tiktoken
+requests~=2.31.0
+tiktoken~=0.5.0
gunicorn
-redis
+redis~=5.0.0
gevent
async-timeout
flask-sock
auto_gptq
+uvicorn~=0.23.2
+fastapi~=0.103.1
+torch~=2.0.1
+urllib3
+PyMySQL~=1.1.0
+DBUtils~=3.0.3
+simplejson
server.py: 23 changed lines
@@ -1,11 +1,16 @@
-import json
import os
import sys
from pathlib import Path
from threading import Thread

+import simplejson as json
from flask import Flask, jsonify, render_template, request

+import llm_server
+from llm_server.database.conn import db_pool
+from llm_server.database.create import create_db
+from llm_server.database.database import get_number_of_rows
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
@@ -19,7 +24,6 @@ except ModuleNotFoundError as e:
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
-from llm_server.database import get_number_of_rows, init_db
from llm_server.helpers import resolve_path
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
@@ -48,8 +52,8 @@ if not success:
-if config['database_path'].startswith('./'):
-    config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
-opts.database_path = resolve_path(config['database_path'])
-init_db()
+db_pool.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+create_db()

if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
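Implied by the init_db() call above, the loaded config must now carry a mysql section with these four keys. A hypothetical sketch of the mapping the config loader has to produce (all values are placeholders):

# Hypothetical: the shape server.py reads out of the parsed config.
config['mysql'] = {
    'host': '127.0.0.1',
    'username': 'llm_server',
    'password': 'changeme',
    'database': 'llm_server',
}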
@@ -93,6 +97,15 @@ if config['average_generation_time_mode'] not in ['database', 'minute']:
    sys.exit(1)
opts.average_generation_time_mode = config['average_generation_time_mode']

+if opts.mode == 'oobabooga':
+    raise NotImplementedError
+    # llm_server.llm.tokenizer = OobaboogaBackend()
+elif opts.mode == 'vllm':
+    llm_server.llm.tokenizer = llm_server.llm.vllm.tokenize
+else:
+    raise Exception

# Start background processes
start_workers(opts.concurrent_gens)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)