Add completion route to client and add stop parameter where it's missing (#1869)
# What does this PR do? - Add the stop parameter to the completion route - Add the completion method to the python client - Add the stop parameter to the python client's chat method ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil --------- Co-authored-by: Thomas SCHILLACI <tschilla@px101.prod.exalead.com> Co-authored-by: Thomas Schillaci <thomas.schillaci@3ds.com>
This commit is contained in:
parent
f4a073ae6d
commit
629047cb82
|
@ -13,6 +13,9 @@ from text_generation.types import (
|
|||
Request,
|
||||
Parameters,
|
||||
Grammar,
|
||||
CompletionRequest,
|
||||
Completion,
|
||||
CompletionComplete,
|
||||
ChatRequest,
|
||||
ChatCompletionChunk,
|
||||
ChatComplete,
|
||||
|
@ -70,6 +73,94 @@ class Client:
|
|||
self.cookies = cookies
|
||||
self.timeout = timeout
|
||||
|
||||
def completion(
|
||||
self,
|
||||
prompt: str,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Given a prompt, generate a response synchronously
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Prompt
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty
|
||||
Penalize new tokens based on their existing frequency in the text so far,
|
||||
decreasing the model's likelihood to repeat the same line verbatim.
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
stop (`List[str]`):
|
||||
Stop generating tokens if a member of `stop` is generated
|
||||
"""
|
||||
request = CompletionRequest(
|
||||
model="tgi",
|
||||
prompt=prompt,
|
||||
frequency_penalty=frequency_penalty,
|
||||
max_tokens=max_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
seed=seed,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
)
|
||||
if not stream:
|
||||
resp = requests.post(
|
||||
f"{self.base_url}/v1/completions",
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
payload = resp.json()
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, payload)
|
||||
return Completion(**payload)
|
||||
else:
|
||||
return self._completion_stream_response(request)
|
||||
|
||||
def _completion_stream_response(self, request):
|
||||
resp = requests.post(
|
||||
f"{self.base_url}/v1/completions",
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
)
|
||||
# iterate and print stream
|
||||
for byte_payload in resp.iter_lines():
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
try:
|
||||
response = CompletionComplete(**json_payload)
|
||||
yield response
|
||||
except ValidationError:
|
||||
raise parse_error(resp.status, json_payload)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
|
@ -88,6 +179,7 @@ class Client:
|
|||
tools: Optional[List[Tool]] = None,
|
||||
tool_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Given a list of messages, generate a response asynchronously
|
||||
|
@ -130,6 +222,8 @@ class Client:
|
|||
A prompt to be appended before the tools
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
stop (`List[str]`):
|
||||
Stop generating tokens if a member of `stop` is generated
|
||||
|
||||
"""
|
||||
request = ChatRequest(
|
||||
|
@ -150,6 +244,7 @@ class Client:
|
|||
tools=tools,
|
||||
tool_prompt=tool_prompt,
|
||||
tool_choice=tool_choice,
|
||||
stop=stop,
|
||||
)
|
||||
if not stream:
|
||||
resp = requests.post(
|
||||
|
@ -461,6 +556,93 @@ class AsyncClient:
|
|||
self.cookies = cookies
|
||||
self.timeout = ClientTimeout(timeout)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
prompt: str,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> Union[Completion, AsyncIterator[CompletionComplete]]:
|
||||
"""
|
||||
Given a prompt, generate a response asynchronously
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
Prompt
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty
|
||||
Penalize new tokens based on their existing frequency in the text so far,
|
||||
decreasing the model's likelihood to repeat the same line verbatim.
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
repetition_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
stop (`List[str]`):
|
||||
Stop generating tokens if a member of `stop` is generated
|
||||
"""
|
||||
request = CompletionRequest(
|
||||
model="tgi",
|
||||
prompt=prompt,
|
||||
frequency_penalty=frequency_penalty,
|
||||
max_tokens=max_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
seed=seed,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
)
|
||||
if not stream:
|
||||
return await self._completion_single_response(request)
|
||||
else:
|
||||
return self._completion_stream_response(request)
|
||||
|
||||
async def _completion_single_response(self, request):
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/completions", json=request.dict()
|
||||
) as resp:
|
||||
payload = await resp.json()
|
||||
if resp.status != 200:
|
||||
raise parse_error(resp.status, payload)
|
||||
return Completion(**payload)
|
||||
|
||||
async def _completion_stream_response(self, request):
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/completions", json=request.dict()
|
||||
) as resp:
|
||||
async for byte_payload in resp.content:
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
try:
|
||||
response = CompletionComplete(**json_payload)
|
||||
yield response
|
||||
except ValidationError:
|
||||
raise parse_error(resp.status, json_payload)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
|
@ -479,6 +661,7 @@ class AsyncClient:
|
|||
tools: Optional[List[Tool]] = None,
|
||||
tool_prompt: Optional[str] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
|
||||
"""
|
||||
Given a list of messages, generate a response asynchronously
|
||||
|
@ -521,6 +704,8 @@ class AsyncClient:
|
|||
A prompt to be appended before the tools
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
stop (`List[str]`):
|
||||
Stop generating tokens if a member of `stop` is generated
|
||||
|
||||
"""
|
||||
request = ChatRequest(
|
||||
|
@ -541,6 +726,7 @@ class AsyncClient:
|
|||
tools=tools,
|
||||
tool_prompt=tool_prompt,
|
||||
tool_choice=tool_choice,
|
||||
stop=stop,
|
||||
)
|
||||
if not stream:
|
||||
return await self._chat_single_response(request)
|
||||
|
|
|
@ -46,30 +46,6 @@ class Tool(BaseModel):
|
|||
function: dict
|
||||
|
||||
|
||||
class ChatCompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
message: Message
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
# Usage details of the chat completion
|
||||
usage: Optional[Any] = None
|
||||
|
||||
|
||||
class CompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
text: str
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
|
@ -95,24 +71,41 @@ class Choice(BaseModel):
|
|||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionChunk(BaseModel):
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
class CompletionRequest(BaseModel):
|
||||
# Model identifier
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[Choice]
|
||||
# Prompt
|
||||
prompt: str
|
||||
# The parameter for repetition penalty. 1.0 means no penalty.
|
||||
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
repetition_penalty: Optional[float] = None
|
||||
# The parameter for frequency penalty. 1.0 means no penalty
|
||||
# Penalize new tokens based on their existing frequency in the text so far,
|
||||
# decreasing the model's likelihood to repeat the same line verbatim.
|
||||
frequency_penalty: Optional[float] = None
|
||||
# Maximum number of tokens to generate
|
||||
max_tokens: Optional[int] = None
|
||||
# Flag to indicate streaming response
|
||||
stream: bool = False
|
||||
# Random sampling seed
|
||||
seed: Optional[int] = None
|
||||
# Sampling temperature
|
||||
temperature: Optional[float] = None
|
||||
# Top-p value for nucleus sampling
|
||||
top_p: Optional[float] = None
|
||||
# Stop generating tokens if a member of `stop` is generated
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ChatComplete(BaseModel):
|
||||
# Chat completion details
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[ChatCompletionComplete]
|
||||
usage: Any
|
||||
class CompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
text: str
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
|
||||
|
||||
class Completion(BaseModel):
|
||||
|
@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
|
|||
tool_prompt: Optional[str] = None
|
||||
# Choice of tool to be used
|
||||
tool_choice: Optional[str] = None
|
||||
# Stop generating tokens if a member of `stop` is generated
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ChatCompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
message: Message
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
# Usage details of the chat completion
|
||||
usage: Optional[Any] = None
|
||||
|
||||
|
||||
class ChatComplete(BaseModel):
|
||||
# Chat completion details
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[ChatCompletionComplete]
|
||||
usage: Any
|
||||
|
||||
|
||||
class ChatCompletionChunk(BaseModel):
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[Choice]
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
|
|
|
@ -1121,6 +1121,15 @@
|
|||
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
|
||||
"example": 0.95,
|
||||
"nullable": true
|
||||
},
|
||||
"stop": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Up to 4 sequences where the API will stop generating further tokens.",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -402,6 +402,11 @@ pub struct CompletionRequest {
|
|||
#[serde(default)]
|
||||
#[schema(example = "1.0")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
||||
|
|
|
@ -597,9 +597,22 @@ async fn completions(
|
|||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
let stream = req.stream;
|
||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||
let seed = req.seed;
|
||||
let CompletionRequest {
|
||||
max_tokens,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
temperature,
|
||||
..
|
||||
} = req;
|
||||
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
|
||||
// if suffix is present throw an error
|
||||
if req.suffix.is_some() {
|
||||
|
@ -635,16 +648,16 @@ async fn completions(
|
|||
inputs: prompt.to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: req.temperature,
|
||||
temperature,
|
||||
repetition_penalty: req.repetition_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
top_k: None,
|
||||
top_p: req.top_p,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
stop: stop.clone(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
|
|
Loading…
Reference in New Issue