more work

This commit is contained in:
Cyberes 2024-04-27 22:36:36 -06:00
parent 6c97f16415
commit 2e80ccceab
5 changed files with 54 additions and 14 deletions

2
.gitignore vendored
View File

@ -1,4 +1,5 @@
.idea
punkt/
# ---> Python
# Byte-compiled / optimized / DLL files
@ -161,4 +162,3 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

9
punkt-download.py Normal file
View File

@ -0,0 +1,9 @@
# Download the NLTK "punkt" sentence-tokenizer data into a ./punkt
# directory that lives beside this script, so the server can load it
# locally instead of fetching it at startup.
import os
from pathlib import Path

import nltk

destination = Path(os.path.realpath(__file__)).parent / 'punkt'
print('Downloading punkt to:', destination)
nltk.download('punkt', download_dir=destination)

View File

@ -2,3 +2,5 @@ starlette==0.37.2
uvicorn==0.29.0
transformers==4.40.1
torch==2.3.0
gunicorn==22.0.0
nltk==3.8.1

View File

@ -1,22 +1,27 @@
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import logging
import os
import signal
import sys
from pathlib import Path
import nltk
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import AutoTokenizer
from transformers import pipeline
# Configure the root logger with default handlers, then create a
# module-level logger for this server set to INFO verbosity.
logging.basicConfig()
logger = logging.getLogger('MAIN')
logger.setLevel(logging.INFO)
async def generate(request):
payload = await request.body()
string = payload.decode("utf-8")
data = await request.json()
string = data.get('text')
if not string:
return JSONResponse({"error": "invalid request"}, status_code=400)
response_q = asyncio.Queue()
await request.app.model_queue.put((string, response_q))
output = await response_q.get()
@ -25,11 +30,29 @@ async def generate(request):
async def server_loop(queue):
    """Model worker: pull (text, response_queue) pairs off *queue* forever,
    summarize the text, and push the pipeline output back on the response queue.

    Loads the summarization pipeline and its tokenizer once, then serves
    requests sequentially. Never returns.
    """
    model = os.getenv("MODEL", default="facebook/bart-large-cnn")
    max_length = int(os.getenv("MAXLEN", default=200)) - 5  # add some buffer
    summarizer = pipeline("summarization", model=model, device=0)
    tokenizer = AutoTokenizer.from_pretrained(model)  # load tokenizer from the model
    while True:
        (string, response_q) = await queue.get()
        # Automatically adjust the string and the task's max length:
        # keep whole sentences while they still fit within the token budget.
        sentences = nltk.sent_tokenize(string)
        tokens = []
        for sentence in sentences:
            sentence_tokens = tokenizer.encode(sentence, truncation=False)
            if len(tokens) + len(sentence_tokens) < max_length:
                tokens.extend(sentence_tokens)
            else:
                break
        if not tokens and sentences:
            # The very first sentence already exceeds the budget: hard-truncate
            # it rather than handing the model an empty prompt.
            tokens = tokenizer.encode(sentences[0], truncation=True, max_length=max_length)
        try:
            # min_length must never exceed max_length, otherwise generation
            # misbehaves on short inputs (len(tokens) can be well under 100).
            out = summarizer(tokenizer.decode(tokens),
                             max_length=len(tokens),
                             min_length=min(100, len(tokens)),
                             do_sample=False)
            await response_q.put(out)
        except RuntimeError:
            # Just kill ourselves and let gunicorn restart the worker.
            os.kill(os.getpid(), signal.SIGINT)
app = Starlette(
@ -41,6 +64,11 @@ app = Starlette(
@app.on_event("startup")
async def startup_event():
    """Verify the local punkt data exists, register it with NLTK, and spawn
    the model worker task.

    Exits the process if the punkt data has not been downloaded — the worker
    would otherwise fail later inside nltk.sent_tokenize().
    """
    script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
    if not (script_dir / 'punkt' / 'tokenizers' / 'punkt').is_dir():
        logger.critical(f'Punkt not found at "{script_dir}/punkt". Please run "./punkt-download.py" first.')
        sys.exit(1)
    # Checking the directory is not enough: NLTK only searches nltk.data.path,
    # so the local download location must be registered explicitly.
    nltk.data.path.append(str(script_dir / 'punkt'))
    q = asyncio.Queue()
    app.model_queue = q
    # Fire-and-forget: do NOT await this task — server_loop never returns, so
    # awaiting it would block the startup event and the app would never serve.
    asyncio.create_task(server_loop(q))

View File

@ -7,7 +7,8 @@ Type=simple
User=flask
WorkingDirectory=/srv/text-summarization-api
Environment="MODEL=/storage/models/facebook/bart-large-cnn"
ExecStart=/srv/text-summarization-api/venv/bin/uvicorn server:app --host 0.0.0.0
Environment="MAXLEN=1000"
ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000
Restart=on-failure
RestartSec=5s