Add completion route to client and add stop parameter where it's missing (#1869)
# What does this PR do? - Add the stop parameter to the completion route - Add the completion method to the python client - Add the stop parameter to the python client's chat method ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil --------- Co-authored-by: Thomas SCHILLACI <tschilla@px101.prod.exalead.com> Co-authored-by: Thomas Schillaci <thomas.schillaci@3ds.com>
This commit is contained in:
parent
f4a073ae6d
commit
629047cb82
|
@ -13,6 +13,9 @@ from text_generation.types import (
|
||||||
Request,
|
Request,
|
||||||
Parameters,
|
Parameters,
|
||||||
Grammar,
|
Grammar,
|
||||||
|
CompletionRequest,
|
||||||
|
Completion,
|
||||||
|
CompletionComplete,
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
ChatComplete,
|
ChatComplete,
|
||||||
|
@ -70,6 +73,94 @@ class Client:
|
||||||
self.cookies = cookies
|
self.cookies = cookies
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Given a prompt, generate a response synchronously
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Prompt
|
||||||
|
frequency_penalty (`float`):
|
||||||
|
The parameter for frequency penalty. 0.0 means no penalty
|
||||||
|
Penalize new tokens based on their existing frequency in the text so far,
|
||||||
|
decreasing the model's likelihood to repeat the same line verbatim.
|
||||||
|
max_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stream (`bool`):
|
||||||
|
Stream the response
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation
|
||||||
|
stop (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop` is generated
|
||||||
|
"""
|
||||||
|
request = CompletionRequest(
|
||||||
|
model="tgi",
|
||||||
|
prompt=prompt,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
seed=seed,
|
||||||
|
stream=stream,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
if not stream:
|
||||||
|
resp = requests.post(
|
||||||
|
f"{self.base_url}/v1/completions",
|
||||||
|
json=request.dict(),
|
||||||
|
headers=self.headers,
|
||||||
|
cookies=self.cookies,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
payload = resp.json()
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise parse_error(resp.status_code, payload)
|
||||||
|
return Completion(**payload)
|
||||||
|
else:
|
||||||
|
return self._completion_stream_response(request)
|
||||||
|
|
||||||
|
def _completion_stream_response(self, request):
|
||||||
|
resp = requests.post(
|
||||||
|
f"{self.base_url}/v1/completions",
|
||||||
|
json=request.dict(),
|
||||||
|
headers=self.headers,
|
||||||
|
cookies=self.cookies,
|
||||||
|
timeout=self.timeout,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
# iterate and print stream
|
||||||
|
for byte_payload in resp.iter_lines():
|
||||||
|
if byte_payload == b"\n":
|
||||||
|
continue
|
||||||
|
payload = byte_payload.decode("utf-8")
|
||||||
|
if payload.startswith("data:"):
|
||||||
|
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||||
|
try:
|
||||||
|
response = CompletionComplete(**json_payload)
|
||||||
|
yield response
|
||||||
|
except ValidationError:
|
||||||
|
raise parse_error(resp.status, json_payload)
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
@ -88,6 +179,7 @@ class Client:
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[List[Tool]] = None,
|
||||||
tool_prompt: Optional[str] = None,
|
tool_prompt: Optional[str] = None,
|
||||||
tool_choice: Optional[str] = None,
|
tool_choice: Optional[str] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a list of messages, generate a response asynchronously
|
Given a list of messages, generate a response asynchronously
|
||||||
|
@ -130,6 +222,8 @@ class Client:
|
||||||
A prompt to be appended before the tools
|
A prompt to be appended before the tools
|
||||||
tool_choice (`str`):
|
tool_choice (`str`):
|
||||||
The tool to use
|
The tool to use
|
||||||
|
stop (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop` is generated
|
||||||
|
|
||||||
"""
|
"""
|
||||||
request = ChatRequest(
|
request = ChatRequest(
|
||||||
|
@ -150,6 +244,7 @@ class Client:
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_prompt=tool_prompt,
|
tool_prompt=tool_prompt,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
|
stop=stop,
|
||||||
)
|
)
|
||||||
if not stream:
|
if not stream:
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
|
@ -461,6 +556,93 @@ class AsyncClient:
|
||||||
self.cookies = cookies
|
self.cookies = cookies
|
||||||
self.timeout = ClientTimeout(timeout)
|
self.timeout = ClientTimeout(timeout)
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
) -> Union[Completion, AsyncIterator[CompletionComplete]]:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate a response asynchronously
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Prompt
|
||||||
|
frequency_penalty (`float`):
|
||||||
|
The parameter for frequency penalty. 0.0 means no penalty
|
||||||
|
Penalize new tokens based on their existing frequency in the text so far,
|
||||||
|
decreasing the model's likelihood to repeat the same line verbatim.
|
||||||
|
max_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stream (`bool`):
|
||||||
|
Stream the response
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation
|
||||||
|
stop (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop` is generated
|
||||||
|
"""
|
||||||
|
request = CompletionRequest(
|
||||||
|
model="tgi",
|
||||||
|
prompt=prompt,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
seed=seed,
|
||||||
|
stream=stream,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
if not stream:
|
||||||
|
return await self._completion_single_response(request)
|
||||||
|
else:
|
||||||
|
return self._completion_stream_response(request)
|
||||||
|
|
||||||
|
async def _completion_single_response(self, request):
|
||||||
|
async with ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
) as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/v1/completions", json=request.dict()
|
||||||
|
) as resp:
|
||||||
|
payload = await resp.json()
|
||||||
|
if resp.status != 200:
|
||||||
|
raise parse_error(resp.status, payload)
|
||||||
|
return Completion(**payload)
|
||||||
|
|
||||||
|
async def _completion_stream_response(self, request):
|
||||||
|
async with ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
) as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/v1/completions", json=request.dict()
|
||||||
|
) as resp:
|
||||||
|
async for byte_payload in resp.content:
|
||||||
|
if byte_payload == b"\n":
|
||||||
|
continue
|
||||||
|
payload = byte_payload.decode("utf-8")
|
||||||
|
if payload.startswith("data:"):
|
||||||
|
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||||
|
try:
|
||||||
|
response = CompletionComplete(**json_payload)
|
||||||
|
yield response
|
||||||
|
except ValidationError:
|
||||||
|
raise parse_error(resp.status, json_payload)
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
|
@ -479,6 +661,7 @@ class AsyncClient:
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[List[Tool]] = None,
|
||||||
tool_prompt: Optional[str] = None,
|
tool_prompt: Optional[str] = None,
|
||||||
tool_choice: Optional[str] = None,
|
tool_choice: Optional[str] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
|
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
|
||||||
"""
|
"""
|
||||||
Given a list of messages, generate a response asynchronously
|
Given a list of messages, generate a response asynchronously
|
||||||
|
@ -521,6 +704,8 @@ class AsyncClient:
|
||||||
A prompt to be appended before the tools
|
A prompt to be appended before the tools
|
||||||
tool_choice (`str`):
|
tool_choice (`str`):
|
||||||
The tool to use
|
The tool to use
|
||||||
|
stop (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop` is generated
|
||||||
|
|
||||||
"""
|
"""
|
||||||
request = ChatRequest(
|
request = ChatRequest(
|
||||||
|
@ -541,6 +726,7 @@ class AsyncClient:
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_prompt=tool_prompt,
|
tool_prompt=tool_prompt,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
|
stop=stop,
|
||||||
)
|
)
|
||||||
if not stream:
|
if not stream:
|
||||||
return await self._chat_single_response(request)
|
return await self._chat_single_response(request)
|
||||||
|
|
|
@ -46,30 +46,6 @@ class Tool(BaseModel):
|
||||||
function: dict
|
function: dict
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionComplete(BaseModel):
|
|
||||||
# Index of the chat completion
|
|
||||||
index: int
|
|
||||||
# Message associated with the chat completion
|
|
||||||
message: Message
|
|
||||||
# Log probabilities for the chat completion
|
|
||||||
logprobs: Optional[Any]
|
|
||||||
# Reason for completion
|
|
||||||
finish_reason: str
|
|
||||||
# Usage details of the chat completion
|
|
||||||
usage: Optional[Any] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionComplete(BaseModel):
|
|
||||||
# Index of the chat completion
|
|
||||||
index: int
|
|
||||||
# Message associated with the chat completion
|
|
||||||
text: str
|
|
||||||
# Log probabilities for the chat completion
|
|
||||||
logprobs: Optional[Any]
|
|
||||||
# Reason for completion
|
|
||||||
finish_reason: str
|
|
||||||
|
|
||||||
|
|
||||||
class Function(BaseModel):
|
class Function(BaseModel):
|
||||||
name: Optional[str]
|
name: Optional[str]
|
||||||
arguments: str
|
arguments: str
|
||||||
|
@ -95,24 +71,41 @@ class Choice(BaseModel):
|
||||||
finish_reason: Optional[str] = None
|
finish_reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionChunk(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
id: str
|
# Model identifier
|
||||||
object: str
|
|
||||||
created: int
|
|
||||||
model: str
|
model: str
|
||||||
system_fingerprint: str
|
# Prompt
|
||||||
choices: List[Choice]
|
prompt: str
|
||||||
|
# The parameter for repetition penalty. 1.0 means no penalty.
|
||||||
|
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
# The parameter for frequency penalty. 1.0 means no penalty
|
||||||
|
# Penalize new tokens based on their existing frequency in the text so far,
|
||||||
|
# decreasing the model's likelihood to repeat the same line verbatim.
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
# Maximum number of tokens to generate
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
# Flag to indicate streaming response
|
||||||
|
stream: bool = False
|
||||||
|
# Random sampling seed
|
||||||
|
seed: Optional[int] = None
|
||||||
|
# Sampling temperature
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
# Top-p value for nucleus sampling
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
# Stop generating tokens if a member of `stop` is generated
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatComplete(BaseModel):
|
class CompletionComplete(BaseModel):
|
||||||
# Chat completion details
|
# Index of the chat completion
|
||||||
id: str
|
index: int
|
||||||
object: str
|
# Message associated with the chat completion
|
||||||
created: int
|
text: str
|
||||||
model: str
|
# Log probabilities for the chat completion
|
||||||
system_fingerprint: str
|
logprobs: Optional[Any]
|
||||||
choices: List[ChatCompletionComplete]
|
# Reason for completion
|
||||||
usage: Any
|
finish_reason: str
|
||||||
|
|
||||||
|
|
||||||
class Completion(BaseModel):
|
class Completion(BaseModel):
|
||||||
|
@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
|
||||||
tool_prompt: Optional[str] = None
|
tool_prompt: Optional[str] = None
|
||||||
# Choice of tool to be used
|
# Choice of tool to be used
|
||||||
tool_choice: Optional[str] = None
|
tool_choice: Optional[str] = None
|
||||||
|
# Stop generating tokens if a member of `stop` is generated
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionComplete(BaseModel):
|
||||||
|
# Index of the chat completion
|
||||||
|
index: int
|
||||||
|
# Message associated with the chat completion
|
||||||
|
message: Message
|
||||||
|
# Log probabilities for the chat completion
|
||||||
|
logprobs: Optional[Any]
|
||||||
|
# Reason for completion
|
||||||
|
finish_reason: str
|
||||||
|
# Usage details of the chat completion
|
||||||
|
usage: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatComplete(BaseModel):
|
||||||
|
# Chat completion details
|
||||||
|
id: str
|
||||||
|
object: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
system_fingerprint: str
|
||||||
|
choices: List[ChatCompletionComplete]
|
||||||
|
usage: Any
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChunk(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
system_fingerprint: str
|
||||||
|
choices: List[Choice]
|
||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
|
|
|
@ -1121,6 +1121,15 @@
|
||||||
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
|
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
|
||||||
"example": 0.95,
|
"example": 0.95,
|
||||||
"nullable": true
|
"nullable": true
|
||||||
|
},
|
||||||
|
"stop": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "Up to 4 sequences where the API will stop generating further tokens.",
|
||||||
|
"example": "null",
|
||||||
|
"nullable": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -402,6 +402,11 @@ pub struct CompletionRequest {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(example = "1.0")]
|
#[schema(example = "1.0")]
|
||||||
pub frequency_penalty: Option<f32>,
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
||||||
|
|
|
@ -597,9 +597,22 @@ async fn completions(
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
let stream = req.stream;
|
let CompletionRequest {
|
||||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
max_tokens,
|
||||||
let seed = req.seed;
|
seed,
|
||||||
|
stop,
|
||||||
|
stream,
|
||||||
|
temperature,
|
||||||
|
..
|
||||||
|
} = req;
|
||||||
|
|
||||||
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
|
let stop = stop.unwrap_or_default();
|
||||||
|
// enable greedy only when temperature is 0
|
||||||
|
let (do_sample, temperature) = match temperature {
|
||||||
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
|
other => (true, other),
|
||||||
|
};
|
||||||
|
|
||||||
// if suffix is present throw an error
|
// if suffix is present throw an error
|
||||||
if req.suffix.is_some() {
|
if req.suffix.is_some() {
|
||||||
|
@ -635,16 +648,16 @@ async fn completions(
|
||||||
inputs: prompt.to_string(),
|
inputs: prompt.to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
best_of: None,
|
best_of: None,
|
||||||
temperature: req.temperature,
|
temperature,
|
||||||
repetition_penalty: req.repetition_penalty,
|
repetition_penalty: req.repetition_penalty,
|
||||||
frequency_penalty: req.frequency_penalty,
|
frequency_penalty: req.frequency_penalty,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: req.top_p,
|
top_p: req.top_p,
|
||||||
typical_p: None,
|
typical_p: None,
|
||||||
do_sample: true,
|
do_sample,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: stop.clone(),
|
||||||
truncate: None,
|
truncate: None,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
details: true,
|
details: true,
|
||||||
|
|
Loading…
Reference in New Issue