fix(python-client): add auth headers to is supported requests (#234)

This commit is contained in:
OlivierDehaene 2023-04-25 13:55:26 +02:00 committed by GitHub
parent 37b64a5c10
commit 323546df1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 13 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation"
version = "0.5.0"
version = "0.5.1"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -1,7 +1,7 @@
import os
import requests
from typing import Optional, List
from typing import Dict, Optional, List
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
)
def deployed_models() -> List[DeployedModel]:
def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
"""
Get all currently deployed models with text-generation-inference-support
@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]:
"""
resp = requests.get(
f"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5,
)
@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]:
return models
def check_model_support(repo_id: str) -> bool:
def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
"""
Check if a given model is supported by text-generation-inference
@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool:
"""
resp = requests.get(
f"https://api-inference.huggingface.co/status/{repo_id}",
headers=headers,
timeout=5,
)
@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id):
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id, headers):
raise NotSupportedError(repo_id)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__(
@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id):
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id, headers):
raise NotSupportedError(repo_id)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__(