feat(python-client): add cookies to Client constructors and requests (#132)
I have a use case where we need to pass cookies (for auth reasons) to an internally hosted server. Note: I couldn't get the client tests to pass - do you need to have an HF token? ```python FAILED tests/test_client.py::test_generate - text_generation.errors.BadRequestError: Authorization header is correct, but the token seems invalid ```
This commit is contained in:
parent
a3b7db932f
commit
7850119055
|
@ -36,7 +36,11 @@ class Client:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
|
self,
|
||||||
|
base_url: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
cookies: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: int = 10,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -44,11 +48,14 @@ class Client:
|
||||||
text-generation-inference instance base url
|
text-generation-inference instance base url
|
||||||
headers (`Optional[Dict[str, str]]`):
|
headers (`Optional[Dict[str, str]]`):
|
||||||
Additional headers
|
Additional headers
|
||||||
|
cookies (`Optional[Dict[str, str]]`):
|
||||||
|
Cookies to include in the requests
|
||||||
timeout (`int`):
|
timeout (`int`):
|
||||||
Timeout in seconds
|
Timeout in seconds
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
|
self.cookies = cookies
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
@ -130,6 +137,7 @@ class Client:
|
||||||
self.base_url,
|
self.base_url,
|
||||||
json=request.dict(),
|
json=request.dict(),
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
|
cookies=self.cookies,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
payload = resp.json()
|
payload = resp.json()
|
||||||
|
@ -216,6 +224,7 @@ class Client:
|
||||||
self.base_url,
|
self.base_url,
|
||||||
json=request.dict(),
|
json=request.dict(),
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
|
cookies=self.cookies,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
@ -267,7 +276,11 @@ class AsyncClient:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
|
self,
|
||||||
|
base_url: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
cookies: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: int = 10,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -275,11 +288,14 @@ class AsyncClient:
|
||||||
text-generation-inference instance base url
|
text-generation-inference instance base url
|
||||||
headers (`Optional[Dict[str, str]]`):
|
headers (`Optional[Dict[str, str]]`):
|
||||||
Additional headers
|
Additional headers
|
||||||
|
cookies (`Optional[Dict[str, str]]`):
|
||||||
|
Cookies to include in the requests
|
||||||
timeout (`int`):
|
timeout (`int`):
|
||||||
Timeout in seconds
|
Timeout in seconds
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
|
self.cookies = cookies
|
||||||
self.timeout = ClientTimeout(timeout * 60)
|
self.timeout = ClientTimeout(timeout * 60)
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
|
@ -357,7 +373,9 @@ class AsyncClient:
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
|
async with ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
) as session:
|
||||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
payload = await resp.json()
|
payload = await resp.json()
|
||||||
|
|
||||||
|
@ -440,7 +458,9 @@ class AsyncClient:
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
|
async with ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
) as session:
|
||||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
|
|
|
@ -92,7 +92,9 @@ class InferenceAPIClient(Client):
|
||||||
)
|
)
|
||||||
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||||
|
|
||||||
super(InferenceAPIClient, self).__init__(base_url, headers, timeout)
|
super(InferenceAPIClient, self).__init__(
|
||||||
|
base_url, headers=headers, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InferenceAPIAsyncClient(AsyncClient):
|
class InferenceAPIAsyncClient(AsyncClient):
|
||||||
|
@ -147,4 +149,6 @@ class InferenceAPIAsyncClient(AsyncClient):
|
||||||
)
|
)
|
||||||
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||||
|
|
||||||
super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout)
|
super(InferenceAPIAsyncClient, self).__init__(
|
||||||
|
base_url, headers=headers, timeout=timeout
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue