diff --git a/clients/python/README.md b/clients/python/README.md index 79b8837..f509e65 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -133,6 +133,22 @@ class FinishReason(Enum): StopSequence = "stop_sequence" +# Additional sequences when using the `best_of` parameter +class BestOfSequence: + # Generated text + generated_text: str + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + # Prompt tokens + prefill: List[PrefillToken] + # Generated tokens + tokens: List[Token] + + # `generate` details class Details: # Generation finish reason @@ -145,6 +161,8 @@ class Details: prefill: List[PrefillToken] # Generated tokens tokens: List[Token] + # Additional sequences when using the `best_of` parameter + best_of_sequences: Optional[List[BestOfSequence]] # `generate` return value diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 0b8fa8c..51ecce8 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.2.1" +version = "0.3.0" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 4298623..48734f0 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -4,11 +4,6 @@ from text_generation import __version__ from huggingface_hub.utils import build_hf_headers -@pytest.fixture -def bloom_model(): - return "bigscience/bloom" - - @pytest.fixture def flan_t5_xxl(): return "google/flan-t5-xxl" diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 2f96aa8..c998de4 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -5,24 +5,32 @@ from text_generation.errors import NotFoundError, ValidationError from text_generation.types import FinishReason, PrefillToken, Token -def test_generate(bloom_url, hf_headers): - client = Client(bloom_url, hf_headers) +def test_generate(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) response = client.generate("test", max_new_tokens=1) - assert response.generated_text == "." + assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == PrefillToken( - id=9234, text="test", logprob=None - ) + assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0] == Token( - id=17, text=".", logprob=-1.75, special=False + id=3, text=" ", logprob=-1.984375, special=False ) +def test_generate_best_of(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) + response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True) + + assert response.details.seed is not None + assert response.details.best_of_sequences is not None + assert len(response.details.best_of_sequences) == 1 + assert response.details.best_of_sequences[0].seed is not None + + def test_generate_not_found(fake_url, hf_headers): client = Client(fake_url, hf_headers) with pytest.raises(NotFoundError): @@ -35,8 +43,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers): client.generate("test", max_new_tokens=10_000) -def test_generate_stream(bloom_url, hf_headers): - client = Client(bloom_url, hf_headers) +def test_generate_stream(flan_t5_xxl_url, hf_headers): + client = Client(flan_t5_xxl_url, hf_headers) responses = [ response for response in client.generate_stream("test", max_new_tokens=1) ] @@ -44,7 +52,7 @@ def test_generate_stream(bloom_url, hf_headers): assert len(responses) == 1 response = responses[0] - assert response.generated_text == "." + assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None @@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): @pytest.mark.asyncio -async def test_generate_async(bloom_url, hf_headers): - client = AsyncClient(bloom_url, hf_headers) +async def test_generate_async(flan_t5_xxl_url, hf_headers): + client = AsyncClient(flan_t5_xxl_url, hf_headers) response = await client.generate("test", max_new_tokens=1) - assert response.generated_text == "." + assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == PrefillToken( - id=9234, text="test", logprob=None - ) + assert response.details.prefill[0] == PrefillToken(id=0, text="", logprob=None) assert len(response.details.tokens) == 1 assert response.details.tokens[0] == Token( - id=17, text=".", logprob=-1.75, special=False + id=3, text=" ", logprob=-1.984375, special=False ) @@ -96,8 +102,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers): @pytest.mark.asyncio -async def test_generate_stream_async(bloom_url, hf_headers): - client = AsyncClient(bloom_url, hf_headers) +async def test_generate_stream_async(flan_t5_xxl_url, hf_headers): + client = AsyncClient(flan_t5_xxl_url, hf_headers) responses = [ response async for response in client.generate_stream("test", max_new_tokens=1) ] @@ -105,7 +111,7 @@ async def test_generate_stream_async(bloom_url, hf_headers): assert len(responses) == 1 response = responses[0] - assert response.generated_text == "." + assert response.generated_text == "" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None diff --git a/clients/python/tests/test_inference_api.py b/clients/python/tests/test_inference_api.py index dc74494..79e503a 100644 --- a/clients/python/tests/test_inference_api.py +++ b/clients/python/tests/test_inference_api.py @@ -14,8 +14,8 @@ def test_get_supported_models(): assert isinstance(get_supported_models(), list) -def test_client(bloom_model): - client = InferenceAPIClient(bloom_model) +def test_client(flan_t5_xxl): + client = InferenceAPIClient(flan_t5_xxl) assert isinstance(client, Client) @@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model): InferenceAPIClient(unsupported_model) -def test_async_client(bloom_model): - client = InferenceAPIAsyncClient(bloom_model) +def test_async_client(flan_t5_xxl): + client = InferenceAPIAsyncClient(flan_t5_xxl) assert isinstance(client, AsyncClient) diff --git a/clients/python/tests/test_types.py b/clients/python/tests/test_types.py index d319b57..4c9d4c8 100644 --- a/clients/python/tests/test_types.py +++ b/clients/python/tests/test_types.py @@ -1,10 +1,20 @@ import pytest -from text_generation.types import Parameters +from text_generation.types import Parameters, Request from text_generation.errors import ValidationError def test_parameters_validation(): + # Test best_of + Parameters(best_of=1) + with pytest.raises(ValidationError): + Parameters(best_of=0) + with pytest.raises(ValidationError): + Parameters(best_of=-1) + Parameters(best_of=2, do_sample=True) + with pytest.raises(ValidationError): + Parameters(best_of=2) + # Test repetition_penalty Parameters(repetition_penalty=1) with pytest.raises(ValidationError): @@ -32,8 +42,41 @@ def test_parameters_validation(): Parameters(top_k=-1) # Test top_p - Parameters(top_p=1) + Parameters(top_p=0.5) with pytest.raises(ValidationError): Parameters(top_p=0) with pytest.raises(ValidationError): Parameters(top_p=-1) + with pytest.raises(ValidationError): + Parameters(top_p=1) + + # Test truncate + Parameters(truncate=1) + with pytest.raises(ValidationError): + Parameters(truncate=0) + with pytest.raises(ValidationError): + Parameters(truncate=-1) + + # Test typical_p + Parameters(typical_p=0.5) + with pytest.raises(ValidationError): + Parameters(typical_p=0) + with pytest.raises(ValidationError): + Parameters(typical_p=-1) + with pytest.raises(ValidationError): + Parameters(typical_p=1) + + +def test_request_validation(): + Request(inputs="test") + + with pytest.raises(ValidationError): + Request(inputs="") + + Request(inputs="test", stream=True) + Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True)) + + with pytest.raises(ValidationError): + Request( + inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True + ) diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index db09cdf..4610983 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.1" +__version__ = "0.3.0" from text_generation.client import Client, AsyncClient from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 2365524..e05a002 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -56,6 +56,7 @@ class Client: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, + best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -63,6 +64,8 @@ class Client: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, watermark: bool = False, ) -> Response: """ @@ -75,6 +78,8 @@ class Client: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -91,6 +96,11 @@ class Client: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) @@ -99,6 +109,7 @@ class Client: """ # Validate parameters parameters = Parameters( + best_of=best_of, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, @@ -109,6 +120,8 @@ class Client: temperature=temperature, top_k=top_k, top_p=top_p, + truncate=truncate, + typical_p=typical_p, watermark=watermark, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -129,6 +142,7 @@ class Client: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, + best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -136,6 +150,8 @@ class Client: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, watermark: bool = False, ) -> Iterator[StreamResponse]: """ @@ -148,6 +164,8 @@ class Client: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -164,6 +182,11 @@ class Client: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) @@ -172,6 +195,7 @@ class Client: """ # Validate parameters parameters = Parameters( + best_of=best_of, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, @@ -182,6 +206,8 @@ class Client: temperature=temperature, top_k=top_k, top_p=top_p, + truncate=truncate, + typical_p=typical_p, watermark=watermark, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -261,6 +287,7 @@ class AsyncClient: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, + best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -268,6 +295,8 @@ class AsyncClient: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, watermark: bool = False, ) -> Response: """ @@ -280,6 +309,8 @@ class AsyncClient: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -296,6 +327,11 @@ class AsyncClient: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) @@ -304,6 +340,7 @@ class AsyncClient: """ # Validate parameters parameters = Parameters( + best_of=best_of, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, @@ -314,6 +351,8 @@ class AsyncClient: temperature=temperature, top_k=top_k, top_p=top_p, + truncate=truncate, + typical_p=typical_p, watermark=watermark, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -331,6 +370,7 @@ class AsyncClient: prompt: str, do_sample: bool = False, max_new_tokens: int = 20, + best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -338,6 +378,8 @@ class AsyncClient: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, watermark: bool = False, ) -> AsyncIterator[StreamResponse]: """ @@ -350,6 +392,8 @@ class AsyncClient: Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + best_of (`int`): + Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -366,6 +410,11 @@ class AsyncClient: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + truncate (`int`): + Truncate inputs tokens to the given size + typical_p (`float`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) @@ -374,6 +423,7 @@ class AsyncClient: """ # Validate parameters parameters = Parameters( + best_of=best_of, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, @@ -384,6 +434,8 @@ class AsyncClient: temperature=temperature, top_k=top_k, top_p=top_p, + truncate=truncate, + typical_p=typical_p, watermark=watermark, ) request = Request(inputs=prompt, stream=True, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index d276b60..ea2070b 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -6,27 +6,64 @@ from text_generation.errors import ValidationError class Parameters(BaseModel): + # Activate logits sampling do_sample: bool = False + # Maximum number of generated tokens max_new_tokens: int = 20 + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. repetition_penalty: Optional[float] = None + # Whether to prepend the prompt to the generated text return_full_text: bool = False + # Stop generating tokens if a member of `stop_sequences` is generated stop: List[str] = [] + # Random sampling seed seed: Optional[int] + # The value used to module the logits distribution. temperature: Optional[float] + # The number of highest probability vocabulary tokens to keep for top-k-filtering. top_k: Optional[int] + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation. top_p: Optional[float] + # truncate inputs tokens to the given size + truncate: Optional[int] + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] + # Generate best_of sequences and return the one if the highest token logprobs + best_of: Optional[int] + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) watermark: bool = False + # Get generation details details: bool = False + @validator("best_of") + def valid_best_of(cls, field_value, values): + if field_value is not None: + if field_value <= 0: + raise ValidationError("`best_of` must be strictly positive") + sampling = ( + values["do_sample"] + | (values["temperature"] is not None) + | (values["top_k"] is not None) + | (values["top_p"] is not None) + | (values["typical_p"] is not None) + ) + if field_value > 1 and not sampling: + raise ValidationError("you must use sampling when `best_of` is > 1") + + return field_value + @validator("repetition_penalty") def valid_repetition_penalty(cls, v): - if v is not None and v is v <= 0: + if v is not None and v <= 0: raise ValidationError("`repetition_penalty` must be strictly positive") return v @validator("seed") def valid_seed(cls, v): - if v is not None and v is v < 0: + if v is not None and v < 0: raise ValidationError("`seed` must be positive") return v @@ -44,56 +81,143 @@ class Parameters(BaseModel): @validator("top_p") def valid_top_p(cls, v): - if v is not None and (v <= 0 or v > 1.0): - raise ValidationError("`top_p` must be > 0.0 and <= 1.0") + if v is not None and (v <= 0 or v >= 1.0): + raise ValidationError("`top_p` must be > 0.0 and < 1.0") + return v + + @validator("truncate") + def valid_truncate(cls, v): + if v is not None and v <= 0: + raise ValidationError("`truncate` must be strictly positive") + return v + + @validator("typical_p") + def valid_typical_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + raise ValidationError("`typical_p` must be > 0.0 and < 1.0") return v class Request(BaseModel): + # Prompt inputs: str - parameters: Parameters + # Generation parameters + parameters: Optional[Parameters] + # Whether to stream output tokens stream: bool = False + @validator("inputs") + def valid_input(cls, v): + if not v: + raise ValidationError("`inputs` cannot be empty") + return v + @validator("stream") + def valid_best_of_stream(cls, field_value, values): + parameters = values["parameters"] + if ( + parameters is not None + and parameters.best_of is not None + and parameters.best_of > 1 + and field_value + ): + raise ValidationError( + "`best_of` != 1 is not supported when `stream` == True" + ) + return field_value + + +# Prompt tokens class PrefillToken(BaseModel): + # Token ID from the model tokenizer id: int + # Token text text: str + # Logprob + # Optional since the logprob of the first token cannot be computed logprob: Optional[float] +# Generated tokens class Token(BaseModel): + # Token ID from the model tokenizer id: int + # Token text text: str + # Logprob logprob: float + # Is the token a special token + # Can be used to ignore tokens when concatenating special: bool +# Generation finish reason class FinishReason(Enum): + # number of generated tokens == `max_new_tokens` Length = "length" + # the model generated its end of sequence token EndOfSequenceToken = "eos_token" + # the model generated a text included in `stop_sequences` StopSequence = "stop_sequence" -class Details(BaseModel): +# Additional sequences when using the `best_of` parameter +class BestOfSequence(BaseModel): + # Generated text + generated_text: str + # Generation finish reason finish_reason: FinishReason + # Number of generated tokens generated_tokens: int + # Sampling seed if sampling was activated seed: Optional[int] + # Prompt tokens prefill: List[PrefillToken] + # Generated tokens tokens: List[Token] -class StreamDetails(BaseModel): +# `generate` details +class Details(BaseModel): + # Generation finish reason finish_reason: FinishReason + # Number of generated tokens generated_tokens: int + # Sampling seed if sampling was activated seed: Optional[int] + # Prompt tokens + prefill: List[PrefillToken] + # Generated tokens + tokens: List[Token] + # Additional sequences when using the `best_of` parameter + best_of_sequences: Optional[List[BestOfSequence]] +# `generate` return value class Response(BaseModel): + # Generated text generated_text: str + # Generation details details: Details +# `generate_stream` details +class StreamDetails(BaseModel): + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + + +# `generate_stream` return value class StreamResponse(BaseModel): + # Generated token token: Token + # Complete generated text + # Only available when the generation is finished generated_text: Optional[str] + # Generation details + # Only available when the generation is finished details: Optional[StreamDetails]