fix: merge 'main' into lora-internal to resolve conflicts
This commit is contained in:
commit
0e1c28cafd
|
@ -0,0 +1,18 @@
|
|||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
|
@ -1856,12 +1856,23 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "minijinja"
|
||||
version = "1.0.12"
|
||||
source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minijinja-contrib"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07"
|
||||
dependencies = [
|
||||
"minijinja",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
|
@ -3604,6 +3615,7 @@ dependencies = [
|
|||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"minijinja",
|
||||
"minijinja-contrib",
|
||||
"ngrok",
|
||||
"nohash-hasher",
|
||||
"once_cell",
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1718044128,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.5-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 39,
|
||||
"prompt_tokens": 136,
|
||||
"total_tokens": 175
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llama_grammar_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
num_shard=1,
|
||||
disable_grammar_support=False,
|
||||
use_flash_attention=False,
|
||||
max_batch_prefill_tokens=3000,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def llama_grammar(llama_grammar_handle):
|
||||
await llama_grammar_handle.health(300)
|
||||
return llama_grammar_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
|
||||
|
||||
class Weather(BaseModel):
|
||||
unit: str
|
||||
temperature: List[int]
|
||||
|
||||
# send the request
|
||||
response = requests.post(
|
||||
f"{llama_grammar.base_url}/v1/chat/completions",
|
||||
headers=llama_grammar.headers,
|
||||
json={
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||
},
|
||||
],
|
||||
"seed": 42,
|
||||
"max_tokens": 500,
|
||||
"response_format": {"type": "json_object", "value": Weather.schema()},
|
||||
},
|
||||
)
|
||||
|
||||
chat_completion = response.json()
|
||||
called = chat_completion["choices"][0]["message"]["content"]
|
||||
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
called
|
||||
== '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
|
||||
)
|
||||
assert chat_completion == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grammar_response_format_llama_error_if_tools_not_installed(
|
||||
llama_grammar,
|
||||
):
|
||||
class Weather(BaseModel):
|
||||
unit: str
|
||||
temperature: List[int]
|
||||
|
||||
# send the request
|
||||
response = requests.post(
|
||||
f"{llama_grammar.base_url}/v1/chat/completions",
|
||||
headers=llama_grammar.headers,
|
||||
json={
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||
},
|
||||
],
|
||||
"seed": 42,
|
||||
"max_tokens": 500,
|
||||
"tools": [],
|
||||
"response_format": {"type": "json_object", "value": Weather.schema()},
|
||||
},
|
||||
)
|
||||
|
||||
# 422 means the server was unable to process the request because it contains invalid data.
|
||||
assert response.status_code == 422
|
||||
assert response.json() == {
|
||||
"error": "Grammar and tools are mutually exclusive",
|
||||
"error_type": "grammar and tools",
|
||||
}
|
|
@ -44,7 +44,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
|||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
|
@ -58,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
|||
default = ["ngrok"]
|
||||
ngrok = ["dep:ngrok"]
|
||||
google = []
|
||||
kserve = []
|
||||
|
|
|
@ -12,6 +12,8 @@ use crate::{
|
|||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
@ -62,14 +64,7 @@ impl Infer {
|
|||
.find(|t| t.name == "default")
|
||||
.map(|t| t.template),
|
||||
})
|
||||
.map(|t| {
|
||||
// .strip() is not supported in minijinja
|
||||
// .capitalize() is not supported in minijinja but we can use | capitalize
|
||||
let t = t
|
||||
.replace(".strip()", " | trim")
|
||||
.replace(".capitalize()", " | capitalize");
|
||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||
});
|
||||
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
|
@ -277,6 +272,8 @@ struct ChatTemplate {
|
|||
impl ChatTemplate {
|
||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
|
|
|
@ -0,0 +1,247 @@
|
|||
use crate::{
|
||||
default_parameters,
|
||||
server::{generate_internal, ComputeType},
|
||||
Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
|
||||
};
|
||||
use axum::extract::{Extension, Path};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::StatusCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct OutputChunk {
|
||||
pub name: String,
|
||||
pub shape: Vec<usize>,
|
||||
pub datatype: String,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct InferenceOutput {
|
||||
pub id: String,
|
||||
pub outputs: Vec<OutputChunk>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
pub(crate) struct InferenceRequest {
|
||||
pub id: String,
|
||||
#[serde(default = "default_parameters")]
|
||||
pub parameters: GenerateParameters,
|
||||
pub inputs: Vec<Input>,
|
||||
pub outputs: Vec<Output>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub(crate) struct Input {
|
||||
pub name: String,
|
||||
pub shape: Vec<usize>,
|
||||
pub datatype: String,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub(crate) struct Output {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct LiveResponse {
|
||||
pub live: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ReadyResponse {
|
||||
pub live: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct MetadataServerResponse {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub extensions: Vec<String>,
|
||||
}
|
||||
|
||||
// Routes
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v2/health/live",
|
||||
responses(
|
||||
(status = 200, description = "Service is live", body = LiveReponse),
|
||||
(status = 404, description = "Service not found", body = ErrorResponse,
|
||||
example = json!({"error": "No response"}))
|
||||
)
|
||||
)]
|
||||
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let data = LiveResponse { live: true };
|
||||
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v2/health/ready",
|
||||
responses(
|
||||
(status = 200, description = "Service is ready", body = ReadyResponse),
|
||||
(status = 404, description = "Service not found", body = ErrorResponse,
|
||||
example = json!({"error": "No response"}))
|
||||
)
|
||||
)]
|
||||
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let data = ReadyResponse { live: true };
|
||||
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v2",
|
||||
responses(
|
||||
(status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
|
||||
(status = 404, description = "Service not found", body = ErrorResponse,
|
||||
example = json!({"error": "No response"}))
|
||||
)
|
||||
)]
|
||||
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let data = MetadataServerResponse {
|
||||
name: "text-generation-inference".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
extensions: vec![
|
||||
"health".to_string(),
|
||||
"models".to_string(),
|
||||
"metrics".to_string(),
|
||||
],
|
||||
};
|
||||
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v2/models/{model_name}/versions/{model_version}",
|
||||
responses(
|
||||
(status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
|
||||
(status = 404, description = "Model or version not found", body = ErrorResponse,
|
||||
example = json!({"error": "No response"}))
|
||||
)
|
||||
)]
|
||||
pub async fn kserve_model_metadata(
|
||||
Path((model_name, model_version)): Path<(String, String)>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let data = MetadataServerResponse {
|
||||
name: model_name,
|
||||
version: model_version,
|
||||
extensions: vec!["infer".to_string(), "ready".to_string()],
|
||||
};
|
||||
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v2/models/{model_name}/versions/{model_version}/infer",
|
||||
request_body = Json<InferenceRequest>,
|
||||
responses(
|
||||
(status = 200, description = "Inference executed successfully", body = InferenceOutput),
|
||||
(status = 404, description = "Model or version not found", body = ErrorResponse,
|
||||
example = json!({"error": "No response"}))
|
||||
)
|
||||
)]
|
||||
pub async fn kserve_model_infer(
|
||||
infer: Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Json(payload): Json<InferenceRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let id = payload.id.clone();
|
||||
let str_inputs = payload
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
std::str::from_utf8(&input.data).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "utf8".to_string(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
if str_inputs.len() != payload.outputs.len() {
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Inputs and outputs length mismatch".to_string(),
|
||||
error_type: "length mismatch".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let output_chunks = str_inputs
|
||||
.iter()
|
||||
.zip(&payload.outputs)
|
||||
.map(|(str_input, output)| {
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: str_input.to_string(),
|
||||
parameters: payload.parameters.clone(),
|
||||
};
|
||||
let infer = infer.clone();
|
||||
let compute_type = compute_type.clone();
|
||||
let span = tracing::Span::current();
|
||||
async move {
|
||||
generate_internal(infer, compute_type, Json(generate_request), span)
|
||||
.await
|
||||
.map(|(_, Json(generation))| {
|
||||
let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
|
||||
OutputChunk {
|
||||
name: output.name.clone(),
|
||||
shape: vec![1, generation_as_bytes.len()],
|
||||
datatype: "BYTES".to_string(),
|
||||
data: generation_as_bytes,
|
||||
}
|
||||
})
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
|
||||
let inference_output = InferenceOutput {
|
||||
id: id.clone(),
|
||||
outputs: output_chunks,
|
||||
};
|
||||
|
||||
Ok((HeaderMap::new(), Json(inference_output)).into_response())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v2/models/{model_name}/versions/{model_version}/ready",
|
||||
responses(
|
||||
(status = 200, description = "Model version is ready", body = ReadyResponse),
|
||||
(status = 404, description = "Model or version not found", body = ErrorResponse,
|
||||
example = json!({"error": "No response"}))
|
||||
)
|
||||
)]
|
||||
pub async fn kserve_model_metadata_ready(
|
||||
Path((_model_name, _model_version)): Path<(String, String)>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let data = ReadyResponse { live: true };
|
||||
Ok((HeaderMap::new(), Json(data)).into_response())
|
||||
}
|
|
@ -4,6 +4,9 @@ mod infer;
|
|||
pub mod server;
|
||||
mod validation;
|
||||
|
||||
#[cfg(feature = "kserve")]
|
||||
mod kserve;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
|
@ -89,6 +92,7 @@ pub(crate) enum GrammarType {
|
|||
/// JSON Schema is a declarative language that allows to annotate JSON documents
|
||||
/// with types and descriptions.
|
||||
#[serde(rename = "json")]
|
||||
#[serde(alias = "json_object")]
|
||||
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
|
||||
Json(serde_json::Value),
|
||||
#[serde(rename = "regex")]
|
||||
|
@ -797,6 +801,13 @@ pub(crate) struct ChatRequest {
|
|||
#[schema(nullable = true, example = "null")]
|
||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||
pub tool_choice: Option<ToolType>,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
///
|
||||
/// NOTE: A request can use `response_format` OR `tools` but not both.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub response_format: Option<GrammarType>,
|
||||
}
|
||||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
|
|
|
@ -4,6 +4,11 @@ use crate::infer::v2::SchedulerV2;
|
|||
use crate::infer::v3::SchedulerV3;
|
||||
use crate::infer::{HealthCheck, Scheduler};
|
||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
#[cfg(feature = "kserve")]
|
||||
use crate::kserve::{
|
||||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
|
@ -172,7 +177,7 @@ async fn generate(
|
|||
generate_internal(infer, ComputeType(compute_type), Json(req), span).await
|
||||
}
|
||||
|
||||
async fn generate_internal(
|
||||
pub(crate) async fn generate_internal(
|
||||
infer: Extension<Infer>,
|
||||
ComputeType(compute_type): ComputeType,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
|
@ -1017,6 +1022,7 @@ async fn chat_completions(
|
|||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
..
|
||||
} = req;
|
||||
|
||||
|
@ -1031,6 +1037,18 @@ async fn chat_completions(
|
|||
other => (true, other),
|
||||
};
|
||||
|
||||
// response_format and tools are mutually exclusive
|
||||
if response_format.is_some() && tools.as_ref().is_some() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Grammar and tools are mutually exclusive".to_string(),
|
||||
error_type: "grammar and tools".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
|
@ -1047,16 +1065,21 @@ async fn chat_completions(
|
|||
}
|
||||
};
|
||||
|
||||
let grammar_with_prompt = tool_grammar
|
||||
// determine the appropriate arguments for apply_chat_template
|
||||
let tools_grammar_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
||||
|
||||
let typed_grammar = grammar_with_prompt
|
||||
.as_ref()
|
||||
.map(|(grammar, _)| grammar.clone());
|
||||
let (tools_grammar_prompt, grammar) = match response_format {
|
||||
Some(response_format) => (None, Some(response_format)),
|
||||
None => (
|
||||
tools_grammar_prompt.clone(),
|
||||
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
|
||||
),
|
||||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
|
||||
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
|
@ -1092,7 +1115,7 @@ async fn chat_completions(
|
|||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar: typed_grammar,
|
||||
grammar,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
@ -1711,28 +1734,58 @@ pub async fn run(
|
|||
docker_label: option_env!("DOCKER_LABEL"),
|
||||
};
|
||||
|
||||
// Define VertextApiDoc conditionally only if the "google" feature is enabled
|
||||
let doc = {
|
||||
// avoid `mut` if possible
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
use crate::VertexInstance;
|
||||
#[allow(unused_mut)] // mut is needed for conditional compilation
|
||||
let mut doc = ApiDoc::openapi();
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(vertex_compatibility),
|
||||
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||
)]
|
||||
struct VertextApiDoc;
|
||||
#[cfg(feature = "google")]
|
||||
{
|
||||
use crate::VertexInstance;
|
||||
|
||||
// limiting mutability to the smallest scope necessary
|
||||
let mut doc = ApiDoc::openapi();
|
||||
doc.merge(VertextApiDoc::openapi());
|
||||
doc
|
||||
}
|
||||
#[cfg(not(feature = "google"))]
|
||||
ApiDoc::openapi()
|
||||
};
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(vertex_compatibility),
|
||||
components(schemas(VertexInstance, VertexRequest, VertexResponse))
|
||||
)]
|
||||
struct VertexApiDoc;
|
||||
|
||||
doc.merge(VertexApiDoc::openapi());
|
||||
}
|
||||
|
||||
#[cfg(feature = "kserve")]
|
||||
{
|
||||
use crate::kserve::{
|
||||
InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
|
||||
ReadyResponse,
|
||||
};
|
||||
use crate::kserve::{
|
||||
__path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
|
||||
__path_kserve_model_infer, __path_kserve_model_metadata,
|
||||
__path_kserve_model_metadata_ready,
|
||||
};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(
|
||||
kserve_model_infer,
|
||||
kserve_health_live,
|
||||
kserve_health_ready,
|
||||
kerve_server_metadata,
|
||||
kserve_model_metadata,
|
||||
kserve_model_metadata_ready,
|
||||
),
|
||||
components(schemas(
|
||||
InferenceOutput,
|
||||
InferenceRequest,
|
||||
LiveResponse,
|
||||
MetadataServerResponse,
|
||||
OutputChunk,
|
||||
ReadyResponse,
|
||||
))
|
||||
)]
|
||||
struct KServeApiDoc;
|
||||
|
||||
doc.merge(KServeApiDoc::openapi());
|
||||
}
|
||||
|
||||
// Configure Swagger UI
|
||||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||
|
@ -1782,6 +1835,27 @@ pub async fn run(
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "kserve")]
|
||||
{
|
||||
tracing::info!("Built with `kserve` feature");
|
||||
app = app
|
||||
.route(
|
||||
"/v2/models/:model_name/versions/:model_version/infer",
|
||||
post(kserve_model_infer),
|
||||
)
|
||||
.route(
|
||||
"/v2/models/:model_name/versions/:model_version",
|
||||
get(kserve_model_metadata),
|
||||
)
|
||||
.route("/v2/health/ready", get(kserve_health_ready))
|
||||
.route("/v2/health/live", get(kserve_health_live))
|
||||
.route("/v2", get(kerve_server_metadata))
|
||||
.route(
|
||||
"/v2/models/:model_name/versions/:model_version/ready",
|
||||
get(kserve_model_metadata_ready),
|
||||
);
|
||||
}
|
||||
|
||||
// add layers after routes
|
||||
app = app
|
||||
.layer(Extension(info))
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
|
||||
commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
|
||||
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
|
||||
build-vllm-cuda:
|
||||
if [ ! -d 'vllm' ]; then \
|
||||
pip install -U ninja packaging --no-cache-dir && \
|
||||
|
|
|
@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
|
|||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
if seqlen > self.original_max_position_embeddings:
|
||||
inv_freq = self.long_inv_freq
|
||||
else:
|
||||
inv_freq = self.short_inv_freq
|
||||
t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
|
||||
if self.scaling_factor is not None:
|
||||
t /= self.scaling_factor
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
|
||||
short_freqs = torch.outer(
|
||||
t[: self.original_max_position_embeddings],
|
||||
self.short_inv_freq.to(device=t.device),
|
||||
)
|
||||
long_freqs = torch.outer(
|
||||
t[self.original_max_position_embeddings :],
|
||||
self.long_inv_freq.to(device=t.device),
|
||||
)
|
||||
|
||||
freqs = torch.cat([short_freqs, long_freqs])
|
||||
|
||||
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
|
||||
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
|
|
|
@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
|||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits, speculative_logits = self.lm_head(outputs)
|
||||
logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
|
||||
|
||||
loss = None
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import Optional
|
|||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
PhiConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
|
@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = PhiConfig.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
|
|
@ -86,5 +86,4 @@ class GPTNeoxSharded(CausalLM):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
logits = outputs.logits
|
||||
return logits, speculative_logits, outputs.past_key_values
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
|
@ -76,11 +76,11 @@ class OPTSharded(CausalLM):
|
|||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs = self.model.forward(
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
return outputs.logits, outputs.past_key_values
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
|
@ -72,11 +72,13 @@ class RW(CausalLM):
|
|||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
):
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
Loading…
Reference in New Issue