local-llm-server/llm_server/llm/hf_textgen/generate.py

37 lines
1.2 KiB
Python
Raw Normal View History

2023-08-21 21:28:52 -06:00
import requests
from flask import current_app
from llm_server import opts
def prepare_json(json_data: dict) -> dict:
    """Translate a client request payload into the HuggingFace
    text-generation-inference ``/generate`` request format.

    :param json_data: incoming request body; recognized keys are ``prompt``,
        ``seed``, ``repetition_penalty``, ``stopping_strings``, ``temperature``,
        ``top_k``, ``top_p``, ``typical_p``. Missing sampling keys pass through
        as ``None`` so the backend applies its own defaults.
    :return: dict with ``inputs`` (the raw prompt) and a ``parameters`` dict
        in the TGI schema.
    """
    prompt = json_data.get('prompt', '')
    # Tokenize once to know how much of the context budget the prompt consumes.
    # NOTE(review): assumes current_app.tokenizer was attached at app setup — confirm.
    token_count = len(current_app.tokenizer.encode(prompt))
    seed = json_data.get('seed', None)
    # Clients send -1 to mean "no fixed seed"; TGI expects None for that.
    if seed == -1:
        seed = None
    return {
        'inputs': prompt,
        'parameters': {
            # BUG FIX: was `token_count - opts.token_limit`, which is negative
            # for any prompt that fits in the context window. The generation
            # budget is the context limit minus the prompt's tokens, floored
            # at 0 for over-long prompts (which `truncate` then handles).
            'max_new_tokens': max(opts.token_limit - token_count, 0),
            'repetition_penalty': json_data.get('repetition_penalty', None),
            'seed': seed,
            'stop': json_data.get('stopping_strings', []),
            'temperature': json_data.get('temperature', None),
            'top_k': json_data.get('top_k', None),
            'top_p': json_data.get('top_p', None),
            'truncate': True,
            'typical_p': json_data.get('typical_p', None),
            'watermark': False
        }
    }
def generate(json_data: dict):
    """POST a client request to the backend's ``/generate`` endpoint.

    :param json_data: raw client payload, converted via :func:`prepare_json`.
    :return: a ``(success, response, error)`` triple — ``(True, response, None)``
        on HTTP 200, ``(False, response, message)`` on a non-200 status, and
        ``(False, None, message)`` when the request itself raised.
    """
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data))
    except Exception as e:
        # Transport-level failure (connection refused, timeout, ...):
        # there is no response object to hand back to the caller.
        return False, None, f'{e.__class__.__name__}: {e}'
    # Success path first; anything other than 200 is reported as a backend error.
    if r.status_code == 200:
        return True, r, None
    return False, r, f'Backend returned {r.status_code}'