more work
This commit is contained in:
parent
6c97f16415
commit
2e80ccceab
|
@ -1,4 +1,5 @@
|
|||
.idea
|
||||
punkt/
|
||||
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
@ -161,4 +162,3 @@ cython_debug/
|
|||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
"""Fetch the NLTK 'punkt' sentence-tokenizer data into a punkt/ folder next to this script."""
import os
from pathlib import Path

import nltk

# Anchor the download directory to this file's real location (symlinks resolved)
# rather than the current working directory, so the data always lands in a
# predictable place no matter where the script is invoked from.
here = Path(os.path.dirname(os.path.realpath(__file__)))
punkt_dir = here / 'punkt'
print('Downloading punkt to:', punkt_dir)
nltk.download('punkt', download_dir=punkt_dir)
|
|
@ -2,3 +2,5 @@ starlette==0.37.2
|
|||
uvicorn==0.29.0
|
||||
transformers==4.40.1
|
||||
torch==2.3.0
|
||||
gunicorn==22.0.0
|
||||
nltk==3.8.1
|
52
server.py
52
server.py
|
@ -1,22 +1,27 @@
|
|||
"""
|
||||
NOTE: This API server is used only for demonstrating usage of AsyncEngine
|
||||
and simple performance benchmarks. It is not intended for production use.
|
||||
For production use, we recommend using our OpenAI compatible server.
|
||||
We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import nltk
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import pipeline
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger('MAIN')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
async def generate(request):
|
||||
payload = await request.body()
|
||||
string = payload.decode("utf-8")
|
||||
data = await request.json()
|
||||
string = data.get('text')
|
||||
if not string:
|
||||
return JSONResponse({"error": "invalid request"}, status_code=400)
|
||||
response_q = asyncio.Queue()
|
||||
await request.app.model_queue.put((string, response_q))
|
||||
output = await response_q.get()
|
||||
|
@ -25,11 +30,29 @@ async def generate(request):
|
|||
|
||||
async def server_loop(queue):
|
||||
model = os.getenv("MODEL", default="facebook/bart-large-cnn")
|
||||
max_length = int(os.getenv("MAXLEN", default=200)) - 5 # add some buffer
|
||||
summarizer = pipeline("summarization", model=model, device=0)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model) # load tokenizer from the model
|
||||
|
||||
while True:
|
||||
(string, response_q) = await queue.get()
|
||||
out = summarizer(string, max_length=200, min_length=100, do_sample=False)
|
||||
await response_q.put(out)
|
||||
|
||||
# Automatically adjust the string and the task's max length
|
||||
sentences = nltk.sent_tokenize(string)
|
||||
tokens = []
|
||||
for sentence in sentences:
|
||||
sentence_tokens = tokenizer.encode(sentence, truncation=False)
|
||||
if len(tokens) + len(sentence_tokens) < max_length:
|
||||
tokens.extend(sentence_tokens)
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
out = summarizer(tokenizer.decode(tokens), max_length=len(tokens), min_length=100, do_sample=False)
|
||||
await response_q.put(out)
|
||||
except RuntimeError:
|
||||
# Just kill ourselves and let gunicorn restart the worker.
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
|
||||
app = Starlette(
|
||||
|
@ -41,6 +64,11 @@ app = Starlette(
|
|||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||
if not (script_dir / 'punkt' / 'tokenizers' / 'punkt').is_dir():
|
||||
logger.critical(f'Punkt not found at "{script_dir}/punkt". Please run "./punkt-download.py" first.')
|
||||
sys.exit(1)
|
||||
nltk.download('punkt', download_dir=script_dir / 'punkt')
|
||||
q = asyncio.Queue()
|
||||
app.model_queue = q
|
||||
asyncio.create_task(server_loop(q))
|
||||
await asyncio.create_task(server_loop(q))
|
||||
|
|
|
@ -7,7 +7,8 @@ Type=simple
|
|||
User=flask
|
||||
WorkingDirectory=/srv/text-summarization-api
|
||||
Environment="MODEL=/storage/models/facebook/bart-large-cnn"
|
||||
ExecStart=/srv/text-summarization-api/venv/bin/uvicorn server:app --host 0.0.0.0
|
||||
Environment="MAXLEN=1000"
|
||||
ExecStart=/srv/text-summarization-api/venv/bin/gunicorn -w 4 -k uvicorn.workers.UvicornWorker server:app -b 0.0.0.0:8000
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
|
||||
|
|
Loading…
Reference in New Issue