From b927244eb57d69fd92ecf36525f97e48f5b8b54f Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Mon, 17 Apr 2023 18:43:24 +0200
Subject: [PATCH] feat(python-client): get list of currently deployed tgi models using the inference API (#191)

---
 README.md                                   |  2 +-
 clients/python/README.md                    | 13 ++++
 clients/python/pyproject.toml               |  2 +-
 clients/python/tests/test_inference_api.py  | 16 +++--
 .../python/text_generation/inference_api.py | 63 +++++++++++--------
 clients/python/text_generation/types.py     |  6 ++
 6 files changed, 70 insertions(+), 32 deletions(-)

diff --git a/README.md b/README.md
index 60c1a6b..0c63f36 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ to power LLMs api-inference widgets.
 ## Table of contents
 
 - [Features](#features)
-- [Officially Supported Models](#officially-supported-models)
+- [Optimized Architectures](#optimized-architectures)
 - [Get Started](#get-started)
 - [Docker](#docker)
 - [API Documentation](#api-documentation)
diff --git a/clients/python/README.md b/clients/python/README.md
index f509e65..99ff185 100644
--- a/clients/python/README.md
+++ b/clients/python/README.md
@@ -52,6 +52,14 @@ print(text)
 # ' Rayleigh scattering'
 ```
 
+Check all models currently deployed on the Hugging Face Inference API with `Text Generation` support:
+
+```python
+from text_generation.inference_api import deployed_models
+
+print(deployed_models())
+```
+
 ### Hugging Face Inference Endpoint usage
 
 ```python
@@ -193,4 +201,9 @@ class StreamResponse:
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+# Inference API currently deployed model
+class DeployedModel:
+    model_id: str
+    sha: str
 ```
\ No newline at end of file
diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml
index d07f1db..1af5de9 100644
--- a/clients/python/pyproject.toml
+++ b/clients/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.4.1"
+version = "0.5.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene "]
diff --git a/clients/python/tests/test_inference_api.py b/clients/python/tests/test_inference_api.py
index 79e503a..59297c2 100644
--- a/clients/python/tests/test_inference_api.py
+++ b/clients/python/tests/test_inference_api.py
@@ -6,12 +6,20 @@ from text_generation import (
     Client,
     AsyncClient,
 )
-from text_generation.errors import NotSupportedError
-from text_generation.inference_api import get_supported_models
+from text_generation.errors import NotSupportedError, NotFoundError
+from text_generation.inference_api import check_model_support, deployed_models
 
 
-def test_get_supported_models():
-    assert isinstance(get_supported_models(), list)
+def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
+    assert check_model_support(flan_t5_xxl)
+    assert not check_model_support(unsupported_model)
+
+    with pytest.raises(NotFoundError):
+        check_model_support(fake_model)
+
+
+def test_deployed_models():
+    deployed_models()
 
 
 def test_client(flan_t5_xxl):
diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py
index eb70b3d..31635e6 100644
--- a/clients/python/text_generation/inference_api.py
+++ b/clients/python/text_generation/inference_api.py
@@ -1,44 +1,57 @@
 import os
 import requests
-import base64
-import json
-import warnings
-from typing import List, Optional
+from typing import Optional, List
 
 from huggingface_hub.utils import build_hf_headers
 
 from text_generation import Client, AsyncClient, __version__
-from text_generation.errors import NotSupportedError
+from text_generation.types import DeployedModel
+from text_generation.errors import NotSupportedError, parse_error
 
 INFERENCE_ENDPOINT = os.environ.get(
     "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
 )
 
-SUPPORTED_MODELS = None
-
 
-def get_supported_models() -> Optional[List[str]]:
+def deployed_models() -> List[DeployedModel]:
     """
-    Get the list of supported text-generation models from GitHub
+    Get all currently deployed models with text-generation-inference support
 
     Returns:
-        Optional[List[str]]: supported models list or None if unable to get the list from GitHub
+        List[DeployedModel]: list of all currently deployed models
     """
-    global SUPPORTED_MODELS
-    if SUPPORTED_MODELS is not None:
-        return SUPPORTED_MODELS
-
-    response = requests.get(
-        "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
+    resp = requests.get(
+        f"https://api-inference.huggingface.co/framework/text-generation-inference",
         timeout=5,
     )
 
-    if response.status_code == 200:
-        file_content = response.json()["content"]
-        SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
-        return SUPPORTED_MODELS
-    warnings.warn("Could not retrieve list of supported models.")
-    return None
+    payload = resp.json()
+    if resp.status_code != 200:
+        raise parse_error(resp.status_code, payload)
+
+    models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
+    return models
+
+
+def check_model_support(repo_id: str) -> bool:
+    """
+    Check if a given model is supported by text-generation-inference
+
+    Returns:
+        bool: whether the model is supported by this client
+    """
+    resp = requests.get(
+        f"https://api-inference.huggingface.co/status/{repo_id}",
+        timeout=5,
+    )
+
+    payload = resp.json()
+    if resp.status_code != 200:
+        raise parse_error(resp.status_code, payload)
+
+    framework = payload["framework"]
+    supported = framework == "text-generation-inference"
+    return supported
 
 
 class InferenceAPIClient(Client):
@@ -83,8 +96,7 @@ class InferenceAPIClient(Client):
         """
 
         # Text Generation Inference client only supports a subset of the available hub models
-        supported_models = get_supported_models()
-        if supported_models is not None and repo_id not in supported_models:
+        if not check_model_support(repo_id):
             raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
@@ -140,8 +152,7 @@ class InferenceAPIAsyncClient(AsyncClient):
         """
 
         # Text Generation Inference client only supports a subset of the available hub models
-        supported_models = get_supported_models()
-        if supported_models is not None and repo_id not in supported_models:
+        if not check_model_support(repo_id):
             raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 21a9849..f3f9dcb 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -223,3 +223,9 @@ class StreamResponse(BaseModel):
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+
+# Inference API currently deployed model
+class DeployedModel(BaseModel):
+    model_id: str
+    sha: str
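
Taken together, this patch replaces the old GitHub-hosted `supported_models.json` lookup with live Inference API queries. Below is a minimal usage sketch of the new surface (`deployed_models`, `check_model_support`, and the existing `InferenceAPIClient`); it is not part of the patch. The concrete repo id is only an example, and the `generate(...).generated_text` call assumes the response shape documented in the client README; the sketch also assumes network access to the Inference API.

```python
from text_generation.errors import NotFoundError
from text_generation.inference_api import (
    InferenceAPIClient,
    check_model_support,
    deployed_models,
)

# List every model currently deployed behind the Inference API with
# text-generation-inference support; each entry carries model_id and sha.
for model in deployed_models():
    print(f"{model.model_id} @ {model.sha}")

# Check a single repository before building a client. The client constructor
# performs the same check itself and raises NotSupportedError on failure.
repo_id = "google/flan-t5-xxl"  # example repo id, not prescribed by the patch
try:
    if check_model_support(repo_id):
        client = InferenceAPIClient(repo_id)
        print(client.generate("Why is the sky blue?").generated_text)
except NotFoundError:
    print(f"{repo_id} does not exist on the Hugging Face Hub")
```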