From 629047cb82d2ff97a8f0d0446ed7a3a68bed63a7 Mon Sep 17 00:00:00 2001 From: Thomas Schillaci Date: Thu, 23 May 2024 15:37:09 +0200 Subject: [PATCH] Add completion route to client and add stop parameter where it's missing (#1869) # What does this PR do? - Add the stop parameter to the completion route - Add the completion method to the python client - Add the stop parameter to the python client's chat method ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil --------- Co-authored-by: Thomas SCHILLACI Co-authored-by: Thomas Schillaci --- clients/python/text_generation/client.py | 186 +++++++++++++++++++++++ clients/python/text_generation/types.py | 106 ++++++++----- docs/openapi.json | 9 ++ router/src/lib.rs | 5 + router/src/server.rs | 25 ++- 5 files changed, 286 insertions(+), 45 deletions(-) diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 98c018d5..12966747 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -13,6 +13,9 @@ from text_generation.types import ( Request, Parameters, Grammar, + CompletionRequest, + Completion, + CompletionComplete, ChatRequest, ChatCompletionChunk, ChatComplete, @@ -70,6 +73,94 @@ class Client: self.cookies = cookies self.timeout = timeout + def completion( + self, + prompt: str, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + seed: Optional[int] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + stop: Optional[List[str]] = None, + ): + """ + Given a prompt, generate a response synchronously + + Args: + prompt (`str`): + Prompt + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + max_tokens (`int`): + Maximum number of generated tokens + repetition_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + seed (`int`): + Random sampling seed + stream (`bool`): + Stream the response + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated + """ + request = CompletionRequest( + model="tgi", + prompt=prompt, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + repetition_penalty=repetition_penalty, + seed=seed, + stream=stream, + temperature=temperature, + top_p=top_p, + stop=stop, + ) + if not stream: + resp = requests.post( + f"{self.base_url}/v1/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return Completion(**payload) + else: + return self._completion_stream_response(request) + + def _completion_stream_response(self, request): + resp = requests.post( + f"{self.base_url}/v1/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + # iterate and print stream + for byte_payload in resp.iter_lines(): + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = CompletionComplete(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + def chat( self, messages: List[Message], @@ -88,6 +179,7 @@ class Client: tools: Optional[List[Tool]] = None, tool_prompt: Optional[str] = None, tool_choice: Optional[str] = None, + stop: Optional[List[str]] = None, ): """ Given a list of messages, generate a response asynchronously @@ -130,6 +222,8 @@ class Client: A prompt to be appended before the tools tool_choice (`str`): The tool to use + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated """ request = ChatRequest( @@ -150,6 +244,7 @@ class Client: tools=tools, tool_prompt=tool_prompt, tool_choice=tool_choice, + stop=stop, ) if not stream: resp = requests.post( @@ -461,6 +556,93 @@ class AsyncClient: self.cookies = cookies self.timeout = ClientTimeout(timeout) + async def completion( + self, + prompt: str, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + seed: Optional[int] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + stop: Optional[List[str]] = None, + ) -> Union[Completion, AsyncIterator[CompletionComplete]]: + """ + Given a prompt, generate a response asynchronously + + Args: + prompt (`str`): + Prompt + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + max_tokens (`int`): + Maximum number of generated tokens + repetition_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + seed (`int`): + Random sampling seed + stream (`bool`): + Stream the response + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated + """ + request = CompletionRequest( + model="tgi", + prompt=prompt, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + repetition_penalty=repetition_penalty, + seed=seed, + stream=stream, + temperature=temperature, + top_p=top_p, + stop=stop, + ) + if not stream: + return await self._completion_single_response(request) + else: + return self._completion_stream_response(request) + + async def _completion_single_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/completions", json=request.dict() + ) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return Completion(**payload) + + async def _completion_stream_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/completions", json=request.dict() + ) as resp: + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = CompletionComplete(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + async def chat( self, messages: List[Message], @@ -479,6 +661,7 @@ class AsyncClient: tools: Optional[List[Tool]] = None, tool_prompt: Optional[str] = None, tool_choice: Optional[str] = None, + stop: Optional[List[str]] = None, ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]: """ Given a list of messages, generate a response asynchronously @@ -521,6 +704,8 @@ class AsyncClient: A prompt to be appended before the tools tool_choice (`str`): The tool to use + stop (`List[str]`): + Stop generating tokens if a member of `stop` is generated """ request = ChatRequest( @@ -541,6 +726,7 @@ class AsyncClient: tools=tools, tool_prompt=tool_prompt, tool_choice=tool_choice, + stop=stop, ) if not stream: return await self._chat_single_response(request) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 5e32bc6f..eb872ee6 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -46,30 +46,6 @@ class Tool(BaseModel): function: dict -class ChatCompletionComplete(BaseModel): - # Index of the chat completion - index: int - # Message associated with the chat completion - message: Message - # Log probabilities for the chat completion - logprobs: Optional[Any] - # Reason for completion - finish_reason: str - # Usage details of the chat completion - usage: Optional[Any] = None - - -class CompletionComplete(BaseModel): - # Index of the chat completion - index: int - # Message associated with the chat completion - text: str - # Log probabilities for the chat completion - logprobs: Optional[Any] - # Reason for completion - finish_reason: str - - class Function(BaseModel): name: Optional[str] arguments: str @@ -95,24 +71,41 @@ class Choice(BaseModel): finish_reason: Optional[str] = None -class ChatCompletionChunk(BaseModel): - id: str - object: str - created: int +class CompletionRequest(BaseModel): + # Model identifier model: str - system_fingerprint: str - choices: List[Choice] + # Prompt + prompt: str + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] = None + # The parameter for frequency penalty. 1.0 means no penalty + # Penalize new tokens based on their existing frequency in the text so far, + # decreasing the model's likelihood to repeat the same line verbatim. + frequency_penalty: Optional[float] = None + # Maximum number of tokens to generate + max_tokens: Optional[int] = None + # Flag to indicate streaming response + stream: bool = False + # Random sampling seed + seed: Optional[int] = None + # Sampling temperature + temperature: Optional[float] = None + # Top-p value for nucleus sampling + top_p: Optional[float] = None + # Stop generating tokens if a member of `stop` is generated + stop: Optional[List[str]] = None -class ChatComplete(BaseModel): - # Chat completion details - id: str - object: str - created: int - model: str - system_fingerprint: str - choices: List[ChatCompletionComplete] - usage: Any +class CompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + text: str + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str class Completion(BaseModel): @@ -163,6 +156,41 @@ class ChatRequest(BaseModel): tool_prompt: Optional[str] = None # Choice of tool to be used tool_choice: Optional[str] = None + # Stop generating tokens if a member of `stop` is generated + stop: Optional[List[str]] = None + + +class ChatCompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + message: Message + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + # Usage details of the chat completion + usage: Optional[Any] = None + + +class ChatComplete(BaseModel): + # Chat completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[ChatCompletionComplete] + usage: Any + + +class ChatCompletionChunk(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choice] class Parameters(BaseModel): diff --git a/docs/openapi.json b/docs/openapi.json index 2a387c2f..79c3b80f 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1121,6 +1121,15 @@ "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "example": 0.95, "nullable": true + }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Up to 4 sequences where the API will stop generating further tokens.", + "example": "null", + "nullable": true } } }, diff --git a/router/src/lib.rs b/router/src/lib.rs index febbf277..ba1d9acc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -402,6 +402,11 @@ pub struct CompletionRequest { #[serde(default)] #[schema(example = "1.0")] pub frequency_penalty: Option, + + /// Up to 4 sequences where the API will stop generating further tokens. + #[serde(default)] + #[schema(nullable = true, example = "null")] + pub stop: Option>, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] diff --git a/router/src/server.rs b/router/src/server.rs index f51bbbef..e7570ded 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -597,9 +597,22 @@ async fn completions( let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); - let stream = req.stream; - let max_new_tokens = req.max_tokens.or(Some(100)); - let seed = req.seed; + let CompletionRequest { + max_tokens, + seed, + stop, + stream, + temperature, + .. + } = req; + + let max_new_tokens = max_tokens.or(Some(100)); + let stop = stop.unwrap_or_default(); + // enable greedy only when temperature is 0 + let (do_sample, temperature) = match temperature { + Some(temperature) if temperature == 0.0 => (false, None), + other => (true, other), + }; // if suffix is present throw an error if req.suffix.is_some() { @@ -635,16 +648,16 @@ async fn completions( inputs: prompt.to_string(), parameters: GenerateParameters { best_of: None, - temperature: req.temperature, + temperature, repetition_penalty: req.repetition_penalty, frequency_penalty: req.frequency_penalty, top_k: None, top_p: req.top_p, typical_p: None, - do_sample: true, + do_sample, max_new_tokens, return_full_text: None, - stop: Vec::new(), + stop: stop.clone(), truncate: None, watermark: false, details: true,