actually we don't want to emulate openai
parent 747d838138
commit 40ac84aa9a
@@ -1,6 +1,7 @@
proxy-server.db
.idea
config/config.yml
install vllm_gptq-0.1.3-py3-none-any.whl

# ---> Python
# Byte-compiled / optimized / DLL files

@@ -10,7 +10,8 @@ The purpose of this server is to abstract your LLM backend from your frontend AP
2. `python3 -m venv venv`
3. `source venv/bin/activate`
4. `pip install -r requirements.txt`
5. `python3 server.py`
5. `wget https://git.evulid.cc/attachments/89c87201-58b1-4e28-b8fd-d0b323c810c4 -O vllm_gptq-0.1.3-py3-none-any.whl && pip install vllm_gptq-0.1.3-py3-none-any.whl`
6. `python3 server.py`

An example systemctl service file is provided in `other/local-llm.service`.
@@ -18,6 +18,7 @@ def init_db():
CREATE TABLE prompts (
ip TEXT,
token TEXT DEFAULT NULL,
backend TEXT,
prompt TEXT,
prompt_tokens INTEGER,
response TEXT,

@@ -71,8 +72,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
timestamp = int(time.time())
conn = sqlite3.connect(opts.database_path)
c = conn.cursor()
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, opts.mode, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
conn.commit()
conn.close()

@@ -129,15 +130,17 @@ def average_column_for_model(table_name, column_name, model_name):
return result[0]


def weighted_average_column_for_model(table_name, column_name, model_name, exclude_zeros: bool = False):
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, exclude_zeros: bool = False):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
models = [row[0] for row in cursor.fetchall()]
cursor.execute(f"SELECT DISTINCT model, backend FROM {table_name}")
models_backends = [(row[0], row[1]) for row in cursor.fetchall()]

model_averages = {}
for model in models:
cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
for model, backend in models_backends:
if backend != backend_name:
continue
cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend = ? ORDER BY ROWID DESC", (model, backend))
results = cursor.fetchall()

if not results:

@@ -155,11 +158,11 @@ def weighted_average_column_for_model(table_name, column_name, model_name, exclu
if total_weight == 0:
continue

model_averages[model] = weighted_sum / total_weight
model_averages[(model, backend)] = weighted_sum / total_weight

conn.close()

return model_averages.get(model_name)
return model_averages.get((model_name, backend_name))


def sum_column(table_name, column_name):
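The database hunks above re-key the recency-weighted averages from model alone to (model, backend). A minimal, self-contained sketch of that keying against the `prompts` table defined above; the linear recency weights are purely illustrative, since the real weighting logic is outside this hunk:

import sqlite3

def weighted_average(db_path: str, column: str, model: str, backend: str):
    # Pull the column for one (model, backend) pair, oldest row first.
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(f"SELECT {column} FROM prompts WHERE model = ? AND backend = ? ORDER BY ROWID ASC",
                   (model, backend))
    values = [row[0] for row in cursor.fetchall() if row[0] is not None]
    conn.close()
    if not values:
        return None
    # Weight newer rows more heavily: 1, 2, ..., n (illustrative scheme only).
    weights = range(1, len(values) + 1)
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)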
@@ -1,10 +1,11 @@
import requests

from llm_server import opts
from pathlib import Path


def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
# TODO: only use one try/catch

if opts.mode == 'oobabooga':
try:

@@ -22,11 +23,9 @@ def get_running_model():
return False, e
elif opts.mode == 'vllm':
try:
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
model_name = Path(r_json['data'][0]['root']).name
# r_json['data'][0]['root'] = model_name
return model_name, None
return r_json['model'], None
except Exception as e:
return False, e
else:
@@ -1,4 +1,4 @@
from typing import Union, Tuple
from typing import Tuple, Union


class LLMBackend:

@@ -10,3 +10,12 @@ class LLMBackend:

# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# raise NotImplementedError

def get_parameters(self, parameters) -> Union[dict, None]:
"""
Validate and return the parameters for this backend.
Lets you set defaults for specific backends.
:param parameters:
:return:
"""
raise NotImplementedError
@@ -1,15 +1,11 @@
from typing import Tuple

import requests
from flask import jsonify

from ... import opts
from ..llm_backend import LLMBackend
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
from ..llm_backend import LLMBackend


class OobaboogaLLMBackend(LLMBackend):

@@ -71,3 +67,7 @@ class OobaboogaLLMBackend(LLMBackend):
# return r_json['result'], None
# except Exception as e:
# return False, e

def get_parameters(self, parameters):
del parameters['prompt']
return parameters
@@ -1,17 +1,14 @@
"""
This file is used by the worker that processes requests.
"""
import io
import json
import time
from uuid import uuid4

import requests
from requests import Response

from llm_server import opts
from llm_server.database import tokenizer
from llm_server.routes.cache import redis


# TODO: make the VLMM backend return TPS and time elapsed

@@ -19,7 +16,7 @@ from llm_server.routes.cache import redis

def prepare_json(json_data: dict):
# logit_bias is not currently supported
del json_data['logit_bias']
# del json_data['logit_bias']
return json_data


@@ -83,26 +80,26 @@ def transform_prompt_to_text(prompt: list):

def handle_blocking_request(json_data: dict):
try:
r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl)
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'

# TODO: check for error here?
response_json = r.json()
response_json['error'] = False
# response_json = r.json()
# response_json['error'] = False

new_response = Response()
new_response.status_code = r.status_code
new_response._content = json.dumps(response_json).encode('utf-8')
new_response.raw = io.BytesIO(new_response._content)
new_response.headers = r.headers
new_response.url = r.url
new_response.reason = r.reason
new_response.cookies = r.cookies
new_response.elapsed = r.elapsed
new_response.request = r.request
# new_response = Response()
# new_response.status_code = r.status_code
# new_response._content = json.dumps(response_json).encode('utf-8')
# new_response.raw = io.BytesIO(new_response._content)
# new_response.headers = r.headers
# new_response.url = r.url
# new_response.reason = r.reason
# new_response.cookies = r.cookies
# new_response.elapsed = r.elapsed
# new_response.request = r.request

return True, new_response, None
return True, r, None


def generate(json_data: dict):
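The change above makes the blocking path POST the prepared payload straight to the backend's vLLM-style `/generate` route and return the raw `requests.Response` instead of rebuilding one. A rough standalone sketch of that flow; the URL and parameters in the usage comment are placeholders:

import requests

def blocking_generate(backend_url: str, prompt: str, **sampling_params):
    # Mirror handle_blocking_request(): POST the prompt plus sampling parameters
    # to /generate and hand the raw Response back to the caller.
    try:
        r = requests.post(f'{backend_url}/generate', json={'prompt': prompt, **sampling_params})
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'
    return True, r, None

# Example usage (URL and parameter values are assumptions):
# ok, resp, err = blocking_generate('http://localhost:8000', 'Once upon a time,', max_tokens=32)
# if ok:
#     print(resp.json()['text'])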
@@ -1,13 +1,8 @@
from pathlib import Path

import requests

from llm_server import opts


def get_vlmm_models_info():
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name
r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name
return r_json
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/chu-tianxiang/vllm-gptq" target="_blank">vllm-gptq</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong>
<ul>
<li><kbd>temperature</kbd></li>
<li><kbd>top_p</kbd></li>
<li><kbd>top_k</kbd></li>
<li><kbd>max_new_tokens</kbd></li>
</ul>"""
@@ -1,8 +1,9 @@
from typing import Tuple

from flask import jsonify
from vllm import SamplingParams

from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json

@@ -22,19 +23,23 @@ class VLLMBackend(LLMBackend):
response_status_code = 0

if response_valid_json:
backend_response = response_json_body
if len(response_json_body.get('text', [])):
# Does vllm return the prompt and the response together???
backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
else:
# Failsafe
backend_response = ''

if response_json_body.get('error'):
backend_err = True
error_type = response_json_body.get('error_type')
error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
backend_response = format_sillytavern_err(
f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
f'HTTP CODE {response_status_code}'
)
# TODO: how to detect an error?
# if backend_response == '':
# backend_err = True
# backend_response = format_sillytavern_err(
# f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
# f'HTTP CODE {response_status_code}'
# )

log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify(backend_response), 200
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({'results': [{'text': backend_response}]}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)

@@ -44,13 +49,24 @@ class VLLMBackend(LLMBackend):
'results': [{'text': backend_response}]
}), 200

def validate_params(self, params_dict: dict):
try:
sampling_params = SamplingParams(**params_dict)
except ValueError as e:
print(e)
return False, e
return True, None
# def validate_params(self, params_dict: dict):
# default_params = SamplingParams()
# try:
# sampling_params = SamplingParams(
# temperature=params_dict.get('temperature', default_params.temperature),
# top_p=params_dict.get('top_p', default_params.top_p),
# top_k=params_dict.get('top_k', default_params.top_k),
# use_beam_search=True if params_dict['num_beams'] > 1 else False,
# length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
# early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
# stop=params_dict.get('stopping_strings', default_params.stop),
# ignore_eos=params_dict.get('ban_eos_token', False),
# max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
# )
# except ValueError as e:
# print(e)
# return False, e
# return True, None

# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:

@@ -61,3 +77,33 @@ class VLLMBackend(LLMBackend):
# return r_json, None
# except Exception as e:
# return False, e

def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
default_params = SamplingParams()
try:
sampling_params = SamplingParams(
temperature=parameters.get('temperature', default_params.temperature),
top_p=parameters.get('top_p', default_params.top_p),
top_k=parameters.get('top_k', default_params.top_k),
use_beam_search=True if parameters['num_beams'] > 1 else False,
stop=parameters.get('stopping_strings', default_params.stop),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
)
except ValueError as e:
print(e)
return None, e
return vars(sampling_params), None

# def transform_sampling_params(params: SamplingParams):
# return {
# 'temperature': params['temperature'],
# 'top_p': params['top_p'],
# 'top_k': params['top_k'],
# 'use_beam_search' = True if parameters['num_beams'] > 1 else False,
# length_penalty = parameters.get('length_penalty', default_params.length_penalty),
# early_stopping = parameters.get('early_stopping', default_params.early_stopping),
# stop = parameters.get('stopping_strings', default_params.stop),
# ignore_eos = parameters.get('ban_eos_token', False),
# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
# }
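The new `get_parameters()` above translates Oobabooga-style request fields into vLLM `SamplingParams` keyword arguments and hands them back as a plain dict. A dependency-free sketch of the same renaming; the fallback values are placeholders, not vLLM's real defaults:

def map_ooba_params_to_vllm(parameters: dict) -> dict:
    # Same field renames as VLLMBackend.get_parameters(); the fallback values
    # here are illustrative placeholders, not vLLM's actual defaults.
    return {
        'temperature': parameters.get('temperature', 1.0),
        'top_p': parameters.get('top_p', 1.0),
        'top_k': parameters.get('top_k', -1),
        'use_beam_search': parameters.get('num_beams', 1) > 1,
        'stop': parameters.get('stopping_strings'),
        'ignore_eos': parameters.get('ban_eos_token', False),
        'max_tokens': parameters.get('max_new_tokens', 16),
    }

# e.g. map_ooba_params_to_vllm({'temperature': 0.7, 'max_new_tokens': 200})
#      -> {'temperature': 0.7, ..., 'max_tokens': 200}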
@@ -38,9 +38,9 @@ class OobaRequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.request.headers.get('X-Api-Key')
self.parameters = self.get_parameters()
self.priority = self.get_priority()
self.backend = self.get_backend()
self.parameters = self.parameters_invalid_msg = None

def validate_request(self) -> (bool, Union[str, None]):
# TODO: move this to LLMBackend

@@ -56,19 +56,9 @@ class OobaRequestHandler:
else:
return self.request.remote_addr

def get_parameters(self):
# TODO: make this a LLMBackend method
request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json:
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
parameters = self.request_json_body.copy()
if opts.mode in ['oobabooga', 'hf-textgen']:
del parameters['prompt']
elif opts.mode == 'vllm':
parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
else:
raise Exception
return parameters
# def get_parameters(self):
# # TODO: make this a LLMBackend method
# return self.backend.get_parameters()

def get_priority(self):
if self.token:

@@ -91,24 +81,26 @@ class OobaRequestHandler:
else:
raise Exception

def get_parameters(self):
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

def handle_request(self):
request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json:
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

self.get_parameters()

SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

# Fix bug on text-generation-inference
# https://github.com/huggingface/text-generation-inference/issues/929
if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
self.request_json_body['typical_p'] = 0.998

if opts.mode == 'vllm':
full_model_path = redis.get('full_model_path')
if not full_model_path:
raise Exception
self.request_json_body['model'] = full_model_path.decode()

request_valid, invalid_request_err_msg = self.validate_request()
params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
if not self.parameters:
params_valid = False
else:
params_valid = True

if not request_valid or not params_valid:
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)

@@ -119,21 +111,27 @@ class OobaRequestHandler:
'results': [{'text': err}]
}), 200

# Reconstruct the request JSON with the validated parameters and prompt.
prompt = self.request_json_body.get('prompt', '')
llm_request = {**self.parameters, 'prompt': prompt}

queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
else:
# Client was rate limited
event = None

if not event:
return self.handle_ratelimited()

event.wait()
success, response, error_msg = event.data

end_time = time.time()
elapsed_time = end_time - self.start_time
return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))

return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
@@ -7,14 +7,7 @@ from ... import opts


@bp.route('/generate', methods=['POST'])
@bp.route('/chat/completions', methods=['POST'])
def generate():
if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
return jsonify({
'code': 404,
'error': 'this LLM backend is in VLLM mode'
}), 404

request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
@@ -82,6 +82,7 @@ def generate_stats():
'online': online,
'endpoints': {
'blocking': f'https://{opts.base_client_api}',
'streaming': f'wss://{opts.base_client_api}/stream',
},
'queue': {
'processing': active_gen_workers,

@@ -104,9 +105,9 @@ def generate_stats():
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
}

if opts.mode in ['oobabooga', 'hf-textgen']:
output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
else:
output['endpoints']['streaming'] = None
# if opts.mode in ['oobabooga', 'hf-textgen']:
# output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
# else:
# output['endpoints']['streaming'] = None

return deep_sort(output)
@@ -4,9 +4,7 @@ from flask import jsonify, request

from . import bp
from ..cache import cache
from ... import opts
from ...llm.info import get_running_model
from ...llm.vllm.info import get_vlmm_models_info


# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})

@@ -20,16 +18,7 @@ from ...llm.vllm.info import get_vlmm_models_info


@bp.route('/model', methods=['GET'])
@bp.route('/models', methods=['GET'])
def get_model():
if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
return jsonify({
'code': 404,
'error': 'this LLM backend is in VLLM mode'
}), 404


# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url

@@ -46,18 +35,10 @@ def get_model():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
if opts.mode in ['oobabooga', 'hf-texgen']:
response = jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
elif opts.mode == 'vllm':
response = jsonify({
**get_vlmm_models_info(),
'timestamp': int(time.time())
}), 200
else:
raise Exception
response = jsonify({
'result': model,
'timestamp': int(time.time())
}), 200
cache.set(cache_key, response, timeout=60)

return response
@@ -3,7 +3,6 @@ from threading import Thread

import requests

import llm_server
from llm_server import opts
from llm_server.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model

@@ -25,16 +24,6 @@ class MainBackgroundThread(Thread):
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})

if opts.mode == 'vllm':
while True:
try:
backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
redis.set('full_model_path', r_json['data'][0]['root'])
break
except Exception as e:
print(e)

def run(self):
while True:
if opts.mode == 'oobabooga':

@@ -77,13 +66,13 @@ class MainBackgroundThread(Thread):

# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, exclude_zeros=True) or 0
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, exclude_zeros=True) or 0
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, exclude_zeros=True) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, exclude_zeros=True) or 0
redis.set('average_output_tokens', average_output_tokens)

# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
@@ -0,0 +1,70 @@
import io
import os
import re
from typing import List

import setuptools
from torch.utils.cpp_extension import BuildExtension

ROOT_DIR = os.path.dirname(__file__)

"""
Build vllm-gptq without any CUDA
"""


def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)


def find_version(filepath: str):
"""Extract version information from the given filepath.

Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
with open(filepath) as fp:
version_match = re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")


def read_readme() -> str:
"""Read the README file."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read()


def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
return requirements


setuptools.setup(
name="vllm-gptq",
version=find_version(get_path("vllm", "__init__.py")),
author="vLLM Team",
license="Apache 2.0",
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
long_description=read_readme(),
long_description_content_type="text/markdown",
url="https://github.com/vllm-project/vllm",
project_urls={
"Homepage": "https://github.com/vllm-project/vllm",
"Documentation": "https://vllm.readthedocs.io/en/latest/",
},
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
cmdclass={"build_ext": BuildExtension},
)
@@ -0,0 +1,94 @@
import argparse
import json
import time
from pathlib import Path
from typing import AsyncGenerator

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()

served_model = None


@app.get("/model")
async def generate(request: Request) -> Response:
return JSONResponse({'model': served_model, 'timestamp': int(time.time())})


@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.

The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")

async def abort_request() -> None:
await engine.abort(request_id)

if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)

# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output

assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

served_model = Path(args.model).name

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)

uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
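The new standalone API server above exposes `/model` and `/generate`. A quick client-side sketch, assuming the script is started with its default `--host localhost --port 8000`:

import requests

base_url = 'http://localhost:8000'  # defaults from the argparse block above

# Ask which model is being served.
print(requests.get(f'{base_url}/model').json())

# Non-streaming generation; extra JSON fields are passed through to SamplingParams.
r = requests.post(f'{base_url}/generate', json={
    'prompt': 'Hello, my name is',
    'max_tokens': 32,      # SamplingParams field, per the vLLM backend hunks above
    'temperature': 0.7,
})
print(r.json()['text'])    # {"text": ["<prompt + completion>", ...]}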
@@ -9,4 +9,4 @@ redis
gevent
async-timeout
flask-sock
vllm
auto_gptq

server.py
@@ -11,7 +11,7 @@ from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.database import get_number_of_rows, init_db
from llm_server.helpers import resolve_path
from llm_server.llm.hf_textgen.info import hf_textget_info
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time

@@ -20,6 +20,13 @@ from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread

try:
import vllm
except ModuleNotFoundError as e:
print('Could not import vllm-gptq:', e)
print('Please see vllm.md for install instructions')
sys.exit(1)

script_path = os.path.dirname(os.path.realpath(__file__))

config_path_environ = os.getenv("CONFIG_PATH")

@@ -130,6 +137,10 @@ def home():
else:
info_html = ''

mode_info = ''
if opts.mode == 'vllm':
mode_info = vllm_info

return render_template('home.html',
llm_middleware_name=config['llm_middleware_name'],
analytics_tracking_code=analytics_tracking_code,

@@ -143,7 +154,7 @@ def home():
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
extra_info=mode_info,
)

@@ -156,5 +167,11 @@ def fallback(first=None, rest=None):
}), 404


@app.errorhandler(500)
def server_error(e):
print(e)
return {'error': True}, 500


if __name__ == "__main__":
app.run(host='0.0.0.0')