actually we don't want to emulate openai

Cyberes 2023-09-12 01:04:11 -06:00
parent 747d838138
commit 40ac84aa9a
19 changed files with 348 additions and 150 deletions

.gitignore

@@ -1,6 +1,7 @@
 proxy-server.db
 .idea
 config/config.yml
+install vllm_gptq-0.1.3-py3-none-any.whl
 # ---> Python
 # Byte-compiled / optimized / DLL files


@@ -10,7 +10,8 @@ The purpose of this server is to abstract your LLM backend from your frontend API
 2. `python3 -m venv venv`
 3. `source venv/bin/activate`
 4. `pip install -r requirements.txt`
-5. `python3 server.py`
+5. `wget https://git.evulid.cc/attachments/89c87201-58b1-4e28-b8fd-d0b323c810c4 -O vllm_gptq-0.1.3-py3-none-any.whl && pip install vllm_gptq-0.1.3-py3-none-any.whl`
+6. `python3 server.py`
 An example systemctl service file is provided in `other/local-llm.service`.


@@ -18,6 +18,7 @@ def init_db():
 CREATE TABLE prompts (
 ip TEXT,
 token TEXT DEFAULT NULL,
+backend TEXT,
 prompt TEXT,
 prompt_tokens INTEGER,
 response TEXT,
@@ -71,8 +72,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
 timestamp = int(time.time())
 conn = sqlite3.connect(opts.database_path)
 c = conn.cursor()
-c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
-(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
+c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+(ip, token, opts.mode, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
 conn.commit()
 conn.close()
@@ -129,15 +130,17 @@ def average_column_for_model(table_name, column_name, model_name):
 return result[0]
-def weighted_average_column_for_model(table_name, column_name, model_name, exclude_zeros: bool = False):
+def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, exclude_zeros: bool = False):
 conn = sqlite3.connect(opts.database_path)
 cursor = conn.cursor()
-cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
-models = [row[0] for row in cursor.fetchall()]
+cursor.execute(f"SELECT DISTINCT model, backend FROM {table_name}")
+models_backends = [(row[0], row[1]) for row in cursor.fetchall()]
 model_averages = {}
-for model in models:
-cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
+for model, backend in models_backends:
+if backend != backend_name:
+continue
+cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend = ? ORDER BY ROWID DESC", (model, backend))
 results = cursor.fetchall()
 if not results:
@@ -155,11 +158,11 @@ def weighted_average_column_for_model(table_name, column_name, model_name, exclu
 if total_weight == 0:
 continue
-model_averages[model] = weighted_sum / total_weight
+model_averages[(model, backend)] = weighted_sum / total_weight
 conn.close()
-return model_averages.get(model_name)
+return model_averages.get((model_name, backend_name))
 def sum_column(table_name, column_name):
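With the new `backend` column, the rolling averages are now keyed by (model, backend) pairs. A minimal usage sketch of the new signature, assuming placeholder names (the real caller passes `opts.running_model` and `opts.mode`, as the thread changes further down show):

```python
from llm_server.database import weighted_average_column_for_model

# Weighted average generation time for one model, counting only rows that were
# logged by the named backend ('vllm' and the model name here are placeholders).
avg_gen_time = weighted_average_column_for_model(
    'prompts', 'generation_time',
    'example-model',   # placeholder for opts.running_model
    'vllm',            # compared against the new `backend` column
    exclude_zeros=True,
) or 0
```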


@@ -1,10 +1,11 @@
 import requests
 from llm_server import opts
-from pathlib import Path
 def get_running_model():
 # TODO: cache the results for 1 min so we don't have to keep calling the backend
+# TODO: only use one try/catch
 if opts.mode == 'oobabooga':
 try:
@@ -22,11 +23,9 @@ def get_running_model():
 return False, e
 elif opts.mode == 'vllm':
 try:
-backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
+backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
 r_json = backend_response.json()
-model_name = Path(r_json['data'][0]['root']).name
-# r_json['data'][0]['root'] = model_name
-return model_name, None
+return r_json['model'], None
 except Exception as e:
 return False, e
 else:
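The vLLM branch now queries the custom `/model` route served by `other/vllm_api_server.py` (added further down) instead of parsing an OpenAI-style `/v1/models` listing. A rough sketch of the exchange, with a placeholder backend URL:

```python
import requests

# GET /model on the custom vLLM server returns:
#   {'model': '<served model name>', 'timestamp': <unix time>}
r = requests.get('http://127.0.0.1:8000/model', timeout=3)  # placeholder for opts.backend_url
model_name = r.json()['model']
```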


@@ -1,4 +1,4 @@
-from typing import Union, Tuple
+from typing import Tuple, Union
 class LLMBackend:
@@ -10,3 +10,12 @@ class LLMBackend:
 # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
 # raise NotImplementedError
+def get_parameters(self, parameters) -> Union[dict, None]:
+"""
+Validate and return the parameters for this backend.
+Lets you set defaults for specific backends.
+:param parameters:
+:return:
+"""
+raise NotImplementedError


@@ -1,15 +1,11 @@
-from typing import Tuple
-import requests
 from flask import jsonify
-from ... import opts
+from ..llm_backend import LLMBackend
 from ...database import log_prompt
 from ...helpers import safe_list_get
 from ...routes.cache import redis
 from ...routes.helpers.client import format_sillytavern_err
 from ...routes.helpers.http import validate_json
-from ..llm_backend import LLMBackend
 class OobaboogaLLMBackend(LLMBackend):
@@ -71,3 +67,7 @@ class OobaboogaLLMBackend(LLMBackend):
 # return r_json['result'], None
 # except Exception as e:
 # return False, e
+def get_parameters(self, parameters):
+del parameters['prompt']
+return parameters


@@ -1,17 +1,14 @@
 """
 This file is used by the worker that processes requests.
 """
-import io
 import json
 import time
 from uuid import uuid4
 import requests
-from requests import Response
 from llm_server import opts
 from llm_server.database import tokenizer
-from llm_server.routes.cache import redis
 # TODO: make the VLMM backend return TPS and time elapsed
@@ -19,7 +16,7 @@ from llm_server.routes.cache import redis
 def prepare_json(json_data: dict):
 # logit_bias is not currently supported
-del json_data['logit_bias']
+# del json_data['logit_bias']
 return json_data
@@ -83,26 +80,26 @@ def transform_prompt_to_text(prompt: list):
 def handle_blocking_request(json_data: dict):
 try:
-r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl)
+r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
 except Exception as e:
 return False, None, f'{e.__class__.__name__}: {e}'
 # TODO: check for error here?
-response_json = r.json()
-response_json['error'] = False
-new_response = Response()
-new_response.status_code = r.status_code
-new_response._content = json.dumps(response_json).encode('utf-8')
-new_response.raw = io.BytesIO(new_response._content)
-new_response.headers = r.headers
-new_response.url = r.url
-new_response.reason = r.reason
-new_response.cookies = r.cookies
-new_response.elapsed = r.elapsed
-new_response.request = r.request
-return True, new_response, None
+# response_json = r.json()
+# response_json['error'] = False
+# new_response = Response()
+# new_response.status_code = r.status_code
+# new_response._content = json.dumps(response_json).encode('utf-8')
+# new_response.raw = io.BytesIO(new_response._content)
+# new_response.headers = r.headers
+# new_response.url = r.url
+# new_response.reason = r.reason
+# new_response.cookies = r.cookies
+# new_response.elapsed = r.elapsed
+# new_response.request = r.request
+return True, r, None
 def generate(json_data: dict):
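Since the worker no longer rebuilds a fake OpenAI `Response`, callers receive the backend's `/generate` reply untouched. A rough sketch of that request/response shape (URL and sampling values are placeholders); the backend echoes the prompt ahead of the completion, which is why the `VLLMBackend` response handler further down splits it back off:

```python
import requests

payload = {
    'prompt': 'Once upon a time',  # placeholder prompt
    'temperature': 0.7,            # remaining keys are SamplingParams fields
    'top_p': 0.9,
    'max_tokens': 64,
}
r = requests.post('http://127.0.0.1:8000/generate', json=payload)  # placeholder for opts.backend_url
# Response body: {"text": ["<prompt><completion>", ...]}
full_text = r.json()['text'][0]
completion = full_text.split(payload['prompt'])[1].strip()
```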


@@ -1,13 +1,8 @@
-from pathlib import Path
-import requests
-from llm_server import opts
-def get_vlmm_models_info():
-backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
-r_json = backend_response.json()
-r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name
-r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name
-return r_json
+vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/chu-tianxiang/vllm-gptq" target="_blank">vllm-gptq</a> and not all Oobabooga parameters are supported.</p>
+<strong>Supported Parameters:</strong>
+<ul>
+<li><kbd>temperature</kbd></li>
+<li><kbd>top_p</kbd></li>
+<li><kbd>top_k</kbd></li>
+<li><kbd>max_new_tokens</kbd></li>
+</ul>"""


@@ -1,8 +1,9 @@
+from typing import Tuple
 from flask import jsonify
 from vllm import SamplingParams
 from llm_server.database import log_prompt
-from llm_server.helpers import indefinite_article
 from llm_server.llm.llm_backend import LLMBackend
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.helpers.http import validate_json
@@ -22,19 +23,23 @@ class VLLMBackend(LLMBackend):
 response_status_code = 0
 if response_valid_json:
-backend_response = response_json_body
-if response_json_body.get('error'):
-backend_err = True
-error_type = response_json_body.get('error_type')
-error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
-backend_response = format_sillytavern_err(
-f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
-f'HTTP CODE {response_status_code}'
-)
-log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
-return jsonify(backend_response), 200
+if len(response_json_body.get('text', [])):
+# Does vllm return the prompt and the response together???
+backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
+else:
+# Failsafe
+backend_response = ''
+# TODO: how to detect an error?
+# if backend_response == '':
+# backend_err = True
+# backend_response = format_sillytavern_err(
+# f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
+# f'HTTP CODE {response_status_code}'
+# )
+log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
+return jsonify({'results': [{'text': backend_response}]}), 200
 else:
 backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
 log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
@@ -44,13 +49,24 @@ class VLLMBackend(LLMBackend):
 'results': [{'text': backend_response}]
 }), 200
-def validate_params(self, params_dict: dict):
-try:
-sampling_params = SamplingParams(**params_dict)
-except ValueError as e:
-print(e)
-return False, e
-return True, None
+# def validate_params(self, params_dict: dict):
+# default_params = SamplingParams()
+# try:
+# sampling_params = SamplingParams(
+# temperature=params_dict.get('temperature', default_params.temperature),
+# top_p=params_dict.get('top_p', default_params.top_p),
+# top_k=params_dict.get('top_k', default_params.top_k),
+# use_beam_search=True if params_dict['num_beams'] > 1 else False,
+# length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
+# early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
+# stop=params_dict.get('stopping_strings', default_params.stop),
+# ignore_eos=params_dict.get('ban_eos_token', False),
+# max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
+# )
+# except ValueError as e:
+# print(e)
+# return False, e
+# return True, None
 # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
 # try:
@@ -61,3 +77,33 @@ class VLLMBackend(LLMBackend):
 # return r_json, None
 # except Exception as e:
 # return False, e
+def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+default_params = SamplingParams()
+try:
+sampling_params = SamplingParams(
+temperature=parameters.get('temperature', default_params.temperature),
+top_p=parameters.get('top_p', default_params.top_p),
+top_k=parameters.get('top_k', default_params.top_k),
+use_beam_search=True if parameters['num_beams'] > 1 else False,
+stop=parameters.get('stopping_strings', default_params.stop),
+ignore_eos=parameters.get('ban_eos_token', False),
+max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
+)
+except ValueError as e:
+print(e)
+return None, e
+return vars(sampling_params), None
+# def transform_sampling_params(params: SamplingParams):
+# return {
+# 'temperature': params['temperature'],
+# 'top_p': params['top_p'],
+# 'top_k': params['top_k'],
+# 'use_beam_search' = True if parameters['num_beams'] > 1 else False,
+# length_penalty = parameters.get('length_penalty', default_params.length_penalty),
+# early_stopping = parameters.get('early_stopping', default_params.early_stopping),
+# stop = parameters.get('stopping_strings', default_params.stop),
+# ignore_eos = parameters.get('ban_eos_token', False),
+# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
+# }
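A hypothetical call showing the new parameter flow, with made-up values: Oobabooga-style fields go in, and a plain dict built from vLLM's `SamplingParams` comes out for the request handler to merge with the prompt. Note that `num_beams` is read with `[]`, so the incoming request is expected to carry it:

```python
# Hypothetical illustration; field values are placeholders and VLLMBackend is
# assumed to be imported from wherever this class is defined.
ooba_params = {
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 40,
    'num_beams': 1,          # required: accessed with [] rather than .get()
    'ban_eos_token': False,
    'max_new_tokens': 200,
}
sampling_dict, err = VLLMBackend().get_parameters(ooba_params)
if err is None:
    llm_request = {**sampling_dict, 'prompt': 'Hello'}  # what handle_request() builds
```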


@@ -38,9 +38,9 @@ class OobaRequestHandler:
 self.start_time = time.time()
 self.client_ip = self.get_client_ip()
 self.token = self.request.headers.get('X-Api-Key')
-self.parameters = self.get_parameters()
 self.priority = self.get_priority()
 self.backend = self.get_backend()
+self.parameters = self.parameters_invalid_msg = None
 def validate_request(self) -> (bool, Union[str, None]):
 # TODO: move this to LLMBackend
@@ -56,19 +56,9 @@
 else:
 return self.request.remote_addr
-def get_parameters(self):
-# TODO: make this a LLMBackend method
-request_valid_json, self.request_json_body = validate_json(self.request.data)
-if not request_valid_json:
-return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
-parameters = self.request_json_body.copy()
-if opts.mode in ['oobabooga', 'hf-textgen']:
-del parameters['prompt']
-elif opts.mode == 'vllm':
-parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
-else:
-raise Exception
-return parameters
+# def get_parameters(self):
+# # TODO: make this a LLMBackend method
+# return self.backend.get_parameters()
 def get_priority(self):
 if self.token:
@@ -91,24 +81,26 @@
 else:
 raise Exception
+def get_parameters(self):
+self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
 def handle_request(self):
+request_valid_json, self.request_json_body = validate_json(self.request.data)
+if not request_valid_json:
+return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
+self.get_parameters()
 SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
-# Fix bug on text-generation-inference
-# https://github.com/huggingface/text-generation-inference/issues/929
-if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
-self.request_json_body['typical_p'] = 0.998
-if opts.mode == 'vllm':
-full_model_path = redis.get('full_model_path')
-if not full_model_path:
-raise Exception
-self.request_json_body['model'] = full_model_path.decode()
 request_valid, invalid_request_err_msg = self.validate_request()
-params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
+if not self.parameters:
+params_valid = False
+else:
+params_valid = True
 if not request_valid or not params_valid:
-error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
+error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
 combined_error_message = ', '.join(error_messages)
 err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
 log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
@@ -119,21 +111,27 @@
 'results': [{'text': err}]
 }), 200
+# Reconstruct the request JSON with the validated parameters and prompt.
+prompt = self.request_json_body.get('prompt', '')
+llm_request = {**self.parameters, 'prompt': prompt}
 queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
 if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
-event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
+event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
 else:
 # Client was rate limited
 event = None
 if not event:
 return self.handle_ratelimited()
 event.wait()
 success, response, error_msg = event.data
 end_time = time.time()
 elapsed_time = end_time - self.start_time
-return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))
+return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 def handle_ratelimited(self):
 backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')


@@ -7,14 +7,7 @@ from ... import opts
 @bp.route('/generate', methods=['POST'])
-@bp.route('/chat/completions', methods=['POST'])
 def generate():
-if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
-return jsonify({
-'code': 404,
-'error': 'this LLM backend is in VLLM mode'
-}), 404
 request_valid_json, request_json_body = validate_json(request.data)
 if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
 return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400


@@ -82,6 +82,7 @@ def generate_stats():
 'online': online,
 'endpoints': {
 'blocking': f'https://{opts.base_client_api}',
+'streaming': f'wss://{opts.base_client_api}/stream',
 },
 'queue': {
 'processing': active_gen_workers,
@@ -104,9 +105,9 @@
 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
 }
-if opts.mode in ['oobabooga', 'hf-textgen']:
-output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
-else:
-output['endpoints']['streaming'] = None
+# if opts.mode in ['oobabooga', 'hf-textgen']:
+# output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
+# else:
+# output['endpoints']['streaming'] = None
 return deep_sort(output)
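With the streaming URL folded into the endpoint block, the stats payload now advertises both endpoints in every mode; roughly (the hostname is a placeholder):

```python
# Hypothetical excerpt of the generate_stats() output after this change.
stats_excerpt = {
    'endpoints': {
        'blocking': 'https://proxy.example.com',         # f'https://{opts.base_client_api}'
        'streaming': 'wss://proxy.example.com/stream',   # f'wss://{opts.base_client_api}/stream'
    },
}
```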


@@ -4,9 +4,7 @@ from flask import jsonify, request
 from . import bp
 from ..cache import cache
-from ... import opts
 from ...llm.info import get_running_model
-from ...llm.vllm.info import get_vlmm_models_info
 # cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
@@ -20,16 +18,7 @@ from ...llm.vllm.info import get_vlmm_models_info
 @bp.route('/model', methods=['GET'])
-@bp.route('/models', methods=['GET'])
 def get_model():
-if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
-return jsonify({
-'code': 404,
-'error': 'this LLM backend is in VLLM mode'
-}), 404
 # We will manage caching ourself since we don't want to cache
 # when the backend is down. Also, Cloudflare won't cache 500 errors.
 cache_key = 'model_cache::' + request.url
@@ -46,18 +35,10 @@
 'type': error.__class__.__name__
 }), 500 # return 500 so Cloudflare doesn't intercept us
 else:
-if opts.mode in ['oobabooga', 'hf-texgen']:
-response = jsonify({
-'result': model,
-'timestamp': int(time.time())
-}), 200
-elif opts.mode == 'vllm':
-response = jsonify({
-**get_vlmm_models_info(),
-'timestamp': int(time.time())
-}), 200
-else:
-raise Exception
+response = jsonify({
+'result': model,
+'timestamp': int(time.time())
+}), 200
 cache.set(cache_key, response, timeout=60)
 return response


@@ -3,7 +3,6 @@ from threading import Thread
 import requests
-import llm_server
 from llm_server import opts
 from llm_server.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
@@ -25,16 +24,6 @@ class MainBackgroundThread(Thread):
 redis.set('backend_online', 0)
 redis.set_dict('backend_info', {})
-if opts.mode == 'vllm':
-while True:
-try:
-backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
-r_json = backend_response.json()
-redis.set('full_model_path', r_json['data'][0]['root'])
-break
-except Exception as e:
-print(e)
 def run(self):
 while True:
 if opts.mode == 'oobabooga':
@@ -77,13 +66,13 @@
 # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
 # was entered into the column. The new code enters null instead but we need to be backwards compatible for now
-average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, exclude_zeros=True) or 0
+average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, exclude_zeros=True) or 0
 redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
 # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
 # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
-average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, exclude_zeros=True) or 0
+average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, exclude_zeros=True) or 0
 redis.set('average_output_tokens', average_output_tokens)
 # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)

other/vllm-gptq-setup.py (new file)

@@ -0,0 +1,70 @@
import io
import os
import re
from typing import List

import setuptools
from torch.utils.cpp_extension import BuildExtension

ROOT_DIR = os.path.dirname(__file__)

"""
Build vllm-gptq without any CUDA
"""


def get_path(*filepath) -> str:
    return os.path.join(ROOT_DIR, *filepath)


def find_version(filepath: str):
    """Extract version information from the given filepath.
    Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
    """
    with open(filepath) as fp:
        version_match = re.search(
            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
        if version_match:
            return version_match.group(1)
    raise RuntimeError("Unable to find version string.")


def read_readme() -> str:
    """Read the README file."""
    return io.open(get_path("README.md"), "r", encoding="utf-8").read()


def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt."""
    with open(get_path("requirements.txt")) as f:
        requirements = f.read().strip().split("\n")
    return requirements


setuptools.setup(
    name="vllm-gptq",
    version=find_version(get_path("vllm", "__init__.py")),
    author="vLLM Team",
    license="Apache 2.0",
    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
    long_description=read_readme(),
    long_description_content_type="text/markdown",
    url="https://github.com/vllm-project/vllm",
    project_urls={
        "Homepage": "https://github.com/vllm-project/vllm",
        "Documentation": "https://vllm.readthedocs.io/en/latest/",
    },
    classifiers=[
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    packages=setuptools.find_packages(
        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
    python_requires=">=3.8",
    install_requires=get_requirements(),
    cmdclass={"build_ext": BuildExtension},
)

other/vllm_api_server.py (new file)

@@ -0,0 +1,94 @@
import argparse
import json
import time
from pathlib import Path
from typing import AsyncGenerator

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
app = FastAPI()
served_model = None


@app.get("/model")
async def generate(request: Request) -> Response:
    return JSONResponse({'model': served_model, 'timestamp': int(time.time())})


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.
    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    async def abort_request() -> None:
        await engine.abort(request_id)

    if stream:
        background_tasks = BackgroundTasks()
        # Abort the request if the client disconnects.
        background_tasks.add_task(abort_request)
        return StreamingResponse(stream_results(), background=background_tasks)

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    served_model = Path(args.model).name

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
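A rough client-side sketch against this server, assuming it was launched with something like `python3 other/vllm_api_server.py --model /path/to/model --port 8000` (model path, host, and port are placeholders):

```python
import json
import requests

base = 'http://127.0.0.1:8000'  # placeholder host/port

# Non-streaming: one JSON body with the prompt plus SamplingParams fields.
r = requests.post(f'{base}/generate', json={'prompt': 'Hello, my name is', 'max_tokens': 32})
print(r.json()['text'][0])  # prompt + completion

# Streaming: NUL-delimited JSON chunks, each holding the text generated so far.
with requests.post(f'{base}/generate',
                   json={'prompt': 'Hello, my name is', 'stream': True, 'max_tokens': 32},
                   stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b'\0'):
        if chunk:
            print(json.loads(chunk)['text'][0])
```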


@@ -9,4 +9,4 @@ redis
 gevent
 async-timeout
 flask-sock
-vllm
+auto_gptq


@@ -11,7 +11,7 @@ from llm_server import opts
 from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
 from llm_server.database import get_number_of_rows, init_db
 from llm_server.helpers import resolve_path
-from llm_server.llm.hf_textgen.info import hf_textget_info
+from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
@@ -20,6 +20,13 @@ from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.stream import init_socketio
 from llm_server.threads import MainBackgroundThread
+try:
+import vllm
+except ModuleNotFoundError as e:
+print('Could not import vllm-gptq:', e)
+print('Please see vllm.md for install instructions')
+sys.exit(1)
 script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
@@ -130,6 +137,10 @@ def home():
 else:
 info_html = ''
+mode_info = ''
+if opts.mode == 'vllm':
+mode_info = vllm_info
 return render_template('home.html',
 llm_middleware_name=config['llm_middleware_name'],
 analytics_tracking_code=analytics_tracking_code,
@@ -143,7 +154,7 @@ def home():
 streaming_input_textbox=mode_ui_names[opts.mode][2],
 context_size=opts.context_size,
 stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
-extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
+extra_info=mode_info,
 )
@@ -156,5 +167,11 @@ def fallback(first=None, rest=None):
 }), 404
+@app.errorhandler(500)
+def server_error(e):
+print(e)
+return {'error': True}, 500
 if __name__ == "__main__":
 app.run(host='0.0.0.0')

vllm.md (new file)

@@ -0,0 +1,4 @@
```bash
wget https://git.evulid.cc/attachments/6e7bfc04-cad4-4494-a98d-1391fbb402d3 -O vllm-0.1.3-cp311-cp311-linux_x86_64.whl && pip install vllm-0.1.3-cp311-cp311-linux_x86_64.whl
pip install auto_gptq
```
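The server exits at startup if the import fails (see the `sys.exit(1)` hunk above), so a quick sanity check after installing the wheels is to confirm the module loads; it is still imported as `vllm` even though the wheel is the vllm-gptq fork:

```python
# Quick check that the vllm-gptq wheel is importable in the active venv.
import vllm
print(getattr(vllm, '__version__', 'unknown'))
```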