fix: merge 'main' into lora-internal to resolve conflicts
Commit 0e1c28cafd
@@ -0,0 +1,18 @@
+on:
+  push:
+
+name: Secret Leaks
+
+permissions:
+  contents: read
+
+jobs:
+  trufflehog:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Secret Scanning
+        uses: trufflesecurity/trufflehog@main
@@ -1856,12 +1856,23 @@ dependencies = [

 [[package]]
 name = "minijinja"
-version = "1.0.12"
-source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4"
 dependencies = [
  "serde",
 ]

+[[package]]
+name = "minijinja-contrib"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07"
+dependencies = [
+ "minijinja",
+ "serde",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -3604,6 +3615,7 @@ dependencies = [
  "metrics",
  "metrics-exporter-prometheus",
  "minijinja",
+ "minijinja-contrib",
  "ngrok",
  "nohash-hasher",
  "once_cell",
@@ -0,0 +1,23 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1718044128,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "2.0.5-dev0-native",
+  "usage": {
+    "completion_tokens": 39,
+    "prompt_tokens": 136,
+    "total_tokens": 175
+  }
+}
@@ -0,0 +1,101 @@
+import pytest
+import requests
+from pydantic import BaseModel
+from typing import List
+
+
+@pytest.fixture(scope="module")
+def llama_grammar_handle(launcher):
+    with launcher(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        num_shard=1,
+        disable_grammar_support=False,
+        use_flash_attention=False,
+        max_batch_prefill_tokens=3000,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def llama_grammar(llama_grammar_handle):
+    await llama_grammar_handle.health(300)
+    return llama_grammar_handle.client
+
+
+@pytest.mark.asyncio
+async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
+
+    class Weather(BaseModel):
+        unit: str
+        temperature: List[int]
+
+    # send the request
+    response = requests.post(
+        f"{llama_grammar.base_url}/v1/chat/completions",
+        headers=llama_grammar.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
+                },
+                {
+                    "role": "user",
+                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
+                },
+            ],
+            "seed": 42,
+            "max_tokens": 500,
+            "response_format": {"type": "json_object", "value": Weather.schema()},
+        },
+    )
+
+    chat_completion = response.json()
+    called = chat_completion["choices"][0]["message"]["content"]
+
+    assert response.status_code == 200
+    assert (
+        called
+        == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
+    )
+    assert chat_completion == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_grammar_response_format_llama_error_if_tools_not_installed(
+    llama_grammar,
+):
+    class Weather(BaseModel):
+        unit: str
+        temperature: List[int]
+
+    # send the request
+    response = requests.post(
+        f"{llama_grammar.base_url}/v1/chat/completions",
+        headers=llama_grammar.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
+                },
+                {
+                    "role": "user",
+                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
+                },
+            ],
+            "seed": 42,
+            "max_tokens": 500,
+            "tools": [],
+            "response_format": {"type": "json_object", "value": Weather.schema()},
+        },
+    )
+
+    # 422 means the server was unable to process the request because it contains invalid data.
+    assert response.status_code == 422
+    assert response.json() == {
+        "error": "Grammar and tools are mutually exclusive",
+        "error_type": "grammar and tools",
+    }
@@ -44,7 +44,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
-minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
+minijinja = { version = "2.0.2" }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 futures-util = "0.3.30"
 regex = "1.10.3"
 once_cell = "1.19.0"
@@ -58,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
 google = []
+kserve = []
@@ -12,6 +12,8 @@ use crate::{
 use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
+use minijinja_contrib::pycompat;
+
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -62,14 +64,7 @@ impl Infer {
                 .find(|t| t.name == "default")
                 .map(|t| t.template),
         })
-        .map(|t| {
-            // .strip() is not supported in minijinja
-            // .capitalize() is not supported in minijinja but we can use | capitalize
-            let t = t
-                .replace(".strip()", " | trim")
-                .replace(".capitalize()", " | capitalize");
-            ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
-        });
+        .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));

         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@@ -277,6 +272,8 @@ struct ChatTemplate {
 impl ChatTemplate {
     fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
         let mut env = Box::new(Environment::new());
+        // enable things like .strip() or .capitalize()
+        env.set_unknown_method_callback(pycompat::unknown_method_callback);
         let template_str = template.into_boxed_str();
         env.add_function("raise_exception", raise_exception);

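Side note on the pycompat change above: chat templates shipped with models are written for the Python/Jinja2 renderer used by transformers and often call string methods such as .strip(); the router previously string-replaced those calls with minijinja filters, and with minijinja-contrib's pycompat callback it can now evaluate them directly. A minimal Python sketch of the behaviour such templates assume (the template text here is illustrative, not taken from this diff):

from jinja2 import Template

# Jinja2 on the Python side resolves .strip() as an ordinary method call on the string;
# pycompat's unknown_method_callback emulates exactly this on the minijinja/Rust side.
template = Template("{{ bos_token }}{{ message['content'].strip() }}")
print(template.render(bos_token="<s>", message={"content": "  Hello  "}))  # -> <s>Hello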
@@ -0,0 +1,247 @@
+use crate::{
+    default_parameters,
+    server::{generate_internal, ComputeType},
+    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
+};
+use axum::extract::{Extension, Path};
+use axum::response::{IntoResponse, Response};
+use axum::Json;
+use futures::stream::FuturesUnordered;
+use futures::TryStreamExt;
+use reqwest::header::HeaderMap;
+use reqwest::StatusCode;
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct OutputChunk {
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct InferenceOutput {
+    pub id: String,
+    pub outputs: Vec<OutputChunk>,
+}
+
+#[derive(Debug, Deserialize, ToSchema)]
+pub(crate) struct InferenceRequest {
+    pub id: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
+    pub inputs: Vec<Input>,
+    pub outputs: Vec<Output>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub(crate) struct Input {
+    pub name: String,
+    pub shape: Vec<usize>,
+    pub datatype: String,
+    pub data: Vec<u8>,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub(crate) struct Output {
+    pub name: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct LiveResponse {
+    pub live: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct ReadyResponse {
+    pub live: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct MetadataServerResponse {
+    pub name: String,
+    pub version: String,
+    pub extensions: Vec<String>,
+}
+
+// Routes
+
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v2/health/live",
+    responses(
+        (status = 200, description = "Service is live", body = LiveReponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = LiveResponse { live: true };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v2/health/ready",
+    responses(
+        (status = 200, description = "Service is ready", body = ReadyResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = ReadyResponse { live: true };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2",
+    responses(
+        (status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Service not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = MetadataServerResponse {
+        name: "text-generation-inference".to_string(),
+        version: env!("CARGO_PKG_VERSION").to_string(),
+        extensions: vec![
+            "health".to_string(),
+            "models".to_string(),
+            "metrics".to_string(),
+        ],
+    };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}",
+    responses(
+        (status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_metadata(
+    Path((model_name, model_version)): Path<(String, String)>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = MetadataServerResponse {
+        name: model_name,
+        version: model_version,
+        extensions: vec!["infer".to_string(), "ready".to_string()],
+    };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
+
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/infer",
+    request_body = Json<InferenceRequest>,
+    responses(
+        (status = 200, description = "Inference executed successfully", body = InferenceOutput),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_infer(
+    infer: Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Json(payload): Json<InferenceRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let id = payload.id.clone();
+    let str_inputs = payload
+        .inputs
+        .iter()
+        .map(|input| {
+            std::str::from_utf8(&input.data).map_err(|e| {
+                (
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    Json(ErrorResponse {
+                        error: e.to_string(),
+                        error_type: "utf8".to_string(),
+                    }),
+                )
+            })
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+
+    if str_inputs.len() != payload.outputs.len() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Inputs and outputs length mismatch".to_string(),
+                error_type: "length mismatch".to_string(),
+            }),
+        ));
+    }
+
+    let output_chunks = str_inputs
+        .iter()
+        .zip(&payload.outputs)
+        .map(|(str_input, output)| {
+            let generate_request = GenerateRequest {
+                inputs: str_input.to_string(),
+                parameters: payload.parameters.clone(),
+            };
+            let infer = infer.clone();
+            let compute_type = compute_type.clone();
+            let span = tracing::Span::current();
+            async move {
+                generate_internal(infer, compute_type, Json(generate_request), span)
+                    .await
+                    .map(|(_, Json(generation))| {
+                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
+                        OutputChunk {
+                            name: output.name.clone(),
+                            shape: vec![1, generation_as_bytes.len()],
+                            datatype: "BYTES".to_string(),
+                            data: generation_as_bytes,
+                        }
+                    })
+                    .map_err(|_| {
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            Json(ErrorResponse {
+                                error: "Incomplete generation".into(),
+                                error_type: "Incomplete generation".into(),
+                            }),
+                        )
+                    })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let inference_output = InferenceOutput {
+        id: id.clone(),
+        outputs: output_chunks,
+    };
+
+    Ok((HeaderMap::new(), Json(inference_output)).into_response())
+}
+
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v2/models/{model_name}/versions/{model_version}/ready",
+    responses(
+        (status = 200, description = "Model version is ready", body = ReadyResponse),
+        (status = 404, description = "Model or version not found", body = ErrorResponse,
+        example = json!({"error": "No response"}))
+    )
+)]
+pub async fn kserve_model_metadata_ready(
+    Path((_model_name, _model_version)): Path<(String, String)>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let data = ReadyResponse { live: true };
+    Ok((HeaderMap::new(), Json(data)).into_response())
+}
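A sketch of exercising the new KServe-style infer route once the router is built with the kserve feature; the host, port, request id, and tensor names are assumptions for illustration. Because Input.data and OutputChunk.data are Vec<u8>, JSON carries them as arrays of byte values:

import requests

payload = {
    "id": "req-0",  # echoed back as InferenceOutput.id
    "inputs": [
        {
            "name": "input_0",
            "shape": [1],
            "datatype": "BYTES",
            "data": list(b"What is Deep Learning?"),  # Vec<u8> -> JSON array of integers
        }
    ],
    # must have the same length as "inputs", otherwise the handler returns 422
    "outputs": [{"name": "output_0"}],
}
resp = requests.post(
    "http://localhost:3000/v2/models/tgi/versions/1/infer", json=payload
)
chunk = resp.json()["outputs"][0]
print(bytes(chunk["data"]).decode("utf-8"))  # generated text for input_0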
@@ -4,6 +4,9 @@ mod infer;
 pub mod server;
 mod validation;

+#[cfg(feature = "kserve")]
+mod kserve;
+
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use utoipa::ToSchema;
@@ -89,6 +92,7 @@ pub(crate) enum GrammarType {
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
     #[serde(rename = "json")]
+    #[serde(alias = "json_object")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
     Json(serde_json::Value),
     #[serde(rename = "regex")]
@@ -797,6 +801,13 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
 }

 fn default_tool_prompt() -> Option<String> {
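The new response_format field accepts the same grammar types used elsewhere in the router, and (per the doc comment) cannot be combined with tools. A sketch of the request body it enables on /v1/chat/completions, modelled on the integration test added in this commit; the host, port, and schema contents are illustrative:

import requests

body = {
    "model": "tgi",
    "messages": [
        {"role": "user", "content": "What's the weather like the next 3 days in San Francisco, CA?"}
    ],
    "max_tokens": 500,
    # "json_object" also works thanks to the new serde alias on GrammarType::Json
    "response_format": {
        "type": "json_object",
        "value": {
            "properties": {
                "unit": {"type": "string"},
                "temperature": {"type": "array", "items": {"type": "integer"}},
            }
        },
    },
}
response = requests.post("http://localhost:3000/v1/chat/completions", json=body)
print(response.json()["choices"][0]["message"]["content"])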
@@ -4,6 +4,11 @@ use crate::infer::v2::SchedulerV2;
 use crate::infer::v3::SchedulerV3;
 use crate::infer::{HealthCheck, Scheduler};
 use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+#[cfg(feature = "kserve")]
+use crate::kserve::{
+    kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
+    kserve_model_metadata, kserve_model_metadata_ready,
+};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -172,7 +177,7 @@ async fn generate(
     generate_internal(infer, ComputeType(compute_type), Json(req), span).await
 }

-async fn generate_internal(
+pub(crate) async fn generate_internal(
     infer: Extension<Infer>,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
@@ -1017,6 +1022,7 @@ async fn chat_completions(
         tool_choice,
         tool_prompt,
         temperature,
+        response_format,
         ..
     } = req;

@@ -1031,6 +1037,18 @@ async fn chat_completions(
         other => (true, other),
     };

+    // response_format and tools are mutually exclusive
+    if response_format.is_some() && tools.as_ref().is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Grammar and tools are mutually exclusive".to_string(),
+                error_type: "grammar and tools".to_string(),
+            }),
+        ));
+    }
+
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
         Ok(grammar) => grammar,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let grammar_with_prompt = tool_grammar
|
// determine the appropriate arguments for apply_chat_template
|
||||||
|
let tools_grammar_prompt = tool_grammar
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
||||||
|
|
||||||
let typed_grammar = grammar_with_prompt
|
let (tools_grammar_prompt, grammar) = match response_format {
|
||||||
.as_ref()
|
Some(response_format) => (None, Some(response_format)),
|
||||||
.map(|(grammar, _)| grammar.clone());
|
None => (
|
||||||
|
tools_grammar_prompt.clone(),
|
||||||
|
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
// apply chat template to flatten the request into a single input
|
// apply chat template to flatten the request into a single input
|
||||||
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
|
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
|
||||||
Ok(inputs) => inputs,
|
Ok(inputs) => inputs,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
@ -1092,7 +1115,7 @@ async fn chat_completions(
|
||||||
decoder_input_details: !stream,
|
decoder_input_details: !stream,
|
||||||
seed,
|
seed,
|
||||||
top_n_tokens: req.top_logprobs,
|
top_n_tokens: req.top_logprobs,
|
||||||
grammar: typed_grammar,
|
grammar,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@@ -1711,28 +1734,58 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };

-    // Define VertextApiDoc conditionally only if the "google" feature is enabled
-    let doc = {
-        // avoid `mut` if possible
-        #[cfg(feature = "google")]
-        {
-            use crate::VertexInstance;
-
-            #[derive(OpenApi)]
-            #[openapi(
-                paths(vertex_compatibility),
-                components(schemas(VertexInstance, VertexRequest, VertexResponse))
-            )]
-            struct VertextApiDoc;
-
-            // limiting mutability to the smallest scope necessary
-            let mut doc = ApiDoc::openapi();
-            doc.merge(VertextApiDoc::openapi());
-            doc
-        }
-        #[cfg(not(feature = "google"))]
-        ApiDoc::openapi()
-    };
+    #[allow(unused_mut)] // mut is needed for conditional compilation
+    let mut doc = ApiDoc::openapi();
+
+    #[cfg(feature = "google")]
+    {
+        use crate::VertexInstance;
+
+        #[derive(OpenApi)]
+        #[openapi(
+            paths(vertex_compatibility),
+            components(schemas(VertexInstance, VertexRequest, VertexResponse))
+        )]
+        struct VertexApiDoc;
+
+        doc.merge(VertexApiDoc::openapi());
+    }
+
+    #[cfg(feature = "kserve")]
+    {
+        use crate::kserve::{
+            InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
+            ReadyResponse,
+        };
+        use crate::kserve::{
+            __path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
+            __path_kserve_model_infer, __path_kserve_model_metadata,
+            __path_kserve_model_metadata_ready,
+        };
+
+        #[derive(OpenApi)]
+        #[openapi(
+            paths(
+                kserve_model_infer,
+                kserve_health_live,
+                kserve_health_ready,
+                kerve_server_metadata,
+                kserve_model_metadata,
+                kserve_model_metadata_ready,
+            ),
+            components(schemas(
+                InferenceOutput,
+                InferenceRequest,
+                LiveResponse,
+                MetadataServerResponse,
+                OutputChunk,
+                ReadyResponse,
+            ))
+        )]
+        struct KServeApiDoc;
+
+        doc.merge(KServeApiDoc::openapi());
+    }

     // Configure Swagger UI
     let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
@@ -1782,6 +1835,27 @@ pub async fn run(
         }
     }

+    #[cfg(feature = "kserve")]
+    {
+        tracing::info!("Built with `kserve` feature");
+        app = app
+            .route(
+                "/v2/models/:model_name/versions/:model_version/infer",
+                post(kserve_model_infer),
+            )
+            .route(
+                "/v2/models/:model_name/versions/:model_version",
+                get(kserve_model_metadata),
+            )
+            .route("/v2/health/ready", get(kserve_health_ready))
+            .route("/v2/health/live", get(kserve_health_live))
+            .route("/v2", get(kerve_server_metadata))
+            .route(
+                "/v2/models/:model_name/versions/:model_version/ready",
+                get(kserve_model_metadata_ready),
+            );
+    }
+
     // add layers after routes
     app = app
         .layer(Extension(info))
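The routes above only exist when the router is compiled with the kserve feature; a quick probe of the new health and metadata endpoints (host and port assumed):

import requests

base = "http://localhost:3000"
print(requests.get(f"{base}/v2/health/live").json())   # {"live": true}
print(requests.get(f"{base}/v2/health/ready").json())  # {"live": true}
print(requests.get(f"{base}/v2").json())               # name, version and supported extensions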
@@ -1,5 +1,5 @@
 commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
+commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
@@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
-            if seqlen > self.original_max_position_embeddings:
-                inv_freq = self.long_inv_freq
-            else:
-                inv_freq = self.short_inv_freq
-            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
-            if self.scaling_factor is not None:
-                t /= self.scaling_factor
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

-            freqs = torch.outer(t, inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
+            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+            short_freqs = torch.outer(
+                t[: self.original_max_position_embeddings],
+                self.short_inv_freq.to(device=t.device),
+            )
+            long_freqs = torch.outer(
+                t[self.original_max_position_embeddings :],
+                self.long_inv_freq.to(device=t.device),
+            )
+
+            freqs = torch.cat([short_freqs, long_freqs])
+
+            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
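The rewritten cache above builds separate frequency tables for positions inside and beyond original_max_position_embeddings and scales both cos and sin by scaling_factor, instead of picking a single inv_freq table per call. A self-contained sketch of that computation; the dimensions and base values are made up for illustration:

import torch

seqlen, dim = 8, 4
original_max_position_embeddings = 4
scaling_factor = 1.2
short_inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
long_inv_freq = 1.0 / (20000 ** (torch.arange(0, dim, 2).float() / dim))

t = torch.arange(seqlen, dtype=short_inv_freq.dtype)
# short frequencies cover the original context window, long frequencies the extended positions
short_freqs = torch.outer(t[:original_max_position_embeddings], short_inv_freq)
long_freqs = torch.outer(t[original_max_position_embeddings:], long_inv_freq)
freqs = torch.cat([short_freqs, long_freqs])

cos_cached = torch.cos(freqs) * scaling_factor
sin_cached = torch.sin(freqs) * scaling_factor
print(cos_cached.shape)  # torch.Size([8, 2])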
@@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
             return_dict=return_dict,
         )

-        logits, speculative_logits = self.lm_head(outputs)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)

         loss = None

@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_phi_modeling import (
     FlashPhiForCausalLM,
-    PhiConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

-        config = PhiConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
@@ -86,5 +86,4 @@ class GPTNeoxSharded(CausalLM):
             use_cache=True,
         )

-        logits = outputs.logits
-        return logits, speculative_logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -76,11 +76,11 @@ class OPTSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=True,
         )

-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -72,11 +72,13 @@ class RW(CausalLM):

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ):
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
+            use_cache=True,
         )
-        return outputs.logits, outputs.past_key_values
+
+        return outputs.logits, speculative_logits, outputs.past_key_values