Merge cluster to master #3
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
import time
|
||||
import traceback
|
||||
from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
from llm_server import opts
|
||||
|
@ -10,65 +9,60 @@ from llm_server.database.conn import database
|
|||
from llm_server.llm import get_token_count
|
||||
|
||||
|
||||
def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
    """
    Insert one request/response record into the `prompts` table.

    :param ip: client IP address stored with the record.
    :param token: API token of the caller; when truthy its use counter is incremented.
    :param prompt: prompt text. Replaced by NULL when `opts.log_prompts` is off.
    :param response: backend response. A dict or JSON string carrying
        `results[0].text` is reduced to that text before it is stored.
    :param gen_time: generation time; rounded to 3 places, NULL for errors.
    :param parameters: request parameters, stored as a JSON string.
    :param headers: request headers, stored as a JSON string.
    :param backend_response_code: stored in the `response_status` column.
    :param request_url: URL the client requested.
    :param backend_url: backend URL; used to look up the model/mode and to count tokens.
    :param response_tokens: token count of the response, computed when not given.
    :param is_error: marks an error record (no response tokens, no generation time).
    :raises AssertionError: if `prompt` or `backend_url` is not a `str`.
    """
    assert isinstance(prompt, str)
    assert isinstance(backend_url, str)

    # Try not to shove JSON into the database.
    if isinstance(response, dict) and response.get('results'):
        response = response['results'][0]['text']
    try:
        j = json.loads(response)
        if j.get('results'):
            response = j['results'][0]['text']
    except Exception:
        # Best effort only: the response is usually plain text (or None), in
        # which case it is stored unchanged. `Exception` rather than a bare
        # `except` so SystemExit/KeyboardInterrupt still propagate.
        pass

    prompt_tokens = get_token_count(prompt, backend_url)
    print('starting')

    if not is_error:
        if not response_tokens:
            response_tokens = get_token_count(response, backend_url)
    else:
        response_tokens = None

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
        gen_time = round(gen_time, 3)
    if is_error:
        gen_time = None

    # The token count above was taken before the prompt is blanked out.
    if not opts.log_prompts:
        prompt = None

    if not opts.log_prompts and not is_error:
        # TODO: test and verify this works as expected
        response = None

    if token:
        increment_token_uses(token)

    backend_info = cluster_config.get_backend(backend_url)
    running_model = backend_info.get('model')
    backend_mode = backend_info['mode']
    timestamp = int(time.time())
    cursor = database.cursor()
    try:
        cursor.execute("""
            INSERT INTO prompts
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
                       (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    finally:
        # Always release the cursor, even when the INSERT fails.
        cursor.close()
|
||||
|
||||
|
||||
def is_valid_api_key(api_key):
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
import pickle
|
||||
from typing import Union
|
||||
|
||||
from redis import Redis
|
||||
|
||||
|
||||
# Shared client for the logging channel, created on first use.
_redis_client = None


def _get_redis():
    """Return the Redis client used for the logging channel, creating it once per process."""
    global _redis_client
    if _redis_client is None:
        _redis_client = Redis(host='localhost', port=6379, db=3)
    return _redis_client


def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
    """
    Queue a prompt-log record by publishing it on the `database-logger` Redis channel.

    The arguments are packed unchanged into the `kwargs` of a pickled message
    whose `function` field is `'log_prompt'`; nothing is written to the
    database here.

    NOTE(review): Redis pub/sub does not store messages, so a record published
    while no subscriber is listening is presumably lost - confirm this is acceptable.

    :return: None
    """
    data = {
        'function': 'log_prompt',
        'args': [],
        'kwargs': {
            'ip': ip,
            'token': token,
            'prompt': prompt,
            'response': response,
            'gen_time': gen_time,
            'parameters': parameters,
            'headers': headers,
            'backend_response_code': backend_response_code,
            'request_url': request_url,
            'backend_url': backend_url,
            'response_tokens': response_tokens,
            'is_error': is_error
        }
    }
    # Reuse one client instead of opening a new connection for every logged request.
    _get_redis().publish('database-logger', pickle.dumps(data))
|
|
@ -2,7 +2,7 @@ from flask import jsonify
|
|||
|
||||
from llm_server.custom_redis import redis
|
||||
from ..llm_backend import LLMBackend
|
||||
from ...database.database import log_prompt
|
||||
from ...database.database import do_db_log
|
||||
from ...helpers import safe_list_get
|
||||
from ...routes.helpers.client import format_sillytavern_err
|
||||
from ...routes.helpers.http import validate_json
|
||||
|
@ -34,7 +34,7 @@ class OobaboogaBackend(LLMBackend):
|
|||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
|
||||
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
|
||||
log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
'msg': error_msg,
|
||||
|
@ -57,13 +57,13 @@ class OobaboogaBackend(LLMBackend):
|
|||
if not backend_err:
|
||||
redis.incr('proompts')
|
||||
|
||||
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
|
||||
log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
|
||||
return jsonify({
|
||||
**response_json_body
|
||||
}), 200
|
||||
else:
|
||||
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
|
||||
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
|
||||
log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
'msg': 'the backend did not return valid JSON',
|
||||
|
|
|
@ -3,17 +3,19 @@ from flask import jsonify
|
|||
from llm_server import opts
|
||||
|
||||
|
||||
def oai_to_vllm(request_json_body, hashes: bool, mode):
|
||||
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
|
||||
if not request_json_body.get('stop'):
|
||||
request_json_body['stop'] = []
|
||||
if not isinstance(request_json_body['stop'], list):
|
||||
# It is a string, so create a list with the existing element.
|
||||
request_json_body['stop'] = [request_json_body['stop']]
|
||||
|
||||
if hashes:
|
||||
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT', '### RESPONSE'])
|
||||
if stop_hashes:
|
||||
if opts.openai_force_no_hashes:
|
||||
request_json_body['stop'].append('### ')
|
||||
request_json_body['stop'].append('###')
|
||||
else:
|
||||
# TODO: make stopping strings a configurable
|
||||
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
|
||||
else:
|
||||
request_json_body['stop'].extend(['user:', 'assistant:'])
|
||||
|
||||
|
@ -41,6 +43,11 @@ def format_oai_err(err_msg):
|
|||
|
||||
|
||||
def validate_oai(parameters):
|
||||
if parameters.get('messages'):
|
||||
for m in parameters['messages']:
|
||||
if m['role'].lower() not in ['assistant', 'user', 'system']:
|
||||
return format_oai_err('messages role must be assistant, user, or system')
|
||||
|
||||
if parameters.get('temperature', 0) > 2:
|
||||
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
|
||||
if parameters.get('temperature', 0) < 0:
|
||||
|
|
|
@ -96,7 +96,7 @@ def transform_messages_to_prompt(oai_messages):
|
|||
elif msg['role'] == 'assistant':
|
||||
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
|
||||
else:
|
||||
return False
|
||||
raise Exception(f'Unknown role: {msg["role"]}')
|
||||
except Exception as e:
|
||||
# TODO: use logging
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -1,24 +1,16 @@
|
|||
"""
|
||||
This file is used by the worker that processes requests.
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
# TODO: make the vLLM backend return TPS and time elapsed
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
|
||||
|
||||
def prepare_json(json_data: dict):
    """
    Rename the `max_new_tokens` key of `json_data` to `max_tokens`.

    The dict is modified in place and the same object is returned. An existing
    `max_tokens` entry is overwritten.

    :param json_data: request body; must contain `max_new_tokens`.
    :return: the same `json_data` dict.
    :raises KeyError: if `max_new_tokens` is missing (the dict is left unchanged).
    """
    # logit_bias is not currently supported
    # del json_data['logit_bias']

    # Convert back to VLLM.
    json_data['max_tokens'] = json_data.pop('max_new_tokens')
    return json_data
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
|||
from flask import jsonify
|
||||
from vllm import SamplingParams
|
||||
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.llm.llm_backend import LLMBackend
|
||||
|
||||
|
||||
|
@ -18,8 +18,8 @@ class VLLMBackend(LLMBackend):
|
|||
# Failsafe
|
||||
backend_response = ''
|
||||
|
||||
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
|
||||
log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
|
||||
|
||||
return jsonify({'results': [{'text': backend_response}]}), 200
|
||||
|
||||
|
@ -29,14 +29,15 @@ class VLLMBackend(LLMBackend):
|
|||
top_k = parameters.get('top_k', self._default_params['top_k'])
|
||||
if top_k <= 0:
|
||||
top_k = -1
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=parameters.get('temperature', self._default_params['temperature']),
|
||||
top_p=parameters.get('top_p', self._default_params['top_p']),
|
||||
top_k=top_k,
|
||||
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
|
||||
stop=list(set(parameters.get('stopping_strings', self._default_params['stop']) or parameters.get('stop', self._default_params['stop']))),
|
||||
stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
|
||||
ignore_eos=parameters.get('ban_eos_token', False),
|
||||
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) or parameters.get('max_tokens', self._default_params['max_tokens']),
|
||||
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
|
||||
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
|
||||
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty'])
|
||||
)
|
||||
|
|
|
@ -4,7 +4,8 @@ import flask
|
|||
from flask import jsonify, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.database.database import do_db_log
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
|
@ -40,7 +41,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
backend_response = self.handle_error(msg)
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
return backend_response[0], 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
|
|
|
@ -3,6 +3,7 @@ import time
|
|||
import traceback
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp
|
||||
|
@ -10,7 +11,7 @@ from ..helpers.http import validate_json
|
|||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
|
@ -18,6 +19,7 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p
|
|||
|
||||
# TODO: add rate-limit headers?
|
||||
|
||||
|
||||
@openai_bp.route('/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions():
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
|
@ -36,12 +38,20 @@ def openai_chat_completions():
|
|||
return 'Internal server error', 500
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
return 'DISABLED', 401
|
||||
return
|
||||
|
||||
handler.parameters, _ = handler.get_parameters()
|
||||
handler.request_json_body = {
|
||||
'messages': handler.request_json_body['messages'],
|
||||
'model': handler.request_json_body['model'],
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
|
||||
|
@ -64,7 +74,7 @@ def openai_chat_completions():
|
|||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
|
||||
if not event:
|
||||
log_prompt(
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
|
@ -82,7 +92,6 @@ def openai_chat_completions():
|
|||
_, _, _ = event.wait()
|
||||
|
||||
try:
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
|
@ -90,6 +99,7 @@ def openai_chat_completions():
|
|||
|
||||
def generate():
|
||||
try:
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
|
@ -125,8 +135,7 @@ def openai_chat_completions():
|
|||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
log_prompt(
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
|
|
|
@ -10,7 +10,8 @@ from ..helpers.http import validate_json
|
|||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...database.database import do_db_log
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm import get_token_count
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
|
@ -34,7 +35,7 @@ def openai_completions():
|
|||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
|
@ -102,7 +103,7 @@ def openai_completions():
|
|||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
|
||||
if not event:
|
||||
log_prompt(
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
|
@ -164,7 +165,7 @@ def openai_completions():
|
|||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
log_prompt(
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
|
|
|
@ -11,7 +11,8 @@ from flask import Response, jsonify, make_response
|
|||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated, log_prompt
|
||||
from llm_server.database.database import is_api_key_moderated, do_db_log
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
|
@ -58,7 +59,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
traceback.print_exc()
|
||||
|
||||
# TODO: support Ooba
|
||||
self.parameters = oai_to_vllm(self.parameters, hashes=True, mode=self.cluster_backend_info['mode'])
|
||||
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||
|
@ -88,7 +89,7 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
|
||||
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
|
||||
return response, 429
|
||||
|
||||
|
@ -146,6 +147,6 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
invalid_oai_err_msg = validate_oai(self.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return False, invalid_oai_err_msg
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
# If the parameters were invalid, let the superclass deal with it.
|
||||
return super().validate_request(prompt, do_log)
|
||||
|
|
|
@ -7,7 +7,8 @@ from flask import Response, request
|
|||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_token_ratelimit, log_prompt
|
||||
from llm_server.database.database import get_token_ratelimit, do_db_log
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
|
@ -41,9 +42,11 @@ class RequestHandler:
|
|||
if not self.cluster_backend_info.get('mode'):
|
||||
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
if not self.cluster_backend_info.get('model'):
|
||||
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
if not self.cluster_backend_info.get('model_config'):
|
||||
print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
|
||||
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model'):
|
||||
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
|
||||
self.offline = True
|
||||
else:
|
||||
self.offline = False
|
||||
|
@ -74,8 +77,6 @@ class RequestHandler:
|
|||
return self.request.remote_addr
|
||||
|
||||
def get_parameters(self):
    """
    Pass the request body to the backend's `get_parameters` and return its result.

    The request body is handed over as-is; `max_tokens` is no longer renamed
    to `max_new_tokens` here.

    :return: the `(parameters, parameters_invalid_msg)` pair returned by
        `self.backend.get_parameters`.
    """
    parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
    return parameters, parameters_invalid_msg
|
||||
|
||||
|
@ -117,7 +118,7 @@ class RequestHandler:
|
|||
backend_response = self.handle_error(combined_error_message, 'Validation Error')
|
||||
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
|
||||
return False, backend_response
|
||||
return True, (None, 0)
|
||||
|
||||
|
@ -160,17 +161,17 @@ class RequestHandler:
|
|||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(ip=self.client_ip,
|
||||
token=self.token,
|
||||
prompt=prompt,
|
||||
response=backend_response[0].data.decode('utf-8'),
|
||||
gen_time=None,
|
||||
parameters=self.parameters,
|
||||
headers=dict(self.request.headers),
|
||||
backend_response_code=response_status_code,
|
||||
request_url=self.request.url,
|
||||
backend_url=self.backend_url,
|
||||
is_error=True)
|
||||
log_to_db(ip=self.client_ip,
|
||||
token=self.token,
|
||||
prompt=prompt,
|
||||
response=backend_response[0].data.decode('utf-8'),
|
||||
gen_time=None,
|
||||
parameters=self.parameters,
|
||||
headers=dict(self.request.headers),
|
||||
backend_response_code=response_status_code,
|
||||
request_url=self.request.url,
|
||||
backend_url=self.backend_url,
|
||||
is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -190,7 +191,7 @@ class RequestHandler:
|
|||
if return_json_err:
|
||||
error_msg = 'The backend did not return valid JSON.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
|
||||
log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
|
|
@ -9,7 +9,8 @@ from ..helpers.http import require_api_key, validate_json
|
|||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...database.database import do_db_log
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.generator import generator
|
||||
from ...sock import sock
|
||||
|
||||
|
@ -34,38 +35,38 @@ def stream_with_model(ws, model_name=None):
|
|||
|
||||
|
||||
def do_stream(ws, model_name):
|
||||
def send_err_and_quit(quitting_err_msg):
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': 0,
|
||||
'text': quitting_err_msg
|
||||
}))
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': 1
|
||||
}))
|
||||
log_prompt(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=quitting_err_msg,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.cluster_backend_info,
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming is disabled', 500
|
||||
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
message_num = 0
|
||||
|
||||
try:
|
||||
def send_err_and_quit(quitting_err_msg):
    """
    Send an error message to the websocket client as a one-message stream and log it.

    Emits a `text_stream` event carrying the error text, then a `stream_end`
    event, then queues an error record for the database. It does not close the
    socket itself.

    :param quitting_err_msg: error text shown to the client and stored as the response.
    """
    ws.send(json.dumps({
        'event': 'text_stream',
        'message_num': 0,
        'text': quitting_err_msg
    }))
    ws.send(json.dumps({
        'event': 'stream_end',
        'message_num': 1
    }))
    log_to_db(ip=handler.client_ip,
              token=handler.token,
              prompt=input_prompt,
              response=quitting_err_msg,
              gen_time=None,
              parameters=handler.parameters,
              headers=r_headers,
              backend_response_code=response_status_code,
              request_url=r_url,
              # The logger requires the backend URL string here, not the
              # backend info dict.
              backend_url=handler.backend_url,
              response_tokens=None,
              is_error=True
              )
|
||||
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming is disabled', 500
|
||||
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
message_num = 0
|
||||
|
||||
while ws.connected:
|
||||
message = ws.receive()
|
||||
request_valid_json, request_json_body = validate_json(message)
|
||||
|
@ -197,7 +198,7 @@ def do_stream(ws, model_name):
|
|||
pass
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(ip=handler.client_ip,
|
||||
log_to_db(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
import pickle
|
||||
|
||||
import redis
|
||||
|
||||
from llm_server.database.database import do_db_log
|
||||
|
||||
|
||||
def db_logger():
    """
    We don't want the logging operation to be blocking, so we will use a background worker
    to do the logging.

    Subscribes to the `database-logger` Redis channel and, for each published
    message whose `function` is `'log_prompt'`, calls `do_db_log` with the
    message's args and kwargs. A failure on one message is printed and skipped
    so the worker keeps consuming later messages.
    :return:
    """
    import traceback  # local import: only needed for reporting failures

    r = redis.Redis(host='localhost', port=6379, db=3)
    p = r.pubsub()
    p.subscribe('database-logger')

    for message in p.listen():
        # Skip subscribe confirmations and other non-data events.
        if message['type'] != 'message':
            continue
        try:
            # NOTE(security): pickle.loads executes arbitrary code from its
            # input. This is only safe while every publisher on this Redis
            # instance is trusted.
            data = pickle.loads(message['data'])
            function_name = data['function']
            args = data['args']
            kwargs = data['kwargs']

            if function_name == 'log_prompt':
                do_db_log(*args, **kwargs)
                print('finished log')
        except Exception:
            # One bad record must not kill the logging thread.
            traceback.print_exc()
|
|
@ -2,11 +2,11 @@ import time
|
|||
from threading import Thread
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.cluster.worker import cluster_worker
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.inferencer import start_workers
|
||||
from llm_server.workers.logger import db_logger
|
||||
from llm_server.workers.mainer import main_background_thread
|
||||
from llm_server.workers.moderator import start_moderation_workers
|
||||
from llm_server.workers.printer import console_printer
|
||||
|
@ -49,3 +49,8 @@ def start_background():
|
|||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the cluster worker.')
|
||||
|
||||
t = Thread(target=db_logger)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started background logger')
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
import warnings
|
||||
|
||||
import gradio as gr
|
||||
import openai
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
openai.api_key = 'null'
|
||||
openai.api_base = 'http://localhost:5000/api/openai/v1'
|
||||
|
||||
|
||||
def stream_response(prompt, history):
    """
    Stream a chat completion for `prompt`, yielding the reply accumulated so far.

    :param prompt: the new user message.
    :param history: earlier turns; each item is indexed as `[user_text, assistant_text]`.
    :return: a generator that yields the growing reply once per received chunk.
    """
    messages = []
    for x in history:
        messages.append({'role': 'user', 'content': x[0]})
        messages.append({'role': 'assistant', 'content': x[1]})
    messages.append({'role': 'user', 'content': prompt})

    response = openai.ChatCompletion.create(
        model='0',
        messages=messages,
        temperature=0,
        max_tokens=300,
        stream=True
    )

    message = ''
    for chunk in response:
        # Some chunks carry no text (e.g. a role-only first delta or an empty
        # final delta), so a missing `content` must not raise.
        piece = chunk['choices'][0]['delta'].get('content')
        if piece:
            message += piece
        yield message
|
||||
|
||||
|
||||
gr.ChatInterface(stream_response, examples=["hello", "hola", "merhaba"], title="Chatbot Demo", analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}').queue().launch()
|
Reference in New Issue