fix(python-client): add auth headers to is supported requests (#234)
This commit is contained in:
parent
37b64a5c10
commit
323546df1d
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "text-generation"
|
||||
version = "0.5.0"
|
||||
version = "0.5.1"
|
||||
description = "Hugging Face Text Generation Python Client"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import requests
|
||||
|
||||
from typing import Optional, List
|
||||
from typing import Dict, Optional, List
|
||||
from huggingface_hub.utils import build_hf_headers
|
||||
|
||||
from text_generation import Client, AsyncClient, __version__
|
||||
|
@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
|
|||
)
|
||||
|
||||
|
||||
def deployed_models() -> List[DeployedModel]:
|
||||
def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
|
||||
"""
|
||||
Get all currently deployed models with text-generation-inference-support
|
||||
|
||||
|
@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]:
|
|||
"""
|
||||
resp = requests.get(
|
||||
f"https://api-inference.huggingface.co/framework/text-generation-inference",
|
||||
headers=headers,
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
|
@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]:
|
|||
return models
|
||||
|
||||
|
||||
def check_model_support(repo_id: str) -> bool:
|
||||
def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
|
||||
"""
|
||||
Check if a given model is supported by text-generation-inference
|
||||
|
||||
|
@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool:
|
|||
"""
|
||||
resp = requests.get(
|
||||
f"https://api-inference.huggingface.co/status/{repo_id}",
|
||||
headers=headers,
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
|
@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
|
|||
Timeout in seconds
|
||||
"""
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
if not check_model_support(repo_id):
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
headers = build_hf_headers(
|
||||
token=token, library_name="text-generation", library_version=__version__
|
||||
)
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
if not check_model_support(repo_id, headers):
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||
|
||||
super(InferenceAPIClient, self).__init__(
|
||||
|
@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
|
|||
timeout (`int`):
|
||||
Timeout in seconds
|
||||
"""
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
if not check_model_support(repo_id):
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
headers = build_hf_headers(
|
||||
token=token, library_name="text-generation", library_version=__version__
|
||||
)
|
||||
|
||||
# Text Generation Inference client only supports a subset of the available hub models
|
||||
if not check_model_support(repo_id, headers):
|
||||
raise NotSupportedError(repo_id)
|
||||
|
||||
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||
|
||||
super(InferenceAPIAsyncClient, self).__init__(
|
||||
|
|
Loading…
Reference in New Issue