feat(clients): Python client (#103)

OlivierDehaene 2023-03-07 18:52:22 +01:00 committed by GitHub
parent 0e9ed1a8c2
commit 3fef90d50f
55 changed files with 2522 additions and 112 deletions

View File

@ -13,7 +13,7 @@ server-dev:
cd server && make run-dev
router-dev:
cd router && cargo run
cd router && cargo run -- --port 8080
integration-tests: install-router install-launcher
cargo test
@ -22,16 +22,16 @@ python-tests:
cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests
run-bloom-560m:
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080
run-bloom-560m-quantize:
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080
download-bloom:
HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom
run-bloom:
text-generation-launcher --model-id bigscience/bloom --num-shard 8
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080
run-bloom-quantize:
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080

View File

@ -89,40 +89,35 @@ You can then query the model using either the `/generate` or `/generate_stream`
```shell
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
-H 'Content-Type: application/json'
```
```shell
curl 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
-H 'Content-Type: application/json'
```
or from Python:
```python
import requests
result = requests.post("http://127.0.0.1:8080/generate", json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
print(result.json())
```
```shell
pip install sseclient-py
pip install text-generation
```
````python
import sseclient
import requests
```python
from text_generation import Client
r = requests.post("http://127.0.0.1:8080/generate_stream", stream=True, json={"inputs":"Testing API","parameters":{"max_new_tokens":9}})
sse_client = sseclient.SSEClient(r)
client = Client("http://127.0.0.1:8080")
print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)
for i, event in enumerate(sse_client.events()):
print(i, event.data)
````
text = ""
for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
if not response.token.special:
text += response.token.text
print(text)
```
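The same endpoints can also be queried with the asynchronous client; a minimal sketch, assuming Python 3.7+ and a local instance listening on port 8080:
```python
import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")

    # Single request
    response = await client.generate("What is Deep Learning?", max_new_tokens=17)
    print(response.generated_text)

    # Token streaming
    text = ""
    async for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
        if not response.token.special:
            text += response.token.text
    print(text)


asyncio.run(main())
```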
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).

158
clients/python/.gitignore vendored Normal file
View File

@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors

6
clients/python/Makefile Normal file
View File

@ -0,0 +1,6 @@
unit-tests:
python -m pytest --cov=text_generation tests
install:
pip install pip --upgrade
pip install -e .

52
clients/python/README.md Normal file
View File

@ -0,0 +1,52 @@
# Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
## Get Started
### Install
```shell
pip install text-generation
```
### Usage
```python
from text_generation import InferenceAPIClient
client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import InferenceAPIAsyncClient
client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
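Since the asynchronous client exposes coroutines, the snippet above has to run inside an event loop; a minimal sketch using the standard library `asyncio` (assuming Python 3.7+):
```python
import asyncio

from text_generation import InferenceAPIAsyncClient


async def main():
    client = InferenceAPIAsyncClient("bigscience/bloomz")
    response = await client.generate("Why is the sky blue?")
    print(response.generated_text)


asyncio.run(main())
```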

1038
clients/python/poetry.lock generated Normal file

File diff suppressed because it is too large

View File

@ -0,0 +1,26 @@
[tool.poetry]
name = "text-generation"
version = "0.1.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
readme = "README.md"
homepage = "https://github.com/huggingface/text-generation-inference"
repository = "https://github.com/huggingface/text-generation-inference"
[tool.poetry.dependencies]
python = "^3.7"
pydantic = "^1.10.5"
aiohttp = "^3.8.4"
huggingface-hub = "^0.12.1"
[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
pytest-asyncio = "^0.17.2"
pytest-cov = "^3.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -0,0 +1,56 @@
import pytest
from text_generation import __version__
from huggingface_hub.utils import build_hf_headers
@pytest.fixture
def bloom_model():
return "bigscience/bloom"
@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"
@pytest.fixture
def fake_model():
return "fake/model"
@pytest.fixture
def unsupported_model():
return "gpt2"
@pytest.fixture
def base_url():
return "https://api-inference.huggingface.co/models"
@pytest.fixture
def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"
@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}"
@pytest.fixture
def unsupported_url(base_url, unsupported_model):
return f"{base_url}/{unsupported_model}"
@pytest.fixture(scope="session")
def hf_headers():
return build_hf_headers(
library_name="text-generation-tests", library_version=__version__
)

View File

@ -0,0 +1,127 @@
import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token
def test_generate(bloom_url, hf_headers):
client = Client(bloom_url, hf_headers)
response = client.generate("test", max_new_tokens=1)
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(
id=9234, text="test", logprob=None
)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=17, text=".", logprob=-1.75, special=False
)
def test_generate_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
client.generate("test")
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
client.generate("test", max_new_tokens=10_000)
def test_generate_stream(bloom_url, hf_headers):
client = Client(bloom_url, hf_headers)
responses = [
response for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
def test_generate_stream_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
list(client.generate_stream("test"))
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
list(client.generate_stream("test", max_new_tokens=10_000))
@pytest.mark.asyncio
async def test_generate_async(bloom_url, hf_headers):
client = AsyncClient(bloom_url, hf_headers)
response = await client.generate("test", max_new_tokens=1)
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(
id=9234, text="test", logprob=None
)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=17, text=".", logprob=-1.75, special=False
)
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
await client.generate("test")
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio
async def test_generate_stream_async(bloom_url, hf_headers):
client = AsyncClient(bloom_url, hf_headers)
responses = [
response async for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == "."
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
async for _ in client.generate_stream("test"):
pass
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass

View File

@ -0,0 +1,64 @@
from text_generation.errors import (
parse_error,
GenerationError,
IncompleteGenerationError,
OverloadedError,
ValidationError,
BadRequestError,
ShardNotReadyError,
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
)
def test_generation_error():
payload = {"error_type": "generation", "error": "test"}
assert isinstance(parse_error(400, payload), GenerationError)
def test_incomplete_generation_error():
payload = {"error_type": "incomplete_generation", "error": "test"}
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
def test_overloaded_error():
payload = {"error_type": "overloaded", "error": "test"}
assert isinstance(parse_error(400, payload), OverloadedError)
def test_validation_error():
payload = {"error_type": "validation", "error": "test"}
assert isinstance(parse_error(400, payload), ValidationError)
def test_bad_request_error():
payload = {"error": "test"}
assert isinstance(parse_error(400, payload), BadRequestError)
def test_shard_not_ready_error():
payload = {"error": "test"}
assert isinstance(parse_error(403, payload), ShardNotReadyError)
assert isinstance(parse_error(424, payload), ShardNotReadyError)
def test_shard_timeout_error():
payload = {"error": "test"}
assert isinstance(parse_error(504, payload), ShardTimeoutError)
def test_not_found_error():
payload = {"error": "test"}
assert isinstance(parse_error(404, payload), NotFoundError)
def test_rate_limit_exceeded_error():
payload = {"error": "test"}
assert isinstance(parse_error(429, payload), RateLimitExceededError)
def test_unknown_error():
payload = {"error": "test"}
assert isinstance(parse_error(500, payload), UnknownError)

View File

@ -0,0 +1,34 @@
import pytest
from text_generation import (
InferenceAPIClient,
InferenceAPIAsyncClient,
Client,
AsyncClient,
)
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models
def test_get_supported_models():
assert isinstance(get_supported_models(), list)
def test_client(bloom_model):
client = InferenceAPIClient(bloom_model)
assert isinstance(client, Client)
def test_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIClient(unsupported_model)
def test_async_client(bloom_model):
client = InferenceAPIAsyncClient(bloom_model)
assert isinstance(client, AsyncClient)
def test_async_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIAsyncClient(unsupported_model)

View File

@ -0,0 +1,39 @@
import pytest
from text_generation.types import Parameters
from text_generation.errors import ValidationError
def test_parameters_validation():
# Test repetition_penalty
Parameters(repetition_penalty=1)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=0)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=-1)
# Test seed
Parameters(seed=1)
with pytest.raises(ValidationError):
Parameters(seed=-1)
# Test temperature
Parameters(temperature=1)
with pytest.raises(ValidationError):
Parameters(temperature=0)
with pytest.raises(ValidationError):
Parameters(temperature=-1)
# Test top_k
Parameters(top_k=1)
with pytest.raises(ValidationError):
Parameters(top_k=0)
with pytest.raises(ValidationError):
Parameters(top_k=-1)
# Test top_p
Parameters(top_p=1)
with pytest.raises(ValidationError):
Parameters(top_p=0)
with pytest.raises(ValidationError):
Parameters(top_p=-1)

View File

@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.3.2"
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

View File

@ -0,0 +1,415 @@
import json
import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from text_generation.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation.errors import parse_error
class Client:
"""Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import Client
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.timeout = timeout
def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> Response:
"""
Given a prompt, generate the following text
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Response(**payload[0])
def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Iterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
timeout=self.timeout,
stream=True,
)
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
# Parse ServerSentEvents
for byte_payload in resp.iter_lines():
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status_code, json_payload)
yield response
class AsyncClient:
"""Asynchronous Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import AsyncClient
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.timeout = ClientTimeout(timeout * 60)
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermarking: bool = False,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
watermarking (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermarking,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
# Parse ServerSentEvents
async for byte_payload in resp.content:
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response
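The docstrings above list the sampling parameters accepted by `generate` and `generate_stream`; the following is a minimal usage sketch combining several of them against a local instance (the base URL and parameter values are illustrative, not defaults):
```python
from text_generation import Client

# Hypothetical local endpoint; replace with your own instance.
client = Client("http://127.0.0.1:8080", timeout=30)

response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    stop_sequences=["\n\n"],
    seed=42,
)
print(response.generated_text)
print(response.details.finish_reason, response.details.generated_tokens)
```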

View File

@ -0,0 +1,106 @@
from typing import Dict
# Text Generation Inference Errors
class ValidationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class GenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class OverloadedError(Exception):
def __init__(self, message: str):
super().__init__(message)
class IncompleteGenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
# API Inference Errors
class BadRequestError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardNotReadyError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardTimeoutError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotFoundError(Exception):
def __init__(self, message: str):
super().__init__(message)
class RateLimitExceededError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotSupportedError(Exception):
def __init__(self, model_id: str):
message = (
f"Model `{model_id}` is not available for inference with this client. \n"
"Use `huggingface_hub.inference_api.InferenceApi` instead."
)
super(NotSupportedError, self).__init__(message)
# Unknown error
class UnknownError(Exception):
def __init__(self, message: str):
super().__init__(message)
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Args:
status_code (`int`):
HTTP status code
payload (`Dict[str, str]`):
Json payload
Returns:
Exception: parsed exception
"""
# Try to parse a Text Generation Inference error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
# Try to parse an API Inference error
if status_code == 400:
return BadRequestError(message)
if status_code == 403 or status_code == 424:
return ShardNotReadyError(message)
if status_code == 504:
return ShardTimeoutError(message)
if status_code == 404:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
# Fallback to an unknown error
return UnknownError(message)
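A short sketch of how `parse_error` maps payloads to the exception classes above; the payloads are illustrative examples of the two shapes it handles (with and without `error_type`):
```python
from text_generation.errors import parse_error, NotFoundError, ValidationError

# text-generation-inference errors carry an `error_type` field alongside `error`.
error = parse_error(400, {"error": "`temperature` must be strictly positive", "error_type": "validation"})
assert isinstance(error, ValidationError)

# Plain API Inference errors only carry `error`; the HTTP status code picks the class.
error = parse_error(404, {"error": "Model not found"})
assert isinstance(error, NotFoundError)
```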

View File

@ -0,0 +1,150 @@
import os
import requests
import base64
import json
import warnings
from typing import List, Optional
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
from text_generation.errors import NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
SUPPORTED_MODELS = None
def get_supported_models() -> Optional[List[str]]:
"""
Get the list of supported text-generation models from GitHub
Returns:
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
"""
global SUPPORTED_MODELS
if SUPPORTED_MODELS is not None:
return SUPPORTED_MODELS
response = requests.get(
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
timeout=5,
)
if response.status_code == 200:
file_content = response.json()["content"]
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
return SUPPORTED_MODELS
warnings.warn("Could not retrieve list of supported models.")
return None
class InferenceAPIClient(Client):
"""Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIClient
>>> client = InferenceAPIClient("bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__(base_url, headers, timeout)
class InferenceAPIAsyncClient(AsyncClient):
"""Aynschronous Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIAsyncClient
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout)
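A minimal sketch of checking the supported-model list and handling `NotSupportedError`; `gpt2` is the unsupported-model example used in the test suite:
```python
from text_generation import InferenceAPIClient
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models

# None means the list could not be fetched from GitHub; a warning is emitted in that case.
print(get_supported_models())

try:
    # "gpt2" is not served by text-generation-inference on the Inference API.
    InferenceAPIClient("gpt2")
except NotSupportedError as error:
    print(error)
```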

View File

@ -0,0 +1,99 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from text_generation.errors import ValidationError
class Parameters(BaseModel):
do_sample: bool = False
max_new_tokens: int = 20
repetition_penalty: Optional[float] = None
return_full_text: bool = False
stop: List[str] = []
seed: Optional[int]
temperature: Optional[float]
top_k: Optional[int]
top_p: Optional[float]
watermark: bool = False
details: bool = False
@validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v
@validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v
@validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v
@validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v > 1.0):
raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
return v
class Request(BaseModel):
inputs: str
parameters: Parameters
stream: bool = False
class PrefillToken(BaseModel):
id: int
text: str
logprob: Optional[float]
class Token(BaseModel):
id: int
text: str
logprob: float
special: bool
class FinishReason(Enum):
Length = "length"
EndOfSequenceToken = "eos_token"
StopSequence = "stop_sequence"
class Details(BaseModel):
finish_reason: FinishReason
generated_tokens: int
seed: Optional[int]
prefill: List[PrefillToken]
tokens: List[Token]
class StreamDetails(BaseModel):
finish_reason: FinishReason
generated_tokens: int
seed: Optional[int]
class Response(BaseModel):
generated_text: str
details: Details
class StreamResponse(BaseModel):
token: Token
generated_text: Optional[str]
details: Optional[StreamDetails]
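These pydantic models define the request and response bodies exchanged with the server; a minimal sketch of building the payload the clients POST (values are illustrative):
```python
from text_generation.types import Parameters, Request

parameters = Parameters(max_new_tokens=17, temperature=0.7, details=True)
request = Request(inputs="What is Deep Learning?", stream=False, parameters=parameters)

# This is the JSON-serializable body sent to the /generate route.
print(request.dict())
```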

View File

@ -439,3 +439,14 @@ pub enum InferError {
#[error("Incomplete generation")]
IncompleteGeneration,
}
impl InferError {
pub(crate) fn error_type(&self) -> &str {
match self {
InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
}
}
}

View File

@ -152,8 +152,8 @@ pub(crate) struct Details {
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
pub prefill: Option<Vec<PrefillToken>>,
pub tokens: Option<Vec<Token>>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
}
#[derive(Serialize, ToSchema)]
@ -185,6 +185,6 @@ pub(crate) struct StreamResponse {
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
#[schema(inline)]
pub error: String,
pub error_type: String,
}
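With `error_type` added to `ErrorResponse`, the router's error body matches what the Python client's `parse_error` expects; an illustrative sketch (the message text is an example, not the exact router output):
```python
import json

from text_generation.errors import parse_error, GenerationError

# Illustrative error body as serialized by the router after this change.
body = json.loads('{"error": "Request failed during generation", "error_type": "generation"}')
assert isinstance(parse_error(500, body), GenerationError)
```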

View File

@ -133,8 +133,8 @@ async fn generate(
true => Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
prefill: Some(response.prefill),
tokens: Some(response.tokens),
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
}),
false => None,
@ -554,6 +554,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
status_code,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
)
}
@ -564,6 +565,7 @@ impl From<InferError> for Event {
Event::default()
.json_data(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
})
.unwrap()
}

View File

@ -271,23 +271,23 @@ pub(crate) struct ValidGenerateRequest {
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("temperature must be strictly positive")]
#[error("`temperature` must be strictly positive")]
Temperature,
#[error("repetition_penalty must be strictly positive")]
#[error("`repetition_penalty` must be strictly positive")]
RepetitionPenalty,
#[error("top_p must be > 0.0 and <= 1.0")]
#[error("`top_p` must be > 0.0 and <= 1.0")]
TopP,
#[error("top_k must be strictly positive")]
#[error("`top_k` must be strictly positive")]
TopK,
#[error("max_new_tokens must be strictly positive")]
#[error("`max_new_tokens` must be strictly positive")]
MaxNewTokens,
#[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
#[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
MaxTotalTokens(usize, usize, u32),
#[error("inputs must have less than {0} tokens. Given: {1}")]
#[error("`inputs` must have less than {0} tokens. Given: {1}")]
InputLength(usize, usize),
#[error("inputs cannot be empty")]
#[error("`inputs` cannot be empty")]
EmptyInput,
#[error("stop supports up to {0} stop sequences. Given: {1}")]
#[error("`stop` supports up to {0} stop sequences. Given: {1}")]
StopSequence(usize, usize),
#[error("tokenizer error {0}")]
Tokenizer(String),

4
server/.gitignore vendored
View File

@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
text_generation_server/__pycache__/
text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class

View File

@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 --no-cache-dir
mkdir text_generation/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation/pb/__init__.py
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install-transformers:
# Install specific version of transformers with custom cuda kernels
@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
pip install -e . --no-cache-dir
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

View File

@ -1,11 +1,11 @@
[tool.poetry]
name = "text-generation"
name = "text-generation-server"
version = "0.3.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
text-generation-server = 'text_generation.cli:app'
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"

View File

@ -1,6 +1,6 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation_server.pb import generate_pb2
@pytest.fixture

View File

@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture(scope="session")

View File

@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")

View File

@ -1,8 +1,8 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.santacoder import SantaCoder
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session")

View File

@ -5,8 +5,8 @@ from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")

View File

@ -1,6 +1,10 @@
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation.utils.convert import convert_files
from text_generation_server.utils.convert import convert_files
def test_convert_files():

View File

@ -1,6 +1,6 @@
import pytest
from text_generation.utils.hub import (
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,

View File

@ -1,4 +1,4 @@
from text_generation.utils.tokens import (
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,

View File

@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar
from text_generation.models.types import Batch
from text_generation_server.models.types import Batch
B = TypeVar("B", bound=Batch)

View File

@ -6,8 +6,8 @@ from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation import server, utils
from text_generation.tracing import setup_tracing
from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing
app = typer.Typer()
@ -42,7 +42,7 @@ def serve(
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
@ -68,7 +68,7 @@ def download_weights(
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,

View File

@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional
from text_generation.models.model import Model
from text_generation.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded
from text_generation.models.santacoder import SantaCoder
from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
from text_generation.models.t5 import T5Sharded
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
__all__ = [
"Model",

View File

@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)

View File

@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)

View File

@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,

View File

@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText
from text_generation_server.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)

View File

@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM
from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"

View File

@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models import Model
from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
padding_right_offset: int
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
next_token_choosers = []
stopping_criterias = []

View File

@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import Seq2SeqLM
from text_generation.utils import (
from text_generation_server.models import Seq2SeqLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)

View File

@ -6,8 +6,8 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC):

View File

@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

View File

@ -1,6 +1,6 @@
from text_generation.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import (
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
@ -8,7 +8,7 @@ from text_generation.utils.hub import (
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation.utils.tokens import (
from text_generation_server.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,

View File

@ -11,9 +11,9 @@ from transformers import (
)
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
class Sampling: