From 9fb1cdc8d52dee19330f80cf9231d3dde87751ba Mon Sep 17 00:00:00 2001
From: Thomas SCHILLACI
Date: Tue, 7 May 2024 18:27:28 +0200
Subject: [PATCH 1/3] Add completion route to client and stop parameter

---
 clients/python/text_generation/client.py | 194 ++++++++++++++++++++++-
 clients/python/text_generation/types.py  | 106 ++++++++-----
 2 files changed, 253 insertions(+), 47 deletions(-)

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 0e86901d..8acedfb9 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -11,6 +11,9 @@ from text_generation.types import (
     Request,
     Parameters,
     Grammar,
+    CompletionRequest,
+    Completion,
+    CompletionComplete,
     ChatRequest,
     ChatCompletionChunk,
     ChatComplete,
@@ -64,6 +67,94 @@ class Client:
         self.cookies = cookies
         self.timeout = timeout

+    def completion(
+        self,
+        prompt: str,
+        frequency_penalty: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[List[str]] = None,
+    ):
+        """
+        Given a prompt, generate a response synchronously
+
+        Args:
+            prompt (`str`):
+                Prompt
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
+            max_tokens (`int`):
+                Maximum number of generated tokens
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            seed (`int`):
+                Random sampling seed
+            stream (`bool`):
+                Stream the response
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
+        """
+        request = CompletionRequest(
+            model="tgi",
+            prompt=prompt,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            stream=stream,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+        )
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/completions",
+                json=request.dict(),
+                headers=self.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+            )
+            payload = resp.json()
+            if resp.status_code != 200:
+                raise parse_error(resp.status_code, payload)
+            return Completion(**payload)
+        else:
+            return self._completion_stream_response(request)
+
+    def _completion_stream_response(self, request):
+        resp = requests.post(
+            f"{self.base_url}/v1/completions",
+            json=request.dict(),
+            headers=self.headers,
+            cookies=self.cookies,
+            timeout=self.timeout,
+            stream=True,
+        )
+        # iterate over the stream of server-sent events
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = CompletionComplete(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status_code, json_payload)
+
     def chat(
         self,
         messages: List[Message],
@@ -80,8 +171,8 @@ class Client:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
-        tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
+        stop: Optional[List[str]] = None,
     ):
         """
         Given a list of messages, generate a response asynchronously
@@ -120,10 +211,10 @@ class Client:
                 higher are kept for generation
             tools (`List[Tool]`):
                 List of tools to use
-            tool_prompt (`str`):
-                A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated

         """
         request = ChatRequest(
@@ -142,8 +233,8 @@ class Client:
             temperature=temperature,
             top_p=top_p,
             tools=tools,
-            tool_prompt=tool_prompt,
             tool_choice=tool_choice,
+            stop=stop,
         )
         if not stream:
             resp = requests.post(
@@ -454,6 +545,93 @@ class AsyncClient:
         self.cookies = cookies
         self.timeout = ClientTimeout(timeout)

+    async def completion(
+        self,
+        prompt: str,
+        frequency_penalty: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Union[Completion, AsyncIterator[CompletionComplete]]:
+        """
+        Given a prompt, generate a response asynchronously
+
+        Args:
+            prompt (`str`):
+                Prompt
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty.
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
+            max_tokens (`int`):
+                Maximum number of generated tokens
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            seed (`int`):
+                Random sampling seed
+            stream (`bool`):
+                Stream the response
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
+        """
+        request = CompletionRequest(
+            model="tgi",
+            prompt=prompt,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            stream=stream,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+        )
+        if not stream:
+            return await self._completion_single_response(request)
+        else:
+            return self._completion_stream_response(request)
+
+    async def _completion_single_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/completions", json=request.dict()
+            ) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return Completion(**payload)
+
+    async def _completion_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/completions", json=request.dict()
+            ) as resp:
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = CompletionComplete(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
+
     async def chat(
         self,
         messages: List[Message],
@@ -470,8 +648,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
-        tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
+        stop: Optional[List[str]] = None,
     ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
         Given a list of messages, generate a response asynchronously
@@ -510,10 +688,10 @@ class AsyncClient:
                 higher are kept for generation
             tools (`List[Tool]`):
                 List of tools to use
-            tool_prompt (`str`):
-                A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated

         """
         request = ChatRequest(
@@ -532,8 +710,8 @@ class AsyncClient:
             temperature=temperature,
             top_p=top_p,
             tools=tools,
-            tool_prompt=tool_prompt,
             tool_choice=tool_choice,
+            stop=stop,
         )
         if not stream:
             return await self._chat_single_response(request)
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 5e32bc6f..eb872ee6 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -46,30 +46,6 @@ class Tool(BaseModel):
     function: dict


-class ChatCompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    message: Message
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-    # Usage details of the chat completion
-    usage: Optional[Any] = None
-
-
-class CompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    text: str
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-
-
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -95,24 +71,41 @@ class Choice(BaseModel):
     finish_reason: Optional[str] = None


-class ChatCompletionChunk(BaseModel):
-    id: str
-    object: str
-    created: int
+class CompletionRequest(BaseModel):
+    # Model identifier
     model: str
-    system_fingerprint: str
-    choices: List[Choice]
+    # Prompt
+    prompt: str
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 0.0 means no penalty.
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float] = None
+    # Maximum number of tokens to generate
+    max_tokens: Optional[int] = None
+    # Flag to indicate streaming response
+    stream: bool = False
+    # Random sampling seed
+    seed: Optional[int] = None
+    # Sampling temperature
+    temperature: Optional[float] = None
+    # Top-p value for nucleus sampling
+    top_p: Optional[float] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None


-class ChatComplete(BaseModel):
-    # Chat completion details
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[ChatCompletionComplete]
-    usage: Any
+class CompletionComplete(BaseModel):
+    # Index of the completion
+    index: int
+    # Generated text of the completion
+    text: str
+    # Log probabilities for the completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str


 class Completion(BaseModel):
@@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
     tool_prompt: Optional[str] = None
     # Choice of tool to be used
     tool_choice: Optional[str] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None
+
+
+class ChatCompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    message: Message
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+    # Usage details of the chat completion
+    usage: Optional[Any] = None
+
+
+class ChatComplete(BaseModel):
+    # Chat completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[ChatCompletionComplete]
+    usage: Any
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]


 class Parameters(BaseModel):

From 2f644779cb82eecf7023ffff6d66bcb54283800c Mon Sep 17 00:00:00 2001
From: Thomas SCHILLACI
Date: Tue, 7 May 2024 19:27:05 +0200
Subject: [PATCH 2/3] Add stop parameter to completions route

---
 docs/openapi.json    |  9 +++++++++
 router/src/lib.rs    |  5 +++++
 router/src/server.rs | 27 ++++++++++++++++++++-------
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/docs/openapi.json b/docs/openapi.json
index 2a387c2f..79c3b80f 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1121,6 +1121,15 @@
           "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
           "example": 0.95,
           "nullable": true
+        },
+        "stop": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          },
+          "description": "Up to 4 sequences where the API will stop generating further tokens.",
+          "example": "null",
+          "nullable": true
         }
       }
     },
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 96a9fdf6..37194c63 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -401,6 +401,11 @@ pub struct CompletionRequest {
     #[serde(default)]
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
+
+    /// Up to 4 sequences where the API will stop generating further tokens.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stop: Option<Vec<String>>,
 }

 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
diff --git a/router/src/server.rs b/router/src/server.rs
index cb55d897..747702f5 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -597,9 +597,22 @@ async fn completions(
     let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");

-    let stream = req.stream;
-    let max_new_tokens = req.max_tokens.or(Some(100));
-    let seed = req.seed;
+    let CompletionRequest {
+        max_tokens,
+        seed,
+        stop,
+        stream,
+        temperature,
+        ..
+    } = req;
+
+    let max_new_tokens = max_tokens.or(Some(100));
+    let stop = stop.unwrap_or_default();
+    // enable greedy only when temperature is 0
+    let (do_sample, temperature) = match temperature {
+        Some(temperature) if temperature == 0.0 => (false, None),
+        other => (true, other),
+    };

     // if suffix is present throw an error
     if req.suffix.is_some() {
@@ -629,22 +642,22 @@ async fn completions(
     }

     let generate_requests: Vec<GenerateRequest> = req
-        .prompt
+        .prompt
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
-                temperature: req.temperature,
+                temperature: temperature,
                 repetition_penalty: req.repetition_penalty,
                 frequency_penalty: req.frequency_penalty,
                 top_k: None,
                 top_p: req.top_p,
                 typical_p: None,
-                do_sample: true,
+                do_sample,
                 max_new_tokens,
                 return_full_text: None,
-                stop: Vec::new(),
+                stop: stop.clone(),
                 truncate: None,
                 watermark: false,
                 details: true,

From 4607b7e9c43cb11eba66dc07db0dad392032b269 Mon Sep 17 00:00:00 2001
From: Thomas SCHILLACI
Date: Tue, 7 May 2024 19:30:18 +0200
Subject: [PATCH 3/3] Readd tool_prompt

---
 clients/python/text_generation/client.py | 10 +++++++++-
 clients/python/text_generation/types.py  |  2 +-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 8acedfb9..2db46fbe 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -171,6 +171,7 @@ class Client:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
+        tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
         stop: Optional[List[str]] = None,
     ):
@@ -211,6 +212,8 @@ class Client:
                 higher are kept for generation
             tools (`List[Tool]`):
                 List of tools to use
+            tool_prompt (`str`):
+                A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
             stop (`List[str]`):
@@ -233,6 +236,7 @@ class Client:
             temperature=temperature,
             top_p=top_p,
             tools=tools,
+            tool_prompt=tool_prompt,
             tool_choice=tool_choice,
             stop=stop,
         )
@@ -648,6 +652,7 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
+        tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
         stop: Optional[List[str]] = None,
     ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
@@ -688,6 +693,8 @@ class AsyncClient:
                 higher are kept for generation
             tools (`List[Tool]`):
                 List of tools to use
+            tool_prompt (`str`):
+                A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
             stop (`List[str]`):
@@ -710,6 +717,7 @@ class AsyncClient:
             temperature=temperature,
             top_p=top_p,
             tools=tools,
+            tool_prompt=tool_prompt,
             tool_choice=tool_choice,
             stop=stop,
         )
@@ -967,4 +975,4 @@ class AsyncClient:
                     except ValidationError:
                         # If we failed to parse the payload, then it is an error payload
                         raise parse_error(resp.status, json_payload)
-                    yield response
+                    yield response
\ No newline at end of file
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index eb872ee6..3436d94e 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -453,4 +453,4 @@ class StreamResponse(BaseModel):
 # Inference API currently deployed model
 class DeployedModel(BaseModel):
     model_id: str
-    sha: str
+    sha: str
\ No newline at end of file
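
Below is a minimal usage sketch of the synchronous completion route added in PATCH 1/3, exercising the new stop parameter. The server URL and prompt are placeholders, and the shape of the streamed chunks (CompletionComplete objects with a text field) is taken from the client code above rather than from a verified server response.

from text_generation import Client

# Placeholder endpoint; point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Non-streaming call: the client POSTs to /v1/completions and parses a `Completion` model.
completion = client.completion(
    "def fibonacci(n):",
    max_tokens=64,
    temperature=0.2,
    stop=["\n\n"],  # generation stops once a blank line is produced
)
print(completion)

# Streaming call: returns a generator of `CompletionComplete` chunks.
for chunk in client.completion(
    "def fibonacci(n):",
    max_tokens=64,
    stream=True,
    stop=["\n\n"],
):
    print(chunk.text, end="", flush=True)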
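
The asynchronous client added in the same patch can be sketched the same way; note that AsyncClient.completion is a coroutine, so the streaming iterator has to be awaited before it can be consumed. The URL is again a placeholder.

import asyncio

from text_generation import AsyncClient


async def main() -> None:
    # Placeholder endpoint; adjust to your deployment.
    client = AsyncClient("http://127.0.0.1:8080")

    # Non-streaming: awaiting the coroutine yields a single `Completion` object.
    completion = await client.completion("What is Deep Learning?", max_tokens=32, stop=["."])
    print(completion)

    # Streaming: the awaited call returns an async iterator of `CompletionComplete` chunks.
    stream = await client.completion(
        "What is Deep Learning?", max_tokens=32, stream=True, stop=["."]
    )
    async for chunk in stream:
        print(chunk.text, end="", flush=True)


asyncio.run(main())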
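
PATCH 2/3 makes the router itself accept stop on /v1/completions, independently of the Python client. A hedged sketch of the raw request, mirroring the fields of the CompletionRequest schema above (the base URL is a placeholder and the response is printed without assuming its exact shape):

import requests

# Placeholder base URL for a running router.
resp = requests.post(
    "http://127.0.0.1:8080/v1/completions",
    json={
        "model": "tgi",
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,
        "stop": ["."],  # up to 4 stop sequences, per the OpenAPI description
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())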