* Use minijinja's pycompat mode for Python methods
* fix: cargo fmt lint for pre-commit

---------

Co-authored-by: Armin Ronacher <armin.ronacher@active-4.com>
This commit is contained in:
parent
90184df79c
commit
42aa8ee1bb
|
@ -1856,12 +1856,23 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minijinja"
|
name = "minijinja"
|
||||||
version = "1.0.12"
|
version = "2.0.2"
|
||||||
source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minijinja-contrib"
|
||||||
|
version = "2.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07"
|
||||||
|
dependencies = [
|
||||||
|
"minijinja",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minimal-lexical"
|
name = "minimal-lexical"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
@ -3604,6 +3615,7 @@ dependencies = [
|
||||||
"metrics",
|
"metrics",
|
||||||
"metrics-exporter-prometheus",
|
"metrics-exporter-prometheus",
|
||||||
"minijinja",
|
"minijinja",
|
||||||
|
"minijinja-contrib",
|
||||||
"ngrok",
|
"ngrok",
|
||||||
"nohash-hasher",
|
"nohash-hasher",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
|
|
@ -44,7 +44,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
|
minijinja = { version = "2.0.2" }
|
||||||
|
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
regex = "1.10.3"
|
regex = "1.10.3"
|
||||||
once_cell = "1.19.0"
|
once_cell = "1.19.0"
|
||||||
|
|
|
@ -12,6 +12,8 @@ use crate::{
|
||||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
|
use minijinja_contrib::pycompat;
|
||||||
|
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -62,14 +64,7 @@ impl Infer {
|
||||||
.find(|t| t.name == "default")
|
.find(|t| t.name == "default")
|
||||||
.map(|t| t.template),
|
.map(|t| t.template),
|
||||||
})
|
})
|
||||||
.map(|t| {
|
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||||
// .strip() is not supported in minijinja
|
|
||||||
// .capitalize() is not supported in minijinja but we can use | capitalize
|
|
||||||
let t = t
|
|
||||||
.replace(".strip()", " | trim")
|
|
||||||
.replace(".capitalize()", " | capitalize");
|
|
||||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Inference limit with a semaphore
|
// Inference limit with a semaphore
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
@ -277,6 +272,8 @@ struct ChatTemplate {
|
||||||
impl ChatTemplate {
|
impl ChatTemplate {
|
||||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||||
let mut env = Box::new(Environment::new());
|
let mut env = Box::new(Environment::new());
|
||||||
|
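// Route unknown method calls through minijinja-contrib's pycompat callback so that
// templates can use Python-style string methods such as .strip() or .capitalize().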
|
||||||
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue