feat(server): add special token bool (#85)
parent 4b1c9720c0
commit 0ac184ce77

@@ -1,122 +1,142 @@
 {
+  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
+    "seed": null,
     "prefill": [
       {
         "id": 10264,
-        "logprob": null,
-        "text": "Test"
+        "text": "Test",
+        "logprob": null
       },
       {
         "id": 8821,
-        "logprob": -11.894989,
-        "text": " request"
+        "text": " request",
+        "logprob": -11.894989
       }
     ],
-    "seed": null,
     "tokens": [
       {
         "id": 17,
+        "text": ".",
         "logprob": -1.8267672,
-        "text": "."
+        "special": false
       },
       {
         "id": 1587,
+        "text": "get",
         "logprob": -2.4674969,
-        "text": "get"
+        "special": false
       },
       {
         "id": 11,
+        "text": "(",
         "logprob": -1.906001,
-        "text": "("
+        "special": false
       },
       {
         "id": 5,
+        "text": "\"",
         "logprob": -1.2279545,
-        "text": "\""
+        "special": false
       },
       {
         "id": 4899,
+        "text": "action",
         "logprob": -4.170299,
-        "text": "action"
+        "special": false
       },
       {
         "id": 5,
+        "text": "\"",
         "logprob": -0.32478866,
-        "text": "\""
+        "special": false
       },
       {
         "id": 12,
+        "text": ")",
         "logprob": -1.0773665,
-        "text": ")"
+        "special": false
       },
       {
         "id": 30,
+        "text": ";",
         "logprob": -0.27640742,
-        "text": ";"
+        "special": false
       },
       {
         "id": 837,
+        "text": "\n ",
         "logprob": -1.6970354,
-        "text": "\n "
+        "special": false
       },
       {
         "id": 1320,
+        "text": " if",
         "logprob": -1.4495516,
-        "text": " if"
+        "special": false
       },
       {
         "id": 375,
+        "text": " (",
         "logprob": -0.23609057,
-        "text": " ("
+        "special": false
       },
       {
         "id": 4899,
+        "text": "action",
         "logprob": -1.1916996,
-        "text": "action"
+        "special": false
       },
       {
         "id": 3535,
+        "text": " ==",
         "logprob": -0.8918753,
-        "text": " =="
+        "special": false
       },
       {
         "id": 5109,
+        "text": " null",
         "logprob": -0.3933342,
-        "text": " null"
+        "special": false
       },
       {
         "id": 12,
+        "text": ")",
         "logprob": -0.43212673,
-        "text": ")"
+        "special": false
       },
       {
         "id": 731,
+        "text": " {",
         "logprob": -0.17702064,
-        "text": " {"
+        "special": false
       },
       {
         "id": 1260,
+        "text": "\n ",
         "logprob": -0.07027565,
-        "text": "\n "
+        "special": false
       },
       {
         "id": 10519,
+        "text": " throw",
         "logprob": -1.3915029,
-        "text": " throw"
+        "special": false
       },
       {
         "id": 2084,
+        "text": " new",
         "logprob": -0.04201372,
-        "text": " new"
+        "special": false
       },
       {
         "id": 150858,
+        "text": " RuntimeException",
         "logprob": -1.7329919,
-        "text": " RuntimeException"
+        "special": false
       }
     ]
-  },
-  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+  }
 }
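The snapshot above shows the new response shape: every entry under "tokens" now carries a "special" boolean next to "id", "text", and "logprob". A minimal client-side sketch of how the flag might be used (the hard-coded response is trimmed from the snapshot; variable names are illustrative):

    import json

    # Trimmed copy of the snapshot above; "special" marks tokenizer control tokens.
    response = json.loads("""
    {
      "details": {
        "tokens": [
          {"id": 17, "text": ".", "logprob": -1.8267672, "special": false},
          {"id": 1587, "text": "get", "logprob": -2.4674969, "special": false}
        ]
      }
    }
    """)

    # Drop control tokens (e.g. </s>, <pad>) before reassembling visible text.
    visible = [t["text"] for t in response["details"]["tokens"] if not t["special"]]
    print("".join(visible))  # ".get"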

@@ -14,6 +14,7 @@ pub struct Token {
     id: u32,
     text: String,
     logprob: Option<f32>,
+    special: bool,
 }
 
 #[derive(Deserialize)]

@@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
     {
         assert_eq!(token.id, expected_token.id);
         assert_eq!(token.text, expected_token.text);
+        assert_eq!(token.special, expected_token.special);
         if let Some(logprob) = token.logprob {
             let expected_logprob = expected_token.logprob.unwrap();
             assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
@@ -1,117 +1,137 @@
 {
+  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
+    "seed": null,
     "prefill": [
       {
         "id": 0,
-        "logprob": null,
-        "text": "<pad>"
+        "text": "<pad>",
+        "logprob": null
       }
     ],
-    "seed": null,
     "tokens": [
       {
         "id": 259,
+        "text": "",
         "logprob": -1.3656927,
-        "text": ""
+        "special": false
       },
       {
         "id": 215100,
+        "text": "\"\"\"",
         "logprob": -2.6551573,
-        "text": "\"\"\""
+        "special": false
       },
       {
         "id": 46138,
+        "text": "Test",
         "logprob": -1.8059857,
-        "text": "Test"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -1.2102449,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -1.6057279,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -3.6060903,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 304,
+        "text": "of",
         "logprob": -0.5270343,
-        "text": "of"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -0.62522805,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -1.4069618,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -2.621994,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 304,
+        "text": "of",
         "logprob": -1.3172221,
-        "text": "of"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -0.3501925,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -0.7219573,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -1.0494149,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 260,
+        "text": ".",
         "logprob": -1.0803378,
-        "text": "."
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -0.32933083,
-        "text": ""
+        "special": false
       },
       {
         "id": 215100,
+        "text": "\"\"\"",
         "logprob": -0.11268901,
-        "text": "\"\"\""
+        "special": false
       },
       {
         "id": 2978,
+        "text": "test",
         "logprob": -1.5846587,
-        "text": "test"
+        "special": false
       },
       {
         "id": 290,
+        "text": "_",
         "logprob": -0.49796978,
-        "text": "_"
+        "special": false
       },
       {
         "id": 4125,
+        "text": "test",
         "logprob": -2.0026445,
-        "text": "test"
+        "special": false
       }
     ]
-  },
-  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
+  }
 }
@@ -108,8 +108,10 @@ message Generation {
    float token_logprob = 4;
    /// Text
    string token_text = 5;
+   /// Is it a special token
+   bool token_is_special = 6;
    /// Complete generated text
-   GeneratedText generated_text = 6;
+   GeneratedText generated_text = 7;
 }
 
 message PrefillRequest {
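Note that generated_text moves from field 6 to field 7 so the new boolean can take 6. Protobuf identifies fields on the wire by number, not name, so renumbering is presumably only safe here because the router and server ship and upgrade together, with no serialized messages persisted across versions. A small sketch of the tag arithmetic, in plain Python with no protobuf dependency:

    # A protobuf key byte packs (field_number << 3) | wire_type.
    VARINT, LEN = 0, 2  # wire types: varint for bool, length-delimited for messages

    def tag(field_number: int, wire_type: int) -> int:
        return (field_number << 3) | wire_type

    print(hex(tag(6, VARINT)))  # 0x30: token_is_special now owns field 6
    print(hex(tag(7, LEN)))     # 0x3a: generated_text moved to field 7
    print(hex(tag(6, LEN)))     # 0x32: what an old reader would expect for generated_text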
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
-use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
+use crate::{GenerateRequest, PrefillToken};
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{

@@ -138,7 +138,7 @@ impl Infer {
             .into_iter()
             .zip(tokens.logprobs.into_iter())
             .zip(tokens.texts.into_iter())
-            .map(|((id, logprob), text)| Token { id, text, logprob })
+            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
             .collect();
     }
     // Push last token

@@ -372,6 +372,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
         id: generation.token_id,
         text: generation.token_text,
         logprob: generation.token_logprob,
+        special: generation.token_is_special,
     };
 
     if let Some(generated_text) = generation.generated_text {

@@ -420,7 +421,7 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) prefill: Vec<Token>,
+    pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
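The import change above reflects a type split: prefill tokens never carry a special flag (and the first one has no logprob), so they get their own PrefillToken type while generated tokens keep the richer Token. A rough Python analogue of the zip-and-map above (values borrowed from the first snapshot; names are illustrative):

    # Mirrors .zip(logprobs).zip(texts).map(...) from the Rust above.
    ids = [10264, 8821]
    logprobs = [None, -11.894989]  # the first prompt token has no logprob
    texts = ["Test", " request"]

    prefill = [
        {"id": i, "text": t, "logprob": lp}  # note: no "special" key here
        for i, lp, t in zip(ids, logprobs, texts)
    ]
    assert prefill[0] == {"id": 10264, "text": "Test", "logprob": None}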
@@ -86,6 +86,16 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }
 
+#[derive(Debug, Serialize, ToSchema)]
+pub struct PrefillToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(nullable = true, example = -0.34)]
+    logprob: f32,
+}
+
 #[derive(Debug, Serialize, ToSchema)]
 pub struct Token {
     #[schema(example = 0)]

@@ -94,6 +104,8 @@ pub struct Token {
     text: String,
     #[schema(nullable = true, example = -0.34)]
     logprob: f32,
+    #[schema(example = "false")]
+    special: bool,
 }
 
 #[derive(Serialize, ToSchema)]

@@ -116,7 +128,7 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
     pub seed: Option<u64>,
-    pub prefill: Option<Vec<Token>>,
+    pub prefill: Option<Vec<PrefillToken>>,
     pub tokens: Option<Vec<Token>>,
 }
 
@@ -2,7 +2,7 @@
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    Infer, StreamDetails, StreamResponse, Token, Validation,
+    Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};

@@ -255,11 +255,11 @@ async fn generate_stream(
     let time_per_token = inference_time / generated_text.generated_tokens;
 
     // Tracing metadata
-    span.record("total_time", format!("{:?}", total_time));
-    span.record("validation_time", format!("{:?}", validation_time));
-    span.record("queue_time", format!("{:?}", queue_time));
-    span.record("inference_time", format!("{:?}", inference_time));
-    span.record("time_per_token", format!("{:?}", time_per_token));
+    span.record("total_time", format!("{total_time:?}"));
+    span.record("validation_time", format!("{validation_time:?}"));
+    span.record("queue_time", format!("{queue_time:?}"));
+    span.record("inference_time", format!("{inference_time:?}"));
+    span.record("time_per_token", format!("{time_per_token:?}"));
     span.record("seed", format!("{:?}", generated_text.seed));
     tracing::info!(parent: &span, "Output: {}", generated_text.text);
 

@@ -349,6 +349,7 @@ pub async fn run(
         schemas(
             GenerateRequest,
             GenerateParameters,
+            PrefillToken,
             Token,
             GenerateResponse,
             Details,
@@ -172,7 +172,9 @@ class CausalLMBatch(Batch):
         # and to remove unused allocated space
         left_offset = max_sequence_length - batch.max_sequence_length
         batch_left_offset = (
-            batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            batch.attention_mask.shape[1]
+            - batch.max_sequence_length
+            - batch.padding_right_offset
         )
         attention_mask[
             start_index:end_index,
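The multi-line form above is only a reformat, but the expression is worth unpacking. A worked example with made-up shapes showing what batch_left_offset measures:

    # Assumed numbers: the concatenated mask buffer is wider than this batch needs.
    attention_mask_width = 16  # batch.attention_mask.shape[1]
    max_sequence_length = 12   # longest sequence in this batch
    padding_right_offset = 2   # columns reserved for tokens yet to be generated

    # Unused left-padding columns to skip when copying this batch's mask in.
    batch_left_offset = attention_mask_width - max_sequence_length - padding_right_offset
    assert batch_left_offset == 2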
@@ -443,6 +445,7 @@ class CausalLM(Model):
                 next_token_id_squeezed,
                 next_token_logprob,
                 next_token_text,
+                next_token_id_squeezed in self.all_special_ids,
                 generated_text,
             )
 
@@ -12,6 +12,7 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
     def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
+        self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
 
     @property
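Materializing tokenizer.all_special_ids as a set once per model is what keeps the per-token special check cheap. A standalone sketch, assuming the transformers library and network access to fetch a tokenizer (bigscience/mt0-base is just an example that matches the second snapshot):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
    all_special_ids = set(tokenizer.all_special_ids)  # set membership is O(1)

    def is_special(token_id: int) -> bool:
        # Same check the models now perform for each generated token.
        return token_id in all_special_ids

    print(is_special(tokenizer.pad_token_id))  # True: <pad> is a control token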
@@ -205,7 +205,8 @@ class Seq2SeqLMBatch(Batch):
         else:
             batch_left_offset = (
                 batch.decoder_attention_mask.shape[1]
-                - batch.max_decoder_input_length - batch.padding_right_offset
+                - batch.max_decoder_input_length
+                - batch.padding_right_offset
             )
         decoder_attention_mask[
             start_index:end_index,

@@ -494,14 +495,10 @@ class Seq2SeqLM(Model):
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
-                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, [float("nan")], prefill_texts
+                    [self.tokenizer.bos_token_id],
+                    [float("nan")],
+                    [self.tokenizer.bos_token],
                 )
             else:
                 prefill_tokens = None

@@ -512,6 +509,7 @@ class Seq2SeqLM(Model):
                 next_token_id_squeezed,
                 next_token_logprob,
                 next_token_text,
+                next_token_id_squeezed in self.all_special_ids,
                 generated_text,
             )
 
@@ -73,6 +73,7 @@ class Generation:
     token_id: int
     token_logprob: float
     token_text: str
+    token_is_special: bool
     generated_text: Optional[GeneratedText]
 
     def to_pb(self) -> generate_pb2.Generation:

@@ -84,6 +85,7 @@ class Generation:
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
+            token_is_special=self.token_is_special,
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
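With token_is_special threaded through Generation and its protobuf encoding, the flag travels from the Python server to the Rust router unchanged. A self-contained sketch of the extended payload (protobuf plumbing omitted; a plain str stands in for GeneratedText):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Generation:
        token_id: int
        token_logprob: float
        token_text: str
        token_is_special: bool  # the new field added by this commit
        generated_text: Optional[str]

    # An end-of-sequence token would be flagged, so clients can hide it.
    g = Generation(token_id=2, token_logprob=-0.01, token_text="</s>",
                   token_is_special=True, generated_text=None)
    assert g.token_is_special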