Merge commit 'refs/pull/1869/head' of github.com:huggingface/text-generation-inference into main

Author: drbh
Date: 2024-05-21 20:56:18 +00:00
Commit: 9b08e4ab32
5 changed files with 289 additions and 48 deletions

View File

@ -13,6 +13,9 @@ from text_generation.types import (
Request,
Parameters,
Grammar,
CompletionRequest,
Completion,
CompletionComplete,
ChatRequest,
ChatCompletionChunk,
ChatComplete,
@ -70,6 +73,94 @@ class Client:
self.cookies = cookies
self.timeout = timeout
def completion(
self,
prompt: str,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
seed: Optional[int] = None,
stream: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
):
"""
Given a prompt, generate a response synchronously
Args:
prompt (`str`):
Prompt
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
max_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
seed (`int`):
Random sampling seed
stream (`bool`):
Stream the response
temperature (`float`):
The value used to modulate the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = CompletionRequest(
model="tgi",
prompt=prompt,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
stream=stream,
temperature=temperature,
top_p=top_p,
stop=stop,
)
if not stream:
resp = requests.post(
f"{self.base_url}/v1/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Completion(**payload)
else:
return self._completion_stream_response(request)
def _completion_stream_response(self, request):
resp = requests.post(
f"{self.base_url}/v1/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
# iterate over the SSE stream and yield parsed completions
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = CompletionComplete(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status_code, json_payload)
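
For reference, a minimal usage sketch of the new synchronous completion method; the server URL, prompt, and sampling values are illustrative placeholders, not part of this change:

from text_generation import Client

# assumes a TGI server is reachable at this placeholder address
client = Client("http://127.0.0.1:8080")

# non-streaming call: returns a Completion parsed from the JSON response
completion = client.completion(
    "def fibonacci(n):",
    max_tokens=64,
    temperature=0.7,
    stop=["\n\n"],  # halt generation once a blank line is produced
)
print(completion)

# streaming call: returns a generator of CompletionComplete chunks
for chunk in client.completion("def fibonacci(n):", max_tokens=64, stream=True):
    print(chunk.text, end="")
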
def chat(
self,
messages: List[Message],
@ -88,6 +179,7 @@ class Client:
tools: Optional[List[Tool]] = None,
tool_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
stop: Optional[List[str]] = None,
):
"""
Given a list of messages, generate a response synchronously
@ -130,6 +222,8 @@ class Client:
A prompt to be appended before the tools
tool_choice (`str`):
The tool to use
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = ChatRequest(
@ -150,6 +244,7 @@ class Client:
tools=tools,
tool_prompt=tool_prompt,
tool_choice=tool_choice,
stop=stop,
)
if not stream:
resp = requests.post(
@ -461,6 +556,93 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout)
async def completion(
self,
prompt: str,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
seed: Optional[int] = None,
stream: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
) -> Union[Completion, AsyncIterator[CompletionComplete]]:
"""
Given a prompt, generate a response asynchronously
Args:
prompt (`str`):
Prompt
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty
Penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
max_tokens (`int`):
Maximum number of generated tokens
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
seed (`int`):
Random sampling seed
stream (`bool`):
Stream the response
temperature (`float`):
The value used to modulate the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = CompletionRequest(
model="tgi",
prompt=prompt,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
seed=seed,
stream=stream,
temperature=temperature,
top_p=top_p,
stop=stop,
)
if not stream:
return await self._completion_single_response(request)
else:
return self._completion_stream_response(request)
async def _completion_single_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/completions", json=request.dict()
) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Completion(**payload)
async def _completion_stream_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/completions", json=request.dict()
) as resp:
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = CompletionComplete(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
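
The asynchronous client mirrors the same API; a short sketch (the endpoint URL and prompt are placeholders):

import asyncio

from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")

    # non-streaming: awaiting the call yields a Completion
    completion = await client.completion("Once upon a time", max_tokens=32)
    print(completion)

    # streaming: awaiting the call yields an async iterator of CompletionComplete chunks
    stream = await client.completion("Once upon a time", max_tokens=32, stream=True)
    async for chunk in stream:
        print(chunk.text, end="")

asyncio.run(main())
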
async def chat(
self,
messages: List[Message],
@ -479,6 +661,7 @@ class AsyncClient:
tools: Optional[List[Tool]] = None,
tool_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
stop: Optional[List[str]] = None,
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response asynchronously
@ -521,6 +704,8 @@ class AsyncClient:
A prompt to be appended before the tools
tool_choice (`str`):
The tool to use
stop (`List[str]`):
Stop generating tokens if a member of `stop` is generated
"""
request = ChatRequest(
@ -541,6 +726,7 @@ class AsyncClient:
tools=tools,
tool_prompt=tool_prompt,
tool_choice=tool_choice,
stop=stop,
)
if not stream:
return await self._chat_single_response(request)
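
The stop parameter is now threaded through chat as well, on both Client and AsyncClient; a sketch with the synchronous client, assuming the usual role/content fields on Message (URL and message text are placeholders):

from text_generation import Client
from text_generation.types import Message

client = Client("http://127.0.0.1:8080")

# stop is forwarded into ChatRequest; generation halts as soon as
# any of the listed sequences is produced
response = client.chat(
    messages=[Message(role="user", content="Name three prime numbers.")],
    stop=["\n\n"],
)
print(response)
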

View File

@ -46,30 +46,6 @@ class Tool(BaseModel):
function: dict
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Optional[Any] = None
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Function(BaseModel):
name: Optional[str]
arguments: str
@ -95,24 +71,41 @@ class Choice(BaseModel):
finish_reason: Optional[str] = None
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class CompletionRequest(BaseModel):
# Model identifier
model: str
# Prompt
prompt: str
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# The parameter for frequency penalty. 0.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Maximum number of tokens to generate
max_tokens: Optional[int] = None
# Flag to indicate streaming response
stream: bool = False
# Random sampling seed
seed: Optional[int] = None
# Sampling temperature
temperature: Optional[float] = None
# Top-p value for nucleus sampling
top_p: Optional[float] = None
# Stop generating tokens if a member of `stop` is generated
stop: Optional[List[str]] = None
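
The new CompletionRequest model mirrors the keyword arguments accepted by the completion methods above; a minimal construction sketch (values are illustrative):

from text_generation.types import CompletionRequest

# build the same body Client.completion serializes before POSTing to /v1/completions
request = CompletionRequest(
    model="tgi",
    prompt="Hello",
    max_tokens=32,
    temperature=0.0,  # the router now treats 0.0 as a request for greedy decoding
    stop=["###"],
)
print(request.dict())  # plain dict, ready to be sent as JSON
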
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Completion(BaseModel):
@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
tool_prompt: Optional[str] = None
# Choice of tool to be used
tool_choice: Optional[str] = None
# Stop generating tokens if a member of `stop` is generated
stop: Optional[List[str]] = None
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Optional[Any] = None
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class Parameters(BaseModel):

View File

@ -1121,6 +1121,15 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95,
"nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
}
}
},
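
This schema entry appears to correspond to the CompletionRequest change in the router below; an illustrative raw HTTP request exercising the documented field (host, prompt, and values are placeholders):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/v1/completions",
    json={
        "model": "tgi",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "stop": ["\n"],  # up to 4 stop sequences
    },
)
print(resp.json())
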

View File

@ -402,6 +402,11 @@ pub struct CompletionRequest {
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]

View File

@ -597,9 +597,22 @@ async fn completions(
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100));
let seed = req.seed;
let CompletionRequest {
max_tokens,
seed,
stop,
stream,
temperature,
..
} = req;
let max_new_tokens = max_tokens.or(Some(100));
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
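
The same rule, restated as a small Python sketch purely for illustration (not part of the router code): temperature == 0.0 now selects greedy decoding instead of sampling at zero temperature.

def resolve_sampling(temperature):
    # mirror of the match above: 0.0 disables sampling, anything else passes through
    if temperature == 0.0:
        return False, None  # do_sample=False, temperature left unset
    return True, temperature

print(resolve_sampling(0.0))   # (False, None)
print(resolve_sampling(0.7))   # (True, 0.7)
print(resolve_sampling(None))  # (True, None)
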
// if suffix is present throw an error
if req.suffix.is_some() {
@ -635,16 +648,16 @@ async fn completions(
inputs: prompt.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
temperature: temperature,
repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
do_sample,
max_new_tokens,
return_full_text: None,
stop: Vec::new(),
stop: stop.clone(),
truncate: None,
watermark: false,
details: true,