feat(python-client): get list of currently deployed tgi models using the inference API (#191)
This commit is contained in:
parent
c13b9d87c9
commit
b927244eb5
|
@ -22,7 +22,7 @@ to power LLMs api-inference widgets.
|
|||
## Table of contents
|
||||
|
||||
- [Features](#features)
|
||||
- [Officially Supported Models](#officially-supported-models)
|
||||
- [Optimized Architectures](#optimized-architectures)
|
||||
- [Get Started](#get-started)
|
||||
- [Docker](#docker)
|
||||
- [API Documentation](#api-documentation)
|
||||
|
|
|
@ -52,6 +52,14 @@ print(text)
|
|||
# ' Rayleigh scattering'
|
||||
```
|
||||
|
||||
Check all currently deployed models on the Huggingface Inference API with `Text Generation` support:
|
||||
|
||||
```python
|
||||
from text_generation.inference_api import deployed_models
|
||||
|
||||
print(deployed_models())
|
||||
```
|
||||
|
||||
### Hugging Face Inference Endpoint usage
|
||||
|
||||
```python
|
||||
|
@ -193,4 +201,9 @@ class StreamResponse:
|
|||
# Generation details
|
||||
# Only available when the generation is finished
|
||||
details: Optional[StreamDetails]
|
||||
|
||||
# Inference API currently deployed model
|
||||
class DeployedModel:
|
||||
model_id: str
|
||||
sha: str
|
||||
```
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "text-generation"
|
||||
version = "0.4.1"
|
||||
version = "0.5.0"
|
||||
description = "Hugging Face Text Generation Python Client"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
|
|
@ -6,12 +6,20 @@ from text_generation import (
|
|||
Client,
|
||||
AsyncClient,
|
||||
)
|
||||
from text_generation.errors import NotSupportedError
|
||||
from text_generation.inference_api import get_supported_models
|
||||
from text_generation.errors import NotSupportedError, NotFoundError
|
||||
from text_generation.inference_api import check_model_support, deployed_models
|
||||
|
||||
|
||||
def test_get_supported_models():
|
||||
assert isinstance(get_supported_models(), list)
|
||||
def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
|
||||
assert check_model_support(flan_t5_xxl)
|
||||
assert not check_model_support(unsupported_model)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
check_model_support(fake_model)
|
||||
|
||||
|
||||
def test_deployed_models():
|
||||
deployed_models()
|
||||
|
||||
|
||||
def test_client(flan_t5_xxl):
|
||||
|
|
|
@ -1,44 +1,57 @@
|
|||
import os
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Optional, List
|
||||
from huggingface_hub.utils import build_hf_headers
|
||||
|
||||
from text_generation import Client, AsyncClient, __version__
|
||||
from text_generation.errors import NotSupportedError
|
||||
from text_generation.types import DeployedModel
|
||||
from text_generation.errors import NotSupportedError, parse_error
|
||||
|
||||
INFERENCE_ENDPOINT = os.environ.get(
|
||||
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
|
||||
)
|
||||
|
||||
SUPPORTED_MODELS = None
|
||||
|
||||
|
||||
def get_supported_models() -> Optional[List[str]]:
|
||||
def deployed_models() -> List[DeployedModel]:
|
||||
"""
|
||||
Get the list of supported text-generation models from GitHub
|
||||
Get all currently deployed models with text-generation-inference-support
|
||||
|
||||
Returns:
|
||||
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
|
||||
List[DeployedModel]: list of all currently deployed models
|
||||
"""
|
||||
global SUPPORTED_MODELS
|
||||
if SUPPORTED_MODELS is not None:
|
||||
return SUPPORTED_MODELS
|
||||
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
|
||||
resp = requests.get(
|
||||
f"https://api-inference.huggingface.co/framework/text-generation-inference",
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
file_content = response.json()["content"]
|
||||
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
|
||||
return SUPPORTED_MODELS
|
||||
|
||||
warnings.warn("Could not retrieve list of supported models.")
|
||||
return None
|
||||
payload = resp.json()
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, payload)
|
||||
|
||||
models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
|
||||
return models
|
||||
|
||||
|
||||
def check_model_support(repo_id: str) -> bool:
|
||||
"""
|
||||
Check if a given model is supported by text-generation-inference
|
||||
|
||||
Returns:
|
||||
bool: whether the model is supported by this client
|
||||
"""
|
||||
resp = requests.get(
|
||||
f"https://api-inference.huggingface.co/status/{repo_id}",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
payload = resp.json()
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, payload)
|
||||
|
||||
framework = payload["framework"]
|
||||
supported = framework == "text-generation-inference"
|
||||
return supported
|
||||
|
||||
|
||||
class InferenceAPIClient(Client):
|
||||
|
@ -83,8 +96,7 @@ class InferenceAPIClient(Client):
|
|||
"""
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
supported_models = get_supported_models()
|
||||
if supported_models is not None and repo_id not in supported_models:
|
||||
if not check_model_support(repo_id):
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
headers = build_hf_headers(
|
||||
|
@ -140,8 +152,7 @@ class InferenceAPIAsyncClient(AsyncClient):
|
|||
"""
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
supported_models = get_supported_models()
|
||||
if supported_models is not None and repo_id not in supported_models:
|
||||
if not check_model_support(repo_id):
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
headers = build_hf_headers(
|
||||
|
|
|
@ -223,3 +223,9 @@ class StreamResponse(BaseModel):
|
|||
# Generation details
|
||||
# Only available when the generation is finished
|
||||
details: Optional[StreamDetails]
|
||||
|
||||
|
||||
# Inference API currently deployed model
|
||||
class DeployedModel(BaseModel):
|
||||
model_id: str
|
||||
sha: str
|
||||
|
|
Loading…
Reference in New Issue