# hf_text-generation-inference/clients/python/text_generation/inference_api.py
#
# 151 lines
# 5.1 KiB
# Python

import os
import requests
import base64
import json
import warnings
from typing import List, Optional
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
from text_generation.errors import NotSupportedError
# Base URL used to build model URLs; the HF_INFERENCE_ENDPOINT environment
# variable takes precedence over the default when it is set.
INFERENCE_ENDPOINT = os.getenv(
    "HF_INFERENCE_ENDPOINT",
    "https://api-inference.huggingface.co",
)

# Cache for get_supported_models(): stays None until a download succeeds.
SUPPORTED_MODELS: Optional[List[str]] = None
def get_supported_models() -> Optional[List[str]]:
    """
    Get the list of supported text-generation models from GitHub

    The list is downloaded at most once per successful call: the decoded result
    is stored in the module-level `SUPPORTED_MODELS` and returned directly on
    later calls. Failures are not cached, so the next call tries again.

    Any failure to obtain the list (network error, timeout, non-200 status or
    a payload that cannot be decoded) emits a warning and yields `None` rather
    than raising.

    Returns:
        Optional[List[str]]: supported models list or None if unable to get the list from GitHub
    """
    global SUPPORTED_MODELS

    # A previous call already succeeded: reuse its result
    if SUPPORTED_MODELS is not None:
        return SUPPORTED_MODELS

    try:
        response = requests.get(
            "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
            timeout=5,
        )
    except requests.RequestException:
        # Connection errors and timeouts must not escape: callers treat None
        # as "list unknown" and skip the support check
        warnings.warn("Could not retrieve list of supported models.")
        return None

    if response.status_code == 200:
        try:
            # The JSON body carries the file base64-encoded under "content"
            file_content = response.json()["content"]
            models = json.loads(base64.b64decode(file_content).decode("utf-8"))
        except (KeyError, TypeError, ValueError):
            # Unexpected or corrupted payload: fall through to the warning below
            pass
        else:
            SUPPORTED_MODELS = models
            return SUPPORTED_MODELS

    warnings.warn("Could not retrieve list of supported models.")
    return None
class InferenceAPIClient(Client):
    """Synchronous client for the HuggingFace Inference API.

    Only the subset of text-generation / text2text-generation models that are
    served using text-generation-inference can be used with this client.

    Example:

    ```python
    >>> from text_generation import InferenceAPIClient

    >>> client = InferenceAPIClient("bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    >>>     if not response.token.special:
    >>>         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Check that the model is supported, then set up headers and the model URL

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not
                the authentication token. You can find the token in
                https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using
                `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds

        Raises:
            NotSupportedError: if the supported-models list could be retrieved
                and `repo_id` is not part of it
        """
        # The check is skipped entirely when the list is unavailable (None)
        known_models = get_supported_models()
        if not (known_models is None or repo_id in known_models):
            raise NotSupportedError(repo_id)

        request_headers = build_hf_headers(
            token=token,
            library_name="text-generation",
            library_version=__version__,
        )

        super().__init__(
            f"{INFERENCE_ENDPOINT}/models/{repo_id}",
            request_headers,
            timeout,
        )
class InferenceAPIAsyncClient(AsyncClient):
    """Asynchronous client for the HuggingFace Inference API.

    Only the subset of text-generation / text2text-generation models that are
    served using text-generation-inference can be used with this client.

    Example:

    ```python
    >>> from text_generation import InferenceAPIAsyncClient

    >>> client = InferenceAPIAsyncClient("bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    >>>     if not response.token.special:
    >>>         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Check that the model is supported, then set up headers and the model URL

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not
                the authentication token. You can find the token in
                https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using
                `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds

        Raises:
            NotSupportedError: if the supported-models list could be retrieved
                and `repo_id` is not part of it
        """
        # The check is skipped entirely when the list is unavailable (None)
        known_models = get_supported_models()
        if not (known_models is None or repo_id in known_models):
            raise NotSupportedError(repo_id)

        request_headers = build_hf_headers(
            token=token,
            library_name="text-generation",
            library_version=__version__,
        )

        super().__init__(
            f"{INFERENCE_ENDPOINT}/models/{repo_id}",
            request_headers,
            timeout,
        )