From 03e3ec549088a1d0ddc15fd9dca14c91aea8a84b Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Wed, 20 Sep 2023 20:30:31 -0600
Subject: [PATCH] port to mysql, use vllm tokenizer endpoint

---
 llm_server/database.py                      | 191 --------------------
 llm_server/database/__init__.py             |   0
 llm_server/database/conn.py                 |  25 +++
 llm_server/database/create.py               |  41 +++++
 llm_server/database/database.py             | 144 +++++++++++++++
 llm_server/helpers.py                       |   1 -
 llm_server/llm/backend.py                   |   1 +
 llm_server/llm/llm_backend.py               |   9 +-
 llm_server/llm/oobabooga/ooba_backend.py    |   2 +-
 llm_server/llm/vllm/__init__.py             |   1 +
 llm_server/llm/vllm/generate.py             |   6 +-
 llm_server/llm/vllm/tokenize.py             |  19 ++
 llm_server/llm/vllm/vllm_backend.py         |  21 ++-
 llm_server/opts.py                          |   1 -
 llm_server/routes/helpers/http.py           |   2 +-
 llm_server/routes/ooba_request_handler.py   |   2 +-
 llm_server/routes/openai_request_handler.py |   7 +-
 llm_server/routes/request_handler.py        |  11 +-
 llm_server/routes/v1/generate_stats.py      |   2 +-
 llm_server/routes/v1/generate_stream.py     |   2 +-
 llm_server/threads.py                       |   2 +-
 other/vllm/vllm_api_server.py               |  24 ++-
 requirements.txt                            |  19 +-
 server.py                                   |  23 ++-
 24 files changed, 326 insertions(+), 230 deletions(-)
 delete mode 100644 llm_server/database.py
 create mode 100644 llm_server/database/__init__.py
 create mode 100644 llm_server/database/conn.py
 create mode 100644 llm_server/database/create.py
 create mode 100644 llm_server/database/database.py
 create mode 100644 llm_server/llm/backend.py
 create mode 100644 llm_server/llm/vllm/tokenize.py

diff --git a/llm_server/database.py b/llm_server/database.py
deleted file mode 100644
index 41c3d0d..0000000
--- a/llm_server/database.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import json
-import sqlite3
-import time
-from pathlib import Path
-
-import tiktoken
-
-from llm_server import opts
-
-tokenizer = tiktoken.get_encoding("cl100k_base")
-
-
-def init_db():
-    if not Path(opts.database_path).exists():
-        conn = sqlite3.connect(opts.database_path)
-        c = conn.cursor()
-        c.execute('''
-        CREATE TABLE prompts (
-            ip TEXT,
-            token TEXT DEFAULT NULL,
-            model TEXT,
-            backend_mode TEXT,
-            backend_url TEXT,
-            request_url TEXT,
-            generation_time FLOAT,
-            prompt TEXT,
-            prompt_tokens INTEGER,
-            response TEXT,
-            response_tokens INTEGER,
-            response_status INTEGER,
-            parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
-            headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
-            timestamp INTEGER
-        )
-        ''')
-        c.execute('''
-        CREATE TABLE token_auth (
-            token TEXT UNIQUE,
-            type TEXT NOT NULL,
-            priority INTEGER default 9999,
-            uses INTEGER default 0,
-            max_uses INTEGER,
-            expire INTEGER,
-            disabled BOOLEAN default 0
-        )
-        ''')
-        conn.commit()
-        conn.close()
-
-
-def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
-    prompt_tokens = len(tokenizer.encode(prompt))
-
-    if not is_error:
-        if not response_tokens:
-            response_tokens = len(tokenizer.encode(response, disallowed_special=()))
-    else:
-        response_tokens = None
-
-    # Sometimes we may want to insert null into the DB, but
-    # usually we want to insert a float.
-    if gen_time:
-        gen_time = round(gen_time, 3)
-        if is_error:
-            gen_time = None
-
-    if not opts.log_prompts:
-        prompt = None
-
-    if not opts.log_prompts and not is_error:
-        # TODO: test and verify this works as expected
-        response = None
-
-    timestamp = int(time.time())
-    conn = sqlite3.connect(opts.database_path)
-    c = conn.cursor()
-    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
-              (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-    conn.commit()
-    conn.close()
-
-
-def is_valid_api_key(api_key):
-    conn = sqlite3.connect(opts.database_path)
-    cursor = conn.cursor()
-    cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
-    row = cursor.fetchone()
-    if row is not None:
-        token, uses, max_uses, expire, disabled = row
-        disabled = bool(disabled)
-        if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
-            return True
-    return False
-
-
-def increment_uses(api_key):
-    conn = sqlite3.connect(opts.database_path)
-    cursor = conn.cursor()
-    cursor.execute("SELECT token FROM token_auth WHERE token = ?", (api_key,))
-    row = cursor.fetchone()
-    if row is not None:
-        cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
-        conn.commit()
-        return True
-    return False
-
-
-def get_number_of_rows(table_name):
-    conn = sqlite3.connect(opts.database_path)
-    cur = conn.cursor()
-    cur.execute(f'SELECT COUNT(*) FROM {table_name}')
-    result = cur.fetchone()
-    conn.close()
-    return result[0]
-
-
-def average_column(table_name, column_name):
-    conn = sqlite3.connect(opts.database_path)
-    cursor = conn.cursor()
-    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
-    result = cursor.fetchone()
-    conn.close()
-    return result[0]
-
-
-def average_column_for_model(table_name, column_name, model_name):
-    conn = sqlite3.connect(opts.database_path)
-    cursor = conn.cursor()
-    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = ?", (model_name,))
-    result = cursor.fetchone()
-    conn.close()
-    return result[0]
-
-
-def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False):
-    conn = sqlite3.connect(opts.database_path)
-    cursor = conn.cursor()
-    # cursor.execute(f"SELECT DISTINCT model, backend_mode FROM {table_name}")
-    # models_backends = [(row[0], row[1]) for row in cursor.fetchall()]
-    #
-    # model_averages = {}
-    # for model, backend in models_backends:
-    #     if backend != backend_name:
-    #         continue
-    cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend_mode = ? AND backend_url = ? ORDER BY ROWID DESC", (model_name, backend_name, backend_url))
-    results = cursor.fetchall()
-
-    # if not results:
-    #     continue
-
-    total_weight = 0
-    weighted_sum = 0
-    for i, (value, rowid) in enumerate(results):
-        if value is None or (exclude_zeros and value == 0):
-            continue
-        weight = i + 1
-        total_weight += weight
-        weighted_sum += weight * value
-
-    # if total_weight == 0:
-    #     continue
-
-    if total_weight > 0:
-        # Avoid division by zero
-        calculated_avg = weighted_sum / total_weight
-    else:
-        calculated_avg = 0
-
-    conn.close()
-
-    return calculated_avg
-
-
-def sum_column(table_name, column_name):
-    conn = sqlite3.connect(opts.database_path)
-    cursor = conn.cursor()
-    cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
-    result = cursor.fetchone()
-    conn.close()
-    return result[0] if result[0] else 0
-
-
-def get_distinct_ips_24h():
-    # Get the current time and subtract 24 hours (in seconds)
-    past_24_hours = int(time.time()) - 24 * 60 * 60
-    conn = sqlite3.connect(opts.database_path)
-    cur = conn.cursor()
-    cur.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= ?", (past_24_hours,))
-    result = cur.fetchone()
-    conn.close()
-    return result[0] if result else 0
diff --git a/llm_server/database/__init__.py b/llm_server/database/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py
new file mode 100644
index 0000000..d831990
--- /dev/null
+++ b/llm_server/database/conn.py
@@ -0,0 +1,25 @@
+import pymysql
+from dbutils.pooled_db import PooledDB
+
+
+class DatabaseConnection:
+    db_pool: PooledDB = None
+
+    def init_db(self, host, username, password, database):
+        self.db_pool = PooledDB(
+            creator=pymysql,
+            maxconnections=10,
+            host=host,
+            user=username,
+            password=password,
+            database=database,
+            charset='utf8'
+        )
+        conn = self.db_pool.connection()
+        del conn
+
+    def connection(self):
+        return self.db_pool.connection()
+
+
+db_pool = DatabaseConnection()
diff --git a/llm_server/database/create.py b/llm_server/database/create.py
new file mode 100644
index 0000000..dd6d458
--- /dev/null
+++ b/llm_server/database/create.py
@@ -0,0 +1,41 @@
+from llm_server.database.conn import db_pool
+
+
+def create_db():
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute('''
+    CREATE TABLE IF NOT EXISTS prompts (
+        ip TEXT,
+        token TEXT DEFAULT NULL,
+        model TEXT,
+        backend_mode TEXT,
+        backend_url TEXT,
+        request_url TEXT,
+        generation_time FLOAT,
+        prompt TEXT,
+        prompt_tokens INTEGER,
+        response TEXT,
+        response_tokens INTEGER,
+        response_status INTEGER,
+        parameters TEXT,
+        CHECK (parameters IS NULL OR JSON_VALID(parameters)),
+        headers TEXT,
+        CHECK (headers IS NULL OR JSON_VALID(headers)),
+        timestamp INTEGER
+    )
+    ''')
+    cursor.execute('''
+    CREATE TABLE IF NOT EXISTS token_auth (
+        token TEXT,
+        UNIQUE (token),
+        type TEXT NOT NULL,
+        priority INTEGER DEFAULT 9999,
+        uses INTEGER DEFAULT 0,
+        max_uses INTEGER,
+        expire INTEGER,
+        disabled BOOLEAN DEFAULT 0
+    )
+    ''')
+    conn.commit()
+
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
new file mode 100644
index 0000000..583cf2b
--- /dev/null
+++ b/llm_server/database/database.py
@@ -0,0 +1,144 @@
+import json
+import time
+
+import llm_server
+from llm_server import opts
+from llm_server.database.conn import db_pool
+from llm_server.llm.vllm import tokenize
+
+
+def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
+    prompt_tokens = llm_server.llm.tokenizer(prompt)
+
+    if not is_error:
+        if not response_tokens:
+            response_tokens = llm_server.llm.tokenizer(response)
+    else:
+        response_tokens = None
+
+    # Sometimes we may want to insert null into the DB, but
+    # usually we want to insert a float.
+    if gen_time:
+        gen_time = round(gen_time, 3)
+        if is_error:
+            gen_time = None
+
+    if not opts.log_prompts:
+        prompt = None
+
+    if not opts.log_prompts and not is_error:
+        # TODO: test and verify this works as expected
+        response = None
+
+    timestamp = int(time.time())
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute("""
+        INSERT INTO prompts
+        (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
+        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+        """,
+        (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+    conn.commit()
+
+
+def is_valid_api_key(api_key):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    try:
+        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
+        row = cursor.fetchone()
+        if row is not None:
+            token, uses, max_uses, expire, disabled = row
+            disabled = bool(disabled)
+            if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
+                return True
+        return False
+    finally:
+        conn.commit()
+
+
+def increment_uses(api_key):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    try:
+        cursor.execute("SELECT token FROM token_auth WHERE token = %s", (api_key,))
+        row = cursor.fetchone()
+        if row is not None:
+            cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = %s", (api_key,))
+            return True
+        return False
+    finally:
+        conn.commit()
+
+
+def get_number_of_rows(table_name):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute(f'SELECT COUNT(*) FROM {table_name}')
+    result = cursor.fetchone()
+    conn.commit()
+    return result[0]
+
+
+def average_column(table_name, column_name):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
+    result = cursor.fetchone()
+    conn.commit()
+    return result[0]
+
+
+def average_column_for_model(table_name, column_name, model_name):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s", (model_name,))
+    result = cursor.fetchone()
+    conn.commit()
+    return result[0]
+
+
+def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,))
+    results = cursor.fetchall()
+
+    total_weight = 0
+    weighted_sum = 0
+    for i, (value, rowid) in enumerate(results):
+        if value is None or (exclude_zeros and value == 0):
+            continue
+        weight = i + 1
+        total_weight += weight
+        weighted_sum += weight * value
+
+    if total_weight > 0:
+        # Avoid division by zero
+        calculated_avg = weighted_sum / total_weight
+    else:
+        calculated_avg = 0
+
+    conn.commit()
+    return calculated_avg
+
+
+def sum_column(table_name, column_name):
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
+    result = cursor.fetchone()
+    conn.commit()
+    return result[0] if result[0] else 0
+
+
+def get_distinct_ips_24h():
+    # Get the current time and subtract 24 hours (in seconds)
+    past_24_hours = int(time.time()) - 24 * 60 * 60
+    conn = db_pool.connection()
+    cursor = conn.cursor()
+    cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s", (past_24_hours,))
+    result = cursor.fetchone()
+    conn.commit()
+    return result[0] if result else 0
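
Note on the query above: weighted_average_column_for_model keeps the SQLite version's weighting scheme, a linearly weighted mean in which the i-th row of the DESC-ordered result set gets weight i + 1. A minimal standalone sketch of the same calculation, using hypothetical values rather than real rows:

    # Hypothetical generation times, newest row first (mirrors ORDER BY id DESC).
    values = [2.1, 1.9, None, 2.4]

    total_weight = 0
    weighted_sum = 0
    for i, value in enumerate(values):
        if value is None:
            continue  # NULL values are skipped, as in the function above
        weight = i + 1  # weight grows with position in the result set
        total_weight += weight
        weighted_sum += weight * value

    # (1*2.1 + 2*1.9 + 4*2.4) / (1 + 2 + 4) = 15.5 / 7 ≈ 2.214
    print(weighted_sum / total_weight if total_weight else 0)
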
diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index bf56fac..b12bedb 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -5,7 +5,6 @@ from pathlib import Path
 
 from flask import make_response
 
-
 def resolve_path(*p: str):
     return Path(*p).expanduser().resolve().absolute()
 
diff --git a/llm_server/llm/backend.py b/llm_server/llm/backend.py
new file mode 100644
index 0000000..cb1d405
--- /dev/null
+++ b/llm_server/llm/backend.py
@@ -0,0 +1 @@
+tokenizer = None
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index d6a1f25..f3d08f5 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -3,7 +3,7 @@ from typing import Tuple, Union
 import flask
 
 from llm_server import opts
-from llm_server.database import tokenizer
+from llm_server.llm.backend import tokenizer
 
 
 class LLMBackend:
@@ -30,9 +30,8 @@ class LLMBackend:
     def validate_request(self, parameters: dict) -> Tuple[bool, Union[str, None]]:
         raise NotImplementedError
 
-    @staticmethod
-    def validate_prompt(prompt: str) -> Tuple[bool, Union[str, None]]:
-        prompt_len = len(tokenizer.encode(prompt))
-        if prompt_len > opts.context_size - 10:  # Our tokenizer isn't 100% accurate so we cut it down a bit. TODO: add a tokenizer endpoint to VLLM
+    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
+        prompt_len = len(tokenizer(prompt))
+        if prompt_len > opts.context_size - 10:
             return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
         return True, None
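
A note on the new tokenizer hook: llm/backend.py holds a module-level tokenizer that server.py reassigns at startup, but a from-import like the one above snapshots the value at import time. As a general Python caveat (illustrative module names, not from the patch), attribute access is what sees the later rebind:

    # hook.py (stands in for llm_server/llm/backend.py)
    tokenizer = None

    # consumer_a.py -- from-import copies the current value; a later
    # rebind of hook.tokenizer is not visible through this name
    from hook import tokenizer

    # consumer_b.py -- attribute access resolves at call time, so it
    # sees whatever was assigned at startup
    import hook

    def count(prompt):
        return hook.tokenizer(prompt)
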
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index 48a7336..578f663 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -1,7 +1,7 @@
 from flask import jsonify
 
 from ..llm_backend import LLMBackend
-from ...database import log_prompt
+from ...database.database import log_prompt
 from ...helpers import safe_list_get
 from ...routes.cache import redis
 from ...routes.helpers.client import format_sillytavern_err
diff --git a/llm_server/llm/vllm/__init__.py b/llm_server/llm/vllm/__init__.py
index e69de29..2491e2a 100644
--- a/llm_server/llm/vllm/__init__.py
+++ b/llm_server/llm/vllm/__init__.py
@@ -0,0 +1 @@
+from .tokenize import tokenize
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 8e580e4..f796a39 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -7,8 +7,8 @@ from uuid import uuid4
 
 import requests
 
+import llm_server
 from llm_server import opts
-from llm_server.database import tokenizer
 
 
 # TODO: make the VLMM backend return TPS and time elapsed
@@ -43,8 +43,8 @@ def transform_to_text(json_request, api_response):
         if data['choices'][0]['finish_reason']:
             finish_reason = data['choices'][0]['finish_reason']
 
-    prompt_tokens = len(tokenizer.encode(prompt))
-    completion_tokens = len(tokenizer.encode(text))
+    prompt_tokens = len(llm_server.llm.tokenizer(prompt))
+    completion_tokens = len(llm_server.llm.tokenizer(text))
 
     # https://platform.openai.com/docs/api-reference/making-requests?lang=python
     return {
diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
new file mode 100644
index 0000000..9913603
--- /dev/null
+++ b/llm_server/llm/vllm/tokenize.py
@@ -0,0 +1,19 @@
+import traceback
+
+import requests
+import tiktoken
+
+from llm_server import opts
+
+tokenizer = tiktoken.get_encoding("cl100k_base")
+
+
+def tokenize(prompt: str) -> int:
+    try:
+        r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        j = r.json()
+        return j['length']
+    except:
+        # Fall back to a local tiktoken estimate if the backend endpoint is unavailable.
+        print(traceback.format_exc())
+        return len(tokenizer.encode(prompt)) + 10
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index 980ebd1..9c5af5c 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -1,10 +1,13 @@
+import traceback
 from typing import Tuple, Union
 
+import requests
 from flask import jsonify
 from vllm import SamplingParams
 
+import llm_server
 from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.database import log_prompt
 from llm_server.llm.llm_backend import LLMBackend
 
@@ -45,3 +48,19 @@ class VLLMBackend(LLMBackend):
         if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
             return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
         return True, None
+
+    # def tokenize(self, prompt):
+    #     try:
+    #         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+    #         j = r.json()
+    #         return j['length']
+    #     except:
+    #         # Fall back to whatever the superclass is doing.
+    #         print(traceback.format_exc())
+    #         return super().tokenize(prompt)
+
+    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
+        prompt_len = llm_server.llm.tokenizer(prompt)
+        if prompt_len > opts.context_size:
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
+        return True, None
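
The tokenize() helper above asks the backend's new /tokenize endpoint for an exact count and only falls back to the local tiktoken estimate when the request fails. A rough sketch of the round trip, assuming a vLLM API server from this patch listening at a hypothetical http://localhost:8000:

    import requests

    # Returns {'length': <token count>} by default.
    r = requests.post('http://localhost:8000/tokenize', json={'input': 'Once upon a time'}, timeout=10)
    print(r.json()['length'])

    # Pass "return": true to also receive the token strings.
    r = requests.post('http://localhost:8000/tokenize', json={'input': 'Once upon a time', 'return': True}, timeout=10)
    print(r.json()['tokens'])
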
diff --git a/llm_server/opts.py b/llm_server/opts.py
index bebea9f..72f441c 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -8,7 +8,6 @@ mode = 'oobabooga'
 backend_url = None
 context_size = 5555
 max_new_tokens = 500
-database_path = './proxy-server.db'
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py
index 8cc7a02..2ade636 100644
--- a/llm_server/routes/helpers/http.py
+++ b/llm_server/routes/helpers/http.py
@@ -8,7 +8,7 @@ from flask import make_response, Request
 from flask import request, jsonify
 
 from llm_server import opts
-from llm_server.database import is_valid_api_key
+from llm_server.database.database import is_valid_api_key
 
 
 def cache_control(seconds):
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 44b740c..d3ca482 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -4,7 +4,7 @@ import flask
 from flask import jsonify
 
 from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.database import log_prompt
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index e5530e3..ea54db6 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -10,8 +10,9 @@ import requests
 import tiktoken
 from flask import jsonify
 
+import llm_server
 from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.database import log_prompt
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
 
@@ -130,8 +131,8 @@ def build_openai_response(prompt, response):
     if len(x) > 1:
         response = re.sub(r'\n$', '', y[0].strip(' '))
 
-    prompt_tokens = len(tokenizer.encode(prompt))
-    response_tokens = len(tokenizer.encode(response))
+    prompt_tokens = llm_server.llm.tokenizer(prompt)
+    response_tokens = llm_server.llm.tokenizer(response)
 
     return jsonify({
         "id": f"chatcmpl-{uuid4()}",
         "object": "chat.completion",
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 959b7a3..b0371e8 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -1,4 +1,3 @@
-import sqlite3
 import time
 from typing import Tuple, Union
 
@@ -6,7 +5,8 @@ import flask
 from flask import Response, request
 
 from llm_server import opts
-from llm_server.database import log_prompt
+from llm_server.database.conn import db_pool
+from llm_server.database.database import log_prompt
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.cache import redis
@@ -42,13 +42,14 @@ class RequestHandler:
 
     def get_priority(self):
         if self.token:
-            conn = sqlite3.connect(opts.database_path)
+            conn = db_pool.connection()
             cursor = conn.cursor()
-            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
+            cursor.execute("SELECT priority FROM token_auth WHERE token = %s", (self.token,))
             result = cursor.fetchone()
-            conn.close()
+
             if result:
                 return result[0]
+            conn.commit()
             return DEFAULT_PRIORITY
 
     def get_parameters(self):
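
get_priority above is representative of how the patch swaps sqlite3.connect()/conn.close() pairs for the shared PooledDB pool. With DBUtils, close() hands a connection back to the pool instead of tearing it down, so the general pattern looks roughly like this (a sketch with made-up credentials, not code from the patch):

    import pymysql
    from dbutils.pooled_db import PooledDB

    pool = PooledDB(creator=pymysql, maxconnections=10, host='127.0.0.1',
                    user='llm_proxy', password='secret', database='llm_server')

    conn = pool.connection()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT priority FROM token_auth WHERE token = %s", ('example-token',))
        row = cursor.fetchone()
    finally:
        conn.close()  # returns the connection to the pool
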
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index ca24b74..4a30bef 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -2,7 +2,7 @@ import time
 from datetime import datetime
 
 from llm_server import opts
-from llm_server.database import get_distinct_ips_24h, sum_column
+from llm_server.database.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 11940c7..30e0a32 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -6,7 +6,7 @@ from flask import request
 
 from ..helpers.client import format_sillytavern_err
 from ... import opts
-from ...database import log_prompt
+from ...database.database import log_prompt
 from ...helpers import indefinite_article
 from ...stream import sock
diff --git a/llm_server/threads.py b/llm_server/threads.py
index f97d4d6..23da835 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -2,7 +2,7 @@ import time
 from threading import Thread
 
 from llm_server import opts
-from llm_server.database import weighted_average_column_for_model
+from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
 from llm_server.routes.cache import redis
 from llm_server.routes.v1.generate_stats import generate_stats
diff --git a/other/vllm/vllm_api_server.py b/other/vllm/vllm_api_server.py
index f5b5f45..0b067fd 100644
--- a/other/vllm/vllm_api_server.py
+++ b/other/vllm/vllm_api_server.py
@@ -7,9 +7,10 @@ from typing import AsyncGenerator
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -24,6 +25,20 @@ async def generate(request: Request) -> Response:
     return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
 
 
+@app.post("/tokenize")
+async def tokenize(request: Request) -> Response:
+    request_dict = await request.json()
+    to_tokenize = request_dict.get("input")
+    if not to_tokenize:
+        return JSONResponse({'error': 'must have input field'}, status_code=400)
+    tokens = tokenizer.tokenize(to_tokenize)
+    response = {}
+    if request_dict.get("return", False):
+        response['tokens'] = tokens
+    response['length'] = len(tokens)
+    return JSONResponse(response)
+
+
 @app.post("/generate")
 async def generate(request: Request) -> Response:
     """Generate completion for the request.
@@ -82,11 +97,14 @@ if __name__ == "__main__": parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() - served_model = Path(args.model).name - engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) + served_model = Path(args.model).name + tokenizer = get_tokenizer(engine_args.tokenizer, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code) + uvicorn.run(app, host=args.host, port=args.port, diff --git a/requirements.txt b/requirements.txt index 59a0753..d425b5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,19 @@ -flask +flask~=2.3.3 flask_cors -pyyaml +pyyaml~=6.0.1 flask_caching -requests -tiktoken +requests~=2.31.0 +tiktoken~=0.5.0 gunicorn -redis +redis~=5.0.0 gevent async-timeout flask-sock -auto_gptq \ No newline at end of file +auto_gptq +uvicorn~=0.23.2 +fastapi~=0.103.1 +torch~=2.0.1 +urllib3 +PyMySQL~=1.1.0 +DBUtils~=3.0.3 +simplejson \ No newline at end of file diff --git a/server.py b/server.py index 3b386fc..d9d01b7 100644 --- a/server.py +++ b/server.py @@ -1,11 +1,16 @@ -import json import os import sys from pathlib import Path from threading import Thread - +import simplejson as json from flask import Flask, jsonify, render_template, request +import llm_server +from llm_server.database.conn import db_pool +from llm_server.database.create import create_db +from llm_server.database.database import get_number_of_rows +from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend +from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error @@ -19,7 +24,6 @@ except ModuleNotFoundError as e: import config from llm_server import opts from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names -from llm_server.database import get_number_of_rows, init_db from llm_server.helpers import resolve_path from llm_server.llm.vllm.info import vllm_info from llm_server.routes.cache import cache, redis @@ -48,8 +52,8 @@ if not success: if config['database_path'].startswith('./'): config['database_path'] = resolve_path(script_path, config['database_path'].strip('./')) -opts.database_path = resolve_path(config['database_path']) -init_db() +db_pool.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) +create_db() if config['mode'] not in ['oobabooga', 'vllm']: print('Unknown mode:', config['mode']) @@ -93,6 +97,15 @@ if config['average_generation_time_mode'] not in ['database', 'minute']: sys.exit(1) opts.average_generation_time_mode = config['average_generation_time_mode'] + +if opts.mode == 'oobabooga': + raise NotImplementedError + # llm_server.llm.tokenizer = OobaboogaBackend() +elif opts.mode == 'vllm': + llm_server.llm.tokenizer = llm_server.llm.vllm.tokenize +else: + raise Exception + # Start background processes start_workers(opts.concurrent_gens) process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
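
For reference, server.py now reads its MySQL settings from a `mysql` section of the config file instead of the old `database_path` option. Assuming the YAML config format the ConfigLoader already uses, a fragment with the four keys the code expects (all values are placeholders) would look something like:

    mysql:
      host: 127.0.0.1
      username: llm_proxy
      password: secret
      database: llm_server
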