diff --git a/.gitignore b/.gitignore
index d12d30d..ec4f14b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
proxy-server.db
.idea
config/config.yml
+vllm_gptq-0.1.3-py3-none-any.whl
# ---> Python
# Byte-compiled / optimized / DLL files
diff --git a/README.md b/README.md
index b6abb5f..73af42c 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,8 @@ The purpose of this server is to abstract your LLM backend from your frontend AP
2. `python3 -m venv venv`
3. `source venv/bin/activate`
4. `pip install -r requirements.txt`
-5. `python3 server.py`
+5. `wget https://git.evulid.cc/attachments/89c87201-58b1-4e28-b8fd-d0b323c810c4 -O vllm_gptq-0.1.3-py3-none-any.whl && pip install vllm_gptq-0.1.3-py3-none-any.whl`
+6. `python3 server.py`
An example systemctl service file is provided in `other/local-llm.service`.
diff --git a/llm_server/database.py b/llm_server/database.py
index e0bb1a5..e59eb55 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -18,6 +18,7 @@ def init_db():
CREATE TABLE prompts (
ip TEXT,
token TEXT DEFAULT NULL,
+ backend TEXT,
prompt TEXT,
prompt_tokens INTEGER,
response TEXT,
@@ -71,8 +72,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
timestamp = int(time.time())
conn = sqlite3.connect(opts.database_path)
c = conn.cursor()
- c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
- (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
+ c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (ip, token, opts.mode, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
conn.commit()
conn.close()
@@ -129,15 +130,17 @@ def average_column_for_model(table_name, column_name, model_name):
return result[0]
-def weighted_average_column_for_model(table_name, column_name, model_name, exclude_zeros: bool = False):
+def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, exclude_zeros: bool = False):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
- cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
- models = [row[0] for row in cursor.fetchall()]
+ cursor.execute(f"SELECT DISTINCT model, backend FROM {table_name}")
+ models_backends = [(row[0], row[1]) for row in cursor.fetchall()]
model_averages = {}
- for model in models:
- cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
+ for model, backend in models_backends:
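+ # Only include rows recorded by the requested backend so averages aren't mixed across backends.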
+ if backend != backend_name:
+ continue
+ cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend = ? ORDER BY ROWID DESC", (model, backend))
results = cursor.fetchall()
if not results:
@@ -155,11 +158,11 @@ def weighted_average_column_for_model(table_name, column_name, model_name, exclu
if total_weight == 0:
continue
- model_averages[model] = weighted_sum / total_weight
+ model_averages[(model, backend)] = weighted_sum / total_weight
conn.close()
- return model_averages.get(model_name)
+ return model_averages.get((model_name, backend_name))
def sum_column(table_name, column_name):
diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py
index 9b456e0..4121d3e 100644
--- a/llm_server/llm/info.py
+++ b/llm_server/llm/info.py
@@ -1,10 +1,11 @@
import requests
from llm_server import opts
-from pathlib import Path
+
def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
+ # TODO: only use one try/catch
if opts.mode == 'oobabooga':
try:
@@ -22,11 +23,9 @@ def get_running_model():
return False, e
elif opts.mode == 'vllm':
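+ # The custom API server in other/vllm_api_server.py exposes GET /model returning {'model': <name>, 'timestamp': ...}.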
try:
- backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
+ backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
r_json = backend_response.json()
- model_name = Path(r_json['data'][0]['root']).name
- # r_json['data'][0]['root'] = model_name
- return model_name, None
+ return r_json['model'], None
except Exception as e:
return False, e
else:
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 6285b1d..7302728 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -1,4 +1,4 @@
-from typing import Union, Tuple
+from typing import Tuple, Union
class LLMBackend:
@@ -10,3 +10,12 @@ class LLMBackend:
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# raise NotImplementedError
+
+ def get_parameters(self, parameters) -> Tuple[Union[dict, None], Union[Exception, None]]:
+ """
+ Validate and return the parameters for this backend, along with any validation error.
+ Lets you set defaults for specific backends.
+ :param parameters: the parsed request JSON body
+ :return: (validated parameters or None, error or None)
+ """
+ raise NotImplementedError
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index a5d6e69..ee3a7d6 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -1,15 +1,11 @@
-from typing import Tuple
-
-import requests
from flask import jsonify
-from ... import opts
+from ..llm_backend import LLMBackend
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
-from ..llm_backend import LLMBackend
class OobaboogaLLMBackend(LLMBackend):
@@ -71,3 +67,7 @@ class OobaboogaLLMBackend(LLMBackend):
# return r_json['result'], None
# except Exception as e:
# return False, e
+
+ def get_parameters(self, parameters):
+ # Operate on a copy so the caller's request body (and its prompt) is left intact.
+ parameters = parameters.copy()
+ del parameters['prompt']
+ return parameters, None
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 1a0bf92..2e1267c 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -1,17 +1,14 @@
"""
This file is used by the worker that processes requests.
"""
-import io
import json
import time
from uuid import uuid4
import requests
-from requests import Response
from llm_server import opts
from llm_server.database import tokenizer
-from llm_server.routes.cache import redis
# TODO: make the VLMM backend return TPS and time elapsed
@@ -19,7 +16,7 @@ from llm_server.routes.cache import redis
def prepare_json(json_data: dict):
# logit_bias is not currently supported
- del json_data['logit_bias']
+ # del json_data['logit_bias']
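+ # The parameters have already been validated by VLLMBackend.get_parameters(), so there is nothing left to strip here.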
return json_data
@@ -83,26 +80,26 @@ def transform_prompt_to_text(prompt: list):
def handle_blocking_request(json_data: dict):
try:
- r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl)
+ r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
except Exception as e:
return False, None, f'{e.__class__.__name__}: {e}'
# TODO: check for error here?
- response_json = r.json()
- response_json['error'] = False
+ # response_json = r.json()
+ # response_json['error'] = False
- new_response = Response()
- new_response.status_code = r.status_code
- new_response._content = json.dumps(response_json).encode('utf-8')
- new_response.raw = io.BytesIO(new_response._content)
- new_response.headers = r.headers
- new_response.url = r.url
- new_response.reason = r.reason
- new_response.cookies = r.cookies
- new_response.elapsed = r.elapsed
- new_response.request = r.request
+ # new_response = Response()
+ # new_response.status_code = r.status_code
+ # new_response._content = json.dumps(response_json).encode('utf-8')
+ # new_response.raw = io.BytesIO(new_response._content)
+ # new_response.headers = r.headers
+ # new_response.url = r.url
+ # new_response.reason = r.reason
+ # new_response.cookies = r.cookies
+ # new_response.elapsed = r.elapsed
+ # new_response.request = r.request
- return True, new_response, None
+ return True, r, None
def generate(json_data: dict):
diff --git a/llm_server/llm/vllm/info.py b/llm_server/llm/vllm/info.py
index d83de8b..e873a30 100644
--- a/llm_server/llm/vllm/info.py
+++ b/llm_server/llm/vllm/info.py
@@ -1,13 +1,10 @@
-from pathlib import Path
-
-import requests
-
-from llm_server import opts
-
-
-def get_vlmm_models_info():
- backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
- r_json = backend_response.json()
- r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name
- r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name
- return r_json
+vllm_info = """
+Important: This endpoint is running vllm-gptq and not all Oobabooga parameters are supported.
+Supported Parameters:
+
+- temperature
+- top_p
+- top_k
+- max_new_tokens
+
"""
\ No newline at end of file
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index afbab33..86e8129 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -1,8 +1,9 @@
+from typing import Tuple
+
from flask import jsonify
from vllm import SamplingParams
from llm_server.database import log_prompt
-from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
@@ -22,19 +23,23 @@ class VLLMBackend(LLMBackend):
response_status_code = 0
if response_valid_json:
- backend_response = response_json_body
+ if len(response_json_body.get('text', [])):
+ # The vllm API server returns the prompt concatenated with the generated text, so strip the prompt back off.
+ backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
+ else:
+ # Failsafe
+ backend_response = ''
- if response_json_body.get('error'):
- backend_err = True
- error_type = response_json_body.get('error_type')
- error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
- backend_response = format_sillytavern_err(
- f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
- f'HTTP CODE {response_status_code}'
- )
+ # TODO: how to detect an error?
+ # if backend_response == '':
+ # backend_err = True
+ # backend_response = format_sillytavern_err(
+ # f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
+ # f'HTTP CODE {response_status_code}'
+ # )
- log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
- return jsonify(backend_response), 200
+ log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
+ return jsonify({'results': [{'text': backend_response}]}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
@@ -44,13 +49,24 @@ class VLLMBackend(LLMBackend):
'results': [{'text': backend_response}]
}), 200
- def validate_params(self, params_dict: dict):
- try:
- sampling_params = SamplingParams(**params_dict)
- except ValueError as e:
- print(e)
- return False, e
- return True, None
+ # def validate_params(self, params_dict: dict):
+ # default_params = SamplingParams()
+ # try:
+ # sampling_params = SamplingParams(
+ # temperature=params_dict.get('temperature', default_params.temperature),
+ # top_p=params_dict.get('top_p', default_params.top_p),
+ # top_k=params_dict.get('top_k', default_params.top_k),
+ # use_beam_search=True if params_dict['num_beams'] > 1 else False,
+ # length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
+ # early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
+ # stop=params_dict.get('stopping_strings', default_params.stop),
+ # ignore_eos=params_dict.get('ban_eos_token', False),
+ # max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
+ # )
+ # except ValueError as e:
+ # print(e)
+ # return False, e
+ # return True, None
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
@@ -61,3 +77,33 @@ class VLLMBackend(LLMBackend):
# return r_json, None
# except Exception as e:
# return False, e
+
+ def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+ default_params = SamplingParams()
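+ # A fresh SamplingParams() supplies vllm's own defaults for anything the client didn't send.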
+ try:
+ sampling_params = SamplingParams(
+ temperature=parameters.get('temperature', default_params.temperature),
+ top_p=parameters.get('top_p', default_params.top_p),
+ top_k=parameters.get('top_k', default_params.top_k),
+ use_beam_search=parameters.get('num_beams', 1) > 1,
+ stop=parameters.get('stopping_strings', default_params.stop),
+ ignore_eos=parameters.get('ban_eos_token', False),
+ max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
+ )
+ except ValueError as e:
+ print(e)
+ return None, e
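+ # vars() flattens the validated SamplingParams into a plain dict; the API server in other/vllm_api_server.py rebuilds it with SamplingParams(**request_dict).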
+ return vars(sampling_params), None
+
+# def transform_sampling_params(params: SamplingParams):
+# return {
+# 'temperature': params['temperature'],
+# 'top_p': params['top_p'],
+# 'top_k': params['top_k'],
+# 'use_beam_search' = True if parameters['num_beams'] > 1 else False,
+# length_penalty = parameters.get('length_penalty', default_params.length_penalty),
+# early_stopping = parameters.get('early_stopping', default_params.early_stopping),
+# stop = parameters.get('stopping_strings', default_params.stop),
+# ignore_eos = parameters.get('ban_eos_token', False),
+# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
+# }
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 8ea7356..7c9a962 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -38,9 +38,9 @@ class OobaRequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.request.headers.get('X-Api-Key')
- self.parameters = self.get_parameters()
self.priority = self.get_priority()
self.backend = self.get_backend()
+ self.parameters = self.parameters_invalid_msg = None
def validate_request(self) -> (bool, Union[str, None]):
# TODO: move this to LLMBackend
@@ -56,19 +56,9 @@ class OobaRequestHandler:
else:
return self.request.remote_addr
- def get_parameters(self):
- # TODO: make this a LLMBackend method
- request_valid_json, self.request_json_body = validate_json(self.request.data)
- if not request_valid_json:
- return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
- parameters = self.request_json_body.copy()
- if opts.mode in ['oobabooga', 'hf-textgen']:
- del parameters['prompt']
- elif opts.mode == 'vllm':
- parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
- else:
- raise Exception
- return parameters
+ # def get_parameters(self):
+ # # TODO: make this a LLMBackend method
+ # return self.backend.get_parameters()
def get_priority(self):
if self.token:
@@ -91,24 +81,26 @@ class OobaRequestHandler:
else:
raise Exception
+ def get_parameters(self):
+ self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
+
def handle_request(self):
+ request_valid_json, self.request_json_body = validate_json(self.request.data)
+ if not request_valid_json:
+ return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
+
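+ # Parameter validation is delegated to the backend; any error ends up in self.parameters_invalid_msg.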
+ self.get_parameters()
+
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
- # Fix bug on text-generation-inference
- # https://github.com/huggingface/text-generation-inference/issues/929
- if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
- self.request_json_body['typical_p'] = 0.998
-
- if opts.mode == 'vllm':
- full_model_path = redis.get('full_model_path')
- if not full_model_path:
- raise Exception
- self.request_json_body['model'] = full_model_path.decode()
-
request_valid, invalid_request_err_msg = self.validate_request()
- params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
+ params_valid = bool(self.parameters)
+
if not request_valid or not params_valid:
- error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
+ error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
@@ -119,21 +111,27 @@ class OobaRequestHandler:
'results': [{'text': err}]
}), 200
+ # Reconstruct the request JSON with the validated parameters and prompt.
+ prompt = self.request_json_body.get('prompt', '')
+ llm_request = {**self.parameters, 'prompt': prompt}
+
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
- event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
+ event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
else:
# Client was rate limited
event = None
if not event:
return self.handle_ratelimited()
+
event.wait()
success, response, error_msg = event.data
end_time = time.time()
elapsed_time = end_time - self.start_time
- return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))
+
+ return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def handle_ratelimited(self):
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index c6c3915..0f42ccf 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -7,14 +7,7 @@ from ... import opts
@bp.route('/generate', methods=['POST'])
-@bp.route('/chat/completions', methods=['POST'])
def generate():
- if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
- return jsonify({
- 'code': 404,
- 'error': 'this LLM backend is in VLLM mode'
- }), 404
-
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index f3117be..7e09d6f 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -82,6 +82,7 @@ def generate_stats():
'online': online,
'endpoints': {
'blocking': f'https://{opts.base_client_api}',
+ 'streaming': f'wss://{opts.base_client_api}/stream',
},
'queue': {
'processing': active_gen_workers,
@@ -104,9 +105,9 @@ def generate_stats():
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
}
- if opts.mode in ['oobabooga', 'hf-textgen']:
- output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
- else:
- output['endpoints']['streaming'] = None
+ # if opts.mode in ['oobabooga', 'hf-textgen']:
+ # output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
+ # else:
+ # output['endpoints']['streaming'] = None
return deep_sort(output)
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 56d5e5e..26bb2e3 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -4,9 +4,7 @@ from flask import jsonify, request
from . import bp
from ..cache import cache
-from ... import opts
from ...llm.info import get_running_model
-from ...llm.vllm.info import get_vlmm_models_info
# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
@@ -20,16 +18,7 @@ from ...llm.vllm.info import get_vlmm_models_info
@bp.route('/model', methods=['GET'])
-@bp.route('/models', methods=['GET'])
def get_model():
- if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
- return jsonify({
- 'code': 404,
- 'error': 'this LLM backend is in VLLM mode'
- }), 404
-
-
-
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@@ -46,18 +35,10 @@ def get_model():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
- if opts.mode in ['oobabooga', 'hf-texgen']:
- response = jsonify({
- 'result': model,
- 'timestamp': int(time.time())
- }), 200
- elif opts.mode == 'vllm':
- response = jsonify({
- **get_vlmm_models_info(),
- 'timestamp': int(time.time())
- }), 200
- else:
- raise Exception
+ response = jsonify({
+ 'result': model,
+ 'timestamp': int(time.time())
+ }), 200
cache.set(cache_key, response, timeout=60)
return response
diff --git a/llm_server/threads.py b/llm_server/threads.py
index b6e6132..5ce9c0e 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -3,7 +3,6 @@ from threading import Thread
import requests
-import llm_server
from llm_server import opts
from llm_server.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
@@ -25,16 +24,6 @@ class MainBackgroundThread(Thread):
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
- if opts.mode == 'vllm':
- while True:
- try:
- backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
- r_json = backend_response.json()
- redis.set('full_model_path', r_json['data'][0]['root'])
- break
- except Exception as e:
- print(e)
-
def run(self):
while True:
if opts.mode == 'oobabooga':
@@ -77,13 +66,13 @@ class MainBackgroundThread(Thread):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now
- average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, exclude_zeros=True) or 0
+ average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, exclude_zeros=True) or 0
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
- average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, exclude_zeros=True) or 0
+ average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, exclude_zeros=True) or 0
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
diff --git a/other/vllm-gptq-setup.py b/other/vllm-gptq-setup.py
new file mode 100644
index 0000000..dd1b250
--- /dev/null
+++ b/other/vllm-gptq-setup.py
@@ -0,0 +1,70 @@
+import io
+import os
+import re
+from typing import List
+
+import setuptools
+from torch.utils.cpp_extension import BuildExtension
+
+ROOT_DIR = os.path.dirname(__file__)
+
+"""
+Build vllm-gptq without any CUDA
+"""
+
+
+def get_path(*filepath) -> str:
+ return os.path.join(ROOT_DIR, *filepath)
+
+
+def find_version(filepath: str):
+ """Extract version information from the given filepath.
+
+ Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
+ """
+ with open(filepath) as fp:
+ version_match = re.search(
+ r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+ if version_match:
+ return version_match.group(1)
+ raise RuntimeError("Unable to find version string.")
+
+
+def read_readme() -> str:
+ """Read the README file."""
+ return io.open(get_path("README.md"), "r", encoding="utf-8").read()
+
+
+def get_requirements() -> List[str]:
+ """Get Python package dependencies from requirements.txt."""
+ with open(get_path("requirements.txt")) as f:
+ requirements = f.read().strip().split("\n")
+ return requirements
+
+
+setuptools.setup(
+ name="vllm-gptq",
+ version=find_version(get_path("vllm", "__init__.py")),
+ author="vLLM Team",
+ license="Apache 2.0",
+ description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+ long_description=read_readme(),
+ long_description_content_type="text/markdown",
+ url="https://github.com/vllm-project/vllm",
+ project_urls={
+ "Homepage": "https://github.com/vllm-project/vllm",
+ "Documentation": "https://vllm.readthedocs.io/en/latest/",
+ },
+ classifiers=[
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "License :: OSI Approved :: Apache Software License",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+ packages=setuptools.find_packages(
+ exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+ python_requires=">=3.8",
+ install_requires=get_requirements(),
+ cmdclass={"build_ext": BuildExtension},
+)
diff --git a/other/vllm_api_server.py b/other/vllm_api_server.py
new file mode 100644
index 0000000..f5b5f45
--- /dev/null
+++ b/other/vllm_api_server.py
@@ -0,0 +1,94 @@
+import argparse
+import json
+import time
+from pathlib import Path
+from typing import AsyncGenerator
+
+import uvicorn
+from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
+
+TIMEOUT_KEEP_ALIVE = 5 # seconds.
+TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
+app = FastAPI()
+
+served_model = None
+
+
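+# The proxy's get_running_model() (llm_server/llm/info.py) polls this endpoint to learn the served model name.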
+@app.get("/model")
+async def get_model(request: Request) -> Response:
+ return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+ """Generate completion for the request.
+
+ The request should be a JSON object with the following fields:
+ - prompt: the prompt to use for the generation.
+ - stream: whether to stream the results or not.
+ - other fields: the sampling parameters (See `SamplingParams` for details).
+ """
+ request_dict = await request.json()
+ prompt = request_dict.pop("prompt")
+ stream = request_dict.pop("stream", False)
+ sampling_params = SamplingParams(**request_dict)
+ request_id = random_uuid()
+ results_generator = engine.generate(prompt, sampling_params, request_id)
+
+ # Streaming case
+ async def stream_results() -> AsyncGenerator[bytes, None]:
+ async for request_output in results_generator:
+ prompt = request_output.prompt
+ text_outputs = [
+ prompt + output.text for output in request_output.outputs
+ ]
+ ret = {"text": text_outputs}
+ yield (json.dumps(ret) + "\0").encode("utf-8")
+
+ async def abort_request() -> None:
+ await engine.abort(request_id)
+
+ if stream:
+ background_tasks = BackgroundTasks()
+ # Abort the request if the client disconnects.
+ background_tasks.add_task(abort_request)
+ return StreamingResponse(stream_results(), background=background_tasks)
+
+ # Non-streaming case
+ final_output = None
+ async for request_output in results_generator:
+ if await request.is_disconnected():
+ # Abort the request if the client disconnects.
+ await engine.abort(request_id)
+ return Response(status_code=499)
+ final_output = request_output
+
+ assert final_output is not None
+ prompt = final_output.prompt
+ text_outputs = [prompt + output.text for output in final_output.outputs]
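+ # The prompt is prepended to each completion; the proxy strips it back off in VLLMBackend.handle_response().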
+ ret = {"text": text_outputs}
+ return JSONResponse(ret)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=8000)
+ parser = AsyncEngineArgs.add_cli_args(parser)
+ args = parser.parse_args()
+
+ served_model = Path(args.model).name
+
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+ uvicorn.run(app,
+ host=args.host,
+ port=args.port,
+ log_level="debug",
+ timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
diff --git a/requirements.txt b/requirements.txt
index 691f4ef..59a0753 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,4 @@ redis
gevent
async-timeout
flask-sock
-vllm
\ No newline at end of file
+auto_gptq
\ No newline at end of file
diff --git a/server.py b/server.py
index 7a9310e..fa8c97c 100644
--- a/server.py
+++ b/server.py
@@ -11,7 +11,7 @@ from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.database import get_number_of_rows, init_db
from llm_server.helpers import resolve_path
-from llm_server.llm.hf_textgen.info import hf_textget_info
+from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
@@ -20,6 +20,13 @@ from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread
+try:
+ import vllm
+except ModuleNotFoundError as e:
+ print('Could not import vllm-gptq:', e)
+ print('Please see vllm.md for install instructions')
+ sys.exit(1)
+
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@@ -130,6 +137,10 @@ def home():
else:
info_html = ''
+ mode_info = ''
+ if opts.mode == 'vllm':
+ mode_info = vllm_info
+
return render_template('home.html',
llm_middleware_name=config['llm_middleware_name'],
analytics_tracking_code=analytics_tracking_code,
@@ -143,7 +154,7 @@ def home():
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
- extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
+ extra_info=mode_info,
)
@@ -156,5 +167,11 @@ def fallback(first=None, rest=None):
}), 404
+@app.errorhandler(500)
+def server_error(e):
+ print(e)
+ return {'error': True}, 500
+
+
if __name__ == "__main__":
app.run(host='0.0.0.0')
diff --git a/vllm.md b/vllm.md
new file mode 100644
index 0000000..e091362
--- /dev/null
+++ b/vllm.md
@@ -0,0 +1,4 @@
+```bash
+wget https://git.evulid.cc/attachments/6e7bfc04-cad4-4494-a98d-1391fbb402d3 -O vllm-0.1.3-cp311-cp311-linux_x86_64.whl && pip install vllm-0.1.3-cp311-cp311-linux_x86_64.whl
+pip install auto_gptq
+```
\ No newline at end of file