feat(python-client): get list of currently deployed tgi models using the inference API (#191)
This commit is contained in:
parent c13b9d87c9
commit b927244eb5
README.md

@@ -22,7 +22,7 @@ to power LLMs api-inference widgets.
 ## Table of contents
 
 - [Features](#features)
-- [Officially Supported Models](#officially-supported-models)
+- [Optimized Architectures](#optimized-architectures)
 - [Get Started](#get-started)
 - [Docker](#docker)
 - [API Documentation](#api-documentation)
@@ -52,6 +52,14 @@ print(text)
 # ' Rayleigh scattering'
 ```
 
+Check all currently deployed models on the Huggingface Inference API with `Text Generation` support:
+
+```python
+from text_generation.inference_api import deployed_models
+
+print(deployed_models())
+```
+
 ### Hugging Face Inference Endpoint usage
 
 ```python
@@ -193,4 +201,9 @@ class StreamResponse:
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+# Inference API currently deployed model
+class DeployedModel:
+    model_id: str
+    sha: str
 ```
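Each `DeployedModel` documented above exposes the model id and revision sha, so the list returned by `deployed_models()` can be queried directly. A minimal sketch of such a lookup (the repo id is illustrative, not taken from this diff):

```python
from text_generation.inference_api import deployed_models

# Build a set of the ids of all currently deployed models
deployed_ids = {model.model_id for model in deployed_models()}

# "bigscience/bloom" is only an illustrative repo id
print("bigscience/bloom" in deployed_ids)
```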
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.4.1"
+version = "0.5.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
tests/test_inference_api.py

@@ -6,12 +6,20 @@ from text_generation import (
     Client,
     AsyncClient,
 )
-from text_generation.errors import NotSupportedError
-from text_generation.inference_api import get_supported_models
+from text_generation.errors import NotSupportedError, NotFoundError
+from text_generation.inference_api import check_model_support, deployed_models
 
 
-def test_get_supported_models():
-    assert isinstance(get_supported_models(), list)
+def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
+    assert check_model_support(flan_t5_xxl)
+    assert not check_model_support(unsupported_model)
+
+    with pytest.raises(NotFoundError):
+        check_model_support(fake_model)
+
+
+def test_deployed_models():
+    deployed_models()
 
 
 def test_client(flan_t5_xxl):
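The rewritten test relies on three pytest fixtures (`flan_t5_xxl`, `unsupported_model`, `fake_model`) defined elsewhere in the test suite, typically a conftest.py. A minimal sketch of what they could look like; the concrete repo ids here are assumptions for illustration, not part of this diff:

```python
import pytest


@pytest.fixture
def flan_t5_xxl():
    # assumed id of a model served with text-generation-inference
    return "google/flan-t5-xxl"


@pytest.fixture
def unsupported_model():
    # assumed id of a real hub model served by a different framework
    return "gpt2"


@pytest.fixture
def fake_model():
    # a repo id that does not exist on the hub
    return "fake/model"
```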
text_generation/inference_api.py

@@ -1,44 +1,57 @@
 import os
 import requests
-import base64
-import json
-import warnings
 
-from typing import List, Optional
+from typing import Optional, List
 from huggingface_hub.utils import build_hf_headers
 
 from text_generation import Client, AsyncClient, __version__
-from text_generation.errors import NotSupportedError
+from text_generation.types import DeployedModel
+from text_generation.errors import NotSupportedError, parse_error
 
 INFERENCE_ENDPOINT = os.environ.get(
     "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
 )
 
-SUPPORTED_MODELS = None
-
 
-def get_supported_models() -> Optional[List[str]]:
+def deployed_models() -> List[DeployedModel]:
     """
-    Get the list of supported text-generation models from GitHub
+    Get all currently deployed models with text-generation-inference-support
 
     Returns:
-        Optional[List[str]]: supported models list or None if unable to get the list from GitHub
+        List[DeployedModel]: list of all currently deployed models
     """
-    global SUPPORTED_MODELS
-    if SUPPORTED_MODELS is not None:
-        return SUPPORTED_MODELS
-
-    response = requests.get(
-        "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
+    resp = requests.get(
+        f"https://api-inference.huggingface.co/framework/text-generation-inference",
         timeout=5,
     )
-    if response.status_code == 200:
-        file_content = response.json()["content"]
-        SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
-        return SUPPORTED_MODELS
 
-    warnings.warn("Could not retrieve list of supported models.")
-    return None
+    payload = resp.json()
+    if resp.status_code != 200:
+        raise parse_error(resp.status_code, payload)
+
+    models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
+    return models
+
+
+def check_model_support(repo_id: str) -> bool:
+    """
+    Check if a given model is supported by text-generation-inference
+
+    Returns:
+        bool: whether the model is supported by this client
+    """
+    resp = requests.get(
+        f"https://api-inference.huggingface.co/status/{repo_id}",
+        timeout=5,
+    )
+
+    payload = resp.json()
+    if resp.status_code != 200:
+        raise parse_error(resp.status_code, payload)
+
+    framework = payload["framework"]
+    supported = framework == "text-generation-inference"
+    return supported
 
 
 class InferenceAPIClient(Client):
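Taken together, the two new helpers replace the cached GitHub list with live Inference API calls, and surface HTTP failures through `parse_error` instead of a warning. A minimal usage sketch of the new public surface (the repo id is illustrative; per the updated tests, a non-existent repo makes `check_model_support` raise `NotFoundError`):

```python
from text_generation.errors import NotFoundError
from text_generation.inference_api import check_model_support, deployed_models

# List every model currently deployed with text-generation-inference support
for model in deployed_models():
    print(model.model_id, model.sha)

# Probe a single repo id before creating a client ("google/flan-t5-xxl" is illustrative)
try:
    print(check_model_support("google/flan-t5-xxl"))
except NotFoundError:
    print("this repo id does not exist on the Hub")
```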
@@ -83,8 +96,7 @@ class InferenceAPIClient(Client):
         """
 
         # Text Generation Inference client only supports a subset of the available hub models
-        supported_models = get_supported_models()
-        if supported_models is not None and repo_id not in supported_models:
+        if not check_model_support(repo_id):
             raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
@@ -140,8 +152,7 @@ class InferenceAPIAsyncClient(AsyncClient):
         """
 
         # Text Generation Inference client only supports a subset of the available hub models
-        supported_models = get_supported_models()
-        if supported_models is not None and repo_id not in supported_models:
+        if not check_model_support(repo_id):
             raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
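With this change, both client constructors validate the requested model against the live `/status/{repo_id}` endpoint rather than a static list, so an unsupported model fails fast. A minimal sketch of the resulting behavior (the repo id is illustrative):

```python
from text_generation import InferenceAPIClient
from text_generation.errors import NotSupportedError

try:
    # "gpt2" stands in for any hub model served by a framework other
    # than text-generation-inference
    client = InferenceAPIClient("gpt2")
except NotSupportedError as err:
    print(err)
```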
text_generation/types.py

@@ -223,3 +223,9 @@ class StreamResponse(BaseModel):
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+
+# Inference API currently deployed model
+class DeployedModel(BaseModel):
+    model_id: str
+    sha: str
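`deployed_models()` builds one of these Pydantic objects per entry of the JSON payload returned by the framework endpoint, so each entry must carry a `model_id` and a `sha` field. A minimal round-trip sketch (the payload values are made up for illustration):

```python
from text_generation.types import DeployedModel

# Shape of one entry in the endpoint's JSON payload (values are illustrative)
raw = {"model_id": "bigscience/bloom", "sha": "0123abcd"}

model = DeployedModel(**raw)
print(model.model_id, model.sha)
```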