Add completion route to client and stop parameter

This commit is contained in:
Thomas SCHILLACI 2024-05-07 18:27:28 +02:00
parent 59b3ffea14
commit 9fb1cdc8d5
2 changed files with 253 additions and 47 deletions

View File

@ -11,6 +11,9 @@ from text_generation.types import (
Request,
Parameters,
Grammar,
CompletionRequest,
Completion,
CompletionComplete,
ChatRequest,
ChatCompletionChunk,
ChatComplete,
@ -64,6 +67,94 @@ class Client:
self.cookies = cookies
self.timeout = timeout
def completion(
self,
prompt: str,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
seed: Optional[int] = None,
stream: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
):
"""
Given a prompt, generate a response synchronously
Args:
prompt (`str`):
Prompt
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
max_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
seed (`int`):
Random sampling seed
stream (`bool`):
Stream the response
temperature (`float`):
The value used to module the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = CompletionRequest(
model="tgi",
prompt=prompt,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
stream=stream,
temperature=temperature,
top_p=top_p,
stop=stop,
)
if not stream:
resp = requests.post(
f"{self.base_url}/v1/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Completion(**payload)
else:
return self._completion_stream_response(request)
def _completion_stream_response(self, request):
resp = requests.post(
f"{self.base_url}/v1/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
# iterate and print stream
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = CompletionComplete(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
def chat(
self,
messages: List[Message],
@ -80,8 +171,8 @@ class Client:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
stop: Optional[List[str]] = None,
):
"""
Given a list of messages, generate a response asynchronously
@ -120,10 +211,10 @@ class Client:
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_prompt (`str`):
A prompt to be appended before the tools
tool_choice (`str`):
The tool to use
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = ChatRequest(
@ -142,8 +233,8 @@ class Client:
temperature=temperature,
top_p=top_p,
tools=tools,
tool_prompt=tool_prompt,
tool_choice=tool_choice,
stop=stop,
)
if not stream:
resp = requests.post(
@ -454,6 +545,93 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout)
async def completion(
self,
prompt: str,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
seed: Optional[int] = None,
stream: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
) -> Union[Completion, AsyncIterator[CompletionComplete]]:
"""
Given a prompt, generate a response asynchronously
Args:
prompt (`str`):
Prompt
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
max_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
seed (`int`):
Random sampling seed
stream (`bool`):
Stream the response
temperature (`float`):
The value used to module the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = CompletionRequest(
model="tgi",
prompt=prompt,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
stream=stream,
temperature=temperature,
top_p=top_p,
stop=stop,
)
if not stream:
return await self._completion_single_response(request)
else:
return self._completion_stream_response(request)
async def _completion_single_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/completions", json=request.dict()
) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Completion(**payload)
async def _completion_stream_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/completions", json=request.dict()
) as resp:
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = CompletionComplete(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
async def chat(
self,
messages: List[Message],
@ -470,8 +648,8 @@ class AsyncClient:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
stop: Optional[List[str]] = None,
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response asynchronously
@ -510,10 +688,10 @@ class AsyncClient:
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_prompt (`str`):
A prompt to be appended before the tools
tool_choice (`str`):
The tool to use
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = ChatRequest(
@ -532,8 +710,8 @@ class AsyncClient:
temperature=temperature,
top_p=top_p,
tools=tools,
tool_prompt=tool_prompt,
tool_choice=tool_choice,
stop=stop,
)
if not stream:
return await self._chat_single_response(request)

View File

@ -46,30 +46,6 @@ class Tool(BaseModel):
function: dict
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Optional[Any] = None
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Function(BaseModel):
name: Optional[str]
arguments: str
@ -95,24 +71,41 @@ class Choice(BaseModel):
finish_reason: Optional[str] = None
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
class CompletionRequest(BaseModel):
# Model identifier
model: str
system_fingerprint: str
choices: List[Choice]
# Prompt
prompt: str
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 1.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Maximum number of tokens to generate
max_tokens: Optional[int] = None
# Flag to indicate streaming response
stream: bool = False
# Random sampling seed
seed: Optional[int] = None
# Sampling temperature
temperature: Optional[float] = None
# Top-p value for nucleus sampling
top_p: Optional[float] = None
# Stop generating tokens if a member of `stop` is generated
stop: Optional[List[str]] = None
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Completion(BaseModel):
@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
tool_prompt: Optional[str] = None
# Choice of tool to be used
tool_choice: Optional[str] = None
# Stop generating tokens if a member of `stop` is generated
stop: Optional[List[str]] = None
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Optional[Any] = None
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class Parameters(BaseModel):