local-llm-server/llm_server/routes/openai/models.py

79 lines
2.8 KiB
Python

import traceback
import requests
from flask import jsonify
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp
from ..stats import server_start_time
from ... import opts
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
from ...helpers import jsonify_pretty
from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
    """Serve the OpenAI-compatible ``/models`` listing.

    The list is whatever ``fetch_openai_models()`` returns; when
    ``opts.openai_expose_our_model`` is set, an entry describing the model
    named by the ``running_model`` redis key is placed first.

    Returns:
        A ``(response, status)`` tuple: the pretty-printed model list with
        status 200, or a JSON error body with status 500 when the selected
        backend reports no model.
    """
    backend_info = cluster_config.get_backend(get_a_cluster_backend())
    if not backend_info.get('model'):
        error_body = jsonify({
            'code': 502,
            'msg': 'failed to reach backend',
        })
        return error_body, 500  # return 500 so Cloudflare doesn't intercept us

    # Falls back to the literal 'ERROR' when the key is not set.
    our_model_id = redis.get('running_model', 'ERROR', dtype=str)
    payload = {
        "object": "list",
        "data": fetch_openai_models(),
    }

    # TODO: verify this works
    if opts.openai_expose_our_model:
        started_at = int(server_start_time.timestamp())
        permission_entry = {
            "id": our_model_id,
            "object": "model_permission",
            "created": started_at,
            "allow_create_engine": False,
            "allow_sampling": False,
            "allow_logprobs": False,
            "allow_search_indices": False,
            "allow_view": True,
            "allow_fine_tuning": False,
            "organization": "*",
            "group": None,
            "is_blocking": False,
        }
        our_model_entry = {
            "id": our_model_id,
            "object": "model",
            "created": started_at,
            "owned_by": opts.llm_middleware_name,
            "permission": [permission_entry],
            "root": None,
            "parent": None,
        }
        # Our own model goes first, ahead of the upstream entries.
        payload["data"].insert(0, our_model_entry)

    return jsonify_pretty(payload), 200
@flask_cache.memoize(timeout=ONE_MONTH_SECONDS)
def fetch_openai_models():
    """Fetch the model list from the OpenAI API with permission ids scrubbed.

    Returns:
        The ``data`` list from ``GET https://api.openai.com/v1/models``, with
        the ``id`` of every permission entry replaced by a freshly generated
        ``modelperm-...`` string. An empty list is returned when
        ``opts.openai_api_key`` is not set or when the request or parsing
        fails (the traceback is printed in that case).
    """
    # NOTE(review): this function is memoized with ONE_MONTH_SECONDS, so the
    # empty fallback from a transient failure may be cached as well -- confirm
    # whether that is intended.
    if not opts.openai_api_key:
        return []

    try:
        response = requests.get(
            'https://api.openai.com/v1/models',
            headers={'Authorization': f"Bearer {opts.openai_api_key}"},
            timeout=10,
        )
        # Fail on HTTP errors here so the printed traceback names the real
        # cause instead of a KeyError on the missing 'data' key.
        response.raise_for_status()
        models = response.json()['data']
        # The "modelperm" string appears to be user-specific, so we'll
        # randomize it just to be safe.
        for model in models:
            # Model objects may come without a "permission" list; treat a
            # missing or null value as "nothing to scrub" rather than
            # discarding the whole listing.
            for permission in model.get('permission') or []:
                permission['id'] = f'modelperm-{generate_oai_string(24)}'
        return models
    except Exception:
        # Best effort: the listing is optional, so log and fall back to an
        # empty list. KeyboardInterrupt/SystemExit are allowed to propagate.
        traceback.print_exc()
        return []