feat(python-client): add new parameters (#118)

Authored by OlivierDehaene on 2023-03-09 16:05:33 +01:00, committed by GitHub
parent 55bd4fed7d
commit d8dc8f1b0c
9 changed files with 278 additions and 40 deletions

View File

@@ -133,6 +133,22 @@ class FinishReason(Enum):
    StopSequence = "stop_sequence"

+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
# `generate` details
class Details:
    # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]

# `generate` return value
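As a minimal sketch of how the new `best_of_sequences` field can be consumed: the endpoint URL and prompt below are placeholders, and it assumes a text-generation-inference server recent enough to return `best_of` details.

from text_generation import Client

# Placeholder endpoint; point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# `best_of` requires sampling, hence do_sample=True.
response = client.generate("Why is the sky blue?", max_new_tokens=20, best_of=2, do_sample=True)

print(response.generated_text)
if response.details.best_of_sequences is not None:
    for sequence in response.details.best_of_sequences:
        # Each entry mirrors Details: generated text, finish reason, seed, prefill and generated tokens.
        print(sequence.generated_text, sequence.finish_reason, sequence.seed)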

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation"
-version = "0.2.1"
+version = "0.3.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -4,11 +4,6 @@ from text_generation import __version__
from huggingface_hub.utils import build_hf_headers

-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
@pytest.fixture
def flan_t5_xxl():
    return "google/flan-t5-xxl"

View File

@@ -5,24 +5,32 @@ from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token

-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
    response = client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
    )

+def test_generate_best_of(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
+
+    assert response.details.seed is not None
+    assert response.details.best_of_sequences is not None
+    assert len(response.details.best_of_sequences) == 1
+    assert response.details.best_of_sequences[0].seed is not None
+
def test_generate_not_found(fake_url, hf_headers):
    client = Client(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
@@ -35,8 +43,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
        client.generate("test", max_new_tokens=10_000)

-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
    responses = [
        response for response in client.generate_stream("test", max_new_tokens=1)
    ]
@@ -44,7 +52,7 @@ def test_generate_stream(bloom_url, hf_headers):
    assert len(responses) == 1
    response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
@@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
@pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    response = await client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
    )
@@ -96,8 +102,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
@pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    responses = [
        response async for response in client.generate_stream("test", max_new_tokens=1)
    ]
@@ -105,7 +111,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
    assert len(responses) == 1
    response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None

View File

@@ -14,8 +14,8 @@ def test_get_supported_models():
    assert isinstance(get_supported_models(), list)

-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
    assert isinstance(client, Client)
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
        InferenceAPIClient(unsupported_model)

-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
    assert isinstance(client, AsyncClient)

View File

@@ -1,10 +1,20 @@
import pytest

-from text_generation.types import Parameters
+from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError


def test_parameters_validation():
+    # Test best_of
+    Parameters(best_of=1)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=0)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=-1)
+    Parameters(best_of=2, do_sample=True)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=2)
+
    # Test repetition_penalty
    Parameters(repetition_penalty=1)
    with pytest.raises(ValidationError):
@@ -32,8 +42,41 @@ def test_parameters_validation():
        Parameters(top_k=-1)

    # Test top_p
-    Parameters(top_p=1)
+    Parameters(top_p=0.5)
    with pytest.raises(ValidationError):
        Parameters(top_p=0)
    with pytest.raises(ValidationError):
        Parameters(top_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(top_p=1)
+
+    # Test truncate
+    Parameters(truncate=1)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=0)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=-1)
+
+    # Test typical_p
+    Parameters(typical_p=0.5)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=0)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=1)
+
+
+def test_request_validation():
+    Request(inputs="test")
+
+    with pytest.raises(ValidationError):
+        Request(inputs="")
+
+    Request(inputs="test", stream=True)
+    Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
+
+    with pytest.raises(ValidationError):
+        Request(
+            inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
+        )

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-__version__ = "0.2.1"
+__version__ = "0.3.0"

from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

View File

@@ -56,6 +56,7 @@ class Client:
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
@@ -63,6 +64,8 @@ class Client:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        """
@@ -75,6 +78,8 @@ class Client:
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -91,6 +96,11 @@ class Client:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -99,6 +109,7 @@ class Client:
        """
        # Validate parameters
        parameters = Parameters(
+            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
@@ -109,6 +120,8 @@ class Client:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -129,6 +142,7 @@ class Client:
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
@@ -136,6 +150,8 @@ class Client:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
        """
@@ -148,6 +164,8 @@ class Client:
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -164,6 +182,11 @@ class Client:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -172,6 +195,7 @@ class Client:
        """
        # Validate parameters
        parameters = Parameters(
+            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
@@ -182,6 +206,8 @@ class Client:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -261,6 +287,7 @@ class AsyncClient:
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
@@ -268,6 +295,8 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        """
@@ -280,6 +309,8 @@ class AsyncClient:
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -296,6 +327,11 @@ class AsyncClient:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -304,6 +340,7 @@ class AsyncClient:
        """
        # Validate parameters
        parameters = Parameters(
+            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
@@ -314,6 +351,8 @@ class AsyncClient:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -331,6 +370,7 @@ class AsyncClient:
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
@@ -338,6 +378,8 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        """
@@ -350,6 +392,8 @@ class AsyncClient:
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -366,6 +410,11 @@ class AsyncClient:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -374,6 +423,7 @@ class AsyncClient:
        """
        # Validate parameters
        parameters = Parameters(
+            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
@@ -384,6 +434,8 @@ class AsyncClient:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)
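A short usage sketch of the parameters added above. The endpoint URL, prompt, and parameter values are placeholders chosen for illustration, not values taken from this commit.

from text_generation import Client

# Placeholder endpoint; substitute a reachable text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Non-streaming call exercising the new parameters.
response = client.generate(
    "Why is the sky blue?",
    max_new_tokens=32,
    best_of=2,        # requires sampling, hence do_sample=True below
    do_sample=True,
    truncate=512,     # keep at most 512 prompt tokens
    typical_p=0.95,   # typical decoding mass, must be in (0, 1)
)
print(response.generated_text)

# Streaming call; note that best_of > 1 is rejected when streaming.
text = ""
for stream_response in client.generate_stream(
    "Why is the sky blue?", max_new_tokens=32, truncate=512, typical_p=0.95
):
    if not stream_response.token.special:
        text += stream_response.token.text
print(text)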

View File

@@ -6,27 +6,64 @@ from text_generation.errors import ValidationError

class Parameters(BaseModel):
+    # Activate logits sampling
    do_sample: bool = False
+    # Maximum number of generated tokens
    max_new_tokens: int = 20
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float] = None
+    # Whether to prepend the prompt to the generated text
    return_full_text: bool = False
+    # Stop generating tokens if a member of `stop_sequences` is generated
    stop: List[str] = []
+    # Random sampling seed
    seed: Optional[int]
+    # The value used to modulate the logits distribution.
    temperature: Optional[float]
+    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
    top_k: Optional[int]
+    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+    # higher are kept for generation.
    top_p: Optional[float]
+    # Truncate input tokens to the given size
+    truncate: Optional[int]
+    # Typical Decoding mass
+    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+    typical_p: Optional[float]
+    # Generate best_of sequences and return the one with the highest token logprobs
+    best_of: Optional[int]
+    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    watermark: bool = False
+    # Get generation details
    details: bool = False

+    @validator("best_of")
+    def valid_best_of(cls, field_value, values):
+        if field_value is not None:
+            if field_value <= 0:
+                raise ValidationError("`best_of` must be strictly positive")
+            sampling = (
+                values["do_sample"]
+                | (values["temperature"] is not None)
+                | (values["top_k"] is not None)
+                | (values["top_p"] is not None)
+                | (values["typical_p"] is not None)
+            )
+            if field_value > 1 and not sampling:
+                raise ValidationError("you must use sampling when `best_of` is > 1")
+
+        return field_value
+
    @validator("repetition_penalty")
    def valid_repetition_penalty(cls, v):
-        if v is not None and v is v <= 0:
+        if v is not None and v <= 0:
            raise ValidationError("`repetition_penalty` must be strictly positive")
        return v

    @validator("seed")
    def valid_seed(cls, v):
-        if v is not None and v is v < 0:
+        if v is not None and v < 0:
            raise ValidationError("`seed` must be positive")
        return v
@@ -44,56 +81,143 @@ class Parameters(BaseModel):
    @validator("top_p")
    def valid_top_p(cls, v):
-        if v is not None and (v <= 0 or v > 1.0):
-            raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`top_p` must be > 0.0 and < 1.0")
        return v

+    @validator("truncate")
+    def valid_truncate(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`truncate` must be strictly positive")
+        return v
+
+    @validator("typical_p")
+    def valid_typical_p(cls, v):
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
+        return v
+

class Request(BaseModel):
+    # Prompt
    inputs: str
-    parameters: Parameters
+    # Generation parameters
+    parameters: Optional[Parameters]
+    # Whether to stream output tokens
    stream: bool = False

+    @validator("inputs")
+    def valid_input(cls, v):
+        if not v:
+            raise ValidationError("`inputs` cannot be empty")
+        return v
+
+    @validator("stream")
+    def valid_best_of_stream(cls, field_value, values):
+        parameters = values["parameters"]
+        if (
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
+        ):
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
+        return field_value
+

+# Prompt tokens
class PrefillToken(BaseModel):
+    # Token ID from the model tokenizer
    id: int
+    # Token text
    text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
    logprob: Optional[float]

+# Generated tokens
class Token(BaseModel):
+    # Token ID from the model tokenizer
    id: int
+    # Token text
    text: str
+    # Logprob
    logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
    special: bool

+# Generation finish reason
class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
    Length = "length"
+    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"

-class Details(BaseModel):
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
    finish_reason: FinishReason
+    # Number of generated tokens
    generated_tokens: int
+    # Sampling seed if sampling was activated
    seed: Optional[int]
+    # Prompt tokens
    prefill: List[PrefillToken]
+    # Generated tokens
    tokens: List[Token]

-class StreamDetails(BaseModel):
+# `generate` details
+class Details(BaseModel):
+    # Generation finish reason
    finish_reason: FinishReason
+    # Number of generated tokens
    generated_tokens: int
+    # Sampling seed if sampling was activated
    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]

+# `generate` return value
class Response(BaseModel):
+    # Generated text
    generated_text: str
+    # Generation details
    details: Details

+# `generate_stream` details
+class StreamDetails(BaseModel):
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+# `generate_stream` return value
class StreamResponse(BaseModel):
+    # Generated token
    token: Token
+    # Complete generated text
+    # Only available when the generation is finished
    generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
    details: Optional[StreamDetails]
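To make the new validation rules concrete, here is a small sketch that mirrors the tests above; it only assumes the `Parameters` and `Request` models and the `ValidationError` shown in this diff.

from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError

# top_p must now be strictly below 1.0; passing 1 raises.
try:
    Parameters(top_p=1)
except ValidationError as err:
    print(err)

# best_of > 1 is only accepted when sampling is enabled.
Parameters(best_of=2, do_sample=True)  # accepted
try:
    Parameters(best_of=2)  # no sampling parameter set
except ValidationError as err:
    print(err)

# best_of > 1 cannot be combined with streaming.
try:
    Request(
        inputs="test",
        stream=True,
        parameters=Parameters(best_of=2, do_sample=True),
    )
except ValidationError as err:
    print(err)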