work
parent 05baf3e30e
commit 0a51ad30af

server.py | 60

@@ -22,7 +22,7 @@ logger = logging.getLogger("uvicorn.info")
 logger.setLevel(logging.DEBUG)
 ch = logging.StreamHandler()
 ch.setLevel(logging.DEBUG)
-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
 ch.setFormatter(formatter)
 logger.addHandler(ch)
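
For reference, a minimal standalone sketch (not part of the commit) of what the formatter change does to each record; the handler setup mirrors the context lines above and the timestamps are illustrative:

import logging

# Sketch only: the new format string drops %(name)s, so records no longer
# repeat the logger name ("uvicorn.info") on every line.
logger = logging.getLogger("uvicorn.info")
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(ch)

logger.info("Worker started!")
# Old format: 2024-01-01 12:00:00,123 - uvicorn.info - INFO - Worker started!
# New format: 2024-01-01 12:00:00,123 - INFO - Worker started!
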
@@ -38,48 +38,58 @@ async def generate(request):
     response_q = asyncio.Queue()
     await request.app.model_queue.put((string, response_q, request.client.host))
     output = await response_q.get()
-    return JSONResponse(output)
+    return JSONResponse({'summary_text': output})


 async def server_loop(queue):
+    boot_time = datetime.now()
     model = os.getenv("MODEL", default="facebook/bart-large-cnn")
     max_input_length = int(os.getenv("MAX_IN_LEN", default=200))
     max_new_tokens = int(os.getenv("MAX_NEW_LEN", default=200))
     min_length = int(os.getenv("MIN_LEN", default=100))
     summarizer = pipeline("summarization", model=model, device=0)
     tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
-    logger.info('Worker started!')
+    logger.info(f'Worker started in {calculate_elapsed_time(boot_time)}')
+    del boot_time

     while True:
         (string, response_q, client_ip) = await queue.get()
         start_time = datetime.now()

-        # Automatically adjust the input string to the max length.
-        sentences = nltk.sent_tokenize(string)
-        tokens = []
-        for sentence in sentences:
-            sentence_tokens = tokenizer.encode(sentence, truncation=False)
-            if len(tokens) + len(sentence_tokens) < max_input_length:
-                tokens.extend(sentence_tokens)
-            else:
-                break
-
         out = ''
         err = False
-        for i in range(5):
-            try:
-                out = summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False)
-                await response_q.put(out)
-                break
-            except RuntimeError:
-                err = True
-                logger.error(f"Failed to generate summary for {string}: {traceback.format_exc()}")
-                time.sleep(1)
-                tokens = []
-
-        elapsed = calculate_elapsed_time(start_time)
-        logger.info(f"{client_ip} - {len(tokens)} tokens - {elapsed}")
+        try:
+            # Automatically adjust the input string to the max length.
+            sentences = nltk.sent_tokenize(string)
+            tokens = []
+            for sentence in sentences:
+                sentence_tokens = tokenizer.encode(sentence, truncation=False)
+                if len(tokens) + len(sentence_tokens) < max_input_length:
+                    tokens.extend(sentence_tokens)
+                else:
+                    break
+
+            for i in range(5):
+                try:
+                    out = list(summarizer(tokenizer.decode(tokens), max_new_tokens=max_new_tokens, min_length=min_length, do_sample=False))[0]['summary_text']
+                    await response_q.put(out)
+                    err = False
+                    break
+                except:
+                    err = True
+                    logger.error(f'Error generating summary for "{string}": {traceback.format_exc()}')
+                    time.sleep(1)
+        except:
+            err = True
+            logger.error(f'Failed to generate summary for "{string}": {traceback.format_exc()}')
+
+        out_tokens = len(tokenizer.encode(out, truncation=False))
+        logger.info(f"{client_ip} - {len(tokens)} in tokens - {out_tokens} out tokens - {calculate_elapsed_time(start_time)} elapsed")

         if err:
             # Just kill ourselves and let gunicorn restart the worker.
             logger.error('Worker rebooting due to exception.')
             os.kill(os.getpid(), signal.SIGINT)
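
The handler and the worker communicate through request.app.model_queue, which this diff never creates; below is a hedged sketch of the usual Starlette wiring this file appears to follow. The route path, the use of POST, and the startup hook are assumptions, not shown in the commit:

import asyncio

from starlette.applications import Starlette
from starlette.routing import Route

# Hypothetical wiring (not in this diff): expose the generate handler and run
# server_loop as a background task that consumes app.model_queue.
app = Starlette(routes=[Route("/generate", generate, methods=["POST"])])

@app.on_event("startup")
async def startup_event():
    q = asyncio.Queue()
    app.model_queue = q                   # generate() puts (string, response_q, client_ip) here
    asyncio.create_task(server_loop(q))   # server_loop() consumes the queue and summarizes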
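
And a hedged client-side sketch of the response change in generate(): after this commit the endpoint returns {'summary_text': ...} instead of the raw pipeline output. The URL, port, and the use of the raw request body as the input text are assumptions, since none of them are visible in this diff:

# Hypothetical client; requests is an assumed dependency.
import requests

resp = requests.post(
    "http://localhost:8000/generate",           # assumed route and port
    data="Long article text to summarize ...",  # assumed: raw body is the input text
)
resp.raise_for_status()
print(resp.json()["summary_text"])              # new response shape from this commit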