import json import requests from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError from typing import Dict, Optional, List, AsyncIterator, Iterator from text_generation.types import ( StreamResponse, Response, Request, Parameters, Grammar, ) from text_generation.errors import parse_error class Client: """Client to make calls to a text-generation-inference instance Example: ```python >>> from text_generation import Client >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz") >>> client.generate("Why is the sky blue?").generated_text ' Rayleigh scattering' >>> result = "" >>> for response in client.generate_stream("Why is the sky blue?"): >>> if not response.token.special: >>> result += response.token.text >>> result ' Rayleigh scattering' ``` """ def __init__( self, base_url: str, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, timeout: int = 10, ): """ Args: base_url (`str`): text-generation-inference instance base url headers (`Optional[Dict[str, str]]`): Additional headers cookies (`Optional[Dict[str, str]]`): Cookies to include in the requests timeout (`int`): Timeout in seconds """ self.base_url = base_url self.headers = headers self.cookies = cookies self.timeout = timeout def generate( self, prompt: str, do_sample: bool = False, max_new_tokens: int = 20, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text Args: prompt (`str`): Input text do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens best_of (`int`): Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. return_full_text (`bool`): Whether to prepend the prompt to the generated text seed (`int`): Random sampling seed stop_sequences (`List[str]`): Stop generating tokens if a member of `stop_sequences` is generated temperature (`float`): The value used to module the logits distribution. top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. truncate (`int`): Truncate inputs tokens to the given size typical_p (`float`): Typical Decoding mass See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) decoder_input_details (`bool`): Return the decoder input token logprobs and ids top_n_tokens (`int`): Return the `n` most likely tokens at each step Returns: Response: generated response """ # Validate parameters parameters = Parameters( best_of=best_of, details=True, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, stop=stop_sequences if stop_sequences is not None else [], temperature=temperature, top_k=top_k, top_p=top_p, truncate=truncate, typical_p=typical_p, watermark=watermark, decoder_input_details=decoder_input_details, top_n_tokens=top_n_tokens, grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) resp = requests.post( self.base_url, json=request.dict(), headers=self.headers, cookies=self.cookies, timeout=self.timeout, ) payload = resp.json() if resp.status_code != 200: raise parse_error(resp.status_code, payload) return Response(**payload[0]) def generate_stream( self, prompt: str, do_sample: bool = False, max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, grammar: Optional[Grammar] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens Args: prompt (`str`): Input text do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. return_full_text (`bool`): Whether to prepend the prompt to the generated text seed (`int`): Random sampling seed stop_sequences (`List[str]`): Stop generating tokens if a member of `stop_sequences` is generated temperature (`float`): The value used to module the logits distribution. top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. truncate (`int`): Truncate inputs tokens to the given size typical_p (`float`): Typical Decoding mass See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) top_n_tokens (`int`): Return the `n` most likely tokens at each step Returns: Iterator[StreamResponse]: stream of generated tokens """ # Validate parameters parameters = Parameters( best_of=None, details=True, decoder_input_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, stop=stop_sequences if stop_sequences is not None else [], temperature=temperature, top_k=top_k, top_p=top_p, truncate=truncate, typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, grammar=grammar, ) request = Request(inputs=prompt, stream=True, parameters=parameters) resp = requests.post( self.base_url, json=request.dict(), headers=self.headers, cookies=self.cookies, timeout=self.timeout, stream=True, ) if resp.status_code != 200: raise parse_error(resp.status_code, resp.json()) # Parse ServerSentEvents for byte_payload in resp.iter_lines(): # Skip line if byte_payload == b"\n": continue payload = byte_payload.decode("utf-8") # Event data if payload.startswith("data:"): # Decode payload json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) # Parse payload try: response = StreamResponse(**json_payload) except ValidationError: # If we failed to parse the payload, then it is an error payload raise parse_error(resp.status_code, json_payload) yield response class AsyncClient: """Asynchronous Client to make calls to a text-generation-inference instance Example: ```python >>> from text_generation import AsyncClient >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz") >>> response = await client.generate("Why is the sky blue?") >>> response.generated_text ' Rayleigh scattering' >>> result = "" >>> async for response in client.generate_stream("Why is the sky blue?"): >>> if not response.token.special: >>> result += response.token.text >>> result ' Rayleigh scattering' ``` """ def __init__( self, base_url: str, headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, timeout: int = 10, ): """ Args: base_url (`str`): text-generation-inference instance base url headers (`Optional[Dict[str, str]]`): Additional headers cookies (`Optional[Dict[str, str]]`): Cookies to include in the requests timeout (`int`): Timeout in seconds """ self.base_url = base_url self.headers = headers self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) async def generate( self, prompt: str, do_sample: bool = False, max_new_tokens: int = 20, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text asynchronously Args: prompt (`str`): Input text do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens best_of (`int`): Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. return_full_text (`bool`): Whether to prepend the prompt to the generated text seed (`int`): Random sampling seed stop_sequences (`List[str]`): Stop generating tokens if a member of `stop_sequences` is generated temperature (`float`): The value used to module the logits distribution. top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. truncate (`int`): Truncate inputs tokens to the given size typical_p (`float`): Typical Decoding mass See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) decoder_input_details (`bool`): Return the decoder input token logprobs and ids top_n_tokens (`int`): Return the `n` most likely tokens at each step Returns: Response: generated response """ # Validate parameters parameters = Parameters( best_of=best_of, details=True, decoder_input_details=decoder_input_details, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, stop=stop_sequences if stop_sequences is not None else [], temperature=temperature, top_k=top_k, top_p=top_p, truncate=truncate, typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: async with session.post(self.base_url, json=request.dict()) as resp: payload = await resp.json() if resp.status != 200: raise parse_error(resp.status, payload) return Response(**payload[0]) async def generate_stream( self, prompt: str, do_sample: bool = False, max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, grammar: Optional[Grammar] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously Args: prompt (`str`): Input text do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. return_full_text (`bool`): Whether to prepend the prompt to the generated text seed (`int`): Random sampling seed stop_sequences (`List[str]`): Stop generating tokens if a member of `stop_sequences` is generated temperature (`float`): The value used to module the logits distribution. top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. truncate (`int`): Truncate inputs tokens to the given size typical_p (`float`): Typical Decoding mass See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) top_n_tokens (`int`): Return the `n` most likely tokens at each step Returns: AsyncIterator[StreamResponse]: stream of generated tokens """ # Validate parameters parameters = Parameters( best_of=None, details=True, decoder_input_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, stop=stop_sequences if stop_sequences is not None else [], temperature=temperature, top_k=top_k, top_p=top_p, truncate=truncate, typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, grammar=grammar, ) request = Request(inputs=prompt, stream=True, parameters=parameters) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: async with session.post(self.base_url, json=request.dict()) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.json()) # Parse ServerSentEvents async for byte_payload in resp.content: # Skip line if byte_payload == b"\n": continue payload = byte_payload.decode("utf-8") # Event data if payload.startswith("data:"): # Decode payload json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) # Parse payload try: response = StreamResponse(**json_payload) except ValidationError: # If we failed to parse the payload, then it is an error payload raise parse_error(resp.status, json_payload) yield response