feat(python-client): get list of currently deployed tgi models using the inference API (#191)

OlivierDehaene 2023-04-17 18:43:24 +02:00 committed by GitHub
parent c13b9d87c9
commit b927244eb5
6 changed files with 70 additions and 32 deletions

View File

@@ -22,7 +22,7 @@ to power LLMs api-inference widgets.
 ## Table of contents
 - [Features](#features)
-- [Officially Supported Models](#officially-supported-models)
+- [Optimized Architectures](#optimized-architectures)
 - [Get Started](#get-started)
 - [Docker](#docker)
 - [API Documentation](#api-documentation)

View File

@@ -52,6 +52,14 @@ print(text)
 # ' Rayleigh scattering'
 ```
 
+Check all currently deployed models on the Huggingface Inference API with `Text Generation` support:
+
+```python
+from text_generation.inference_api import deployed_models
+print(deployed_models())
+```
+
 ### Hugging Face Inference Endpoint usage
 
 ```python
@@ -193,4 +201,9 @@ class StreamResponse:
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+# Inference API currently deployed model
+class DeployedModel:
+    model_id: str
+    sha: str
 ```
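
For context, a minimal end-to-end sketch of the new helper (not part of this diff): list the currently deployed models, then point the Inference API client at one of them. It assumes `InferenceAPIClient` is exported by the package as in the existing README examples and that a valid Hugging Face token is configured.

```python
# Illustrative sketch only, not part of this commit.
from text_generation import InferenceAPIClient
from text_generation.inference_api import deployed_models

# List every model currently served with text-generation-inference support
models = deployed_models()
print([m.model_id for m in models])

# Point the Inference API client at one of the deployed models
client = InferenceAPIClient(models[0].model_id)
text = client.generate("Why is the sky blue?").generated_text
print(text)
```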

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.4.1"
+version = "0.5.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -6,12 +6,20 @@ from text_generation import (
     Client,
     AsyncClient,
 )
-from text_generation.errors import NotSupportedError
-from text_generation.inference_api import get_supported_models
+from text_generation.errors import NotSupportedError, NotFoundError
+from text_generation.inference_api import check_model_support, deployed_models
 
 
-def test_get_supported_models():
-    assert isinstance(get_supported_models(), list)
+def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
+    assert check_model_support(flan_t5_xxl)
+    assert not check_model_support(unsupported_model)
+
+    with pytest.raises(NotFoundError):
+        check_model_support(fake_model)
+
+
+def test_deployed_models():
+    deployed_models()
 
 
 def test_client(flan_t5_xxl):

View File

@@ -1,44 +1,57 @@
 import os
 import requests
-import base64
-import json
-import warnings
 
-from typing import List, Optional
+from typing import Optional, List
 from huggingface_hub.utils import build_hf_headers
 
 from text_generation import Client, AsyncClient, __version__
-from text_generation.errors import NotSupportedError
+from text_generation.types import DeployedModel
+from text_generation.errors import NotSupportedError, parse_error
 
 INFERENCE_ENDPOINT = os.environ.get(
     "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
 )
 
-SUPPORTED_MODELS = None
 
-
-def get_supported_models() -> Optional[List[str]]:
+def deployed_models() -> List[DeployedModel]:
     """
-    Get the list of supported text-generation models from GitHub
+    Get all currently deployed models with text-generation-inference-support
 
     Returns:
-        Optional[List[str]]: supported models list or None if unable to get the list from GitHub
+        List[DeployedModel]: list of all currently deployed models
     """
-    global SUPPORTED_MODELS
-    if SUPPORTED_MODELS is not None:
-        return SUPPORTED_MODELS
-
-    response = requests.get(
-        "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
+    resp = requests.get(
+        f"https://api-inference.huggingface.co/framework/text-generation-inference",
         timeout=5,
     )
-    if response.status_code == 200:
-        file_content = response.json()["content"]
-        SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
-        return SUPPORTED_MODELS
 
-    warnings.warn("Could not retrieve list of supported models.")
-    return None
+    payload = resp.json()
+    if resp.status_code != 200:
+        raise parse_error(resp.status_code, payload)
+
+    models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
+    return models
+
+
+def check_model_support(repo_id: str) -> bool:
+    """
+    Check if a given model is supported by text-generation-inference
+
+    Returns:
+        bool: whether the model is supported by this client
+    """
+    resp = requests.get(
+        f"https://api-inference.huggingface.co/status/{repo_id}",
+        timeout=5,
+    )
+
+    payload = resp.json()
+    if resp.status_code != 200:
+        raise parse_error(resp.status_code, payload)
+
+    framework = payload["framework"]
+    supported = framework == "text-generation-inference"
+    return supported
 
 
 class InferenceAPIClient(Client):
@@ -83,8 +96,7 @@ class InferenceAPIClient(Client):
         """
         # Text Generation Inference client only supports a subset of the available hub models
-        supported_models = get_supported_models()
-        if supported_models is not None and repo_id not in supported_models:
+        if not check_model_support(repo_id):
             raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
@@ -140,8 +152,7 @@ class InferenceAPIAsyncClient(AsyncClient):
         """
        # Text Generation Inference client only supports a subset of the available hub models
-        supported_models = get_supported_models()
-        if supported_models is not None and repo_id not in supported_models:
+        if not check_model_support(repo_id):
            raise NotSupportedError(repo_id)
 
        headers = build_hf_headers(
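
As a usage note (not part of the diff), the new `check_model_support` helper can also be called directly before constructing a client. A rough sketch follows; the repo id shown is only an example value:

```python
# Illustrative sketch only; the repo id is an example value.
from text_generation import InferenceAPIClient
from text_generation.errors import NotFoundError, NotSupportedError
from text_generation.inference_api import check_model_support

repo_id = "google/flan-t5-xxl"  # example model, assumed to be deployed
try:
    if check_model_support(repo_id):
        client = InferenceAPIClient(repo_id)
    else:
        # Mirrors the guard the clients now perform internally
        raise NotSupportedError(repo_id)
except NotFoundError:
    print(f"{repo_id} does not exist on the Inference API")
```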

View File

@@ -223,3 +223,9 @@ class StreamResponse(BaseModel):
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+
+# Inference API currently deployed model
+class DeployedModel(BaseModel):
+    model_id: str
+    sha: str
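
To close, a short sketch of how the new pydantic type validates one entry of the framework endpoint payload, which is what `deployed_models()` does per item; the field values below are invented for illustration.

```python
# Illustrative sketch; the payload entry below is made up for demonstration.
from text_generation.types import DeployedModel

raw_entry = {"model_id": "bigscience/bloom", "sha": "0123abcd"}
model = DeployedModel(**raw_entry)  # pydantic validates field names and types
print(model.model_id, model.sha)
```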