Add completion route to client and stop parameter
parent 59b3ffea14, commit 9fb1cdc8d5
@@ -11,6 +11,9 @@ from text_generation.types import (
    Request,
    Parameters,
    Grammar,
    CompletionRequest,
    Completion,
    CompletionComplete,
    ChatRequest,
    ChatCompletionChunk,
    ChatComplete,
@@ -64,6 +67,94 @@ class Client:
        self.cookies = cookies
        self.timeout = timeout

    def completion(
        self,
        prompt: str,
        frequency_penalty: Optional[float] = None,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ):
"""
|
||||
Given a prompt, generate a response synchronously
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Prompt
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty
|
||||
Penalize new tokens based on their existing frequency in the text so far,
|
||||
decreasing the model's likelihood to repeat the same line verbatim.
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
stop (`List[str]`):
|
||||
Stop generating tokens if a member of `stop` is generated
|
||||
"""
|
||||
        request = CompletionRequest(
            model="tgi",
            prompt=prompt,
            frequency_penalty=frequency_penalty,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )
        if not stream:
            resp = requests.post(
                f"{self.base_url}/v1/completions",
                json=request.dict(),
                headers=self.headers,
                cookies=self.cookies,
                timeout=self.timeout,
            )
            payload = resp.json()
            if resp.status_code != 200:
                raise parse_error(resp.status_code, payload)
            return Completion(**payload)
        else:
            return self._completion_stream_response(request)

    def _completion_stream_response(self, request):
        resp = requests.post(
            f"{self.base_url}/v1/completions",
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )
        # iterate over the SSE stream and yield parsed chunks
        for byte_payload in resp.iter_lines():
            if byte_payload == b"\n":
                continue
            payload = byte_payload.decode("utf-8")
            if payload.startswith("data:"):
                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                try:
                    response = CompletionComplete(**json_payload)
                    yield response
                except ValidationError:
                    raise parse_error(resp.status_code, json_payload)

    def chat(
        self,
        messages: List[Message],
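A minimal usage sketch of the new synchronous route (the base URL is illustrative and not part of this diff; `Client` is assumed to be importable from the `text_generation` package whose types are imported above):

from text_generation import Client

client = Client("http://127.0.0.1:8080")  # hypothetical local TGI endpoint

# Non-streaming call: returns a Completion model
completion = client.completion(
    "def fibonacci(n):",
    max_tokens=64,
    stop=["\n\n"],  # stop as soon as a blank line is generated
)
print(completion)  # Completion is a pydantic model defined in types.py

# Streaming call: yields CompletionComplete chunks as they arrive
for chunk in client.completion("def fibonacci(n):", max_tokens=64, stream=True):
    print(chunk.text, end="")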
@@ -80,8 +171,8 @@ class Client:
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_prompt: Optional[str] = None,
        tool_choice: Optional[str] = None,
        stop: Optional[List[str]] = None,
    ):
        """
        Given a list of messages, generate a response synchronously
@@ -120,10 +211,10 @@ class Client:
                higher are kept for generation
            tools (`List[Tool]`):
                List of tools to use
            tool_prompt (`str`):
                A prompt to be appended before the tools
            tool_choice (`str`):
                The tool to use
            stop (`List[str]`):
                Stop generating tokens if a member of `stop` is generated

        """
        request = ChatRequest(
@@ -142,8 +233,8 @@ class Client:
            temperature=temperature,
            top_p=top_p,
            tools=tools,
            tool_prompt=tool_prompt,
            tool_choice=tool_choice,
            stop=stop,
        )
        if not stream:
            resp = requests.post(
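A corresponding sketch for the chat route with the new `stop` argument (the server URL and the `Message` import path are assumptions; the non-streaming call is assumed to return a `ChatComplete` whose choices carry a `Message`, as in the types further down):

from text_generation import Client
from text_generation.types import Message

client = Client("http://127.0.0.1:8080")  # hypothetical local TGI endpoint

response = client.chat(
    messages=[Message(role="user", content="Count to three.")],
    max_tokens=32,
    stop=["3"],  # generation halts once "3" is produced
)
# ChatComplete.choices holds ChatCompletionComplete entries, each with a message
print(response.choices[0].message.content)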
@@ -454,6 +545,93 @@ class AsyncClient:
        self.cookies = cookies
        self.timeout = ClientTimeout(timeout)

    async def completion(
        self,
        prompt: str,
        frequency_penalty: Optional[float] = None,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> Union[Completion, AsyncIterator[CompletionComplete]]:
"""
|
||||
Given a prompt, generate a response asynchronously
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Prompt
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty
|
||||
Penalize new tokens based on their existing frequency in the text so far,
|
||||
decreasing the model's likelihood to repeat the same line verbatim.
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
stop (`List[str]`):
|
||||
Stop generating tokens if a member of `stop` is generated
|
||||
"""
|
||||
        request = CompletionRequest(
            model="tgi",
            prompt=prompt,
            frequency_penalty=frequency_penalty,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )
        if not stream:
            return await self._completion_single_response(request)
        else:
            return self._completion_stream_response(request)

    async def _completion_single_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/completions", json=request.dict()
            ) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Completion(**payload)

    async def _completion_stream_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/completions", json=request.dict()
            ) as resp:
                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue
                    payload = byte_payload.decode("utf-8")
                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        try:
                            response = CompletionComplete(**json_payload)
                            yield response
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)

    async def chat(
        self,
        messages: List[Message],
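The asynchronous client mirrors the synchronous one; a sketch of streaming completions with `AsyncClient` (the endpoint URL is illustrative):

import asyncio
from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")  # hypothetical local TGI endpoint
    # With stream=True, completion() resolves to an async iterator of CompletionComplete chunks
    stream = await client.completion(
        "The capital of France is",
        max_tokens=16,
        stop=["."],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.text, end="", flush=True)

asyncio.run(main())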
@@ -470,8 +648,8 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_prompt: Optional[str] = None,
        tool_choice: Optional[str] = None,
        stop: Optional[List[str]] = None,
    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
        """
        Given a list of messages, generate a response asynchronously
@@ -510,10 +688,10 @@ class Client:
                higher are kept for generation
            tools (`List[Tool]`):
                List of tools to use
            tool_prompt (`str`):
                A prompt to be appended before the tools
            tool_choice (`str`):
                The tool to use
            stop (`List[str]`):
                Stop generating tokens if a member of `stop` is generated

        """
        request = ChatRequest(
@@ -532,8 +710,8 @@ class Client:
            temperature=temperature,
            top_p=top_p,
            tools=tools,
            tool_prompt=tool_prompt,
            tool_choice=tool_choice,
            stop=stop,
        )
        if not stream:
            return await self._chat_single_response(request)
@@ -46,30 +46,6 @@ class Tool(BaseModel):
    function: dict


class ChatCompletionComplete(BaseModel):
    # Index of the chat completion
    index: int
    # Message associated with the chat completion
    message: Message
    # Log probabilities for the chat completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str
    # Usage details of the chat completion
    usage: Optional[Any] = None


class CompletionComplete(BaseModel):
    # Index of the chat completion
    index: int
    # Message associated with the chat completion
    text: str
    # Log probabilities for the chat completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str


class Function(BaseModel):
    name: Optional[str]
    arguments: str
@@ -95,24 +71,41 @@ class Choice(BaseModel):
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[Choice]


class ChatComplete(BaseModel):
    # Chat completion details
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[ChatCompletionComplete]
    usage: Any


class CompletionRequest(BaseModel):
    # Model identifier
    model: str
    # Prompt
    prompt: str
    # The parameter for repetition penalty. 1.0 means no penalty.
    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float] = None
    # The parameter for frequency penalty. 0.0 means no penalty.
    # Penalize new tokens based on their existing frequency in the text so far,
    # decreasing the model's likelihood to repeat the same line verbatim.
    frequency_penalty: Optional[float] = None
    # Maximum number of tokens to generate
    max_tokens: Optional[int] = None
    # Flag to indicate streaming response
    stream: bool = False
    # Random sampling seed
    seed: Optional[int] = None
    # Sampling temperature
    temperature: Optional[float] = None
    # Top-p value for nucleus sampling
    top_p: Optional[float] = None
    # Stop generating tokens if a member of `stop` is generated
    stop: Optional[List[str]] = None


class CompletionComplete(BaseModel):
    # Index of the completion
    index: int
    # Generated text of the completion
    text: str
    # Log probabilities for the completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str


class Completion(BaseModel):
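The new request model mirrors the keyword arguments accepted by `completion()`; a small sketch of the JSON body the client posts to `/v1/completions` (the field values are illustrative):

from text_generation.types import CompletionRequest

request = CompletionRequest(
    model="tgi",
    prompt="Hello",
    max_tokens=8,
    temperature=0.7,
    stop=["\n"],
)
# .dict() is exactly what the client passes as the JSON body of the request
print(request.dict())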
@@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
    tool_prompt: Optional[str] = None
    # Choice of tool to be used
    tool_choice: Optional[str] = None
    # Stop generating tokens if a member of `stop` is generated
    stop: Optional[List[str]] = None


class ChatCompletionComplete(BaseModel):
    # Index of the chat completion
    index: int
    # Message associated with the chat completion
    message: Message
    # Log probabilities for the chat completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str
    # Usage details of the chat completion
    usage: Optional[Any] = None


class ChatComplete(BaseModel):
    # Chat completion details
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[ChatCompletionComplete]
    usage: Any


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[Choice]


class Parameters(BaseModel):
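To tie these models back to the streaming helpers above, a sketch of how a single server-sent-events line would be parsed into a `CompletionComplete` (the example line is made up):

import json
from text_generation.types import CompletionComplete

line = 'data:{"index": 0, "text": " world", "logprobs": null, "finish_reason": ""}'
if line.startswith("data:"):
    payload = json.loads(line.lstrip("data:").rstrip("\n"))
    chunk = CompletionComplete(**payload)
    print(chunk.text)  # " world"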