fix(python-client): add auth headers to is-supported requests (#234)
commit 323546df1d
parent 37b64a5c10
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.5.0"
+version = "0.5.1"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
text_generation/inference_api.py
@@ -1,7 +1,7 @@
 import os
 import requests
 
-from typing import Optional, List
+from typing import Dict, Optional, List
 from huggingface_hub.utils import build_hf_headers
 
 from text_generation import Client, AsyncClient, __version__
@@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
 )
 
 
-def deployed_models() -> List[DeployedModel]:
+def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
     """
     Get all currently deployed models with text-generation-inference-support
 
@@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]:
     """
     resp = requests.get(
         f"https://api-inference.huggingface.co/framework/text-generation-inference",
+        headers=headers,
         timeout=5,
     )
 
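For reference, a minimal sketch of calling the updated helper with authentication. The import path and the placeholder token are assumptions; build_hf_headers falls back to locally cached Hugging Face credentials when no token is given.

from huggingface_hub.utils import build_hf_headers

# Assumed import path for the helper changed above.
from text_generation.inference_api import deployed_models

# "hf_xxx" is a placeholder; omit token= to use cached credentials.
headers = build_hf_headers(token="hf_xxx")

# The helper now forwards the headers to the Inference API endpoint.
for model in deployed_models(headers):
    print(model)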
@@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]:
     return models
 
 
-def check_model_support(repo_id: str) -> bool:
+def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
     """
     Check if a given model is supported by text-generation-inference
 
@@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool:
     """
     resp = requests.get(
         f"https://api-inference.huggingface.co/status/{repo_id}",
+        headers=headers,
         timeout=5,
     )
 
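The support check can be called the same way; again a sketch with an assumed import path and a placeholder token.

from huggingface_hub.utils import build_hf_headers
from text_generation.inference_api import check_model_support

headers = build_hf_headers(token="hf_xxx")  # placeholder token

# The status request is now authenticated, so gated or private
# repos no longer fail the check for lack of credentials.
if check_model_support("bigscience/bloom", headers):
    print("supported by text-generation-inference")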
@@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
             Timeout in seconds
         """
-        # Text Generation Inference client only supports a subset of the available hub models
-        if not check_model_support(repo_id):
-            raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
             token=token, library_name="text-generation", library_version=__version__
         )
 
+        # Text Generation Inference client only supports a subset of the available hub models
+        if not check_model_support(repo_id, headers):
+            raise NotSupportedError(repo_id)
+
         base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
 
         super(InferenceAPIClient, self).__init__(
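The reordering matters: build_hf_headers runs first, so the check_model_support call inside __init__ now carries the same credentials as subsequent generation requests. A usage sketch, assuming the package-level export and with a placeholder token:

from text_generation import InferenceAPIClient

# __init__ raises NotSupportedError only after an authenticated check.
client = InferenceAPIClient("bigscience/bloom", token="hf_xxx")
print(client.generate("Why is the sky blue?").generated_text)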
@@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
         timeout (`int`):
             Timeout in seconds
         """
-        # Text Generation Inference client only supports a subset of the available hub models
-        if not check_model_support(repo_id):
-            raise NotSupportedError(repo_id)
 
         headers = build_hf_headers(
             token=token, library_name="text-generation", library_version=__version__
         )
 
+        # Text Generation Inference client only supports a subset of the available hub models
+        if not check_model_support(repo_id, headers):
+            raise NotSupportedError(repo_id)
         base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
 
         super(InferenceAPIAsyncClient, self).__init__(
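And the async variant, mirroring the change; a sketch under the same assumptions:

import asyncio

from text_generation import InferenceAPIAsyncClient

async def main() -> None:
    # Same authenticated support check as the sync client.
    client = InferenceAPIAsyncClient("bigscience/bloom", token="hf_xxx")
    response = await client.generate("Why is the sky blue?")
    print(response.generated_text)

asyncio.run(main())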