feat(clients): Python client (#103)
parent 0e9ed1a8c2
commit 3fef90d50f
Makefile (10 changed lines)

@@ -13,7 +13,7 @@ server-dev:
 	cd server && make run-dev

 router-dev:
-	cd router && cargo run
+	cd router && cargo run -- --port 8080

 integration-tests: install-router install-launcher
 	cargo test
@@ -22,16 +22,16 @@ python-tests:
 	cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests

 run-bloom-560m:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080

 run-bloom-560m-quantize:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080

 download-bloom:
 	HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom

 run-bloom:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080

 run-bloom-quantize:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080

README.md (31 changed lines)

@@ -89,40 +89,35 @@ You can then query the model using either the `/generate` or `/generate_stream`
 ```shell
 curl 127.0.0.1:8080/generate \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```

 ```shell
 curl 127.0.0.1:8080/generate_stream \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```

 or from Python:

-```python
-import requests
-
-result = requests.post("http://127.0.0.1:8080/generate", json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
-print(result.json())
-```
-
 ```shell
-pip install sseclient-py
+pip install text-generation
 ```

-````python
-import sseclient
-import requests
-
-r = requests.post("http://127.0.0.1:8080/generate_stream", stream=True, json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
-sse_client = sseclient.SSEClient(r)
-
-for i, event in enumerate(sse_client.events()):
-    print(i, event.data)
-````
+```python
+from text_generation import Client
+
+client = Client("http://127.0.0.1:8080")
+print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)
+
+text = ""
+for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
+    if not response.token.special:
+        text += response.token.text
+print(text)
+```

 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).

|
|
@ -0,0 +1,158 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
text_generation/__pycache__/
|
||||||
|
text_generation/pb/__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
transformers
|
||||||
|
safetensors
|

@@ -0,0 +1,6 @@
unit-tests:
	python -m pytest --cov=text_generation tests

install:
	pip install pip --upgrade
	pip install -e .

@@ -0,0 +1,52 @@
# Text Generation

The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.

## Get Started

### Install

```shell
pip install text-generation
```

### Usage

```python
from text_generation import InferenceAPIClient

client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```

or with the asynchronous client:

```python
from text_generation import InferenceAPIAsyncClient

client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```

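The asynchronous examples above use top-level `await`, which only works in a notebook or REPL that already provides an event loop. A minimal sketch of running the same calls from a plain script with `asyncio.run` (this assumes `bigscience/bloomz` is reachable through the Inference API and that any required Hugging Face token is already configured in your environment):

```python
import asyncio

from text_generation import InferenceAPIAsyncClient


async def main():
    client = InferenceAPIAsyncClient("bigscience/bloomz")

    # Single request
    response = await client.generate("Why is the sky blue?")
    print(response.generated_text)

    # Token streaming
    text = ""
    async for response in client.generate_stream("Why is the sky blue?"):
        if not response.token.special:
            text += response.token.text
    print(text)


asyncio.run(main())
```
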
File diff suppressed because it is too large

@@ -0,0 +1,26 @@
[tool.poetry]
name = "text-generation"
version = "0.1.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
readme = "README.md"
homepage = "https://github.com/huggingface/text-generation-inference"
repository = "https://github.com/huggingface/text-generation-inference"


[tool.poetry.dependencies]
python = "^3.7"
pydantic = "^1.10.5"
aiohttp = "^3.8.4"
huggingface-hub = "^0.12.1"

[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
pytest-asyncio = "^0.17.2"
pytest-cov = "^3.0.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

@@ -0,0 +1,56 @@
import pytest

from text_generation import __version__
from huggingface_hub.utils import build_hf_headers


@pytest.fixture
def bloom_model():
    return "bigscience/bloom"


@pytest.fixture
def flan_t5_xxl():
    return "google/flan-t5-xxl"


@pytest.fixture
def fake_model():
    return "fake/model"


@pytest.fixture
def unsupported_model():
    return "gpt2"


@pytest.fixture
def base_url():
    return "https://api-inference.huggingface.co/models"


@pytest.fixture
def bloom_url(base_url, bloom_model):
    return f"{base_url}/{bloom_model}"


@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
    return f"{base_url}/{flan_t5_xxl}"


@pytest.fixture
def fake_url(base_url, fake_model):
    return f"{base_url}/{fake_model}"


@pytest.fixture
def unsupported_url(base_url, unsupported_model):
    return f"{base_url}/{unsupported_model}"


@pytest.fixture(scope="session")
def hf_headers():
    return build_hf_headers(
        library_name="text-generation-tests", library_version=__version__
    )

@ -0,0 +1,127 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from text_generation import Client, AsyncClient
|
||||||
|
from text_generation.errors import NotFoundError, ValidationError
|
||||||
|
from text_generation.types import FinishReason, PrefillToken, Token
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate(bloom_url, hf_headers):
|
||||||
|
client = Client(bloom_url, hf_headers)
|
||||||
|
response = client.generate("test", max_new_tokens=1)
|
||||||
|
|
||||||
|
assert response.generated_text == "."
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
assert len(response.details.prefill) == 1
|
||||||
|
assert response.details.prefill[0] == PrefillToken(
|
||||||
|
id=9234, text="test", logprob=None
|
||||||
|
)
|
||||||
|
assert len(response.details.tokens) == 1
|
||||||
|
assert response.details.tokens[0] == Token(
|
||||||
|
id=17, text=".", logprob=-1.75, special=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_not_found(fake_url, hf_headers):
|
||||||
|
client = Client(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
client.generate("test")
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.generate("test", max_new_tokens=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream(bloom_url, hf_headers):
|
||||||
|
client = Client(bloom_url, hf_headers)
|
||||||
|
responses = [
|
||||||
|
response for response in client.generate_stream("test", max_new_tokens=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(responses) == 1
|
||||||
|
response = responses[0]
|
||||||
|
|
||||||
|
assert response.generated_text == "."
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream_not_found(fake_url, hf_headers):
|
||||||
|
client = Client(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
list(client.generate_stream("test"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
list(client.generate_stream("test", max_new_tokens=10_000))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_async(bloom_url, hf_headers):
|
||||||
|
client = AsyncClient(bloom_url, hf_headers)
|
||||||
|
response = await client.generate("test", max_new_tokens=1)
|
||||||
|
|
||||||
|
assert response.generated_text == "."
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
assert len(response.details.prefill) == 1
|
||||||
|
assert response.details.prefill[0] == PrefillToken(
|
||||||
|
id=9234, text="test", logprob=None
|
||||||
|
)
|
||||||
|
assert len(response.details.tokens) == 1
|
||||||
|
assert response.details.tokens[0] == Token(
|
||||||
|
id=17, text=".", logprob=-1.75, special=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_async_not_found(fake_url, hf_headers):
|
||||||
|
client = AsyncClient(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.generate("test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await client.generate("test", max_new_tokens=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_stream_async(bloom_url, hf_headers):
|
||||||
|
client = AsyncClient(bloom_url, hf_headers)
|
||||||
|
responses = [
|
||||||
|
response async for response in client.generate_stream("test", max_new_tokens=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(responses) == 1
|
||||||
|
response = responses[0]
|
||||||
|
|
||||||
|
assert response.generated_text == "."
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_stream_async_not_found(fake_url, hf_headers):
|
||||||
|
client = AsyncClient(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
async for _ in client.generate_stream("test"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
||||||
|
pass
|
|
@@ -0,0 +1,64 @@
from text_generation.errors import (
    parse_error,
    GenerationError,
    IncompleteGenerationError,
    OverloadedError,
    ValidationError,
    BadRequestError,
    ShardNotReadyError,
    ShardTimeoutError,
    NotFoundError,
    RateLimitExceededError,
    UnknownError,
)


def test_generation_error():
    payload = {"error_type": "generation", "error": "test"}
    assert isinstance(parse_error(400, payload), GenerationError)


def test_incomplete_generation_error():
    payload = {"error_type": "incomplete_generation", "error": "test"}
    assert isinstance(parse_error(400, payload), IncompleteGenerationError)


def test_overloaded_error():
    payload = {"error_type": "overloaded", "error": "test"}
    assert isinstance(parse_error(400, payload), OverloadedError)


def test_validation_error():
    payload = {"error_type": "validation", "error": "test"}
    assert isinstance(parse_error(400, payload), ValidationError)


def test_bad_request_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(400, payload), BadRequestError)


def test_shard_not_ready_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(403, payload), ShardNotReadyError)
    assert isinstance(parse_error(424, payload), ShardNotReadyError)


def test_shard_timeout_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(504, payload), ShardTimeoutError)


def test_not_found_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(404, payload), NotFoundError)


def test_rate_limit_exceeded_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(429, payload), RateLimitExceededError)


def test_unknown_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(500, payload), UnknownError)

@@ -0,0 +1,34 @@
import pytest

from text_generation import (
    InferenceAPIClient,
    InferenceAPIAsyncClient,
    Client,
    AsyncClient,
)
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models


def test_get_supported_models():
    assert isinstance(get_supported_models(), list)


def test_client(bloom_model):
    client = InferenceAPIClient(bloom_model)
    assert isinstance(client, Client)


def test_client_unsupported_model(unsupported_model):
    with pytest.raises(NotSupportedError):
        InferenceAPIClient(unsupported_model)


def test_async_client(bloom_model):
    client = InferenceAPIAsyncClient(bloom_model)
    assert isinstance(client, AsyncClient)


def test_async_client_unsupported_model(unsupported_model):
    with pytest.raises(NotSupportedError):
        InferenceAPIAsyncClient(unsupported_model)

@@ -0,0 +1,39 @@
import pytest

from text_generation.types import Parameters
from text_generation.errors import ValidationError


def test_parameters_validation():
    # Test repetition_penalty
    Parameters(repetition_penalty=1)
    with pytest.raises(ValidationError):
        Parameters(repetition_penalty=0)
    with pytest.raises(ValidationError):
        Parameters(repetition_penalty=-1)

    # Test seed
    Parameters(seed=1)
    with pytest.raises(ValidationError):
        Parameters(seed=-1)

    # Test temperature
    Parameters(temperature=1)
    with pytest.raises(ValidationError):
        Parameters(temperature=0)
    with pytest.raises(ValidationError):
        Parameters(temperature=-1)

    # Test top_k
    Parameters(top_k=1)
    with pytest.raises(ValidationError):
        Parameters(top_k=0)
    with pytest.raises(ValidationError):
        Parameters(top_k=-1)

    # Test top_p
    Parameters(top_p=1)
    with pytest.raises(ValidationError):
        Parameters(top_p=0)
    with pytest.raises(ValidationError):
        Parameters(top_p=-1)

@@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.3.2"

from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

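For orientation, the package therefore exposes two pairs of entry points: `Client`/`AsyncClient` take the base URL of a self-hosted `text-generation-inference` instance, while `InferenceAPIClient`/`InferenceAPIAsyncClient` take a Hub repo id and target the hosted Inference API. A small sketch (the URL and model id are only examples; constructing an Inference API client performs a network request to check the supported-model list):

```python
from text_generation import AsyncClient, Client, InferenceAPIClient

# Self-hosted text-generation-inference instance: pass its base URL.
local_client = Client("http://127.0.0.1:8080")
local_async_client = AsyncClient("http://127.0.0.1:8080")

# Hosted Inference API: pass a repo id served with text-generation-inference.
api_client = InferenceAPIClient("bigscience/bloomz")
```
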
@ -0,0 +1,415 @@
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from typing import Dict, Optional, List, AsyncIterator, Iterator
|
||||||
|
|
||||||
|
from text_generation.types import (
|
||||||
|
StreamResponse,
|
||||||
|
Response,
|
||||||
|
Request,
|
||||||
|
Parameters,
|
||||||
|
)
|
||||||
|
from text_generation.errors import parse_error
|
||||||
|
|
||||||
|
|
||||||
|
class Client:
|
||||||
|
"""Client to make calls to a text-generation-inference instance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import Client
|
||||||
|
|
||||||
|
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
|
||||||
|
>>> client.generate("Why is the sky blue?").generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
base_url (`str`):
|
||||||
|
text-generation-inference instance base url
|
||||||
|
headers (`Optional[Dict[str, str]]`):
|
||||||
|
Additional headers
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
self.base_url = base_url
|
||||||
|
self.headers = headers
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
watermarking: bool = False,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following text
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
watermarking (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response: generated response
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
watermark=watermarking,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
self.base_url,
|
||||||
|
json=request.dict(),
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
payload = resp.json()
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise parse_error(resp.status_code, payload)
|
||||||
|
return Response(**payload[0])
|
||||||
|
|
||||||
|
def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
watermarking: bool = False,
|
||||||
|
) -> Iterator[StreamResponse]:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following stream of tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
watermarking (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator[StreamResponse]: stream of generated tokens
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
watermark=watermarking,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
self.base_url,
|
||||||
|
json=request.dict(),
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise parse_error(resp.status_code, resp.json())
|
||||||
|
|
||||||
|
# Parse ServerSentEvents
|
||||||
|
for byte_payload in resp.iter_lines():
|
||||||
|
# Skip line
|
||||||
|
if byte_payload == b"\n":
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = byte_payload.decode("utf-8")
|
||||||
|
|
||||||
|
# Event data
|
||||||
|
if payload.startswith("data:"):
|
||||||
|
# Decode payload
|
||||||
|
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
||||||
|
# Parse payload
|
||||||
|
try:
|
||||||
|
response = StreamResponse(**json_payload)
|
||||||
|
except ValidationError:
|
||||||
|
# If we failed to parse the payload, then it is an error payload
|
||||||
|
raise parse_error(resp.status_code, json_payload)
|
||||||
|
yield response
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncClient:
|
||||||
|
"""Asynchronous Client to make calls to a text-generation-inference instance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import AsyncClient
|
||||||
|
|
||||||
|
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
|
||||||
|
>>> response = await client.generate("Why is the sky blue?")
|
||||||
|
>>> response.generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> async for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
base_url (`str`):
|
||||||
|
text-generation-inference instance base url
|
||||||
|
headers (`Optional[Dict[str, str]]`):
|
||||||
|
Additional headers
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
self.base_url = base_url
|
||||||
|
self.headers = headers
|
||||||
|
self.timeout = ClientTimeout(timeout * 60)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
watermarking: bool = False,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following text asynchronously
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
watermarking (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response: generated response
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
watermark=watermarking,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
|
||||||
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
|
payload = await resp.json()
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
raise parse_error(resp.status, payload)
|
||||||
|
return Response(**payload[0])
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
watermarking: bool = False,
|
||||||
|
) -> AsyncIterator[StreamResponse]:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following stream of tokens asynchronously
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
watermarking (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncIterator[StreamResponse]: stream of generated tokens
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
watermark=watermarking,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
|
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
|
||||||
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
raise parse_error(resp.status, await resp.json())
|
||||||
|
|
||||||
|
# Parse ServerSentEvents
|
||||||
|
async for byte_payload in resp.content:
|
||||||
|
# Skip line
|
||||||
|
if byte_payload == b"\n":
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = byte_payload.decode("utf-8")
|
||||||
|
|
||||||
|
# Event data
|
||||||
|
if payload.startswith("data:"):
|
||||||
|
# Decode payload
|
||||||
|
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
||||||
|
# Parse payload
|
||||||
|
try:
|
||||||
|
response = StreamResponse(**json_payload)
|
||||||
|
except ValidationError:
|
||||||
|
# If we failed to parse the payload, then it is an error payload
|
||||||
|
raise parse_error(resp.status, json_payload)
|
||||||
|
yield response
|
|
@@ -0,0 +1,106 @@
from typing import Dict


# Text Generation Inference Errors
class ValidationError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class GenerationError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class OverloadedError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class IncompleteGenerationError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


# API Inference Errors
class BadRequestError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class ShardNotReadyError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class ShardTimeoutError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class NotFoundError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class RateLimitExceededError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class NotSupportedError(Exception):
    def __init__(self, model_id: str):
        message = (
            f"Model `{model_id}` is not available for inference with this client. \n"
            "Use `huggingface_hub.inference_api.InferenceApi` instead."
        )
        super(NotSupportedError, self).__init__(message)


# Unknown error
class UnknownError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
    """
    Parse error given an HTTP status code and a json payload

    Args:
        status_code (`int`):
            HTTP status code
        payload (`Dict[str, str]`):
            Json payload

    Returns:
        Exception: parsed exception

    """
    # Try to parse a Text Generation Inference error
    message = payload["error"]
    if "error_type" in payload:
        error_type = payload["error_type"]
        if error_type == "generation":
            return GenerationError(message)
        if error_type == "incomplete_generation":
            return IncompleteGenerationError(message)
        if error_type == "overloaded":
            return OverloadedError(message)
        if error_type == "validation":
            return ValidationError(message)

    # Try to parse a APIInference error
    if status_code == 400:
        return BadRequestError(message)
    if status_code == 403 or status_code == 424:
        return ShardNotReadyError(message)
    if status_code == 504:
        return ShardTimeoutError(message)
    if status_code == 404:
        return NotFoundError(message)
    if status_code == 429:
        return RateLimitExceededError(message)

    # Fallback to an unknown error
    return UnknownError(message)

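To make the two error families above concrete, a small sketch of how `parse_error` maps payloads to exceptions (the payload strings are made up for the example):

```python
from text_generation.errors import NotFoundError, ValidationError, parse_error

# Payloads carrying an `error_type` map to the matching
# text-generation-inference exception, whatever the status code.
err = parse_error(400, {"error_type": "validation", "error": "`top_p` must be > 0.0 and <= 1.0"})
assert isinstance(err, ValidationError)

# Payloads without `error_type` fall back to the HTTP status code.
err = parse_error(404, {"error": "Model not found"})
assert isinstance(err, NotFoundError)
```
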
@ -0,0 +1,150 @@
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from huggingface_hub.utils import build_hf_headers
|
||||||
|
|
||||||
|
from text_generation import Client, AsyncClient, __version__
|
||||||
|
from text_generation.errors import NotSupportedError
|
||||||
|
|
||||||
|
INFERENCE_ENDPOINT = os.environ.get(
|
||||||
|
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
|
||||||
|
)
|
||||||
|
|
||||||
|
SUPPORTED_MODELS = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_supported_models() -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
Get the list of supported text-generation models from GitHub
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
|
||||||
|
"""
|
||||||
|
global SUPPORTED_MODELS
|
||||||
|
if SUPPORTED_MODELS is not None:
|
||||||
|
return SUPPORTED_MODELS
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
file_content = response.json()["content"]
|
||||||
|
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
|
||||||
|
return SUPPORTED_MODELS
|
||||||
|
|
||||||
|
warnings.warn("Could not retrieve list of supported models.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceAPIClient(Client):
|
||||||
|
"""Client to make calls to the HuggingFace Inference API.
|
||||||
|
|
||||||
|
Only supports a subset of the available text-generation or text2text-generation models that are served using
|
||||||
|
text-generation-inference
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import InferenceAPIClient
|
||||||
|
|
||||||
|
>>> client = InferenceAPIClient("bigscience/bloomz")
|
||||||
|
>>> client.generate("Why is the sky blue?").generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
|
||||||
|
"""
|
||||||
|
Init headers and API information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
Id of repository (e.g. `bigscience/bloom`).
|
||||||
|
token (`str`, `optional`):
|
||||||
|
The API token to use as HTTP bearer authorization. This is not
|
||||||
|
the authentication token. You can find the token in
|
||||||
|
https://huggingface.co/settings/token. Alternatively, you can
|
||||||
|
find both your organizations and personal API tokens using
|
||||||
|
`HfApi().whoami(token)`.
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Text Generation Inference client only supports a subset of the available hub models
|
||||||
|
supported_models = get_supported_models()
|
||||||
|
if supported_models is not None and repo_id not in supported_models:
|
||||||
|
raise NotSupportedError(repo_id)
|
||||||
|
|
||||||
|
headers = build_hf_headers(
|
||||||
|
token=token, library_name="text-generation", library_version=__version__
|
||||||
|
)
|
||||||
|
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||||
|
|
||||||
|
super(InferenceAPIClient, self).__init__(base_url, headers, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceAPIAsyncClient(AsyncClient):
|
||||||
|
"""Aynschronous Client to make calls to the HuggingFace Inference API.
|
||||||
|
|
||||||
|
Only supports a subset of the available text-generation or text2text-generation models that are served using
|
||||||
|
text-generation-inference
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import InferenceAPIAsyncClient
|
||||||
|
|
||||||
|
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
|
||||||
|
>>> response = await client.generate("Why is the sky blue?")
|
||||||
|
>>> response.generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> async for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
|
||||||
|
"""
|
||||||
|
Init headers and API information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
Id of repository (e.g. `bigscience/bloom`).
|
||||||
|
token (`str`, `optional`):
|
||||||
|
The API token to use as HTTP bearer authorization. This is not
|
||||||
|
the authentication token. You can find the token in
|
||||||
|
https://huggingface.co/settings/token. Alternatively, you can
|
||||||
|
find both your organizations and personal API tokens using
|
||||||
|
`HfApi().whoami(token)`.
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Text Generation Inference client only supports a subset of the available hub models
|
||||||
|
supported_models = get_supported_models()
|
||||||
|
if supported_models is not None and repo_id not in supported_models:
|
||||||
|
raise NotSupportedError(repo_id)
|
||||||
|
|
||||||
|
headers = build_hf_headers(
|
||||||
|
token=token, library_name="text-generation", library_version=__version__
|
||||||
|
)
|
||||||
|
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||||
|
|
||||||
|
super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout)
|
|
@@ -0,0 +1,99 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List

from text_generation.errors import ValidationError


class Parameters(BaseModel):
    do_sample: bool = False
    max_new_tokens: int = 20
    repetition_penalty: Optional[float] = None
    return_full_text: bool = False
    stop: List[str] = []
    seed: Optional[int]
    temperature: Optional[float]
    top_k: Optional[int]
    top_p: Optional[float]
    watermark: bool = False
    details: bool = False

    @validator("repetition_penalty")
    def valid_repetition_penalty(cls, v):
        if v is not None and v is v <= 0:
            raise ValidationError("`repetition_penalty` must be strictly positive")
        return v

    @validator("seed")
    def valid_seed(cls, v):
        if v is not None and v is v < 0:
            raise ValidationError("`seed` must be positive")
        return v

    @validator("temperature")
    def valid_temp(cls, v):
        if v is not None and v <= 0:
            raise ValidationError("`temperature` must be strictly positive")
        return v

    @validator("top_k")
    def valid_top_k(cls, v):
        if v is not None and v <= 0:
            raise ValidationError("`top_k` must be strictly positive")
        return v

    @validator("top_p")
    def valid_top_p(cls, v):
        if v is not None and (v <= 0 or v > 1.0):
            raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
        return v


class Request(BaseModel):
    inputs: str
    parameters: Parameters
    stream: bool = False


class PrefillToken(BaseModel):
    id: int
    text: str
    logprob: Optional[float]


class Token(BaseModel):
    id: int
    text: str
    logprob: float
    special: bool


class FinishReason(Enum):
    Length = "length"
    EndOfSequenceToken = "eos_token"
    StopSequence = "stop_sequence"


class Details(BaseModel):
    finish_reason: FinishReason
    generated_tokens: int
    seed: Optional[int]
    prefill: List[PrefillToken]
    tokens: List[Token]


class StreamDetails(BaseModel):
    finish_reason: FinishReason
    generated_tokens: int
    seed: Optional[int]


class Response(BaseModel):
    generated_text: str
    details: Details


class StreamResponse(BaseModel):
    token: Token
    generated_text: Optional[str]
    details: Optional[StreamDetails]

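A short sketch of how these models behave in practice: the validators raise the client's own `ValidationError` (not pydantic's) at construction time, and `Request` is what the clients serialize before POSTing (the values below are only examples):

```python
from text_generation.errors import ValidationError
from text_generation.types import Parameters, Request

# Out-of-range values are rejected when the model is built.
try:
    Parameters(temperature=0)
except ValidationError as err:
    print(err)  # `temperature` must be strictly positive

# A valid request body, along the lines of what Client.generate assembles.
request = Request(
    inputs="What is Deep Learning?",
    parameters=Parameters(max_new_tokens=17, details=True),
)
print(request.dict())
```
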
@@ -439,3 +439,14 @@ pub enum InferError {
     #[error("Incomplete generation")]
     IncompleteGeneration,
 }
+
+impl InferError {
+    pub(crate) fn error_type(&self) -> &str {
+        match self {
+            InferError::GenerationError(_) => "generation",
+            InferError::Overloaded(_) => "overloaded",
+            InferError::ValidationError(_) => "validation",
+            InferError::IncompleteGeneration => "incomplete_generation",
+        }
+    }
+}

@@ -152,8 +152,8 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
     pub seed: Option<u64>,
-    pub prefill: Option<Vec<PrefillToken>>,
-    pub tokens: Option<Vec<Token>>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
 }

 #[derive(Serialize, ToSchema)]
@@ -185,6 +185,6 @@ pub(crate) struct StreamResponse {

 #[derive(Serialize, ToSchema)]
 pub(crate) struct ErrorResponse {
-    #[schema(inline)]
     pub error: String,
+    pub error_type: String,
 }

@@ -133,8 +133,8 @@ async fn generate(
         true => Some(Details {
             finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
-            prefill: Some(response.prefill),
-            tokens: Some(response.tokens),
+            prefill: response.prefill,
+            tokens: response.tokens,
             seed: response.generated_text.seed,
         }),
         false => None,
@@ -554,6 +554,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             status_code,
             Json(ErrorResponse {
                 error: err.to_string(),
+                error_type: err.error_type().to_string(),
             }),
         )
     }
@@ -564,6 +565,7 @@ impl From<InferError> for Event {
         Event::default()
             .json_data(ErrorResponse {
                 error: err.to_string(),
+                error_type: err.error_type().to_string(),
             })
             .unwrap()
     }

@@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
 
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("temperature must be strictly positive")]
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and <= 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive")]
+    #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
-    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
-    #[error("inputs must have less than {0} tokens. Given: {1}")]
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
-    #[error("inputs cannot be empty")]
+    #[error("`inputs` cannot be empty")]
     EmptyInput,
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
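
Combined with the `error_type` field added to `ErrorResponse` above, these backtick-quoted messages give callers both a human-readable reason and a machine-readable category. A hedged sketch of client-side handling, again assuming `requests` and a local router on port 8080; setting `max_new_tokens` to 0 is simply one way to provoke the `MaxNewTokens` validation error:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 0}},
)

if not resp.ok:
    payload = resp.json()
    # `error_type` mirrors InferError::error_type on the router side:
    # "generation", "overloaded", "validation" or "incomplete_generation".
    if payload.get("error_type") == "validation":
        # e.g. "`max_new_tokens` must be strictly positive"
        print("Invalid request:", payload["error"])
    else:
        print("Inference failure:", payload.get("error"))
```
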
@@ -1,7 +1,7 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class
 
@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py
 
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
@@ -1,11 +1,11 @@
 [tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
 version = "0.3.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 
 [tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
 python = "^3.9"
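
The remaining hunks carry the same `text_generation` → `text_generation_server` rename through the server's .gitignore, Makefile, tests and modules, presumably freeing the `text_generation` name for the new Python client package. For any code that imports the gRPC server package directly, the effect is a one-line change of import path; a small sketch assuming the server package is installed and its protobuf stubs have been generated:

```python
# Old import paths (before this commit):
#   from text_generation.pb import generate_pb2
#   from text_generation.models import Model, get_model
#
# New import paths after the rename; the stubs are still generated into the
# package's pb/ directory by the `gen-server` Makefile target shown above.
from text_generation_server.pb import generate_pb2
from text_generation_server.models import Model, get_model
```
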
@@ -1,6 +1,6 @@
 import pytest
 
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
 
 
 @pytest.fixture
@@ -4,9 +4,9 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
 
 
 @pytest.fixture(scope="session")
@@ -4,8 +4,8 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 
 @pytest.fixture(scope="session")
@@ -1,8 +1,8 @@
 import pytest
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder
 
 
 @pytest.fixture(scope="session")
@@ -5,8 +5,8 @@ from copy import copy
 
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
 
 
 @pytest.fixture(scope="session")
@@ -1,6 +1,10 @@
-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)
 
-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.convert import convert_files
 
 
 def test_convert_files():
@@ -1,6 +1,6 @@
 import pytest
 
-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
     weight_hub_files,
     download_weights,
     weight_files,
@@ -1,4 +1,4 @@
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,
@@ -1,6 +1,6 @@
 from typing import Dict, Optional, TypeVar
 
-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch
 
 B = TypeVar("B", bound=Batch)
 
@@ -6,8 +6,8 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional
 
-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing
 
 app = typer.Typer()
 
@@ -42,7 +42,7 @@ def serve(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,
@@ -68,7 +68,7 @@ def download_weights(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,
@@ -3,14 +3,14 @@ import torch
 from transformers import AutoConfig
 from typing import Optional
 
-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
     "Model",
@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)
 
@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,
@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase
 
-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText
 
 B = TypeVar("B", bound=Batch)
 
@@ -4,7 +4,7 @@ import torch.distributed
 from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)
 
@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
     padding_right_offset: int
 
     def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )
@@ -6,8 +6,8 @@ from typing import List, Optional
 
 from transformers import PreTrainedTokenizerBase
 
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
 
 
 class Batch(ABC):
@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List, Optional
 
-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -1,6 +1,6 @@
-from text_generation.utils.convert import convert_file, convert_files
-from text_generation.utils.dist import initialize_torch_distributed
-from text_generation.utils.hub import (
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
     download_weights,
@@ -8,7 +8,7 @@ from text_generation.utils.hub import (
     LocalEntryNotFoundError,
     RevisionNotFoundError,
 )
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     Greedy,
     NextTokenChooser,
     Sampling,
@@ -11,9 +11,9 @@ from transformers import (
 )
 from typing import List, Tuple, Optional
 
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
-from text_generation.utils.watermark import WatermarkLogitsProcessor
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 
 
 class Sampling: