import os
import requests

from typing import Optional, List
from huggingface_hub.utils import build_hf_headers

from text_generation import Client, AsyncClient, __version__
from text_generation.types import DeployedModel
from text_generation.errors import NotSupportedError, parse_error

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
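
# The default endpoint above can be redirected, e.g. to a private
# text-generation-inference deployment, by setting HF_INFERENCE_ENDPOINT
# before this module is imported. The URL below is an illustrative
# placeholder, not a real endpoint:
#
#   HF_INFERENCE_ENDPOINT=https://tgi.internal.example python my_script.py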
def deployed_models() -> List[DeployedModel]:
    """
    Get all currently deployed models with text-generation-inference support

    Returns:
        List[DeployedModel]: list of all currently deployed models
    """
    resp = requests.get(
        f"{INFERENCE_ENDPOINT}/framework/text-generation-inference",
        timeout=5,
    )

    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)

    models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
    return models
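

# Illustrative sketch, not part of the upstream API: keep only the ids of
# deployed models that are served in sharded mode. Assumes `DeployedModel`
# exposes `model_id` and `sharded` fields.
def sharded_model_ids() -> List[str]:
    """List ids of currently deployed models that are served sharded."""
    return [model.model_id for model in deployed_models() if model.sharded]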
def check_model_support(repo_id: str) -> bool:
    """
    Check if a given model is supported by text-generation-inference

    Args:
        repo_id (`str`):
            Id of repository (e.g. `bigscience/bloom`)

    Returns:
        bool: whether the model is supported by this client
    """
    resp = requests.get(
        f"{INFERENCE_ENDPOINT}/status/{repo_id}",
        timeout=5,
    )

    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)

    framework = payload["framework"]
    supported = framework == "text-generation-inference"
    return supported
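

# Illustrative sketch, not part of the upstream API: return the first repo id
# from `candidates` that is supported, or None. A failed status lookup for one
# model is treated as "not supported" rather than aborting the whole scan.
def first_supported_model(candidates: List[str]) -> Optional[str]:
    for repo_id in candidates:
        try:
            if check_model_support(repo_id):
                return repo_id
        except Exception:
            # e.g. unknown model or transient API error; try the next one
            continue
    return None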
class InferenceAPIClient(Client):
    """Client to make calls to the HuggingFace Inference API.

    Only supports a subset of the available text-generation or
    text2text-generation models that are served using text-generation-inference.

    Example:

    ```python
    >>> from text_generation import InferenceAPIClient

    >>> client = InferenceAPIClient("bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Init headers and API information

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not
                the authentication token. You can find the token in
                https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using
                `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds
        """
        # Text Generation Inference client only supports a subset of the
        # available hub models
        if not check_model_support(repo_id):
            raise NotSupportedError(repo_id)

        headers = build_hf_headers(
            token=token, library_name="text-generation", library_version=__version__
        )
        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

        super().__init__(base_url, headers=headers, timeout=timeout)
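

# Illustrative sketch, not part of the library: end-to-end generation with the
# first currently deployed model. Assumes at least one model is deployed and
# reads the (hypothetical) HF_TOKEN environment variable for authorization.
def _example_generate(prompt: str = "Why is the sky blue?") -> str:
    model = deployed_models()[0]
    client = InferenceAPIClient(model.model_id, token=os.environ.get("HF_TOKEN"))
    return client.generate(prompt).generated_text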
class InferenceAPIAsyncClient(AsyncClient):
    """Asynchronous Client to make calls to the HuggingFace Inference API.

    Only supports a subset of the available text-generation or
    text2text-generation models that are served using text-generation-inference.

    Example:

    ```python
    >>> from text_generation import InferenceAPIAsyncClient

    >>> client = InferenceAPIAsyncClient("bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Init headers and API information

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not
                the authentication token. You can find the token in
                https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using
                `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds
        """
        # Text Generation Inference client only supports a subset of the
        # available hub models
        if not check_model_support(repo_id):
            raise NotSupportedError(repo_id)

        headers = build_hf_headers(
            token=token, library_name="text-generation", library_version=__version__
        )
        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

        super().__init__(base_url, headers=headers, timeout=timeout)
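

# Illustrative sketch, not part of the library: drive the async client from
# synchronous code. Assumes "bigscience/bloomz" is still deployed with
# text-generation-inference support.
def _example_async_generate(prompt: str = "Why is the sky blue?") -> str:
    import asyncio

    async def _run() -> str:
        client = InferenceAPIAsyncClient("bigscience/bloomz")
        response = await client.generate(prompt)
        return response.generated_text

    return asyncio.run(_run())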