fix(python-client): add auth headers to is supported requests (#234)

This commit is contained in:
OlivierDehaene 2023-04-25 13:55:26 +02:00 committed by GitHub
parent 37b64a5c10
commit 323546df1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 13 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation"
version = "0.5.0" version = "0.5.1"
description = "Hugging Face Text Generation Python Client" description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -1,7 +1,7 @@
import os import os
import requests import requests
from typing import Optional, List from typing import Dict, Optional, List
from huggingface_hub.utils import build_hf_headers from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__ from text_generation import Client, AsyncClient, __version__
@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
) )
def deployed_models() -> List[DeployedModel]: def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
""" """
Get all currently deployed models with text-generation-inference-support Get all currently deployed models with text-generation-inference-support
@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]:
""" """
resp = requests.get( resp = requests.get(
f"https://api-inference.huggingface.co/framework/text-generation-inference", f"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5, timeout=5,
) )
@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]:
return models return models
def check_model_support(repo_id: str) -> bool: def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
""" """
Check if a given model is supported by text-generation-inference Check if a given model is supported by text-generation-inference
@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool:
""" """
resp = requests.get( resp = requests.get(
f"https://api-inference.huggingface.co/status/{repo_id}", f"https://api-inference.huggingface.co/status/{repo_id}",
headers=headers,
timeout=5, timeout=5,
) )
@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
Timeout in seconds Timeout in seconds
""" """
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id):
raise NotSupportedError(repo_id)
headers = build_hf_headers( headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__ token=token, library_name="text-generation", library_version=__version__
) )
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id, headers):
raise NotSupportedError(repo_id)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__( super(InferenceAPIClient, self).__init__(
@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
timeout (`int`): timeout (`int`):
Timeout in seconds Timeout in seconds
""" """
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id):
raise NotSupportedError(repo_id)
headers = build_hf_headers( headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__ token=token, library_name="text-generation", library_version=__version__
) )
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id, headers):
raise NotSupportedError(repo_id)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__( super(InferenceAPIAsyncClient, self).__init__(