From 323546df1da2929e433ce197499ab71621dec51d Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 25 Apr 2023 13:55:26 +0200
Subject: [PATCH] fix(python-client): add auth headers to is supported requests (#234)

---
 clients/python/pyproject.toml                |  2 +-
 .../python/text_generation/inference_api.py  | 27 ++++++++++---------
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml
index 1af5de9..d883af3 100644
--- a/clients/python/pyproject.toml
+++ b/clients/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.5.0"
+version = "0.5.1"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene "]
diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py
index 31635e6..93b0de8 100644
--- a/clients/python/text_generation/inference_api.py
+++ b/clients/python/text_generation/inference_api.py
@@ -1,7 +1,7 @@
 import os
 import requests
 
-from typing import Optional, List
+from typing import Dict, Optional, List
 from huggingface_hub.utils import build_hf_headers
 
 from text_generation import Client, AsyncClient, __version__
@@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
 )
 
 
-def deployed_models() -> List[DeployedModel]:
+def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
     """
     Get all currently deployed models with text-generation-inference-support
 
@@ -22,6 +22,7 @@
     """
     resp = requests.get(
         f"https://api-inference.huggingface.co/framework/text-generation-inference",
+        headers=headers,
         timeout=5,
     )
 
@@ -33,7 +34,7 @@
     return models
 
 
-def check_model_support(repo_id: str) -> bool:
+def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
     """
     Check if a given model is supported by text-generation-inference
 
@@ -42,6 +43,7 @@
     """
     resp = requests.get(
         f"https://api-inference.huggingface.co/status/{repo_id}",
+        headers=headers,
         timeout=5,
     )
 
@@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
             Timeout in seconds
         """
 
-        # Text Generation Inference client only supports a subset of the available hub models
-        if not check_model_support(repo_id):
-            raise NotSupportedError(repo_id)
-
         headers = build_hf_headers(
             token=token, library_name="text-generation", library_version=__version__
         )
+
+        # Text Generation Inference client only supports a subset of the available hub models
+        if not check_model_support(repo_id, headers):
+            raise NotSupportedError(repo_id)
+
         base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
 
         super(InferenceAPIClient, self).__init__(
@@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
         timeout (`int`):
             Timeout in seconds
         """
-
-        # Text Generation Inference client only supports a subset of the available hub models
-        if not check_model_support(repo_id):
-            raise NotSupportedError(repo_id)
-
         headers = build_hf_headers(
             token=token, library_name="text-generation", library_version=__version__
         )
+
+        # Text Generation Inference client only supports a subset of the available hub models
+        if not check_model_support(repo_id, headers):
+            raise NotSupportedError(repo_id)
+
        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

        super(InferenceAPIAsyncClient, self).__init__(
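
Note on the behavior this patch changes: both clients previously ran the model-support pre-flight check with an unauthenticated GET, so a gated or private model could be reported as unsupported even when the caller passed a valid token. The patch builds the huggingface_hub auth headers first and forwards them to check_model_support() and deployed_models(), which both keep headers optional (None preserves the old unauthenticated behavior).

Below is a minimal sketch of exercising the new optional headers parameter directly. Assumptions not in the patch: "hf_xxx" is a placeholder token, and "bigscience/bloom" is just an example repo id.

    # sketch.py: exercise the headers parameter introduced by this patch
    from huggingface_hub.utils import build_hf_headers

    from text_generation.inference_api import check_model_support, deployed_models

    # Build the same kind of authenticated headers the clients now forward
    # to the pre-flight requests.
    headers = build_hf_headers(token="hf_xxx")  # placeholder token

    # Authenticated support check against an example repo id; passing
    # headers=None would fall back to the pre-patch unauthenticated request.
    if check_model_support("bigscience/bloom", headers):
        models = deployed_models(headers)
        print(f"{len(models)} models currently deployed")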