diff --git a/.gitignore b/.gitignore
index d12d30d..ec4f14b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 proxy-server.db
 .idea
 config/config.yml
+install vllm_gptq-0.1.3-py3-none-any.whl
 
 # ---> Python
 # Byte-compiled / optimized / DLL files
diff --git a/README.md b/README.md
index b6abb5f..73af42c 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,8 @@ The purpose of this server is to abstract your LLM backend from your frontend AP
 2. `python3 -m venv venv`
 3. `source venv/bin/activate`
 4. `pip install -r requirements.txt`
-5. `python3 server.py`
+5. `wget https://git.evulid.cc/attachments/89c87201-58b1-4e28-b8fd-d0b323c810c4 -O vllm_gptq-0.1.3-py3-none-any.whl && pip install vllm_gptq-0.1.3-py3-none-any.whl`
+6. `python3 server.py`
 
 An example systemctl service file is provided in `other/local-llm.service`.
diff --git a/llm_server/database.py b/llm_server/database.py
index e0bb1a5..e59eb55 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -18,6 +18,7 @@ def init_db():
         CREATE TABLE prompts (
             ip TEXT,
             token TEXT DEFAULT NULL,
+            backend TEXT,
             prompt TEXT,
             prompt_tokens INTEGER,
             response TEXT,
@@ -71,8 +72,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     timestamp = int(time.time())
     conn = sqlite3.connect(opts.database_path)
     c = conn.cursor()
-    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
-              (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
+    c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+              (ip, token, opts.mode, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
     conn.commit()
     conn.close()
 
@@ -129,15 +130,17 @@ def average_column_for_model(table_name, column_name, model_name):
     return result[0]
 
 
-def weighted_average_column_for_model(table_name, column_name, model_name, exclude_zeros: bool = False):
+def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, exclude_zeros: bool = False):
     conn = sqlite3.connect(opts.database_path)
     cursor = conn.cursor()
-    cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
-    models = [row[0] for row in cursor.fetchall()]
+    cursor.execute(f"SELECT DISTINCT model, backend FROM {table_name}")
+    models_backends = [(row[0], row[1]) for row in cursor.fetchall()]
 
     model_averages = {}
-    for model in models:
-        cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
+    for model, backend in models_backends:
+        if backend != backend_name:
+            continue
+        cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend = ? ORDER BY ROWID DESC", (model, backend))
         results = cursor.fetchall()
 
         if not results:
@@ -155,11 +158,11 @@ def weighted_average_column_for_model(table_name, column_name, model_name, exclu
         if total_weight == 0:
             continue
 
-        model_averages[model] = weighted_sum / total_weight
+        model_averages[(model, backend)] = weighted_sum / total_weight
 
     conn.close()
 
-    return model_averages.get(model_name)
+    return model_averages.get((model_name, backend_name))
 
 
 def sum_column(table_name, column_name):
diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py
index 9b456e0..4121d3e 100644
--- a/llm_server/llm/info.py
+++ b/llm_server/llm/info.py
@@ -1,10 +1,11 @@
 import requests
 
 from llm_server import opts
-from pathlib import Path
+
 
 def get_running_model():
     # TODO: cache the results for 1 min so we don't have to keep calling the backend
+    # TODO: only use one try/catch
 
     if opts.mode == 'oobabooga':
         try:
@@ -22,11 +23,9 @@ def get_running_model():
             return False, e
     elif opts.mode == 'vllm':
         try:
-            backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{opts.backend_url}/model', timeout=3, verify=opts.verify_ssl)
             r_json = backend_response.json()
-            model_name = Path(r_json['data'][0]['root']).name
-            # r_json['data'][0]['root'] = model_name
-            return model_name, None
+            return r_json['model'], None
         except Exception as e:
             return False, e
     else:
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 6285b1d..7302728 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -1,4 +1,4 @@
-from typing import Union, Tuple
+from typing import Tuple, Union
 
 
 class LLMBackend:
@@ -10,3 +10,12 @@ class LLMBackend:
 
     # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
     #     raise NotImplementedError
+
+    def get_parameters(self, parameters) -> Union[dict, None]:
+        """
+        Validate and return the parameters for this backend.
+        Lets you set defaults for specific backends.
+        :param parameters:
+        :return:
+        """
+        raise NotImplementedError
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index a5d6e69..ee3a7d6 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -1,15 +1,11 @@
-from typing import Tuple
-
-import requests
 from flask import jsonify
 
-from ... import opts
+from ..llm_backend import LLMBackend
 from ...database import log_prompt
 from ...helpers import safe_list_get
 from ...routes.cache import redis
 from ...routes.helpers.client import format_sillytavern_err
 from ...routes.helpers.http import validate_json
-from ..llm_backend import LLMBackend
 
 
 class OobaboogaLLMBackend(LLMBackend):
@@ -71,3 +67,7 @@ class OobaboogaLLMBackend(LLMBackend):
     #         return r_json['result'], None
     #     except Exception as e:
     #         return False, e
+
+    def get_parameters(self, parameters):
+        del parameters['prompt']
+        return parameters
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 1a0bf92..2e1267c 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -1,17 +1,14 @@
 """
 This file is used by the worker that processes requests.
""" -import io import json import time from uuid import uuid4 import requests -from requests import Response from llm_server import opts from llm_server.database import tokenizer -from llm_server.routes.cache import redis # TODO: make the VLMM backend return TPS and time elapsed @@ -19,7 +16,7 @@ from llm_server.routes.cache import redis def prepare_json(json_data: dict): # logit_bias is not currently supported - del json_data['logit_bias'] + # del json_data['logit_bias'] return json_data @@ -83,26 +80,26 @@ def transform_prompt_to_text(prompt: list): def handle_blocking_request(json_data: dict): try: - r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl) + r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl) except Exception as e: return False, None, f'{e.__class__.__name__}: {e}' # TODO: check for error here? - response_json = r.json() - response_json['error'] = False + # response_json = r.json() + # response_json['error'] = False - new_response = Response() - new_response.status_code = r.status_code - new_response._content = json.dumps(response_json).encode('utf-8') - new_response.raw = io.BytesIO(new_response._content) - new_response.headers = r.headers - new_response.url = r.url - new_response.reason = r.reason - new_response.cookies = r.cookies - new_response.elapsed = r.elapsed - new_response.request = r.request + # new_response = Response() + # new_response.status_code = r.status_code + # new_response._content = json.dumps(response_json).encode('utf-8') + # new_response.raw = io.BytesIO(new_response._content) + # new_response.headers = r.headers + # new_response.url = r.url + # new_response.reason = r.reason + # new_response.cookies = r.cookies + # new_response.elapsed = r.elapsed + # new_response.request = r.request - return True, new_response, None + return True, r, None def generate(json_data: dict): diff --git a/llm_server/llm/vllm/info.py b/llm_server/llm/vllm/info.py index d83de8b..e873a30 100644 --- a/llm_server/llm/vllm/info.py +++ b/llm_server/llm/vllm/info.py @@ -1,13 +1,8 @@ -from pathlib import Path - -import requests - -from llm_server import opts - - -def get_vlmm_models_info(): - backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl) - r_json = backend_response.json() - r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name - r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name - return r_json +vllm_info = """

Important: This endpoint is running vllm-gptq and not all Oobabooga parameters are supported.

+Supported Parameters: +""" \ No newline at end of file diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index afbab33..86e8129 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -1,8 +1,9 @@ +from typing import Tuple + from flask import jsonify from vllm import SamplingParams from llm_server.database import log_prompt -from llm_server.helpers import indefinite_article from llm_server.llm.llm_backend import LLMBackend from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.http import validate_json @@ -22,19 +23,23 @@ class VLLMBackend(LLMBackend): response_status_code = 0 if response_valid_json: - backend_response = response_json_body + if len(response_json_body.get('text', [])): + # Does vllm return the prompt and the response together??? + backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n') + else: + # Failsafe + backend_response = '' - if response_json_body.get('error'): - backend_err = True - error_type = response_json_body.get('error_type') - error_type_string = f'returned {indefinite_article(error_type)} {error_type} error' - backend_response = format_sillytavern_err( - f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}', - f'HTTP CODE {response_status_code}' - ) + # TODO: how to detect an error? + # if backend_response == '': + # backend_err = True + # backend_response = format_sillytavern_err( + # f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.', + # f'HTTP CODE {response_status_code}' + # ) - log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err) - return jsonify(backend_response), 200 + log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err) + return jsonify({'results': [{'text': backend_response}]}), 200 else: backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error') log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True) @@ -44,13 +49,24 @@ class VLLMBackend(LLMBackend): 'results': [{'text': backend_response}] }), 200 - def validate_params(self, params_dict: dict): - try: - sampling_params = SamplingParams(**params_dict) - except ValueError as e: - print(e) - return False, e - return True, None + # def validate_params(self, params_dict: dict): + # default_params = SamplingParams() + # try: + # sampling_params = SamplingParams( + # temperature=params_dict.get('temperature', default_params.temperature), + # top_p=params_dict.get('top_p', default_params.top_p), + # top_k=params_dict.get('top_k', default_params.top_k), + # use_beam_search=True if params_dict['num_beams'] > 1 else False, + # length_penalty=params_dict.get('length_penalty', default_params.length_penalty), + # early_stopping=params_dict.get('early_stopping', default_params.early_stopping), + # stop=params_dict.get('stopping_strings', default_params.stop), + # ignore_eos=params_dict.get('ban_eos_token', False), + # 
+    #             max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
+    #         )
+    #     except ValueError as e:
+    #         print(e)
+    #         return False, e
+    #     return True, None
 
     # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
     #     try:
@@ -61,3 +77,33 @@ class VLLMBackend(LLMBackend):
     #         return r_json, None
     #     except Exception as e:
     #         return False, e
+
+    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+        default_params = SamplingParams()
+        try:
+            sampling_params = SamplingParams(
+                temperature=parameters.get('temperature', default_params.temperature),
+                top_p=parameters.get('top_p', default_params.top_p),
+                top_k=parameters.get('top_k', default_params.top_k),
+                use_beam_search=True if parameters['num_beams'] > 1 else False,
+                stop=parameters.get('stopping_strings', default_params.stop),
+                ignore_eos=parameters.get('ban_eos_token', False),
+                max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
+            )
+        except ValueError as e:
+            print(e)
+            return None, e
+        return vars(sampling_params), None
+
+# def transform_sampling_params(params: SamplingParams):
+#     return {
+#         'temperature': params['temperature'],
+#         'top_p': params['top_p'],
+#         'top_k': params['top_k'],
+#         'use_beam_search' = True if parameters['num_beams'] > 1 else False,
+#         length_penalty = parameters.get('length_penalty', default_params.length_penalty),
+#         early_stopping = parameters.get('early_stopping', default_params.early_stopping),
+#         stop = parameters.get('stopping_strings', default_params.stop),
+#         ignore_eos = parameters.get('ban_eos_token', False),
+#         max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
+#     }
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 8ea7356..7c9a962 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -38,9 +38,9 @@ class OobaRequestHandler:
         self.start_time = time.time()
         self.client_ip = self.get_client_ip()
         self.token = self.request.headers.get('X-Api-Key')
-        self.parameters = self.get_parameters()
         self.priority = self.get_priority()
         self.backend = self.get_backend()
+        self.parameters = self.parameters_invalid_msg = None
 
     def validate_request(self) -> (bool, Union[str, None]):
         # TODO: move this to LLMBackend
@@ -56,19 +56,9 @@ class OobaRequestHandler:
         else:
             return self.request.remote_addr
 
-    def get_parameters(self):
-        # TODO: make this a LLMBackend method
-        request_valid_json, self.request_json_body = validate_json(self.request.data)
-        if not request_valid_json:
-            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
-        parameters = self.request_json_body.copy()
-        if opts.mode in ['oobabooga', 'hf-textgen']:
-            del parameters['prompt']
-        elif opts.mode == 'vllm':
-            parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
-        else:
-            raise Exception
-        return parameters
+    # def get_parameters(self):
+    #     # TODO: make this a LLMBackend method
+    #     return self.backend.get_parameters()
 
     def get_priority(self):
         if self.token:
@@ -91,24 +81,26 @@ class OobaRequestHandler:
         else:
             raise Exception
 
+    def get_parameters(self):
+        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
+
     def handle_request(self):
+        request_valid_json, self.request_json_body = validate_json(self.request.data)
+        if not request_valid_json:
+            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
+
+        self.get_parameters()
+
         SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
 
-        # Fix bug on text-generation-inference
-        # https://github.com/huggingface/text-generation-inference/issues/929
-        if opts.mode == 'hf-textgen' and self.parameters.get('typical_p', 0) > 0.998:
-            self.request_json_body['typical_p'] = 0.998
-
-        if opts.mode == 'vllm':
-            full_model_path = redis.get('full_model_path')
-            if not full_model_path:
-                raise Exception
-            self.request_json_body['model'] = full_model_path.decode()
-
         request_valid, invalid_request_err_msg = self.validate_request()
-        params_valid, invalid_params_err_msg = self.backend.validate_params(self.parameters)
+        if not self.parameters:
+            params_valid = False
+        else:
+            params_valid = True
+
         if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, invalid_params_err_msg)] if not valid]
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
             combined_error_message = ', '.join(error_messages)
             err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
@@ -119,21 +111,27 @@ class OobaRequestHandler:
                 'results': [{'text': err}]
             }), 200
 
+        # Reconstruct the request JSON with the validated parameters and prompt.
+        prompt = self.request_json_body.get('prompt', '')
+        llm_request = {**self.parameters, 'prompt': prompt}
+
         queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
         if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
-            event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
+            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
         else:
             # Client was rate limited
             event = None
 
         if not event:
             return self.handle_ratelimited()
+
         event.wait()
 
         success, response, error_msg = event.data
         end_time = time.time()
         elapsed_time = end_time - self.start_time
-        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))
+
+        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 
     def handle_ratelimited(self):
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index c6c3915..0f42ccf 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -7,14 +7,7 @@ from ... import opts
 
 
 @bp.route('/generate', methods=['POST'])
-@bp.route('/chat/completions', methods=['POST'])
 def generate():
-    if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
-        return jsonify({
-            'code': 404,
-            'error': 'this LLM backend is in VLLM mode'
-        }), 404
-
     request_valid_json, request_json_body = validate_json(request.data)
     if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index f3117be..7e09d6f 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -82,6 +82,7 @@ def generate_stats():
         'online': online,
         'endpoints': {
             'blocking': f'https://{opts.base_client_api}',
+            'streaming': f'wss://{opts.base_client_api}/stream',
         },
         'queue': {
             'processing': active_gen_workers,
@@ -104,9 +105,9 @@ def generate_stats():
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
     }
 
-    if opts.mode in ['oobabooga', 'hf-textgen']:
-        output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
-    else:
-        output['endpoints']['streaming'] = None
+    # if opts.mode in ['oobabooga', 'hf-textgen']:
+    #     output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
+    # else:
+    #     output['endpoints']['streaming'] = None
 
     return deep_sort(output)
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 56d5e5e..26bb2e3 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -4,9 +4,7 @@ from flask import jsonify, request
 
 from . import bp
 from ..cache import cache
-from ... import opts
 from ...llm.info import get_running_model
-from ...llm.vllm.info import get_vlmm_models_info
 
 # cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
 
@@ -20,16 +18,7 @@ from ...llm.vllm.info import get_vlmm_models_info
 
 
 @bp.route('/model', methods=['GET'])
-@bp.route('/models', methods=['GET'])
 def get_model():
-    if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
-        return jsonify({
-            'code': 404,
-            'error': 'this LLM backend is in VLLM mode'
-        }), 404
-
-
-
     # We will manage caching ourself since we don't want to cache
     # when the backend is down. Also, Cloudflare won't cache 500 errors.
     cache_key = 'model_cache::' + request.url
@@ -46,18 +35,10 @@ def get_model():
             'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
-        if opts.mode in ['oobabooga', 'hf-texgen']:
-            response = jsonify({
-                'result': model,
-                'timestamp': int(time.time())
-            }), 200
-        elif opts.mode == 'vllm':
-            response = jsonify({
-                **get_vlmm_models_info(),
-                'timestamp': int(time.time())
-            }), 200
-        else:
-            raise Exception
+        response = jsonify({
+            'result': model,
+            'timestamp': int(time.time())
+        }), 200
         cache.set(cache_key, response, timeout=60)
 
     return response
diff --git a/llm_server/threads.py b/llm_server/threads.py
index b6e6132..5ce9c0e 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -3,7 +3,6 @@ from threading import Thread
 
 import requests
 
-import llm_server
 from llm_server import opts
 from llm_server.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
@@ -25,16 +24,6 @@ class MainBackgroundThread(Thread):
        redis.set('backend_online', 0)
         redis.set_dict('backend_info', {})
 
-        if opts.mode == 'vllm':
-            while True:
-                try:
-                    backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
-                    r_json = backend_response.json()
-                    redis.set('full_model_path', r_json['data'][0]['root'])
-                    break
-                except Exception as e:
-                    print(e)
-
     def run(self):
         while True:
             if opts.mode == 'oobabooga':
@@ -77,13 +66,13 @@ class MainBackgroundThread(Thread):
 
             # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
             # was entered into the column. The new code enters null instead but we need to be backwards compatible for now
-            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, exclude_zeros=True) or 0
+            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, exclude_zeros=True) or 0
             redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
 
             # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
             # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
 
-            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, exclude_zeros=True) or 0
+            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, exclude_zeros=True) or 0
             redis.set('average_output_tokens', average_output_tokens)
 
             # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
diff --git a/other/vllm-gptq-setup.py b/other/vllm-gptq-setup.py
new file mode 100644
index 0000000..dd1b250
--- /dev/null
+++ b/other/vllm-gptq-setup.py
@@ -0,0 +1,70 @@
+import io
+import os
+import re
+from typing import List
+
+import setuptools
+from torch.utils.cpp_extension import BuildExtension
+
+ROOT_DIR = os.path.dirname(__file__)
+
+"""
+Build vllm-gptq without any CUDA
+"""
+
+
+def get_path(*filepath) -> str:
+    return os.path.join(ROOT_DIR, *filepath)
+
+
+def find_version(filepath: str):
+    """Extract version information from the given filepath.
+
+    Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
+    """
+    with open(filepath) as fp:
+        version_match = re.search(
+            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+        if version_match:
+            return version_match.group(1)
+        raise RuntimeError("Unable to find version string.")
+
+
+def read_readme() -> str:
+    """Read the README file."""
+    return io.open(get_path("README.md"), "r", encoding="utf-8").read()
+
+
+def get_requirements() -> List[str]:
+    """Get Python package dependencies from requirements.txt."""
+    with open(get_path("requirements.txt")) as f:
+        requirements = f.read().strip().split("\n")
+    return requirements
+
+
+setuptools.setup(
+    name="vllm-gptq",
+    version=find_version(get_path("vllm", "__init__.py")),
+    author="vLLM Team",
+    license="Apache 2.0",
+    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+    long_description=read_readme(),
+    long_description_content_type="text/markdown",
+    url="https://github.com/vllm-project/vllm",
+    project_urls={
+        "Homepage": "https://github.com/vllm-project/vllm",
+        "Documentation": "https://vllm.readthedocs.io/en/latest/",
+    },
+    classifiers=[
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "License :: OSI Approved :: Apache Software License",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+    packages=setuptools.find_packages(
+        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+    python_requires=">=3.8",
+    install_requires=get_requirements(),
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/other/vllm_api_server.py b/other/vllm_api_server.py
new file mode 100644
index 0000000..f5b5f45
--- /dev/null
+++ b/other/vllm_api_server.py
@@ -0,0 +1,94 @@
+import argparse
+import json
+import time
+from pathlib import Path
+from typing import AsyncGenerator
+
+import uvicorn
+from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
+app = FastAPI()
+
+served_model = None
+
+
+@app.get("/model")
+async def generate(request: Request) -> Response:
+    return JSONResponse({'model': served_model, 'timestamp': int(time.time())})
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+    """Generate completion for the request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - stream: whether to stream the results or not.
+    - other fields: the sampling parameters (See `SamplingParams` for details).
+    """
+    request_dict = await request.json()
+    prompt = request_dict.pop("prompt")
+    stream = request_dict.pop("stream", False)
+    sampling_params = SamplingParams(**request_dict)
+    request_id = random_uuid()
+    results_generator = engine.generate(prompt, sampling_params, request_id)
+
+    # Streaming case
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
+            ret = {"text": text_outputs}
+            yield (json.dumps(ret) + "\0").encode("utf-8")
+
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
+    if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(stream_results(), background=background_tasks)
+
+    # Non-streaming case
+    final_output = None
+    async for request_output in results_generator:
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await engine.abort(request_id)
+            return Response(status_code=499)
+        final_output = request_output
+
+    assert final_output is not None
+    prompt = final_output.prompt
+    text_outputs = [prompt + output.text for output in final_output.outputs]
+    ret = {"text": text_outputs}
+    return JSONResponse(ret)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    served_model = Path(args.model).name
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
diff --git a/requirements.txt b/requirements.txt
index 691f4ef..59a0753 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,4 @@ redis
 gevent
 async-timeout
 flask-sock
-vllm
\ No newline at end of file
+auto_gptq
\ No newline at end of file
diff --git a/server.py b/server.py
index 7a9310e..fa8c97c 100644
--- a/server.py
+++ b/server.py
@@ -11,7 +11,7 @@ from llm_server import opts
 from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
 from llm_server.database import get_number_of_rows, init_db
 from llm_server.helpers import resolve_path
-from llm_server.llm.hf_textgen.info import hf_textget_info
+from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
@@ -20,6 +20,13 @@ from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.stream import init_socketio
 from llm_server.threads import MainBackgroundThread
 
+try:
+    import vllm
+except ModuleNotFoundError as e:
+    print('Could not import vllm-gptq:', e)
+    print('Please see vllm.md for install instructions')
+    sys.exit(1)
+
 script_path = os.path.dirname(os.path.realpath(__file__))
 
 config_path_environ = os.getenv("CONFIG_PATH")
@@ -130,6 +137,10 @@ def home():
     else:
         info_html = ''
 
+    mode_info = ''
+    if opts.mode == 'vllm':
+        mode_info = vllm_info
+
     return render_template('home.html',
                            llm_middleware_name=config['llm_middleware_name'],
                            analytics_tracking_code=analytics_tracking_code,
@@ -143,7 +154,7 @@
                           streaming_input_textbox=mode_ui_names[opts.mode][2],
                           context_size=opts.context_size,
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
-                          extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
+                          extra_info=mode_info,
                           )
 
 
@@ -156,5 +167,11 @@ def fallback(first=None, rest=None):
     }), 404
 
 
+@app.errorhandler(500)
+def server_error(e):
+    print(e)
+    return {'error': True}, 500
+
+
 if __name__ == "__main__":
     app.run(host='0.0.0.0')
diff --git a/vllm.md b/vllm.md
new file mode 100644
index 0000000..e091362
--- /dev/null
+++ b/vllm.md
@@ -0,0 +1,4 @@
+```bash
+wget https://git.evulid.cc/attachments/6e7bfc04-cad4-4494-a98d-1391fbb402d3 -O vllm-0.1.3-cp311-cp311-linux_x86_64.whl && pip install vllm-0.1.3-cp311-cp311-linux_x86_64.whl
+pip install auto_gptq
+```
\ No newline at end of file