diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 0e86901d..8acedfb9 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -11,6 +11,9 @@ from text_generation.types import (
     Request,
     Parameters,
     Grammar,
+    CompletionRequest,
+    Completion,
+    CompletionComplete,
     ChatRequest,
     ChatCompletionChunk,
     ChatComplete,
@@ -64,6 +67,94 @@ class Client:
         self.cookies = cookies
         self.timeout = timeout
 
+    def completion(
+        self,
+        prompt: str,
+        frequency_penalty: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[List[str]] = None,
+    ):
+        """
+        Given a prompt, generate a response synchronously
+
+        Args:
+            prompt (`str`):
+                Prompt
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
+            max_tokens (`int`):
+                Maximum number of generated tokens
+            repetition_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            seed (`int`):
+                Random sampling seed
+            stream (`bool`):
+                Stream the response
+            temperature (`float`):
+                The value used to module the logits distribution.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
+        """
+        request = CompletionRequest(
+            model="tgi",
+            prompt=prompt,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            stream=stream,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+        )
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/completions",
+                json=request.dict(),
+                headers=self.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+            )
+            payload = resp.json()
+            if resp.status_code != 200:
+                raise parse_error(resp.status_code, payload)
+            return Completion(**payload)
+        else:
+            return self._completion_stream_response(request)
+
+    def _completion_stream_response(self, request):
+        resp = requests.post(
+            f"{self.base_url}/v1/completions",
+            json=request.dict(),
+            headers=self.headers,
+            cookies=self.cookies,
+            timeout=self.timeout,
+            stream=True,
+        )
+        # iterate and print stream
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = CompletionComplete(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status, json_payload)
+
     def chat(
         self,
         messages: List[Message],
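As a quick illustration of the new synchronous `completion` method (not part of the diff): a usage sketch assuming a locally running text-generation-inference server at a placeholder URL, and assuming the existing `Completion` model, whose fields are not shown here, exposes a `choices` list of `CompletionComplete` entries.

```python
from text_generation import Client

# Placeholder endpoint for a locally running TGI server.
client = Client("http://127.0.0.1:8080")

# Non-streaming: a single `Completion` object is returned.
# NOTE: `completion.choices[0].text` assumes the existing `Completion` model
# carries a list of `CompletionComplete` choices, which this diff does not show.
completion = client.completion(
    "def fibonacci(n):",
    max_tokens=64,
    temperature=0.2,
    stop=["\n\n"],
)
print(completion.choices[0].text)

# Streaming: `completion()` returns a generator of `CompletionComplete` chunks.
for chunk in client.completion("def fibonacci(n):", max_tokens=64, stream=True):
    print(chunk.text, end="", flush=True)
```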
@@ -80,8 +171,8 @@ class Client:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
-        tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
+        stop: Optional[List[str]] = None,
     ):
         """
         Given a list of messages, generate a response asynchronously
@@ -120,10 +211,10 @@ class Client:
                 higher are kept for generation
             tools (`List[Tool]`):
                 List of tools to use
-            tool_prompt (`str`):
-                A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
 
         """
         request = ChatRequest(
@@ -142,8 +233,8 @@ class Client:
             temperature=temperature,
             top_p=top_p,
             tools=tools,
-            tool_prompt=tool_prompt,
             tool_choice=tool_choice,
+            stop=stop,
         )
         if not stream:
             resp = requests.post(
@@ -454,6 +545,93 @@ class AsyncClient:
         self.cookies = cookies
         self.timeout = ClientTimeout(timeout)
 
+    async def completion(
+        self,
+        prompt: str,
+        frequency_penalty: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stream: bool = False,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Union[Completion, AsyncIterator[CompletionComplete]]:
+        """
+        Given a prompt, generate a response asynchronously
+
+        Args:
+            prompt (`str`):
+                Prompt
+            frequency_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty
+                Penalize new tokens based on their existing frequency in the text so far,
+                decreasing the model's likelihood to repeat the same line verbatim.
+            max_tokens (`int`):
+                Maximum number of generated tokens
+            repetition_penalty (`float`):
+                The parameter for frequency penalty. 0.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            seed (`int`):
+                Random sampling seed
+            stream (`bool`):
+                Stream the response
+            temperature (`float`):
+                The value used to module the logits distribution.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
+        """
+        request = CompletionRequest(
+            model="tgi",
+            prompt=prompt,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            stream=stream,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+        )
+        if not stream:
+            return await self._completion_single_response(request)
+        else:
+            return self._completion_stream_response(request)
+
+    async def _completion_single_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/completions", json=request.dict()
+            ) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return Completion(**payload)
+
+    async def _completion_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/completions", json=request.dict()
+            ) as resp:
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = CompletionComplete(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
+
     async def chat(
         self,
         messages: List[Message],
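The async counterpart, again only a sketch against a placeholder endpoint; the streaming call is awaited first, since the `async def completion` coroutine returns the async generator when `stream=True`.

```python
import asyncio

from text_generation import AsyncClient


async def main():
    # Placeholder endpoint for a locally running TGI server.
    client = AsyncClient("http://127.0.0.1:8080")

    # Non-streaming: awaiting the call yields a single `Completion`.
    completion = await client.completion("What is Deep Learning?", max_tokens=32)
    print(completion)

    # Streaming: the awaited call returns an async iterator of `CompletionComplete`.
    stream = await client.completion("What is Deep Learning?", max_tokens=32, stream=True)
    async for chunk in stream:
        print(chunk.text, end="", flush=True)


asyncio.run(main())
```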
@@ -470,8 +648,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
-        tool_prompt: Optional[str] = None,
         tool_choice: Optional[str] = None,
+        stop: Optional[List[str]] = None,
     ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
         Given a list of messages, generate a response asynchronously
@@ -510,10 +688,10 @@ class AsyncClient:
                 higher are kept for generation
             tools (`List[Tool]`):
                 List of tools to use
-            tool_prompt (`str`):
-                A prompt to be appended before the tools
             tool_choice (`str`):
                 The tool to use
+            stop (`List[str]`):
+                Stop generating tokens if a member of `stop` is generated
 
         """
         request = ChatRequest(
@@ -532,8 +710,8 @@ class AsyncClient:
             temperature=temperature,
             top_p=top_p,
             tools=tools,
-            tool_prompt=tool_prompt,
             tool_choice=tool_choice,
+            stop=stop,
         )
         if not stream:
             return await self._chat_single_response(request)
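For context on the `stop` plumbing through `chat` (and the removal of `tool_prompt` from the client signature), a hedged sketch; `max_tokens` and the `Message(role=..., content=...)` fields belong to unchanged parts of the client and are assumed here.

```python
from text_generation import Client
from text_generation.types import Message

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

response = client.chat(
    messages=[Message(role="user", content="Count from 1 to 10.")],
    max_tokens=64,  # assumed to exist on `chat`, as in the unchanged signature
    stop=["5"],  # generation halts once any listed sequence is produced
)
print(response.choices[0].message.content)
```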
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 5e32bc6f..eb872ee6 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -46,30 +46,6 @@ class Tool(BaseModel):
     function: dict
 
 
-class ChatCompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    message: Message
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-    # Usage details of the chat completion
-    usage: Optional[Any] = None
-
-
-class CompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    text: str
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-
-
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -95,24 +71,41 @@ class Choice(BaseModel):
     finish_reason: Optional[str] = None
 
 
-class ChatCompletionChunk(BaseModel):
-    id: str
-    object: str
-    created: int
+class CompletionRequest(BaseModel):
+    # Model identifier
     model: str
-    system_fingerprint: str
-    choices: List[Choice]
+    # Prompt
+    prompt: str
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 1.0 means no penalty
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float] = None
+    # Maximum number of tokens to generate
+    max_tokens: Optional[int] = None
+    # Flag to indicate streaming response
+    stream: bool = False
+    # Random sampling seed
+    seed: Optional[int] = None
+    # Sampling temperature
+    temperature: Optional[float] = None
+    # Top-p value for nucleus sampling
+    top_p: Optional[float] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None
 
 
-class ChatComplete(BaseModel):
-    # Chat completion details
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[ChatCompletionComplete]
-    usage: Any
+class CompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    text: str
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
 
 
 class Completion(BaseModel):
@@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
     tool_prompt: Optional[str] = None
     # Choice of tool to be used
     tool_choice: Optional[str] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None
+
+
+class ChatCompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    message: Message
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+    # Usage details of the chat completion
+    usage: Optional[Any] = None
+
+
+class ChatComplete(BaseModel):
+    # Chat completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[ChatCompletionComplete]
+    usage: Any
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]
 
 
 class Parameters(BaseModel):
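A minimal round trip through the relocated and new Pydantic models above; the values are made up for illustration, and `request.dict()` is the body the client posts to `/v1/completions`.

```python
from text_generation.types import CompletionComplete, CompletionRequest

# Request body as the client serializes it for POST /v1/completions.
request = CompletionRequest(
    model="tgi",
    prompt="def hello():",
    max_tokens=16,
    stream=False,
    stop=["\n\n"],
)
print(request.dict())

# One streamed choice, parsed the same way `_completion_stream_response` does.
chunk = CompletionComplete(
    index=0,
    text=" print('hello')",
    logprobs=None,
    finish_reason="length",
)
print(chunk.text, chunk.finish_reason)
```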