feat(clients): Python client (#103)
Parent: 0e9ed1a8c2
Commit: 3fef90d50f
Makefile (10 changed lines)
|
@ -13,7 +13,7 @@ server-dev:
	cd server && make run-dev

router-dev:
-	cd router && cargo run
+	cd router && cargo run -- --port 8080

integration-tests: install-router install-launcher
	cargo test
@ -22,16 +22,16 @@ python-tests:
	cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests

run-bloom-560m:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080

run-bloom-560m-quantize:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080

download-bloom:
	HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom

run-bloom:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080

run-bloom-quantize:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
README.md (31 changed lines)
|
@ -89,40 +89,35 @@ You can then query the model using either the `/generate` or `/generate_stream`
```shell
curl 127.0.0.1:8080/generate \
    -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
    -H 'Content-Type: application/json'
```

```shell
curl 127.0.0.1:8080/generate_stream \
    -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
    -H 'Content-Type: application/json'
```

or from Python:

-```python
-import requests
-
-result = requests.post("http://127.0.0.1:8080/generate", json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
-print(result.json())
-```
-
```shell
-pip install sseclient-py
+pip install text-generation
```

-````python
-import sseclient
-import requests
+```python
+from text_generation import Client

-r = requests.post("http://127.0.0.1:8080/generate_stream", stream=True, json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
-sse_client = sseclient.SSEClient(r)
+client = Client("http://127.0.0.1:8080")
+print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)

-for i, event in enumerate(sse_client.events()):
-    print(i, event.data)
-````
+text = ""
+for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
+    if not response.token.special:
+        text += response.token.text
+print(text)
+```

**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).

@ -0,0 +1,158 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
text_generation/__pycache__/
|
||||
text_generation/pb/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
transformers
|
||||
safetensors
|
|
@ -0,0 +1,6 @@
|
|||
unit-tests:
|
||||
python -m pytest --cov=text_generation tests
|
||||
|
||||
install:
|
||||
pip install pip --upgrade
|
||||
pip install -e .
|
|
@ -0,0 +1,52 @@
|
|||
# Text Generation
|
||||
|
||||
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
|
||||
`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
|
||||
|
||||
## Get Started
|
||||
|
||||
### Install
|
||||
|
||||
```shell
|
||||
pip install text-generation
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from text_generation import InferenceAPIClient
|
||||
|
||||
client = InferenceAPIClient("bigscience/bloomz")
|
||||
text = client.generate("Why is the sky blue?").generated_text
|
||||
print(text)
|
||||
# ' Rayleigh scattering'
|
||||
|
||||
# Token Streaming
|
||||
text = ""
|
||||
for response in client.generate_stream("Why is the sky blue?"):
|
||||
if not response.token.special:
|
||||
text += response.token.text
|
||||
|
||||
print(text)
|
||||
# ' Rayleigh scattering'
|
||||
```
|
||||
|
||||
or with the asynchronous client:
|
||||
|
||||
```python
|
||||
from text_generation import InferenceAPIAsyncClient
|
||||
|
||||
client = InferenceAPIAsyncClient("bigscience/bloomz")
|
||||
response = await client.generate("Why is the sky blue?")
|
||||
print(response.generated_text)
|
||||
# ' Rayleigh scattering'
|
||||
|
||||
# Token Streaming
|
||||
text = ""
|
||||
async for response in client.generate_stream("Why is the sky blue?"):
|
||||
if not response.token.special:
|
||||
text += response.token.text
|
||||
|
||||
print(text)
|
||||
# ' Rayleigh scattering'
|
||||
```
|
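Beyond the defaults shown above, the lower-level `Client` exposes the generation parameters as keyword arguments. A minimal sketch, assuming a `text-generation-inference` instance listening on `127.0.0.1:8080` (the URL and parameter values below are illustrative):

```python
from text_generation import Client

# Assumed local endpoint; point this at your own instance.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "Why is the sky blue?",
    max_new_tokens=20,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=42,
    stop_sequences=["\n"],
)
print(response.generated_text)
print(response.details.finish_reason, response.details.generated_tokens)
```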
File diff suppressed because it is too large
@ -0,0 +1,26 @@
|
|||
[tool.poetry]
|
||||
name = "text-generation"
|
||||
version = "0.1.0"
|
||||
description = "Hugging Face Text Generation Python Client"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
readme = "README.md"
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
repository = "https://github.com/huggingface/text-generation-inference"
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.7"
|
||||
pydantic = "^1.10.5"
|
||||
aiohttp = "^3.8.4"
|
||||
huggingface-hub = "^0.12.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^6.2.5"
|
||||
pytest-asyncio = "^0.17.2"
|
||||
pytest-cov = "^3.0.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
|
@ -0,0 +1,56 @@
|
|||
import pytest
|
||||
|
||||
from text_generation import __version__
|
||||
from huggingface_hub.utils import build_hf_headers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bloom_model():
|
||||
return "bigscience/bloom"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flan_t5_xxl():
|
||||
return "google/flan-t5-xxl"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_model():
|
||||
return "fake/model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_model():
|
||||
return "gpt2"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_url():
|
||||
return "https://api-inference.huggingface.co/models"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bloom_url(base_url, bloom_model):
|
||||
return f"{base_url}/{bloom_model}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flan_t5_xxl_url(base_url, flan_t5_xxl):
|
||||
return f"{base_url}/{flan_t5_xxl}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_url(base_url, fake_model):
|
||||
return f"{base_url}/{fake_model}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_url(base_url, unsupported_model):
|
||||
return f"{base_url}/{unsupported_model}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_headers():
|
||||
return build_hf_headers(
|
||||
library_name="text-generation-tests", library_version=__version__
|
||||
)
|
|
@ -0,0 +1,127 @@
|
|||
import pytest
|
||||
|
||||
from text_generation import Client, AsyncClient
|
||||
from text_generation.errors import NotFoundError, ValidationError
|
||||
from text_generation.types import FinishReason, PrefillToken, Token
|
||||
|
||||
|
||||
def test_generate(bloom_url, hf_headers):
|
||||
client = Client(bloom_url, hf_headers)
|
||||
response = client.generate("test", max_new_tokens=1)
|
||||
|
||||
assert response.generated_text == "."
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
assert len(response.details.prefill) == 1
|
||||
assert response.details.prefill[0] == PrefillToken(
|
||||
id=9234, text="test", logprob=None
|
||||
)
|
||||
assert len(response.details.tokens) == 1
|
||||
assert response.details.tokens[0] == Token(
|
||||
id=17, text=".", logprob=-1.75, special=False
|
||||
)
|
||||
|
||||
|
||||
def test_generate_not_found(fake_url, hf_headers):
|
||||
client = Client(fake_url, hf_headers)
|
||||
with pytest.raises(NotFoundError):
|
||||
client.generate("test")
|
||||
|
||||
|
||||
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
client.generate("test", max_new_tokens=10_000)
|
||||
|
||||
|
||||
def test_generate_stream(bloom_url, hf_headers):
|
||||
client = Client(bloom_url, hf_headers)
|
||||
responses = [
|
||||
response for response in client.generate_stream("test", max_new_tokens=1)
|
||||
]
|
||||
|
||||
assert len(responses) == 1
|
||||
response = responses[0]
|
||||
|
||||
assert response.generated_text == "."
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
|
||||
|
||||
def test_generate_stream_not_found(fake_url, hf_headers):
|
||||
client = Client(fake_url, hf_headers)
|
||||
with pytest.raises(NotFoundError):
|
||||
list(client.generate_stream("test"))
|
||||
|
||||
|
||||
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = Client(flan_t5_xxl_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
list(client.generate_stream("test", max_new_tokens=10_000))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_async(bloom_url, hf_headers):
|
||||
client = AsyncClient(bloom_url, hf_headers)
|
||||
response = await client.generate("test", max_new_tokens=1)
|
||||
|
||||
assert response.generated_text == "."
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
assert len(response.details.prefill) == 1
|
||||
assert response.details.prefill[0] == PrefillToken(
|
||||
id=9234, text="test", logprob=None
|
||||
)
|
||||
assert len(response.details.tokens) == 1
|
||||
assert response.details.tokens[0] == Token(
|
||||
id=17, text=".", logprob=-1.75, special=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_async_not_found(fake_url, hf_headers):
|
||||
client = AsyncClient(fake_url, hf_headers)
|
||||
with pytest.raises(NotFoundError):
|
||||
await client.generate("test")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
await client.generate("test", max_new_tokens=10_000)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_stream_async(bloom_url, hf_headers):
|
||||
client = AsyncClient(bloom_url, hf_headers)
|
||||
responses = [
|
||||
response async for response in client.generate_stream("test", max_new_tokens=1)
|
||||
]
|
||||
|
||||
assert len(responses) == 1
|
||||
response = responses[0]
|
||||
|
||||
assert response.generated_text == "."
|
||||
assert response.details.finish_reason == FinishReason.Length
|
||||
assert response.details.generated_tokens == 1
|
||||
assert response.details.seed is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_stream_async_not_found(fake_url, hf_headers):
|
||||
client = AsyncClient(fake_url, hf_headers)
|
||||
with pytest.raises(NotFoundError):
|
||||
async for _ in client.generate_stream("test"):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||
with pytest.raises(ValidationError):
|
||||
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
||||
pass
|
|
@ -0,0 +1,64 @@
|
|||
from text_generation.errors import (
|
||||
parse_error,
|
||||
GenerationError,
|
||||
IncompleteGenerationError,
|
||||
OverloadedError,
|
||||
ValidationError,
|
||||
BadRequestError,
|
||||
ShardNotReadyError,
|
||||
ShardTimeoutError,
|
||||
NotFoundError,
|
||||
RateLimitExceededError,
|
||||
UnknownError,
|
||||
)
|
||||
|
||||
|
||||
def test_generation_error():
|
||||
payload = {"error_type": "generation", "error": "test"}
|
||||
assert isinstance(parse_error(400, payload), GenerationError)
|
||||
|
||||
|
||||
def test_incomplete_generation_error():
|
||||
payload = {"error_type": "incomplete_generation", "error": "test"}
|
||||
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
|
||||
|
||||
|
||||
def test_overloaded_error():
|
||||
payload = {"error_type": "overloaded", "error": "test"}
|
||||
assert isinstance(parse_error(400, payload), OverloadedError)
|
||||
|
||||
|
||||
def test_validation_error():
|
||||
payload = {"error_type": "validation", "error": "test"}
|
||||
assert isinstance(parse_error(400, payload), ValidationError)
|
||||
|
||||
|
||||
def test_bad_request_error():
|
||||
payload = {"error": "test"}
|
||||
assert isinstance(parse_error(400, payload), BadRequestError)
|
||||
|
||||
|
||||
def test_shard_not_ready_error():
|
||||
payload = {"error": "test"}
|
||||
assert isinstance(parse_error(403, payload), ShardNotReadyError)
|
||||
assert isinstance(parse_error(424, payload), ShardNotReadyError)
|
||||
|
||||
|
||||
def test_shard_timeout_error():
|
||||
payload = {"error": "test"}
|
||||
assert isinstance(parse_error(504, payload), ShardTimeoutError)
|
||||
|
||||
|
||||
def test_not_found_error():
|
||||
payload = {"error": "test"}
|
||||
assert isinstance(parse_error(404, payload), NotFoundError)
|
||||
|
||||
|
||||
def test_rate_limit_exceeded_error():
|
||||
payload = {"error": "test"}
|
||||
assert isinstance(parse_error(429, payload), RateLimitExceededError)
|
||||
|
||||
|
||||
def test_unknown_error():
|
||||
payload = {"error": "test"}
|
||||
assert isinstance(parse_error(500, payload), UnknownError)
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from text_generation import (
|
||||
InferenceAPIClient,
|
||||
InferenceAPIAsyncClient,
|
||||
Client,
|
||||
AsyncClient,
|
||||
)
|
||||
from text_generation.errors import NotSupportedError
|
||||
from text_generation.inference_api import get_supported_models
|
||||
|
||||
|
||||
def test_get_supported_models():
|
||||
assert isinstance(get_supported_models(), list)
|
||||
|
||||
|
||||
def test_client(bloom_model):
|
||||
client = InferenceAPIClient(bloom_model)
|
||||
assert isinstance(client, Client)
|
||||
|
||||
|
||||
def test_client_unsupported_model(unsupported_model):
|
||||
with pytest.raises(NotSupportedError):
|
||||
InferenceAPIClient(unsupported_model)
|
||||
|
||||
|
||||
def test_async_client(bloom_model):
|
||||
client = InferenceAPIAsyncClient(bloom_model)
|
||||
assert isinstance(client, AsyncClient)
|
||||
|
||||
|
||||
def test_async_client_unsupported_model(unsupported_model):
|
||||
with pytest.raises(NotSupportedError):
|
||||
InferenceAPIAsyncClient(unsupported_model)
|
|
@ -0,0 +1,39 @@
|
|||
import pytest
|
||||
|
||||
from text_generation.types import Parameters
|
||||
from text_generation.errors import ValidationError
|
||||
|
||||
|
||||
def test_parameters_validation():
|
||||
# Test repetition_penalty
|
||||
Parameters(repetition_penalty=1)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(repetition_penalty=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(repetition_penalty=-1)
|
||||
|
||||
# Test seed
|
||||
Parameters(seed=1)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(seed=-1)
|
||||
|
||||
# Test temperature
|
||||
Parameters(temperature=1)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(temperature=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(temperature=-1)
|
||||
|
||||
# Test top_k
|
||||
Parameters(top_k=1)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(top_k=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(top_k=-1)
|
||||
|
||||
# Test top_p
|
||||
Parameters(top_p=1)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(top_p=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Parameters(top_p=-1)
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.3.2"
|
||||
|
||||
from text_generation.client import Client, AsyncClient
|
||||
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
|
|
@ -0,0 +1,415 @@
|
|||
import json
|
||||
import requests
|
||||
|
||||
from aiohttp import ClientSession, ClientTimeout
|
||||
from pydantic import ValidationError
|
||||
from typing import Dict, Optional, List, AsyncIterator, Iterator
|
||||
|
||||
from text_generation.types import (
|
||||
StreamResponse,
|
||||
Response,
|
||||
Request,
|
||||
Parameters,
|
||||
)
|
||||
from text_generation.errors import parse_error
|
||||
|
||||
|
||||
class Client:
|
||||
"""Client to make calls to a text-generation-inference instance
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from text_generation import Client
|
||||
|
||||
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
|
||||
>>> client.generate("Why is the sky blue?").generated_text
|
||||
' Rayleigh scattering'
|
||||
|
||||
>>> result = ""
|
||||
>>> for response in client.generate_stream("Why is the sky blue?"):
|
||||
>>> if not response.token.special:
|
||||
>>> result += response.token.text
|
||||
>>> result
|
||||
' Rayleigh scattering'
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
base_url (`str`):
|
||||
text-generation-inference instance base url
|
||||
headers (`Optional[Dict[str, str]]`):
|
||||
Additional headers
|
||||
timeout (`int`):
|
||||
Timeout in seconds
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.headers = headers
|
||||
self.timeout = timeout
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
do_sample: bool = False,
|
||||
max_new_tokens: int = 20,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
watermarking: bool = False,
|
||||
) -> Response:
|
||||
"""
|
||||
Given a prompt, generate the following text
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Input text
|
||||
do_sample (`bool`):
|
||||
Activate logits sampling
|
||||
max_new_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
return_full_text (`bool`):
|
||||
Whether to prepend the prompt to the generated text
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stop_sequences (`List[str]`):
|
||||
Stop generating tokens if a member of `stop_sequences` is generated
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
watermarking (`bool`):
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
|
||||
Returns:
|
||||
Response: generated response
|
||||
"""
|
||||
# Validate parameters
|
||||
parameters = Parameters(
|
||||
details=True,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
return_full_text=return_full_text,
|
||||
seed=seed,
|
||||
stop=stop_sequences if stop_sequences is not None else [],
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
watermark=watermarking,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||
|
||||
resp = requests.post(
|
||||
self.base_url,
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
payload = resp.json()
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, payload)
|
||||
return Response(**payload[0])
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
do_sample: bool = False,
|
||||
max_new_tokens: int = 20,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
watermarking: bool = False,
|
||||
) -> Iterator[StreamResponse]:
|
||||
"""
|
||||
Given a prompt, generate the following stream of tokens
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Input text
|
||||
do_sample (`bool`):
|
||||
Activate logits sampling
|
||||
max_new_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
return_full_text (`bool`):
|
||||
Whether to prepend the prompt to the generated text
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stop_sequences (`List[str]`):
|
||||
Stop generating tokens if a member of `stop_sequences` is generated
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
watermarking (`bool`):
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
|
||||
Returns:
|
||||
Iterator[StreamResponse]: stream of generated tokens
|
||||
"""
|
||||
# Validate parameters
|
||||
parameters = Parameters(
|
||||
details=True,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
return_full_text=return_full_text,
|
||||
seed=seed,
|
||||
stop=stop_sequences if stop_sequences is not None else [],
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
watermark=watermarking,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||
|
||||
resp = requests.post(
|
||||
self.base_url,
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, resp.json())
|
||||
|
||||
# Parse ServerSentEvents
|
||||
for byte_payload in resp.iter_lines():
|
||||
# Skip line
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
|
||||
payload = byte_payload.decode("utf-8")
|
||||
|
||||
# Event data
|
||||
if payload.startswith("data:"):
|
||||
# Decode payload
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
# Parse payload
|
||||
try:
|
||||
response = StreamResponse(**json_payload)
|
||||
except ValidationError:
|
||||
# If we failed to parse the payload, then it is an error payload
|
||||
raise parse_error(resp.status_code, json_payload)
|
||||
yield response
|
||||
|
||||
|
||||
class AsyncClient:
|
||||
"""Asynchronous Client to make calls to a text-generation-inference instance
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from text_generation import AsyncClient
|
||||
|
||||
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
|
||||
>>> response = await client.generate("Why is the sky blue?")
|
||||
>>> response.generated_text
|
||||
' Rayleigh scattering'
|
||||
|
||||
>>> result = ""
|
||||
>>> async for response in client.generate_stream("Why is the sky blue?"):
|
||||
>>> if not response.token.special:
|
||||
>>> result += response.token.text
|
||||
>>> result
|
||||
' Rayleigh scattering'
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
base_url (`str`):
|
||||
text-generation-inference instance base url
|
||||
headers (`Optional[Dict[str, str]]`):
|
||||
Additional headers
|
||||
timeout (`int`):
|
||||
Timeout in seconds
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.headers = headers
|
||||
self.timeout = ClientTimeout(timeout * 60)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
do_sample: bool = False,
|
||||
max_new_tokens: int = 20,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
watermarking: bool = False,
|
||||
) -> Response:
|
||||
"""
|
||||
Given a prompt, generate the following text asynchronously
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Input text
|
||||
do_sample (`bool`):
|
||||
Activate logits sampling
|
||||
max_new_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
return_full_text (`bool`):
|
||||
Whether to prepend the prompt to the generated text
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stop_sequences (`List[str]`):
|
||||
Stop generating tokens if a member of `stop_sequences` is generated
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
watermarking (`bool`):
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
|
||||
Returns:
|
||||
Response: generated response
|
||||
"""
|
||||
# Validate parameters
|
||||
parameters = Parameters(
|
||||
details=True,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
return_full_text=return_full_text,
|
||||
seed=seed,
|
||||
stop=stop_sequences if stop_sequences is not None else [],
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
watermark=watermarking,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||
|
||||
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
|
||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||
payload = await resp.json()
|
||||
|
||||
if resp.status != 200:
|
||||
raise parse_error(resp.status, payload)
|
||||
return Response(**payload[0])
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
do_sample: bool = False,
|
||||
max_new_tokens: int = 20,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
watermarking: bool = False,
|
||||
) -> AsyncIterator[StreamResponse]:
|
||||
"""
|
||||
Given a prompt, generate the following stream of tokens asynchronously
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Input text
|
||||
do_sample (`bool`):
|
||||
Activate logits sampling
|
||||
max_new_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
return_full_text (`bool`):
|
||||
Whether to prepend the prompt to the generated text
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stop_sequences (`List[str]`):
|
||||
Stop generating tokens if a member of `stop_sequences` is generated
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
watermarking (`bool`):
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
|
||||
Returns:
|
||||
AsyncIterator[StreamResponse]: stream of generated tokens
|
||||
"""
|
||||
# Validate parameters
|
||||
parameters = Parameters(
|
||||
details=True,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
return_full_text=return_full_text,
|
||||
seed=seed,
|
||||
stop=stop_sequences if stop_sequences is not None else [],
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
watermark=watermarking,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||
|
||||
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
|
||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||
|
||||
if resp.status != 200:
|
||||
raise parse_error(resp.status, await resp.json())
|
||||
|
||||
# Parse ServerSentEvents
|
||||
async for byte_payload in resp.content:
|
||||
# Skip line
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
|
||||
payload = byte_payload.decode("utf-8")
|
||||
|
||||
# Event data
|
||||
if payload.startswith("data:"):
|
||||
# Decode payload
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
# Parse payload
|
||||
try:
|
||||
response = StreamResponse(**json_payload)
|
||||
except ValidationError:
|
||||
# If we failed to parse the payload, then it is an error payload
|
||||
raise parse_error(resp.status, json_payload)
|
||||
yield response
|
|
@ -0,0 +1,106 @@
|
|||
from typing import Dict
|
||||
|
||||
|
||||
# Text Generation Inference Errors
|
||||
class ValidationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GenerationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OverloadedError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class IncompleteGenerationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# API Inference Errors
|
||||
class BadRequestError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ShardNotReadyError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ShardTimeoutError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class NotFoundError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RateLimitExceededError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class NotSupportedError(Exception):
|
||||
def __init__(self, model_id: str):
|
||||
message = (
|
||||
f"Model `{model_id}` is not available for inference with this client. \n"
|
||||
"Use `huggingface_hub.inference_api.InferenceApi` instead."
|
||||
)
|
||||
super(NotSupportedError, self).__init__(message)
|
||||
|
||||
|
||||
# Unknown error
|
||||
class UnknownError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
|
||||
"""
|
||||
Parse error given an HTTP status code and a json payload
|
||||
|
||||
Args:
|
||||
status_code (`int`):
|
||||
HTTP status code
|
||||
payload (`Dict[str, str]`):
|
||||
Json payload
|
||||
|
||||
Returns:
|
||||
Exception: parsed exception
|
||||
|
||||
"""
|
||||
# Try to parse a Text Generation Inference error
|
||||
message = payload["error"]
|
||||
if "error_type" in payload:
|
||||
error_type = payload["error_type"]
|
||||
if error_type == "generation":
|
||||
return GenerationError(message)
|
||||
if error_type == "incomplete_generation":
|
||||
return IncompleteGenerationError(message)
|
||||
if error_type == "overloaded":
|
||||
return OverloadedError(message)
|
||||
if error_type == "validation":
|
||||
return ValidationError(message)
|
||||
|
||||
# Try to parse an API Inference error
|
||||
if status_code == 400:
|
||||
return BadRequestError(message)
|
||||
if status_code == 403 or status_code == 424:
|
||||
return ShardNotReadyError(message)
|
||||
if status_code == 504:
|
||||
return ShardTimeoutError(message)
|
||||
if status_code == 404:
|
||||
return NotFoundError(message)
|
||||
if status_code == 429:
|
||||
return RateLimitExceededError(message)
|
||||
|
||||
# Fallback to an unknown error
|
||||
return UnknownError(message)
|
|
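To show how these error classes surface to callers, here is a hedged sketch: the endpoint URL is an assumption, and whether the oversized `max_new_tokens` value is rejected depends on the server's configured limits.

```python
from text_generation import Client
from text_generation.errors import ValidationError, UnknownError

client = Client("http://127.0.0.1:8080")  # assumed local instance
try:
    # A request over the server's token budget comes back with
    # error_type "validation", which parse_error maps to ValidationError.
    client.generate("test", max_new_tokens=10_000)
except ValidationError as exc:
    print(f"validation error: {exc}")
except UnknownError as exc:
    print(f"unexpected error: {exc}")
```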
@ -0,0 +1,150 @@
|
|||
import os
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from typing import List, Optional
|
||||
from huggingface_hub.utils import build_hf_headers
|
||||
|
||||
from text_generation import Client, AsyncClient, __version__
|
||||
from text_generation.errors import NotSupportedError
|
||||
|
||||
INFERENCE_ENDPOINT = os.environ.get(
|
||||
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
|
||||
)
|
||||
|
||||
SUPPORTED_MODELS = None
|
||||
|
||||
|
||||
def get_supported_models() -> Optional[List[str]]:
|
||||
"""
|
||||
Get the list of supported text-generation models from GitHub
|
||||
|
||||
Returns:
|
||||
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
|
||||
"""
|
||||
global SUPPORTED_MODELS
|
||||
if SUPPORTED_MODELS is not None:
|
||||
return SUPPORTED_MODELS
|
||||
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
file_content = response.json()["content"]
|
||||
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
|
||||
return SUPPORTED_MODELS
|
||||
|
||||
warnings.warn("Could not retrieve list of supported models.")
|
||||
return None
|
||||
|
||||
|
||||
class InferenceAPIClient(Client):
|
||||
"""Client to make calls to the HuggingFace Inference API.
|
||||
|
||||
Only supports a subset of the available text-generation or text2text-generation models that are served using
|
||||
text-generation-inference
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from text_generation import InferenceAPIClient
|
||||
|
||||
>>> client = InferenceAPIClient("bigscience/bloomz")
|
||||
>>> client.generate("Why is the sky blue?").generated_text
|
||||
' Rayleigh scattering'
|
||||
|
||||
>>> result = ""
|
||||
>>> for response in client.generate_stream("Why is the sky blue?"):
|
||||
>>> if not response.token.special:
|
||||
>>> result += response.token.text
|
||||
>>> result
|
||||
' Rayleigh scattering'
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
|
||||
"""
|
||||
Init headers and API information
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
Id of repository (e.g. `bigscience/bloom`).
|
||||
token (`str`, `optional`):
|
||||
The API token to use as HTTP bearer authorization. This is not
|
||||
the authentication token. You can find the token in
|
||||
https://huggingface.co/settings/token. Alternatively, you can
|
||||
find both your organizations and personal API tokens using
|
||||
`HfApi().whoami(token)`.
|
||||
timeout (`int`):
|
||||
Timeout in seconds
|
||||
"""
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
supported_models = get_supported_models()
|
||||
if supported_models is not None and repo_id not in supported_models:
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
headers = build_hf_headers(
|
||||
token=token, library_name="text-generation", library_version=__version__
|
||||
)
|
||||
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||
|
||||
super(InferenceAPIClient, self).__init__(base_url, headers, timeout)
|
||||
|
||||
|
||||
class InferenceAPIAsyncClient(AsyncClient):
|
||||
"""Aynschronous Client to make calls to the HuggingFace Inference API.
|
||||
|
||||
Only supports a subset of the available text-generation or text2text-generation models that are served using
|
||||
text-generation-inference
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from text_generation import InferenceAPIAsyncClient
|
||||
|
||||
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
|
||||
>>> response = await client.generate("Why is the sky blue?")
|
||||
>>> response.generated_text
|
||||
' Rayleigh scattering'
|
||||
|
||||
>>> result = ""
|
||||
>>> async for response in client.generate_stream("Why is the sky blue?"):
|
||||
>>> if not response.token.special:
|
||||
>>> result += response.token.text
|
||||
>>> result
|
||||
' Rayleigh scattering'
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
|
||||
"""
|
||||
Init headers and API information
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
Id of repository (e.g. `bigscience/bloom`).
|
||||
token (`str`, `optional`):
|
||||
The API token to use as HTTP bearer authorization. This is not
|
||||
the authentication token. You can find the token in
|
||||
https://huggingface.co/settings/token. Alternatively, you can
|
||||
find both your organizations and personal API tokens using
|
||||
`HfApi().whoami(token)`.
|
||||
timeout (`int`):
|
||||
Timeout in seconds
|
||||
"""
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
supported_models = get_supported_models()
|
||||
if supported_models is not None and repo_id not in supported_models:
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
headers = build_hf_headers(
|
||||
token=token, library_name="text-generation", library_version=__version__
|
||||
)
|
||||
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||
|
||||
super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout)
|
|
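A short usage sketch with an access token (the token string is a placeholder; use your own from https://huggingface.co/settings/tokens):

```python
from text_generation import InferenceAPIClient

# "hf_xxx" is a placeholder, not a real credential.
client = InferenceAPIClient("bigscience/bloomz", token="hf_xxx", timeout=30)
print(client.generate("Why is the sky blue?", max_new_tokens=10).generated_text)
```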
@ -0,0 +1,99 @@
|
|||
from enum import Enum
|
||||
from pydantic import BaseModel, validator
|
||||
from typing import Optional, List
|
||||
|
||||
from text_generation.errors import ValidationError
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
do_sample: bool = False
|
||||
max_new_tokens: int = 20
|
||||
repetition_penalty: Optional[float] = None
|
||||
return_full_text: bool = False
|
||||
stop: List[str] = []
|
||||
seed: Optional[int]
|
||||
temperature: Optional[float]
|
||||
top_k: Optional[int]
|
||||
top_p: Optional[float]
|
||||
watermark: bool = False
|
||||
details: bool = False
|
||||
|
||||
@validator("repetition_penalty")
|
||||
def valid_repetition_penalty(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`repetition_penalty` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("seed")
|
||||
def valid_seed(cls, v):
|
||||
if v is not None and v < 0:
|
||||
raise ValidationError("`seed` must be positive")
|
||||
return v
|
||||
|
||||
@validator("temperature")
|
||||
def valid_temp(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`temperature` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("top_k")
|
||||
def valid_top_k(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`top_k` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("top_p")
|
||||
def valid_top_p(cls, v):
|
||||
if v is not None and (v <= 0 or v > 1.0):
|
||||
raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
|
||||
return v
|
||||
|
||||
|
||||
class Request(BaseModel):
|
||||
inputs: str
|
||||
parameters: Parameters
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class PrefillToken(BaseModel):
|
||||
id: int
|
||||
text: str
|
||||
logprob: Optional[float]
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
id: int
|
||||
text: str
|
||||
logprob: float
|
||||
special: bool
|
||||
|
||||
|
||||
class FinishReason(Enum):
|
||||
Length = "length"
|
||||
EndOfSequenceToken = "eos_token"
|
||||
StopSequence = "stop_sequence"
|
||||
|
||||
|
||||
class Details(BaseModel):
|
||||
finish_reason: FinishReason
|
||||
generated_tokens: int
|
||||
seed: Optional[int]
|
||||
prefill: List[PrefillToken]
|
||||
tokens: List[Token]
|
||||
|
||||
|
||||
class StreamDetails(BaseModel):
|
||||
finish_reason: FinishReason
|
||||
generated_tokens: int
|
||||
seed: Optional[int]
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
generated_text: str
|
||||
details: Details
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
token: Token
|
||||
generated_text: Optional[str]
|
||||
details: Optional[StreamDetails]
|
|
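As a small illustration of how these models fit together, the sketch below builds the same request body the clients POST to the `/generate` route; the field values are arbitrary examples and `.dict()` is the pydantic v1 serialization used by `Client`.

```python
from text_generation.types import Parameters, Request

parameters = Parameters(max_new_tokens=17, temperature=0.5, details=True)
request = Request(inputs="What is Deep Learning?", stream=False, parameters=parameters)

# The resulting dictionary mirrors the JSON payload sent by Client/AsyncClient.
print(request.dict())
```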
@ -439,3 +439,14 @@ pub enum InferError {
|
|||
#[error("Incomplete generation")]
|
||||
IncompleteGeneration,
|
||||
}
|
||||
|
||||
impl InferError {
|
||||
pub(crate) fn error_type(&self) -> &str {
|
||||
match self {
|
||||
InferError::GenerationError(_) => "generation",
|
||||
InferError::Overloaded(_) => "overloaded",
|
||||
InferError::ValidationError(_) => "validation",
|
||||
InferError::IncompleteGeneration => "incomplete_generation",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -152,8 +152,8 @@ pub(crate) struct Details {
|
|||
pub generated_tokens: u32,
|
||||
#[schema(example = 42)]
|
||||
pub seed: Option<u64>,
|
||||
pub prefill: Option<Vec<PrefillToken>>,
|
||||
pub tokens: Option<Vec<Token>>,
|
||||
pub prefill: Vec<PrefillToken>,
|
||||
pub tokens: Vec<Token>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
|
@ -185,6 +185,6 @@ pub(crate) struct StreamResponse {
|
|||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct ErrorResponse {
|
||||
#[schema(inline)]
|
||||
pub error: String,
|
||||
pub error_type: String,
|
||||
}
|
||||
|
|
|
@ -133,8 +133,8 @@ async fn generate(
|
|||
true => Some(Details {
|
||||
finish_reason: FinishReason::from(response.generated_text.finish_reason),
|
||||
generated_tokens: response.generated_text.generated_tokens,
|
||||
prefill: Some(response.prefill),
|
||||
tokens: Some(response.tokens),
|
||||
prefill: response.prefill,
|
||||
tokens: response.tokens,
|
||||
seed: response.generated_text.seed,
|
||||
}),
|
||||
false => None,
|
||||
|
@ -554,6 +554,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||
status_code,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -564,6 +565,7 @@ impl From<InferError> for Event {
|
|||
Event::default()
|
||||
.json_data(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
|
|
@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
|
|||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidationError {
|
||||
#[error("temperature must be strictly positive")]
|
||||
#[error("`temperature` must be strictly positive")]
|
||||
Temperature,
|
||||
#[error("repetition_penalty must be strictly positive")]
|
||||
#[error("`repetition_penalty` must be strictly positive")]
|
||||
RepetitionPenalty,
|
||||
#[error("top_p must be > 0.0 and <= 1.0")]
|
||||
#[error("`top_p` must be > 0.0 and <= 1.0")]
|
||||
TopP,
|
||||
#[error("top_k must be strictly positive")]
|
||||
#[error("`top_k` must be strictly positive")]
|
||||
TopK,
|
||||
#[error("max_new_tokens must be strictly positive")]
|
||||
#[error("`max_new_tokens` must be strictly positive")]
|
||||
MaxNewTokens,
|
||||
#[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
|
||||
#[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
|
||||
MaxTotalTokens(usize, usize, u32),
|
||||
#[error("inputs must have less than {0} tokens. Given: {1}")]
|
||||
#[error("`inputs` must have less than {0} tokens. Given: {1}")]
|
||||
InputLength(usize, usize),
|
||||
#[error("inputs cannot be empty")]
|
||||
#[error("`inputs` cannot be empty")]
|
||||
EmptyInput,
|
||||
#[error("stop supports up to {0} stop sequences. Given: {1}")]
|
||||
#[error("`stop` supports up to {0} stop sequences. Given: {1}")]
|
||||
StopSequence(usize, usize),
|
||||
#[error("tokenizer error {0}")]
|
||||
Tokenizer(String),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
text_generation/__pycache__/
|
||||
text_generation/pb/__pycache__/
|
||||
text_generation_server/__pycache__/
|
||||
text_generation_server/pb/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
|
|
|
@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
|
|||
gen-server:
|
||||
# Compile protos
|
||||
pip install grpcio-tools==1.51.1 --no-cache-dir
|
||||
mkdir text_generation/pb || true
|
||||
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
|
||||
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch text_generation/pb/__init__.py
|
||||
mkdir text_generation_server/pb || true
|
||||
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
|
||||
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch text_generation_server/pb/__init__.py
|
||||
|
||||
install-transformers:
|
||||
# Install specific version of transformers with custom cuda kernels
|
||||
|
@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
|
|||
pip install -e . --no-cache-dir
|
||||
|
||||
run-dev:
|
||||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
|
||||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
|
@ -1,11 +1,11 @@
|
|||
[tool.poetry]
|
||||
name = "text-generation"
|
||||
name = "text-generation-server"
|
||||
version = "0.3.2"
|
||||
description = "Text Generation Inference Python gRPC Server"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
||||
[tool.poetry.scripts]
|
||||
text-generation-server = 'text_generation.cli:app'
|
||||
text-generation-server = 'text_generation_server.cli:app'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation_server.pb import generate_pb2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -4,9 +4,9 @@ import torch
|
|||
from copy import copy
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.causal_lm import CausalLMBatch
|
||||
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch
|
|||
from copy import copy
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import pytest
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.causal_lm import CausalLMBatch
|
||||
from text_generation.models.santacoder import SantaCoder
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.models.santacoder import SantaCoder
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -5,8 +5,8 @@ from copy import copy
|
|||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
|
||||
from text_generation_server.utils.hub import (
|
||||
download_weights,
|
||||
weight_hub_files,
|
||||
weight_files,
|
||||
)
|
||||
|
||||
from text_generation.utils.convert import convert_files
|
||||
from text_generation_server.utils.convert import convert_files
|
||||
|
||||
|
||||
def test_convert_files():
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from text_generation.utils.hub import (
|
||||
from text_generation_server.utils.hub import (
|
||||
weight_hub_files,
|
||||
download_weights,
|
||||
weight_files,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from text_generation.utils.tokens import (
|
||||
from text_generation_server.utils.tokens import (
|
||||
StopSequenceCriteria,
|
||||
StoppingCriteria,
|
||||
FinishReason,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Dict, Optional, TypeVar
|
||||
|
||||
from text_generation.models.types import Batch
|
||||
from text_generation_server.models.types import Batch
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
|
@ -6,8 +6,8 @@ from pathlib import Path
|
|||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from text_generation import server, utils
|
||||
from text_generation.tracing import setup_tracing
|
||||
from text_generation_server import server, utils
|
||||
from text_generation_server.tracing import setup_tracing
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
@@ -42,7 +42,7 @@ def serve(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,

@@ -68,7 +68,7 @@ def download_weights(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
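In both CLI commands the only change is the loguru `filter` string, which now names the renamed package. A string filter in loguru keeps records emitted by that module and its submodules; a self-contained sketch of the roughly equivalent callable (illustrative, not the repository's code):

```python
import sys

from loguru import logger

logger.remove()  # drop the default stderr sink so only our sink is active


def only_server_records(record) -> bool:
    # Roughly what filter="text_generation_server" means: keep records whose
    # logger name is the package itself or one of its submodules.
    name = record["name"] or ""
    return name == "text_generation_server" or name.startswith("text_generation_server.")


logger.add(sys.stdout, format="{message}", filter=only_server_records, level="INFO")
```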
@@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional

-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded

__all__ = [
    "Model",
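`models/__init__.py` re-exports the model classes so the rest of the server (for example `get_model`, imported further down in `server.py`) can pick an implementation for a given checkpoint from one place. A toy sketch of that dispatch idea; the function name and mapping below are assumptions for illustration, not the real `get_model`:

```python
# Toy illustration of architecture dispatch. The real get_model in
# text_generation_server.models also handles sharding and quantization and
# may branch differently, so treat this mapping as an assumption.
def pick_model_family(model_type: str) -> str:
    mapping = {
        "bloom": "BLOOM",
        "gpt_neox": "GPTNeox",
        "t5": "T5Sharded",  # assumption: T5 checkpoints take the seq2seq path
    }
    return mapping.get(model_type, "CausalLM")  # generic decoder-only fallback


print(pick_model_family("bloom"))  # BLOOM
print(pick_model_family("gpt2"))   # CausalLM
```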
@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)
@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,

@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)
@@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM

-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"

@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)
@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,

@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
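These two hunks only touch the docstrings, which now reference the renamed protobuf package `text_generation_server.v1`. For context, the methods implement a simple wire round trip; a stand-alone sketch of that shape with a plain dict standing in for the protobuf message (purely illustrative, not the repository's code):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    """Stand-in for Seq2SeqLMBatch; only the round-trip shape is illustrated."""
    batch_id: int
    requests: List[str]

    def to_pb(self) -> dict:
        # The real method returns a generate_pb2.Batch protobuf message.
        return {"id": self.batch_id, "requests": list(self.requests)}

    @classmethod
    def from_pb(cls, pb: dict) -> "ToyBatch":
        # The real from_pb also takes a tokenizer and a torch.device and
        # re-tokenizes the requests; that part is omitted here.
        return cls(batch_id=pb["id"], requests=list(pb["requests"]))


batch = ToyBatch(batch_id=0, requests=["What is Deep Learning?"])
assert ToyBatch.from_pb(batch.to_pb()) == batch
```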
@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

@@ -6,8 +6,8 @@ from typing import List, Optional

from transformers import PreTrainedTokenizerBase

-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason


class Batch(ABC):
@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -1,6 +1,6 @@
-from text_generation.utils.convert import convert_file, convert_files
-from text_generation.utils.dist import initialize_torch_distributed
-from text_generation.utils.hub import (
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.hub import (
    weight_files,
    weight_hub_files,
    download_weights,

@@ -8,7 +8,7 @@ from text_generation.utils.hub import (
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
    Greedy,
    NextTokenChooser,
    Sampling,
@@ -11,9 +11,9 @@ from transformers import (
)
from typing import List, Tuple, Optional

-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
-from text_generation.utils.watermark import WatermarkLogitsProcessor
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor


class Sampling:
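`utils/tokens.py` keeps its public names (`Sampling`, `Greedy`, `NextTokenChooser`, `StopSequenceCriteria`, `StoppingCriteria`, `FinishReason`); only its imports move to the new package. As a reminder of what a stop-sequence criterion does, here is a minimal illustrative version (not the repository's implementation): it fires once the generated text ends with the configured stop string.

```python
import re


class ToyStopSequenceCriteria:
    """Illustrative only: fires when the generated text ends with the stop
    sequence, similar in spirit to StopSequenceCriteria above."""

    def __init__(self, stop_sequence: str):
        # Anchor at the end so a stop sequence in the middle of the text
        # does not trigger early stopping.
        self.pattern = re.compile(re.escape(stop_sequence) + r"$")

    def __call__(self, generated_text: str) -> bool:
        return self.pattern.search(generated_text) is not None


criteria = ToyStopSequenceCriteria("\n\n")
assert not criteria("What is Deep Learning?")
assert criteria("Deep Learning is a subfield of machine learning.\n\n")
```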