feat(python-client): add new parameters (#118)

parent 55bd4fed7d
commit d8dc8f1b0c

@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"
 
 
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 
 
 # `generate` return value

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.2.1"
+version = "0.3.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers
 
 
-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"

@@ -5,24 +5,32 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token
 
 
-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
 
 
+def test_generate_best_of(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
+
+    assert response.details.seed is not None
+    assert response.details.best_of_sequences is not None
+    assert len(response.details.best_of_sequences) == 1
+    assert response.details.best_of_sequences[0].seed is not None
+
+
 def test_generate_not_found(fake_url, hf_headers):
     client = Client(fake_url, hf_headers)
     with pytest.raises(NotFoundError):
@@ -35,8 +43,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
         client.generate("test", max_new_tokens=10_000)
 
 
-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -44,7 +52,7 @@ def test_generate_stream(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
 
 
@@ -96,8 +102,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
 
 
 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +111,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]
 
-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None

@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)
 
 
-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)
 
 
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
     InferenceAPIClient(unsupported_model)
 
 
-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)
 
 

@@ -1,10 +1,20 @@
 import pytest
 
-from text_generation.types import Parameters
+from text_generation.types import Parameters, Request
 from text_generation.errors import ValidationError
 
 
 def test_parameters_validation():
+    # Test best_of
+    Parameters(best_of=1)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=0)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=-1)
+    Parameters(best_of=2, do_sample=True)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=2)
+
     # Test repetition_penalty
     Parameters(repetition_penalty=1)
     with pytest.raises(ValidationError):
@@ -32,8 +42,41 @@ def test_parameters_validation():
         Parameters(top_k=-1)
 
     # Test top_p
-    Parameters(top_p=1)
+    Parameters(top_p=0.5)
     with pytest.raises(ValidationError):
         Parameters(top_p=0)
     with pytest.raises(ValidationError):
         Parameters(top_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(top_p=1)
+
+    # Test truncate
+    Parameters(truncate=1)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=0)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=-1)
+
+    # Test typical_p
+    Parameters(typical_p=0.5)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=0)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=1)
+
+
+def test_request_validation():
+    Request(inputs="test")
+
+    with pytest.raises(ValidationError):
+        Request(inputs="")
+
+    Request(inputs="test", stream=True)
+    Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
+
+    with pytest.raises(ValidationError):
+        Request(
+            inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
+        )

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.2.1"
+__version__ = "0.3.0"
 
 from text_generation.client import Client, AsyncClient
 from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

@@ -56,6 +56,7 @@ class Client:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -63,6 +64,8 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Response:
         """
@@ -75,6 +78,8 @@ class Client:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -91,6 +96,11 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -99,6 +109,7 @@ class Client:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -109,6 +120,8 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -129,6 +142,7 @@ class Client:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -136,6 +150,8 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Iterator[StreamResponse]:
         """
@@ -148,6 +164,8 @@ class Client:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -164,6 +182,11 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -172,6 +195,7 @@ class Client:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -182,6 +206,8 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -261,6 +287,7 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -268,6 +295,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Response:
         """
@@ -280,6 +309,8 @@ class AsyncClient:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -296,6 +327,11 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -304,6 +340,7 @@ class AsyncClient:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
@@ -314,6 +351,8 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -331,6 +370,7 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -338,6 +378,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
@@ -350,6 +392,8 @@ class AsyncClient:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one with the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -366,6 +410,11 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 
@@ -374,6 +423,7 @@ class AsyncClient:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -384,6 +434,8 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

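The client changes above expose `best_of`, `truncate` and `typical_p` on both `Client` and `AsyncClient`. A minimal usage sketch of the new parameters, assuming a text-generation-inference server is reachable at the placeholder URL `http://127.0.0.1:8080` (illustrative only, not part of the commit):

from text_generation import Client

# Placeholder endpoint; point this at your own text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# `best_of` requires sampling, so at least one sampling option must be enabled.
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=20,
    do_sample=True,
    best_of=2,
    truncate=512,    # truncate the input to at most 512 tokens
    typical_p=0.95,  # typical decoding mass
)

print(response.generated_text)

# Candidate sequences that were not selected are returned in the details.
if response.details.best_of_sequences is not None:
    for sequence in response.details.best_of_sequences:
        print(sequence.generated_text, sequence.finish_reason, sequence.seed)
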
@@ -6,27 +6,64 @@ from text_generation.errors import ValidationError
 
 
 class Parameters(BaseModel):
+    # Activate logits sampling
     do_sample: bool = False
+    # Maximum number of generated tokens
     max_new_tokens: int = 20
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
+    # Whether to prepend the prompt to the generated text
     return_full_text: bool = False
+    # Stop generating tokens if a member of `stop_sequences` is generated
     stop: List[str] = []
+    # Random sampling seed
     seed: Optional[int]
+    # The value used to modulate the logits distribution.
     temperature: Optional[float]
+    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
     top_k: Optional[int]
+    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+    # higher are kept for generation.
     top_p: Optional[float]
+    # Truncate input tokens to the given size
+    truncate: Optional[int]
+    # Typical Decoding mass
+    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+    typical_p: Optional[float]
+    # Generate best_of sequences and return the one with the highest token logprobs
+    best_of: Optional[int]
+    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool = False
+    # Get generation details
     details: bool = False
 
+    @validator("best_of")
+    def valid_best_of(cls, field_value, values):
+        if field_value is not None:
+            if field_value <= 0:
+                raise ValidationError("`best_of` must be strictly positive")
+            sampling = (
+                values["do_sample"]
+                | (values["temperature"] is not None)
+                | (values["top_k"] is not None)
+                | (values["top_p"] is not None)
+                | (values["typical_p"] is not None)
+            )
+            if field_value > 1 and not sampling:
+                raise ValidationError("you must use sampling when `best_of` is > 1")
+
+        return field_value
+
     @validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
-        if v is not None and v is v <= 0:
+        if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v
 
     @validator("seed")
     def valid_seed(cls, v):
-        if v is not None and v is v < 0:
+        if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v
 
@@ -44,56 +81,143 @@ class Parameters(BaseModel):
 
     @validator("top_p")
     def valid_top_p(cls, v):
-        if v is not None and (v <= 0 or v > 1.0):
-            raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`top_p` must be > 0.0 and < 1.0")
+        return v
+
+    @validator("truncate")
+    def valid_truncate(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`truncate` must be strictly positive")
+        return v
+
+    @validator("typical_p")
+    def valid_typical_p(cls, v):
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v
 
 
 class Request(BaseModel):
+    # Prompt
     inputs: str
-    parameters: Parameters
+    # Generation parameters
+    parameters: Optional[Parameters]
+    # Whether to stream output tokens
     stream: bool = False
 
+    @validator("inputs")
+    def valid_input(cls, v):
+        if not v:
+            raise ValidationError("`inputs` cannot be empty")
+        return v
+
+    @validator("stream")
+    def valid_best_of_stream(cls, field_value, values):
+        parameters = values["parameters"]
+        if (
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
+        ):
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
+        return field_value
+
 
+# Prompt tokens
 class PrefillToken(BaseModel):
+    # Token ID from the model tokenizer
     id: int
+    # Token text
     text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
     logprob: Optional[float]
 
 
+# Generated tokens
 class Token(BaseModel):
+    # Token ID from the model tokenizer
     id: int
+    # Token text
     text: str
+    # Logprob
     logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
     special: bool
 
 
+# Generation finish reason
 class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
     Length = "length"
+    # the model generated its end of sequence token
     EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
     StopSequence = "stop_sequence"
 
 
-class Details(BaseModel):
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
     finish_reason: FinishReason
+    # Number of generated tokens
     generated_tokens: int
+    # Sampling seed if sampling was activated
     seed: Optional[int]
+    # Prompt tokens
     prefill: List[PrefillToken]
+    # Generated tokens
     tokens: List[Token]
 
 
-class StreamDetails(BaseModel):
+# `generate` details
+class Details(BaseModel):
+    # Generation finish reason
     finish_reason: FinishReason
+    # Number of generated tokens
     generated_tokens: int
+    # Sampling seed if sampling was activated
     seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
 
 
+# `generate` return value
 class Response(BaseModel):
+    # Generated text
     generated_text: str
+    # Generation details
     details: Details
 
 
+# `generate_stream` details
+class StreamDetails(BaseModel):
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
 class StreamResponse(BaseModel):
+    # Generated token
     token: Token
+    # Complete generated text
+    # Only available when the generation is finished
     generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
     details: Optional[StreamDetails]

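The validators added to the types above enforce the new constraints at request-construction time: `best_of` must be strictly positive, `best_of > 1` requires a sampling option, and it cannot be combined with streaming. A short sketch of the resulting behaviour, mirroring the tests in this commit (illustrative only):

from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError

# best_of > 1 without any sampling option is rejected.
try:
    Parameters(best_of=2)
except ValidationError as err:
    print(err)  # you must use sampling when `best_of` is > 1

# Turning on sampling makes the same value valid.
Parameters(best_of=2, do_sample=True)

# best_of > 1 cannot be combined with token streaming.
try:
    Request(
        inputs="test",
        parameters=Parameters(best_of=2, do_sample=True),
        stream=True,
    )
except ValidationError as err:
    print(err)  # `best_of` != 1 is not supported when `stream` == True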