feat(clients): Python client (#103)

Authored by OlivierDehaene on 2023-03-07 18:52:22 +01:00, committed by GitHub
parent 0e9ed1a8c2
commit 3fef90d50f
55 changed files with 2522 additions and 112 deletions

View File

@@ -13,7 +13,7 @@ server-dev:
 	cd server && make run-dev
 
 router-dev:
-	cd router && cargo run
+	cd router && cargo run -- --port 8080
 
 integration-tests: install-router install-launcher
 	cargo test
@@ -22,16 +22,16 @@ python-tests:
 	cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests
 
 run-bloom-560m:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080
 
 run-bloom-560m-quantize:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080
 
 download-bloom:
 	HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom
 
 run-bloom:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080
 
 run-bloom-quantize:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080

View File

@@ -89,40 +89,35 @@ You can then query the model using either the `/generate` or `/generate_stream`
 ```shell
 curl 127.0.0.1:8080/generate \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```
 
 ```shell
 curl 127.0.0.1:8080/generate_stream \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```
 
 or from Python:
 
-```python
-import requests
-
-result = requests.post("http://127.0.0.1:8080/generate", json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
-print(result.json())
-```
-
 ```shell
-pip install sseclient-py
+pip install text-generation
 ```
 
-````python
-import sseclient
-import requests
-
-r = requests.post("http://127.0.0.1:8080/generate_stream", stream=True, json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
-sse_client = sseclient.SSEClient(r)
-
-for i, event in enumerate(sse_client.events()):
-    print(i, event.data)
-````
+```python
+from text_generation import Client
+
+client = Client("http://127.0.0.1:8080")
+print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)
+
+text = ""
+for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
+    if not response.token.special:
+        text += response.token.text
+print(text)
+```
 
 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
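
For reference, the same queries can also be driven through the asynchronous client added in this PR. A minimal sketch, assuming a local server started with the Makefile targets above is listening on 127.0.0.1:8080 and the `text-generation` package is installed:

```python
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    # Assumes a local text-generation-inference instance on port 8080
    client = AsyncClient("http://127.0.0.1:8080")

    # Single-shot generation
    response = await client.generate("What is Deep Learning?", max_new_tokens=17)
    print(response.generated_text)

    # Token streaming
    text = ""
    async for chunk in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
        if not chunk.token.special:
            text += chunk.token.text
    print(text)


asyncio.run(main())
```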

clients/python/.gitignore (vendored, new file, 158 additions)
View File

@@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors

clients/python/Makefile (new file, 6 additions)
View File

@@ -0,0 +1,6 @@
unit-tests:
python -m pytest --cov=text_generation tests
install:
pip install pip --upgrade
pip install -e .

clients/python/README.md (new file, 52 additions)
View File

@@ -0,0 +1,52 @@
# Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
## Get Started
### Install
```shell
pip install text-generation
```
### Usage
```python
from text_generation import InferenceAPIClient
client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import InferenceAPIAsyncClient
client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
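
The library also ships a lower-level `Client`/`AsyncClient` pair for instances you host yourself. A short sketch; the URL and the sampling values below are placeholders rather than part of this README:

```python
from text_generation import Client

# Placeholder URL: point this at your own text-generation-inference deployment
client = Client("http://127.0.0.1:8080", timeout=30)

response = client.generate(
    "Why is the sky blue?",
    max_new_tokens=20,
    do_sample=True,
    temperature=0.5,
    top_p=0.95,
    stop_sequences=["\n"],
)
print(response.generated_text)
print(response.details.finish_reason, response.details.generated_tokens)
```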

clients/python/poetry.lock (generated, new file, 1038 additions)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,26 @@
[tool.poetry]
name = "text-generation"
version = "0.1.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
readme = "README.md"
homepage = "https://github.com/huggingface/text-generation-inference"
repository = "https://github.com/huggingface/text-generation-inference"
[tool.poetry.dependencies]
python = "^3.7"
pydantic = "^1.10.5"
aiohttp = "^3.8.4"
huggingface-hub = "^0.12.1"
[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
pytest-asyncio = "^0.17.2"
pytest-cov = "^3.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,56 @@
import pytest
from text_generation import __version__
from huggingface_hub.utils import build_hf_headers
@pytest.fixture
def bloom_model():
return "bigscience/bloom"
@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"
@pytest.fixture
def fake_model():
return "fake/model"
@pytest.fixture
def unsupported_model():
return "gpt2"
@pytest.fixture
def base_url():
return "https://api-inference.huggingface.co/models"
@pytest.fixture
def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"
@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}"
@pytest.fixture
def unsupported_url(base_url, unsupported_model):
return f"{base_url}/{unsupported_model}"
@pytest.fixture(scope="session")
def hf_headers():
return build_hf_headers(
library_name="text-generation-tests", library_version=__version__
)

View File

@@ -0,0 +1,127 @@
import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token
def test_generate(bloom_url, hf_headers):
client = Client(bloom_url, hf_headers)
response = client.generate("test", max_new_tokens=1)
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(
id=9234, text="test", logprob=None
)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=17, text=".", logprob=-1.75, special=False
)
def test_generate_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
client.generate("test")
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
client.generate("test", max_new_tokens=10_000)
def test_generate_stream(bloom_url, hf_headers):
client = Client(bloom_url, hf_headers)
responses = [
response for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
def test_generate_stream_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
list(client.generate_stream("test"))
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
list(client.generate_stream("test", max_new_tokens=10_000))
@pytest.mark.asyncio
async def test_generate_async(bloom_url, hf_headers):
client = AsyncClient(bloom_url, hf_headers)
response = await client.generate("test", max_new_tokens=1)
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(
id=9234, text="test", logprob=None
)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=17, text=".", logprob=-1.75, special=False
)
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
await client.generate("test")
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio
async def test_generate_stream_async(bloom_url, hf_headers):
client = AsyncClient(bloom_url, hf_headers)
responses = [
response async for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
async for _ in client.generate_stream("test"):
pass
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass

View File

@@ -0,0 +1,64 @@
from text_generation.errors import (
parse_error,
GenerationError,
IncompleteGenerationError,
OverloadedError,
ValidationError,
BadRequestError,
ShardNotReadyError,
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
)
def test_generation_error():
payload = {"error_type": "generation", "error": "test"}
assert isinstance(parse_error(400, payload), GenerationError)
def test_incomplete_generation_error():
payload = {"error_type": "incomplete_generation", "error": "test"}
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
def test_overloaded_error():
payload = {"error_type": "overloaded", "error": "test"}
assert isinstance(parse_error(400, payload), OverloadedError)
def test_validation_error():
payload = {"error_type": "validation", "error": "test"}
assert isinstance(parse_error(400, payload), ValidationError)
def test_bad_request_error():
payload = {"error": "test"}
assert isinstance(parse_error(400, payload), BadRequestError)
def test_shard_not_ready_error():
payload = {"error": "test"}
assert isinstance(parse_error(403, payload), ShardNotReadyError)
assert isinstance(parse_error(424, payload), ShardNotReadyError)
def test_shard_timeout_error():
payload = {"error": "test"}
assert isinstance(parse_error(504, payload), ShardTimeoutError)
def test_not_found_error():
payload = {"error": "test"}
assert isinstance(parse_error(404, payload), NotFoundError)
def test_rate_limit_exceeded_error():
payload = {"error": "test"}
assert isinstance(parse_error(429, payload), RateLimitExceededError)
def test_unknown_error():
payload = {"error": "test"}
assert isinstance(parse_error(500, payload), UnknownError)

View File

@@ -0,0 +1,34 @@
import pytest
from text_generation import (
InferenceAPIClient,
InferenceAPIAsyncClient,
Client,
AsyncClient,
)
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models
def test_get_supported_models():
assert isinstance(get_supported_models(), list)
def test_client(bloom_model):
client = InferenceAPIClient(bloom_model)
assert isinstance(client, Client)
def test_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIClient(unsupported_model)
def test_async_client(bloom_model):
client = InferenceAPIAsyncClient(bloom_model)
assert isinstance(client, AsyncClient)
def test_async_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIAsyncClient(unsupported_model)

View File

@@ -0,0 +1,39 @@
import pytest
from text_generation.types import Parameters
from text_generation.errors import ValidationError
def test_parameters_validation():
# Test repetition_penalty
Parameters(repetition_penalty=1)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=0)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=-1)
# Test seed
Parameters(seed=1)
with pytest.raises(ValidationError):
Parameters(seed=-1)
# Test temperature
Parameters(temperature=1)
with pytest.raises(ValidationError):
Parameters(temperature=0)
with pytest.raises(ValidationError):
Parameters(temperature=-1)
# Test top_k
Parameters(top_k=1)
with pytest.raises(ValidationError):
Parameters(top_k=0)
with pytest.raises(ValidationError):
Parameters(top_k=-1)
# Test top_p
Parameters(top_p=1)
with pytest.raises(ValidationError):
Parameters(top_p=0)
with pytest.raises(ValidationError):
Parameters(top_p=-1)

View File

@@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.3.2"
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

View File

@@ -0,0 +1,415 @@
import json
import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from text_generation.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation.errors import parse_error
class Client:
"""Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import Client
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.timeout = timeout
def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> Response:
"""
Given a prompt, generate the following text
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Response(**payload[0])
def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Iterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
timeout=self.timeout,
stream=True,
)
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
# Parse ServerSentEvents
for byte_payload in resp.iter_lines():
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status_code, json_payload)
yield response
class AsyncClient:
"""Asynchronous Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import AsyncClient
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.timeout = ClientTimeout(timeout * 60)
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
# Parse ServerSentEvents
async for byte_payload in resp.content:
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response
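
Since `generate` and `generate_stream` surface server failures through `parse_error`, callers can branch on the typed exceptions. A minimal sketch; the endpoint URL is a placeholder:

```python
from text_generation import Client
from text_generation.errors import OverloadedError, ValidationError

client = Client("http://127.0.0.1:8080")

try:
    response = client.generate("Why is the sky blue?", max_new_tokens=20)
    print(response.generated_text)
except ValidationError as e:
    # Raised when the request parameters are rejected (e.g. too many new tokens)
    print(f"Invalid request: {e}")
except OverloadedError:
    # Raised when the model is overloaded; one option is to retry later
    print("Model is overloaded, try again later")
```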

View File

@@ -0,0 +1,106 @@
from typing import Dict
# Text Generation Inference Errors
class ValidationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class GenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class OverloadedError(Exception):
def __init__(self, message: str):
super().__init__(message)
class IncompleteGenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
# API Inference Errors
class BadRequestError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardNotReadyError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardTimeoutError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotFoundError(Exception):
def __init__(self, message: str):
super().__init__(message)
class RateLimitExceededError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotSupportedError(Exception):
def __init__(self, model_id: str):
message = (
f"Model `{model_id}` is not available for inference with this client. \n"
"Use `huggingface_hub.inference_api.InferenceApi` instead."
)
super(NotSupportedError, self).__init__(message)
# Unknown error
class UnknownError(Exception):
def __init__(self, message: str):
super().__init__(message)
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Args:
status_code (`int`):
HTTP status code
payload (`Dict[str, str]`):
Json payload
Returns:
Exception: parsed exception
"""
# Try to parse a Text Generation Inference error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
# Try to parse an API Inference error
if status_code == 400:
return BadRequestError(message)
if status_code == 403 or status_code == 424:
return ShardNotReadyError(message)
if status_code == 504:
return ShardTimeoutError(message)
if status_code == 404:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
# Fallback to an unknown error
return UnknownError(message)
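
As an illustration of the dispatch above (the payloads here are hand-written examples, not captured server responses): when `error_type` is present it wins, otherwise the HTTP status code decides, and anything unmatched becomes `UnknownError`.

```python
from text_generation.errors import GenerationError, NotFoundError, UnknownError, parse_error

assert isinstance(
    parse_error(400, {"error_type": "generation", "error": "boom"}), GenerationError
)
assert isinstance(parse_error(404, {"error": "model not found"}), NotFoundError)
assert isinstance(parse_error(500, {"error": "internal error"}), UnknownError)
```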

View File

@@ -0,0 +1,150 @@
import os
import requests
import base64
import json
import warnings
from typing import List, Optional
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
from text_generation.errors import NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
SUPPORTED_MODELS = None
def get_supported_models() -> Optional[List[str]]:
"""
Get the list of supported text-generation models from GitHub
Returns:
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
"""
global SUPPORTED_MODELS
if SUPPORTED_MODELS is not None:
return SUPPORTED_MODELS
response = requests.get(
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
timeout=5,
)
if response.status_code == 200:
file_content = response.json()["content"]
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
return SUPPORTED_MODELS
warnings.warn("Could not retrieve list of supported models.")
return None
class InferenceAPIClient(Client):
"""Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIClient
>>> client = InferenceAPIClient("bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__(base_url, headers, timeout)
class InferenceAPIAsyncClient(AsyncClient):
"""Aynschronous Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIAsyncClient
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout)
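
A hedged sketch of how the Inference API clients are typically constructed. The token value is a placeholder, and `gpt2` is used only because the tests above treat it as a model that is not served by `text-generation-inference`; the `NotSupportedError` is only raised when the supported-models list can actually be fetched:

```python
from text_generation import InferenceAPIClient
from text_generation.errors import NotSupportedError

try:
    InferenceAPIClient("gpt2", token="hf_xxx")  # placeholder token
except NotSupportedError as e:
    # gpt2 is not served with text-generation-inference on the Inference API
    print(e)

client = InferenceAPIClient("bigscience/bloomz", token="hf_xxx")
print(client.generate("Why is the sky blue?", max_new_tokens=10).generated_text)
```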

View File

@@ -0,0 +1,99 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from text_generation.errors import ValidationError
class Parameters(BaseModel):
do_sample: bool = False
max_new_tokens: int = 20
repetition_penalty: Optional[float] = None
return_full_text: bool = False
stop: List[str] = []
seed: Optional[int]
temperature: Optional[float]
top_k: Optional[int]
top_p: Optional[float]
watermark: bool = False
details: bool = False
@validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v
@validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v
@validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v
@validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v > 1.0):
raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
return v
class Request(BaseModel):
inputs: str
parameters: Parameters
stream: bool = False
class PrefillToken(BaseModel):
id: int
text: str
logprob: Optional[float]
class Token(BaseModel):
id: int
text: str
logprob: float
special: bool
class FinishReason(Enum):
Length = "length"
EndOfSequenceToken = "eos_token"
StopSequence = "stop_sequence"
class Details(BaseModel):
finish_reason: FinishReason
generated_tokens: int
seed: Optional[int]
prefill: List[PrefillToken]
tokens: List[Token]
class StreamDetails(BaseModel):
finish_reason: FinishReason
generated_tokens: int
seed: Optional[int]
class Response(BaseModel):
generated_text: str
details: Details
class StreamResponse(BaseModel):
token: Token
generated_text: Optional[str]
details: Optional[StreamDetails]
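
A quick sketch of how the validators above behave; the values are illustrative. Note that they raise the library's own `ValidationError` from `text_generation.errors`, not pydantic's:

```python
from text_generation.errors import ValidationError
from text_generation.types import Parameters, Request

# Valid: all constrained fields inside their allowed ranges
params = Parameters(temperature=0.7, top_k=50, top_p=0.95, max_new_tokens=20)
request = Request(inputs="Why is the sky blue?", parameters=params, stream=False)
print(request.dict())

# Invalid: top_p must be > 0.0 and <= 1.0
try:
    Parameters(top_p=1.5)
except ValidationError as e:
    print(e)
```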

View File

@@ -439,3 +439,14 @@ pub enum InferError {
     #[error("Incomplete generation")]
     IncompleteGeneration,
 }
+
+impl InferError {
+    pub(crate) fn error_type(&self) -> &str {
+        match self {
+            InferError::GenerationError(_) => "generation",
+            InferError::Overloaded(_) => "overloaded",
+            InferError::ValidationError(_) => "validation",
+            InferError::IncompleteGeneration => "incomplete_generation",
+        }
+    }
+}

View File

@@ -152,8 +152,8 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
     pub seed: Option<u64>,
-    pub prefill: Option<Vec<PrefillToken>>,
-    pub tokens: Option<Vec<Token>>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -185,6 +185,6 @@ pub(crate) struct StreamResponse {
 
 #[derive(Serialize, ToSchema)]
 pub(crate) struct ErrorResponse {
-    #[schema(inline)]
     pub error: String,
+    pub error_type: String,
 }
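
For context, with the new `error_type` field the error body the router returns looks roughly like this (illustrative values), which is exactly what `text_generation.errors.parse_error` in the new client consumes:

```python
from text_generation.errors import ValidationError, parse_error

payload = {"error": "`inputs` cannot be empty", "error_type": "validation"}
assert isinstance(parse_error(422, payload), ValidationError)
```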

View File

@@ -133,8 +133,8 @@ async fn generate(
         true => Some(Details {
             finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
-            prefill: Some(response.prefill),
-            tokens: Some(response.tokens),
+            prefill: response.prefill,
+            tokens: response.tokens,
             seed: response.generated_text.seed,
         }),
         false => None,
@@ -554,6 +554,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             status_code,
             Json(ErrorResponse {
                 error: err.to_string(),
+                error_type: err.error_type().to_string(),
             }),
         )
     }
@@ -564,6 +565,7 @@ impl From<InferError> for Event {
         Event::default()
             .json_data(ErrorResponse {
                 error: err.to_string(),
+                error_type: err.error_type().to_string(),
             })
             .unwrap()
     }

View File

@@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("temperature must be strictly positive")]
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and <= 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive")]
+    #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
-    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
-    #[error("inputs must have less than {0} tokens. Given: {1}")]
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
-    #[error("inputs cannot be empty")]
+    #[error("`inputs` cannot be empty")]
     EmptyInput,
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),

server/.gitignore (vendored, 4 changes)
View File

@@ -1,7 +1,7 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class

View File

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py
 
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

View File

@@ -1,11 +1,11 @@
 [tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
 version = "0.3.2"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 
 [tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
 python = "^3.9"

View File

@@ -1,6 +1,6 @@
 import pytest
 
-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
 
 
 @pytest.fixture

View File

@@ -4,9 +4,9 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
 
 
 @pytest.fixture(scope="session")

View File

@@ -4,8 +4,8 @@ import torch
 from copy import copy
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 
 
 @pytest.fixture(scope="session")

View File

@@ -1,8 +1,8 @@
 import pytest
 
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder
 
 
 @pytest.fixture(scope="session")

View File

@@ -5,8 +5,8 @@ from copy import copy
 
 from transformers import AutoTokenizer
 
-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
 
 
 @pytest.fixture(scope="session")

View File

@@ -1,6 +1,10 @@
-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)
 
-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.convert import convert_files
 
 
 def test_convert_files():

View File

@@ -1,6 +1,6 @@
 import pytest
 
-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
     weight_hub_files,
     download_weights,
     weight_files,

View File

@@ -1,4 +1,4 @@
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
     FinishReason,

View File

@@ -1,6 +1,6 @@
 from typing import Dict, Optional, TypeVar
 
-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch
 
 B = TypeVar("B", bound=Batch)

View File

@@ -6,8 +6,8 @@ from pathlib import Path
 from loguru import logger
 from typing import Optional
 
-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing
 
 app = typer.Typer()
@@ -42,7 +42,7 @@ def serve(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,
@@ -68,7 +68,7 @@ def download_weights(
     logger.add(
         sys.stdout,
         format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
         level=logger_level,
         serialize=json_output,
         backtrace=True,

View File

@@ -3,14 +3,14 @@ import torch
 from transformers import AutoConfig
 from typing import Optional
 
-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
     "Model",

View File

@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )

View File

@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)

View File

@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,

View File

@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase
 
-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText
 
 B = TypeVar("B", bound=Batch)

View File

@@ -4,7 +4,7 @@ import torch.distributed
 from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"

View File

@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)
@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
     padding_right_offset: int
 
     def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
         next_token_choosers = []
         stopping_criterias = []

View File

@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
 )

View File

@@ -6,8 +6,8 @@ from typing import List, Optional
 
 from transformers import PreTrainedTokenizerBase
 
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
 
 
 class Batch(ABC):

View File

@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List, Optional
 
-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

View File

@@ -1,6 +1,6 @@
-from text_generation.utils.convert import convert_file, convert_files
-from text_generation.utils.dist import initialize_torch_distributed
-from text_generation.utils.hub import (
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
     download_weights,
@@ -8,7 +8,7 @@ from text_generation.utils.hub import (
     LocalEntryNotFoundError,
     RevisionNotFoundError,
 )
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
     Greedy,
     NextTokenChooser,
     Sampling,

View File

@@ -11,9 +11,9 @@ from transformers import (
 )
 from typing import List, Tuple, Optional
 
-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
-from text_generation.utils.watermark import WatermarkLogitsProcessor
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 
 
 class Sampling: