hf_text-generation-inference/clients/python/text_generation/inference_api.py

169 lines
5.4 KiB
Python
Raw Normal View History

2023-03-07 10:52:22 -07:00
import os
import requests
from typing import Dict, Optional, List
2023-03-07 10:52:22 -07:00
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
from text_generation.types import DeployedModel
from text_generation.errors import NotSupportedError, parse_error
2023-03-07 10:52:22 -07:00
# Base URL of the Hugging Face Inference API. Falls back to the public endpoint
# unless the HF_INFERENCE_ENDPOINT environment variable provides another one.
INFERENCE_ENDPOINT = os.getenv(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
    """
    Get all currently deployed models with text-generation-inference-support

    The request goes to the Inference API at `INFERENCE_ENDPOINT`, so the
    `HF_INFERENCE_ENDPOINT` override is honored. The client classes of this
    module use the same base URL.

    Args:
        headers (`Optional[Dict]`):
            HTTP headers to send with the request (e.g. authorization headers)

    Returns:
        List[DeployedModel]: list of all currently deployed models

    Raises:
        The exception built by `parse_error` when the API answers with a
        non-200 status code.
    """
    resp = requests.get(
        f"{INFERENCE_ENDPOINT}/framework/text-generation-inference",
        headers=headers,
        timeout=5,
    )
    # Decode before checking the status: `parse_error` needs the error payload.
    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)
    return [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
    """
    Check if a given model is supported by text-generation-inference

    The status route of the Inference API at `INFERENCE_ENDPOINT` is queried,
    i.e. the same deployment the client classes of this module send their
    generation requests to.

    Args:
        repo_id (`str`):
            Id of repository (e.g. `bigscience/bloom`)
        headers (`Optional[Dict]`):
            HTTP headers to send with the request (e.g. authorization headers)

    Returns:
        bool: whether the model is supported by this client

    Raises:
        The exception built by `parse_error` when the API answers with a
        non-200 status code.
    """
    resp = requests.get(
        f"{INFERENCE_ENDPOINT}/status/{repo_id}",
        headers=headers,
        timeout=5,
    )
    # Decode before checking the status: `parse_error` needs the error payload.
    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)
    # Treat a status payload without a "framework" field as unsupported
    # instead of letting a bare KeyError escape to the caller.
    return payload.get("framework") == "text-generation-inference"
class InferenceAPIClient(Client):
    """Client to make calls to the HuggingFace Inference API.

    Only supports a subset of the available text-generation or text2text-generation models that are served using
    text-generation-inference

    Example:

    ```python
    >>> from text_generation import InferenceAPIClient

    >>> client = InferenceAPIClient("bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Init headers and API information

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. You can
                find your tokens in https://huggingface.co/settings/tokens.
                Alternatively, you can find both your organizations and
                personal API tokens using `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds for the requests made by this client
                (the model-support check uses its own fixed timeout)

        Raises:
            NotSupportedError: if `repo_id` is not served by
                text-generation-inference
        """
        headers = build_hf_headers(
            token=token, library_name="text-generation", library_version=__version__
        )

        # Text Generation Inference client only supports a subset of the available hub models
        if not check_model_support(repo_id, headers):
            raise NotSupportedError(repo_id)

        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
        super().__init__(base_url, headers=headers, timeout=timeout)
class InferenceAPIAsyncClient(AsyncClient):
    """Asynchronous Client to make calls to the HuggingFace Inference API.

    Only supports a subset of the available text-generation or text2text-generation models that are served using
    text-generation-inference

    Example:

    ```python
    >>> from text_generation import InferenceAPIAsyncClient

    >>> client = InferenceAPIAsyncClient("bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Init headers and API information

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. You can
                find your tokens in https://huggingface.co/settings/tokens.
                Alternatively, you can find both your organizations and
                personal API tokens using `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds for the requests made by this client
                (the model-support check uses its own fixed timeout)

        Raises:
            NotSupportedError: if `repo_id` is not served by
                text-generation-inference
        """
        headers = build_hf_headers(
            token=token, library_name="text-generation", library_version=__version__
        )

        # Text Generation Inference client only supports a subset of the available hub models.
        # NOTE: check_model_support uses the synchronous `requests` library, so
        # constructing this async client performs a blocking HTTP call.
        if not check_model_support(repo_id, headers):
            raise NotSupportedError(repo_id)

        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
        super().__init__(base_url, headers=headers, timeout=timeout)