implement vllm backend
parent c14cc51f09
commit 4c9d543eab
@@ -9,7 +9,7 @@ token_limit: 8192
# How many requests a single IP is allowed to put in the queue.
# If an IP tries to put more than this their request will be rejected
# until the other(s) are completed.
ip_in_queue_max: 2
simultaneous_requests_per_ip: 2

## Optional
@@ -13,7 +13,7 @@ config_default_vars = {
    'average_generation_time_mode': 'database',
    'info_html': None,
    'show_total_output_tokens': True,
    'ip_in_queue_max': 3,
    'simultaneous_requests_per_ip': 3,
    'show_backend_info': True,
    'max_new_tokens': 500
}
@@ -22,6 +22,7 @@ config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middlewar
mode_ui_names = {
    'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
    'hf-textgen': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
    'vllm': ('Chat Completion', 'Reverse Proxy', 'N/A'),
}
@@ -8,5 +8,9 @@ def generator(request_json_body):
    elif opts.mode == 'hf-textgen':
        from .hf_textgen.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'vllm':
        from .vllm.generate import generate
        r = generate(request_json_body)
        return r
    else:
        raise Exception
@@ -1,7 +1,9 @@
import sys
from typing import Tuple

import requests
from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
@@ -47,3 +49,11 @@ class HfTextgenLLMBackend(LLMBackend):
        if params_dict.get('typical_p', 0) > 0.998:
            return False, '`typical_p` must be less than 0.999'
        return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/info', verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         return r_json['model_id'].replace('/', '_'), None
    #     except Exception as e:
    #         return False, e
@@ -1,9 +1,11 @@
import requests

from llm_server import opts

from pathlib import Path

def get_running_model():
    # TODO: cache the results for 1 min so we don't have to keep calling the backend

    if opts.mode == 'oobabooga':
        try:
            backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
@@ -18,5 +20,14 @@ def get_running_model():
            return r_json['model_id'].replace('/', '_'), None
        except Exception as e:
            return False, e
    elif opts.mode == 'vllm':
        try:
            backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
            r_json = backend_response.json()
            model_name = Path(r_json['data'][0]['root']).name
            # r_json['data'][0]['root'] = model_name
            return model_name, None
        except Exception as e:
            return False, e
    else:
        raise Exception
@@ -1,9 +1,12 @@
from typing import Union
from typing import Union, Tuple


class LLMBackend:
    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        raise NotImplementedError

    def validate_params(self, params_dict: dict) -> (bool, Union[str, None]):
    def validate_params(self, params_dict: dict) -> Tuple[bool, str | None]:
        raise NotImplementedError

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     raise NotImplementedError
@@ -1,5 +1,9 @@
from typing import Tuple

import requests
from flask import jsonify

from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
@@ -59,3 +63,11 @@ class OobaboogaLLMBackend(LLMBackend):
    def validate_params(self, params_dict: dict):
        # No validation required
        return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         return r_json['result'], None
    #     except Exception as e:
    #         return False, e
@@ -0,0 +1,116 @@
"""
This file is used by the worker that processes requests.
"""
import io
import json
import time
from uuid import uuid4

import requests
from requests import Response

from llm_server import opts
from llm_server.database import tokenizer
from llm_server.routes.cache import redis


# TODO: make the VLLM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py

def prepare_json(json_data: dict):
    # logit_bias is not currently supported
    del json_data['logit_bias']
    return json_data


def transform_to_text(json_request, api_response):
    """
    This converts a streaming request to a non-streamed request. Don't think this is necessary.
    :param json_request:
    :param api_response:
    :return:
    """
    prompt = transform_prompt_to_text(json_request['messages'])
    text = ''
    finish_reason = None
    for line in api_response.split('\n'):
        if line.startswith('data:'):
            try:
                data = json.loads(line[5:].strip())
            except json.decoder.JSONDecodeError:
                break
            print(data)
            if 'choices' in data:
                for choice in data['choices']:
                    if 'delta' in choice and 'content' in choice['delta']:
                        text += choice['delta']['content']
                if data['choices'][0]['finish_reason']:
                    finish_reason = data['choices'][0]['finish_reason']

    prompt_tokens = len(tokenizer.encode(prompt))
    completion_tokens = len(tokenizer.encode(text))

    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
        "id": str(uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        },
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": text
                },
                "finish_reason": finish_reason,
                "index": 0
            }
        ]
    }


def transform_prompt_to_text(prompt: list):
    text = ''
    for item in prompt:
        text += item['content'] + '\n'
    return text.strip('\n')


def handle_blocking_request(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/v1/chat/completions', json=prepare_json(json_data), verify=opts.verify_ssl)
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'

    # TODO: check for error here?
    response_json = r.json()
    response_json['error'] = False

    new_response = Response()
    new_response.status_code = r.status_code
    new_response._content = json.dumps(response_json).encode('utf-8')
    new_response.raw = io.BytesIO(new_response._content)
    new_response.headers = r.headers
    new_response.url = r.url
    new_response.reason = r.reason
    new_response.cookies = r.cookies
    new_response.elapsed = r.elapsed
    new_response.request = r.request

    return True, new_response, None


def generate(json_data: dict):
    full_model_path = redis.get('full_model_path')
    if not full_model_path:
        raise Exception
    json_data['model'] = full_model_path.decode()
    if json_data.get('stream'):
        raise Exception('streaming not implemented')
    else:
        return handle_blocking_request(json_data)
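For context, a minimal usage sketch of transform_to_text. The SSE payload below is an invented example of the OpenAI-style stream a vLLM server might emit, not captured output, and it assumes the module's tokenizer and opts.running_model are already configured:

# Hypothetical SSE body; real vLLM chunks carry more fields.
sse_body = (
    'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}\n'
    'data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}\n'
    'data: [DONE]'
)
request_body = {'messages': [{'role': 'user', 'content': 'Say hello'}]}
completion = transform_to_text(request_body, sse_body)
print(completion['choices'][0]['message']['content'])  # 'Hello world'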
@@ -0,0 +1,13 @@
from pathlib import Path

import requests

from llm_server import opts


def get_vlmm_models_info():
    backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
    r_json = backend_response.json()
    r_json['data'][0]['root'] = Path(r_json['data'][0]['root']).name
    r_json['data'][0]['id'] = Path(r_json['data'][0]['id']).name
    return r_json
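The trimming above relies on pathlib: Path(...).name keeps only the final path component, so a backend that reports a filesystem path as the model id ends up advertised by folder name alone. The path below is a made-up example:

from pathlib import Path

# Made-up model path of the kind a vLLM server might report under 'root'/'id'.
print(Path('/storage/models/some-model-13b').name)  # -> 'some-model-13b'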
@@ -0,0 +1,63 @@
from flask import jsonify
from vllm import SamplingParams

from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json


# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69

class VLLMBackend(LLMBackend):
    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except:
            response_status_code = 0

        if response_valid_json:
            backend_response = response_json_body

            if response_json_body.get('error'):
                backend_err = True
                error_type = response_json_body.get('error_type')
                error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
                backend_response = format_sillytavern_err(
                    f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
                    f'HTTP CODE {response_status_code}'
                )

            log_prompt(client_ip, token, prompt, backend_response['choices'][0]['message']['content'], elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify(backend_response), 200
        else:
            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    def validate_params(self, params_dict: dict):
        try:
            sampling_params = SamplingParams(**params_dict)
        except ValueError as e:
            print(e)
            return False, e
        return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         model_path = Path(r_json['data'][0]['root']).name
    #         r_json['data'][0]['root'] = model_path
    #         return r_json, None
    #     except Exception as e:
    #         return False, e
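A rough sketch of how the SamplingParams-based validation behaves. It assumes the vllm package is importable, and the parameter values are illustrative:

backend = VLLMBackend()
ok, err = backend.validate_params({'temperature': 0.7, 'top_p': 0.9})
# ok is True, err is None
ok, err = backend.validate_params({'temperature': -1.0})
# SamplingParams rejects a negative temperature with a ValueError, so ok is False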
@@ -20,5 +20,5 @@ show_uptime = True
average_generation_time_mode = 'database'
show_total_output_tokens = True
netdata_root = None
ip_in_queue_max = 3
simultaneous_requests_per_ip = 3
show_backend_info = True
@@ -40,7 +40,7 @@ class PriorityQueue:
        with self._cv:
            # Check if the IP is already in the dictionary and if it has reached the limit
            ip_count = redis.get_dict('queued_ip_count')
            if item[1] in ip_count and ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
            if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
                return None  # reject the request
            heapq.heappush(self._queue, (-priority, self._index, item, event))
            self._index += 1
@@ -8,6 +8,7 @@ from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
@@ -17,6 +18,19 @@ from llm_server.routes.stats import SemaphoreCheckerThread
DEFAULT_PRIORITY = 9999


def delete_dict_key(d: dict, k: Union[str, list]):
    if isinstance(k, str):
        if k in d.keys():
            del d[k]
    elif isinstance(k, list):
        for item in k:
            if item in d.keys():
                del d[item]
    else:
        raise ValueError
    return d


class OobaRequestHandler:
    def __init__(self, incoming_request):
        self.request_json_body = None
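For illustration, delete_dict_key drops a single key or a list of keys and hands the dict back; the get_parameters change below uses it to strip the non-sampling fields out of an OpenAI-style body (values here are made up):

params = {'messages': [], 'model': 'x', 'stream': False, 'temperature': 0.5}
params = delete_dict_key(params, ['messages', 'model', 'stream', 'logit_bias'])
# params is now {'temperature': 0.5}; missing keys like 'logit_bias' are silently skipped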
@@ -29,7 +43,8 @@ class OobaRequestHandler:
        self.backend = self.get_backend()

    def validate_request(self) -> (bool, Union[str, None]):
        if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
        # TODO: move this to LLMBackend
        if self.parameters.get('max_new_tokens', 0) > opts.max_new_tokens or self.parameters.get('max_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None
@@ -42,11 +57,17 @@ class OobaRequestHandler:
        return self.request.remote_addr

    def get_parameters(self):
        # TODO: make this a LLMBackend method
        request_valid_json, self.request_json_body = validate_json(self.request.data)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
        parameters = self.request_json_body.copy()
        del parameters['prompt']
        if opts.mode in ['oobabooga', 'hf-textgen']:
            del parameters['prompt']
        elif opts.mode == 'vllm':
            parameters = delete_dict_key(parameters, ['messages', 'model', 'stream', 'logit_bias'])
        else:
            raise Exception
        return parameters

    def get_priority(self):
@@ -65,6 +86,8 @@ class OobaRequestHandler:
            return OobaboogaLLMBackend()
        elif opts.mode == 'hf-textgen':
            return HfTextgenLLMBackend()
        elif opts.mode == 'vllm':
            return VLLMBackend()
        else:
            raise Exception
@@ -83,6 +106,7 @@ class OobaRequestHandler:
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
@@ -90,7 +114,7 @@ class OobaRequestHandler:
            }), 200

        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.ip_in_queue_max or self.priority == 0:
        if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
            event = priority_queue.put((self.request_json_body, self.client_ip, self.token, self.parameters), self.priority)
        else:
            # Client was rate limited
@@ -106,7 +130,7 @@ class OobaRequestHandler:
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, self.request_json_body.get('prompt', ''), elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
@@ -3,12 +3,20 @@ from flask import jsonify, request
from . import bp
from ..helpers.http import validate_json
from ..request_handler import OobaRequestHandler
from ... import opts


@bp.route('/generate', methods=['POST'])
@bp.route('/chat/completions', methods=['POST'])
def generate():
    if opts.mode == 'vllm' and request.url.split('/')[-1] == 'generate':
        return jsonify({
            'code': 404,
            'error': 'this LLM backend is in VLLM mode'
        }), 404

    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json or not request_json_body.get('prompt'):
    if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
    else:
        handler = OobaRequestHandler(request)
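A hypothetical client call against the new OpenAI-style route; the host and the /api/v1 prefix are assumptions about how the blueprint is mounted, not taken from this commit:

import requests

# Assumed base URL; substitute the proxy's real client API endpoint.
resp = requests.post('https://llm.example.com/api/v1/chat/completions',
                     json={'messages': [{'role': 'user', 'content': 'Hello'}], 'temperature': 0.7})
print(resp.json()['choices'][0]['message']['content'])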
@@ -82,7 +82,6 @@ def generate_stats():
        'online': online,
        'endpoints': {
            'blocking': f'https://{opts.base_client_api}',
            'streaming': f'wss://{opts.base_client_api}/v1/stream',
        },
        'queue': {
            'processing': active_gen_workers,
@@ -96,7 +95,7 @@ def generate_stats():
            'concurrent': opts.concurrent_gens,
            'model': model_name,
            'mode': opts.mode,
            'simultaneous_requests_per_ip': opts.ip_in_queue_max,
            'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
        },
        'keys': {
            'openaiKeys': '∞',
@@ -104,4 +103,10 @@ def generate_stats():
        },
        'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
    }

    if opts.mode in ['oobabooga', 'hf-textgen']:
        output['endpoints']['streaming'] = f'wss://{opts.base_client_api}/v1/stream'
    else:
        output['endpoints']['streaming'] = None

    return deep_sort(output)
@@ -3,10 +3,10 @@ import time
from flask import jsonify, request

from . import bp
from ..helpers.http import cache_control
from ..cache import cache
from ... import opts
from ...llm.info import get_running_model
from ..cache import cache
from ...llm.vllm.info import get_vlmm_models_info


# cache = Cache(bp, config={'CACHE_TYPE': 'simple'})
@@ -20,7 +20,16 @@ from ..cache import cache


@bp.route('/model', methods=['GET'])
@bp.route('/models', methods=['GET'])
def get_model():
    if opts.mode == 'vllm' and request.url.split('/')[-1] == 'model':
        return jsonify({
            'code': 404,
            'error': 'this LLM backend is in VLLM mode'
        }), 404

    # We will manage caching ourself since we don't want to cache
    # when the backend is down. Also, Cloudflare won't cache 500 errors.
    cache_key = 'model_cache::' + request.url
@@ -37,10 +46,18 @@ def get_model():
            'type': error.__class__.__name__
        }), 500  # return 500 so Cloudflare doesn't intercept us
    else:
        response = jsonify({
            'result': model,
            'timestamp': int(time.time())
        }), 200
        if opts.mode in ['oobabooga', 'hf-textgen']:
            response = jsonify({
                'result': model,
                'timestamp': int(time.time())
            }), 200
        elif opts.mode == 'vllm':
            response = jsonify({
                **get_vlmm_models_info(),
                'timestamp': int(time.time())
            }), 200
        else:
            raise Exception
        cache.set(cache_key, response, timeout=60)

        return response
@@ -3,8 +3,10 @@ from threading import Thread

import requests

import llm_server
from llm_server import opts
from llm_server.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
@@ -23,17 +25,34 @@ class MainBackgroundThread(Thread):
        redis.set('backend_online', 0)
        redis.set_dict('backend_info', {})

        if opts.mode == 'vllm':
            while True:
                try:
                    backend_response = requests.get(f'{opts.backend_url}/v1/models', timeout=3, verify=opts.verify_ssl)
                    r_json = backend_response.json()
                    redis.set('full_model_path', r_json['data'][0]['root'])
                    break
                except Exception as e:
                    print(e)

    def run(self):
        while True:
            if opts.mode == 'oobabooga':
                try:
                    r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
                    opts.running_model = r.json()['result']
                    redis.set('backend_online', 1)
                except Exception as e:
                # try:
                #     r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
                #     opts.running_model = r.json()['result']
                #     redis.set('backend_online', 1)
                # except Exception as e:
                #     redis.set('backend_online', 0)
                #     # TODO: handle error
                #     print(e)
                model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                    # TODO: handle error
                    print(e)
                else:
                    opts.running_model = model
                    redis.set('backend_online', 1)
            elif opts.mode == 'hf-textgen':
                try:
                    r = requests.get(f'{opts.backend_url}/info', timeout=3, verify=opts.verify_ssl)
@@ -45,6 +64,14 @@ class MainBackgroundThread(Thread):
                    redis.set('backend_online', 0)
                    # TODO: handle error
                    print(e)
            elif opts.mode == 'vllm':
                model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    opts.running_model = model
                    redis.set('backend_online', 1)
            else:
                raise Exception
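A condensed sketch of the hand-off this sets up between the background thread and the vLLM worker; the path value is illustrative and the redis wrapper's byte return is inferred from the .decode() call in the generate() code above:

from llm_server.routes.cache import redis

# At startup the background thread stores whatever the backend reports as the model root
# (an illustrative value here), and the vLLM generate() worker reads it back per request.
redis.set('full_model_path', '/storage/models/some-model-13b')
json_data = {'messages': [{'role': 'user', 'content': 'hi'}]}
json_data['model'] = redis.get('full_model_path').decode()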
@@ -9,3 +9,4 @@ redis
gevent
async-timeout
flask-sock
vllm
@@ -41,7 +41,7 @@ if config['database_path'].startswith('./'):
    opts.database_path = resolve_path(config['database_path'])
init_db()

if config['mode'] not in ['oobabooga', 'hf-textgen']:
if config['mode'] not in ['oobabooga', 'hf-textgen', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)
opts.mode = config['mode']
@@ -55,7 +55,7 @@ opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.ip_in_queue_max = config['ip_in_queue_max']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
@@ -135,8 +135,8 @@ def home():
    analytics_tracking_code=analytics_tracking_code,
    info_html=info_html,
    current_model=running_model,
    client_api=f'https://{opts.base_client_api}',
    ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
    client_api=stats['endpoints']['blocking'],
    ws_client_api=stats['endpoints']['streaming'],
    estimated_wait=estimated_wait_sec,
    mode_name=mode_ui_names[opts.mode][0],
    api_input_textbox=mode_ui_names[opts.mode][1],
@@ -61,6 +61,10 @@
            font-size: 1.5em;
        }
    }

    .hidden {
        display: none;
    }
    </style>
</head>
@@ -83,11 +87,17 @@
    <strong>Instructions:</strong>
    <ol>
        <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
        {% if not ws_client_api %}
            <li>Set <kbd>Chat Completion Source</kbd> to <kbd>OpenAI</kbd>.</li>
        {% endif %}
        <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
        <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
        {% if ws_client_api %}
            <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
            </li>
            <li>If using a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
                API key</kbd> textbox.
            </li>
        {% endif %}
        <li>Click <kbd>Connect</kbd> to test the connection.</li>
        <li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
        <li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>