import json

import requests
from flask import current_app

from llm_server import opts
from llm_server.database import tokenizer


def prepare_json(json_data: dict):
    # Count the prompt's tokens so max_new_tokens can fill the remaining context window.
    token_count = len(tokenizer.encode(json_data.get('prompt', '')))

    # A seed of -1 means "random"; the backend expects None in that case.
    seed = json_data.get('seed', None)
    if seed == -1:
        seed = None

    # The backend rejects typical_p values of 1 or above, so clamp them just below 1.
    # Guard against None, since the key may be absent from the request.
    typical_p = json_data.get('typical_p', None)
    if typical_p is not None and typical_p >= 1:
        typical_p = 0.999

    return {
        'inputs': json_data.get('prompt', ''),
        'parameters': {
            'max_new_tokens': opts.context_size - token_count,
            'repetition_penalty': json_data.get('repetition_penalty', None),
            'seed': seed,
            'stop': json_data.get('stopping_strings', []),
            'temperature': json_data.get('temperature', None),
            'top_k': json_data.get('top_k', None),
            'top_p': json_data.get('top_p', None),
            # 'truncate': opts.token_limit,
            'typical_p': typical_p,
            'watermark': False
        }
    }


def generate(json_data: dict):
    # Returns a (success, response, error) tuple so callers can distinguish
    # transport failures from backend errors.
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl)
    except Exception as e:
        return False, None, f'{e.__class__.__name__}: {e}'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
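

# A minimal usage sketch. Assumptions: the backend at opts.backend_url serves a
# /generate endpoint that returns JSON, and the sampling values below are
# illustrative placeholders, not defaults taken from this project.
if __name__ == '__main__':
    ok, response, error = generate({
        'prompt': 'Hello, world!',
        'temperature': 0.7,
        'top_p': 0.9,
        'seed': -1,
    })
    if ok:
        print(response.json())
    else:
        print(f'Generation failed: {error}')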