diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index e05a002e..03bc3888 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -36,7 +36,11 @@ class Client: """ def __init__( - self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10 + self, + base_url: str, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + timeout: int = 10, ): """ Args: @@ -44,11 +48,14 @@ class Client: text-generation-inference instance base url headers (`Optional[Dict[str, str]]`): Additional headers + cookies (`Optional[Dict[str, str]]`): + Cookies to include in the requests timeout (`int`): Timeout in seconds """ self.base_url = base_url self.headers = headers + self.cookies = cookies self.timeout = timeout def generate( @@ -130,6 +137,7 @@ class Client: self.base_url, json=request.dict(), headers=self.headers, + cookies=self.cookies, timeout=self.timeout, ) payload = resp.json() @@ -216,6 +224,7 @@ class Client: self.base_url, json=request.dict(), headers=self.headers, + cookies=self.cookies, timeout=self.timeout, stream=True, ) @@ -267,7 +276,11 @@ class AsyncClient: """ def __init__( - self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10 + self, + base_url: str, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + timeout: int = 10, ): """ Args: @@ -275,11 +288,14 @@ class AsyncClient: text-generation-inference instance base url headers (`Optional[Dict[str, str]]`): Additional headers + cookies (`Optional[Dict[str, str]]`): + Cookies to include in the requests timeout (`int`): Timeout in seconds """ self.base_url = base_url self.headers = headers + self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) async def generate( @@ -357,7 +373,9 @@ class AsyncClient: ) request = Request(inputs=prompt, stream=False, parameters=parameters) - async with ClientSession(headers=self.headers, timeout=self.timeout) as session: + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: async with session.post(self.base_url, json=request.dict()) as resp: payload = await resp.json() @@ -440,7 +458,9 @@ class AsyncClient: ) request = Request(inputs=prompt, stream=True, parameters=parameters) - async with ClientSession(headers=self.headers, timeout=self.timeout) as session: + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: async with session.post(self.base_url, json=request.dict()) as resp: if resp.status != 200: diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py index bc6022b3..eb70b3d1 100644 --- a/clients/python/text_generation/inference_api.py +++ b/clients/python/text_generation/inference_api.py @@ -92,7 +92,9 @@ class InferenceAPIClient(Client): ) base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" - super(InferenceAPIClient, self).__init__(base_url, headers, timeout) + super(InferenceAPIClient, self).__init__( + base_url, headers=headers, timeout=timeout + ) class InferenceAPIAsyncClient(AsyncClient): @@ -147,4 +149,6 @@ class InferenceAPIAsyncClient(AsyncClient): ) base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" - super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout) + super(InferenceAPIAsyncClient, self).__init__( + base_url, headers=headers, timeout=timeout + )