37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import requests
|
|
from flask import current_app
|
|
|
|
from llm_server import opts
|
|
|
|
|
|
def prepare_json(json_data: dict) -> dict:
    """Translate a client request dict into the backend ``/generate`` payload.

    The result has the prompt under ``'inputs'`` and the sampling settings
    under ``'parameters'``.

    :param json_data: Client request. Reads the optional keys ``prompt``,
        ``seed``, ``repetition_penalty``, ``stopping_strings``,
        ``temperature``, ``top_k``, ``top_p`` and ``typical_p``. Missing keys
        fall back to ``''`` (prompt), ``[]`` (stopping_strings) or ``None``
        (everything else).
    :return: Request body to POST to the backend.
    """
    prompt = json_data.get('prompt', '')
    token_count = len(current_app.tokenizer.encode(prompt))

    # A seed of -1 means "no fixed seed" and is sent to the backend as null.
    seed = json_data.get('seed', None)
    if seed == -1:
        seed = None

    # Generation budget = token limit minus the tokens the prompt already uses.
    # NOTE(review): assumes opts.token_limit is the combined prompt + output
    # budget — confirm. Clamped to 1 so an over-long prompt still sends a
    # positive value; the backend then reports the input-length problem
    # instead of rejecting a negative/zero max_new_tokens.
    max_new_tokens = max(opts.token_limit - token_count, 1)

    return {
        'inputs': prompt,
        'parameters': {
            'max_new_tokens': max_new_tokens,
            'repetition_penalty': json_data.get('repetition_penalty', None),
            'seed': seed,
            'stop': json_data.get('stopping_strings', []),
            'temperature': json_data.get('temperature', None),
            'top_k': json_data.get('top_k', None),
            'top_p': json_data.get('top_p', None),
            # NOTE(review): this payload shape matches Hugging Face
            # text-generation-inference, which documents `truncate` as an
            # integer token count, not a boolean — confirm the backend
            # accepts True here.
            'truncate': True,
            'typical_p': json_data.get('typical_p', None),
            'watermark': False,
        },
    }
|
|
|
|
|
|
def generate(json_data: dict, *, timeout=None):
    """POST a generation request to ``{opts.backend_url}/generate``.

    :param json_data: Client request; converted with ``prepare_json()``.
    :param timeout: Forwarded to ``requests.post``. ``None`` (the default)
        waits indefinitely, as before. Pass seconds, or a
        ``(connect, read)`` tuple, to bound how long a stalled backend can
        block the caller.
    :return: A ``(success, response, error)`` tuple:

        - ``(False, None, '<ExceptionName>: <message>')`` if building or
          sending the request raised;
        - ``(False, response, 'Backend returned <status>')`` for any
          non-200 reply;
        - ``(True, response, None)`` on HTTP 200.
    """
    try:
        # prepare_json() runs inside the try, so payload-building errors
        # (e.g. from the tokenizer) are reported the same way as network errors.
        r = requests.post(
            f'{opts.backend_url}/generate',
            json=prepare_json(json_data),
            timeout=timeout,
        )
    except Exception as e:
        # Boundary handler: convert any failure into the error-tuple contract
        # instead of raising to the caller.
        return False, None, f'{e.__class__.__name__}: {e}'

    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'

    return True, r, None
|