fix: merge 'main' into lora-internal to resolve conflicts

This commit is contained in:
drbh 2024-06-14 14:02:33 +00:00
commit 0e1c28cafd
16 changed files with 548 additions and 61 deletions

18
.github/workflows/trufflehog.yml vendored Normal file
View File

@ -0,0 +1,18 @@
on:
push:
name: Secret Leaks
permissions:
contents: read
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

16
Cargo.lock generated
View File

@ -1856,12 +1856,23 @@ dependencies = [
[[package]] [[package]]
name = "minijinja" name = "minijinja"
version = "1.0.12" version = "2.0.2"
source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4"
dependencies = [ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "minijinja-contrib"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07"
dependencies = [
"minijinja",
"serde",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@ -3604,6 +3615,7 @@ dependencies = [
"metrics", "metrics",
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
"minijinja", "minijinja",
"minijinja-contrib",
"ngrok", "ngrok",
"nohash-hasher", "nohash-hasher",
"once_cell", "once_cell",

View File

@ -0,0 +1,23 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
"role": "assistant"
}
}
],
"created": 1718044128,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.5-dev0-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 136,
"total_tokens": 175
}
}

View File

@ -0,0 +1,101 @@
import pytest
import requests
from pydantic import BaseModel
from typing import List
@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
num_shard=1,
disable_grammar_support=False,
use_flash_attention=False,
max_batch_prefill_tokens=3000,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
await llama_grammar_handle.health(300)
return llama_grammar_handle.client
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
class Weather(BaseModel):
unit: str
temperature: List[int]
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert (
called
== '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
)
assert chat_completion == response_snapshot
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
llama_grammar,
):
class Weather(BaseModel):
unit: str
temperature: List[int]
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"tools": [],
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
# 422 means the server was unable to process the request because it contains invalid data.
assert response.status_code == 422
assert response.json() == {
"error": "Grammar and tools are mutually exclusive",
"error_type": "grammar and tools",
}

View File

@ -44,7 +44,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" } minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30" futures-util = "0.3.30"
regex = "1.10.3" regex = "1.10.3"
once_cell = "1.19.0" once_cell = "1.19.0"
@ -58,3 +59,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
default = ["ngrok"] default = ["ngrok"]
ngrok = ["dep:ngrok"] ngrok = ["dep:ngrok"]
google = [] google = []
kserve = []

View File

@ -12,6 +12,8 @@ use crate::{
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -62,14 +64,7 @@ impl Infer {
.find(|t| t.name == "default") .find(|t| t.name == "default")
.map(|t| t.template), .map(|t| t.template),
}) })
.map(|t| { .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
// .strip() is not supported in minijinja
// .capitalize() is not supported in minijinja but we can use | capitalize
let t = t
.replace(".strip()", " | trim")
.replace(".capitalize()", " | capitalize");
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
});
// Inference limit with a semaphore // Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@ -277,6 +272,8 @@ struct ChatTemplate {
impl ChatTemplate { impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self { fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
let mut env = Box::new(Environment::new()); let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str(); let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);

247
router/src/kserve.rs Normal file
View File

@ -0,0 +1,247 @@
use crate::{
default_parameters,
server::{generate_internal, ComputeType},
Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Infer, Serialize, ToSchema,
};
use axum::extract::{Extension, Path};
use axum::response::{IntoResponse, Response};
use axum::Json;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct OutputChunk {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct InferenceOutput {
pub id: String,
pub outputs: Vec<OutputChunk>,
}
#[derive(Debug, Deserialize, ToSchema)]
pub(crate) struct InferenceRequest {
pub id: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
pub inputs: Vec<Input>,
pub outputs: Vec<Output>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Input {
pub name: String,
pub shape: Vec<usize>,
pub datatype: String,
pub data: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub(crate) struct Output {
pub name: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct LiveResponse {
pub live: bool,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ReadyResponse {
pub live: bool,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MetadataServerResponse {
pub name: String,
pub version: String,
pub extensions: Vec<String>,
}
// Routes
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/health/live",
responses(
(status = 200, description = "Service is live", body = LiveReponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_live() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = LiveResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/health/ready",
responses(
(status = 200, description = "Service is ready", body = ReadyResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_health_ready() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2",
responses(
(status = 200, description = "Metadata retrieved", body = MetadataServerResponse),
(status = 404, description = "Service not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kerve_server_metadata() -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: "text-generation-inference".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
extensions: vec![
"health".to_string(),
"models".to_string(),
"metrics".to_string(),
],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}",
responses(
(status = 200, description = "Model version metadata retrieved", body = MetadataServerResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata(
Path((model_name, model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = MetadataServerResponse {
name: model_name,
version: model_version,
extensions: vec!["infer".to_string(), "ready".to_string()],
};
Ok((HeaderMap::new(), Json(data)).into_response())
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/infer",
request_body = Json<InferenceRequest>,
responses(
(status = 200, description = "Inference executed successfully", body = InferenceOutput),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_infer(
infer: Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(payload): Json<InferenceRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let id = payload.id.clone();
let str_inputs = payload
.inputs
.iter()
.map(|input| {
std::str::from_utf8(&input.data).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "utf8".to_string(),
}),
)
})
})
.collect::<Result<Vec<_>, _>>()?;
if str_inputs.len() != payload.outputs.len() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Inputs and outputs length mismatch".to_string(),
error_type: "length mismatch".to_string(),
}),
));
}
let output_chunks = str_inputs
.iter()
.zip(&payload.outputs)
.map(|(str_input, output)| {
let generate_request = GenerateRequest {
inputs: str_input.to_string(),
parameters: payload.parameters.clone(),
};
let infer = infer.clone();
let compute_type = compute_type.clone();
let span = tracing::Span::current();
async move {
generate_internal(infer, compute_type, Json(generate_request), span)
.await
.map(|(_, Json(generation))| {
let generation_as_bytes = generation.generated_text.as_bytes().to_vec();
OutputChunk {
name: output.name.clone(),
shape: vec![1, generation_as_bytes.len()],
datatype: "BYTES".to_string(),
data: generation_as_bytes,
}
})
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
let inference_output = InferenceOutput {
id: id.clone(),
outputs: output_chunks,
};
Ok((HeaderMap::new(), Json(inference_output)).into_response())
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v2/models/{model_name}/versions/{model_version}/ready",
responses(
(status = 200, description = "Model version is ready", body = ReadyResponse),
(status = 404, description = "Model or version not found", body = ErrorResponse,
example = json!({"error": "No response"}))
)
)]
pub async fn kserve_model_metadata_ready(
Path((_model_name, _model_version)): Path<(String, String)>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let data = ReadyResponse { live: true };
Ok((HeaderMap::new(), Json(data)).into_response())
}

View File

@ -4,6 +4,9 @@ mod infer;
pub mod server; pub mod server;
mod validation; mod validation;
#[cfg(feature = "kserve")]
mod kserve;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::warn; use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
@ -89,6 +92,7 @@ pub(crate) enum GrammarType {
/// JSON Schema is a declarative language that allows to annotate JSON documents /// JSON Schema is a declarative language that allows to annotate JSON documents
/// with types and descriptions. /// with types and descriptions.
#[serde(rename = "json")] #[serde(rename = "json")]
#[serde(alias = "json_object")]
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
Json(serde_json::Value), Json(serde_json::Value),
#[serde(rename = "regex")] #[serde(rename = "regex")]
@ -797,6 +801,13 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")] #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>, pub tool_choice: Option<ToolType>,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
} }
fn default_tool_prompt() -> Option<String> { fn default_tool_prompt() -> Option<String> {

View File

@ -4,6 +4,11 @@ use crate::infer::v2::SchedulerV2;
use crate::infer::v3::SchedulerV3; use crate::infer::v3::SchedulerV3;
use crate::infer::{HealthCheck, Scheduler}; use crate::infer::{HealthCheck, Scheduler};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
#[cfg(feature = "kserve")]
use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@ -172,7 +177,7 @@ async fn generate(
generate_internal(infer, ComputeType(compute_type), Json(req), span).await generate_internal(infer, ComputeType(compute_type), Json(req), span).await
} }
async fn generate_internal( pub(crate) async fn generate_internal(
infer: Extension<Infer>, infer: Extension<Infer>,
ComputeType(compute_type): ComputeType, ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
@ -1017,6 +1022,7 @@ async fn chat_completions(
tool_choice, tool_choice,
tool_prompt, tool_prompt,
temperature, temperature,
response_format,
.. ..
} = req; } = req;
@ -1031,6 +1037,18 @@ async fn chat_completions(
other => (true, other), other => (true, other),
}; };
// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Grammar and tools are mutually exclusive".to_string(),
error_type: "grammar and tools".to_string(),
}),
));
}
// extract tool grammar if present // extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar, Ok(grammar) => grammar,
@ -1047,16 +1065,21 @@ async fn chat_completions(
} }
}; };
let grammar_with_prompt = tool_grammar // determine the appropriate arguments for apply_chat_template
let tools_grammar_prompt = tool_grammar
.as_ref() .as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let typed_grammar = grammar_with_prompt let (tools_grammar_prompt, grammar) = match response_format {
.as_ref() Some(response_format) => (None, Some(response_format)),
.map(|(grammar, _)| grammar.clone()); None => (
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
),
};
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) { let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -1092,7 +1115,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar: typed_grammar, grammar,
..Default::default() ..Default::default()
}, },
}; };
@ -1711,28 +1734,58 @@ pub async fn run(
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),
}; };
// Define VertextApiDoc conditionally only if the "google" feature is enabled #[allow(unused_mut)] // mut is needed for conditional compilation
let doc = { let mut doc = ApiDoc::openapi();
// avoid `mut` if possible
#[cfg(feature = "google")]
{
use crate::VertexInstance;
#[derive(OpenApi)] #[cfg(feature = "google")]
#[openapi( {
paths(vertex_compatibility), use crate::VertexInstance;
components(schemas(VertexInstance, VertexRequest, VertexResponse))
)]
struct VertextApiDoc;
// limiting mutability to the smallest scope necessary #[derive(OpenApi)]
let mut doc = ApiDoc::openapi(); #[openapi(
doc.merge(VertextApiDoc::openapi()); paths(vertex_compatibility),
doc components(schemas(VertexInstance, VertexRequest, VertexResponse))
} )]
#[cfg(not(feature = "google"))] struct VertexApiDoc;
ApiDoc::openapi()
}; doc.merge(VertexApiDoc::openapi());
}
#[cfg(feature = "kserve")]
{
use crate::kserve::{
InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,
ReadyResponse,
};
use crate::kserve::{
__path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,
__path_kserve_model_infer, __path_kserve_model_metadata,
__path_kserve_model_metadata_ready,
};
#[derive(OpenApi)]
#[openapi(
paths(
kserve_model_infer,
kserve_health_live,
kserve_health_ready,
kerve_server_metadata,
kserve_model_metadata,
kserve_model_metadata_ready,
),
components(schemas(
InferenceOutput,
InferenceRequest,
LiveResponse,
MetadataServerResponse,
OutputChunk,
ReadyResponse,
))
)]
struct KServeApiDoc;
doc.merge(KServeApiDoc::openapi());
}
// Configure Swagger UI // Configure Swagger UI
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
@ -1782,6 +1835,27 @@ pub async fn run(
} }
} }
#[cfg(feature = "kserve")]
{
tracing::info!("Built with `kserve` feature");
app = app
.route(
"/v2/models/:model_name/versions/:model_version/infer",
post(kserve_model_infer),
)
.route(
"/v2/models/:model_name/versions/:model_version",
get(kserve_model_metadata),
)
.route("/v2/health/ready", get(kserve_health_ready))
.route("/v2/health/live", get(kserve_health_live))
.route("/v2", get(kerve_server_metadata))
.route(
"/v2/models/:model_name/versions/:model_version/ready",
get(kserve_model_metadata_ready),
);
}
// add layers after routes // add layers after routes
app = app app = app
.layer(Extension(info)) .layer(Extension(info))

View File

@ -1,5 +1,5 @@
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0 commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
build-vllm-cuda: build-vllm-cuda:
if [ ! -d 'vllm' ]; then \ if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \ pip install -U ninja packaging --no-cache-dir && \

View File

@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
if seqlen > self.original_max_position_embeddings:
inv_freq = self.long_inv_freq
else:
inv_freq = self.short_inv_freq
t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq.to(device=t.device)) t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
self._cos_cached = torch.cos(freqs).to(dtype) short_freqs = torch.outer(
self._sin_cached = torch.sin(freqs).to(dtype) t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
freqs = torch.cat([short_freqs, long_freqs])
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):

View File

@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
) )
logits, speculative_logits = self.lm_head(outputs) logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
loss = None loss = None

View File

@ -8,7 +8,6 @@ from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import ( from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM, FlashPhiForCausalLM,
PhiConfig,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config = PhiConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize

View File

@ -86,5 +86,4 @@ class GPTNeoxSharded(CausalLM):
use_cache=True, use_cache=True,
) )
logits = outputs.logits return outputs.logits, speculative_logits, outputs.past_key_values
return logits, speculative_logits, outputs.past_key_values

View File

@ -76,11 +76,11 @@ class OPTSharded(CausalLM):
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):
outputs = self.model.forward( outputs, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
return outputs.logits, outputs.past_key_values return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -72,11 +72,13 @@ class RW(CausalLM):
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: ):
# Model Forward # Model Forward
outputs = self.model.forward( outputs, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True,
) )
return outputs.logits, outputs.past_key_values
return outputs.logits, speculative_logits, outputs.past_key_values