feat: Support sampling seeding (#37)
Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
This commit is contained in:
parent
1539d3cbbe
commit
cd298bc5e5
|
@ -36,6 +36,8 @@ message NextTokenChooserParameters {
|
|||
float top_p = 3;
|
||||
/// apply sampling on the logits
|
||||
bool do_sample = 4;
|
||||
/// random seed for sampling
|
||||
optional uint64 seed = 5;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
|
@ -82,6 +84,8 @@ message GeneratedText {
|
|||
repeated float logprobs = 6;
|
||||
/// Finish reason
|
||||
string finish_reason = 7;
|
||||
/// Seed
|
||||
optional uint64 seed = 8;
|
||||
}
|
||||
|
||||
message GenerateRequest {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/generate.proto");
|
||||
fs::create_dir("src/pb").unwrap_or(());
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
|
|
|
@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
|
|||
tokens: output.tokens,
|
||||
logprobs: output.logprobs,
|
||||
finish_reason: output.finish_reason,
|
||||
seed: output.seed,
|
||||
queued: entry.time,
|
||||
start: entry.batch_time.unwrap(), // unwrap is always valid
|
||||
end: Instant::now(),
|
||||
|
@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
|
|||
pub(crate) tokens: Vec<String>,
|
||||
pub(crate) logprobs: Vec<f32>,
|
||||
pub(crate) finish_reason: String,
|
||||
pub(crate) seed: Option<u64>,
|
||||
pub(crate) queued: Instant,
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) end: Instant,
|
||||
|
|
|
@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
|
|||
top_k: parameters.top_k as u32,
|
||||
top_p: parameters.top_p,
|
||||
do_sample: parameters.do_sample,
|
||||
seed: parameters.seed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
|
|||
pub stop: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub details: bool,
|
||||
#[serde(default)]
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
fn default_temperature() -> f32 {
|
||||
|
@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
|
|||
max_new_tokens: default_max_new_tokens(),
|
||||
stop: vec![],
|
||||
details: false,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
|
|||
pub(crate) struct Details {
|
||||
pub finish_reason: String,
|
||||
pub generated_tokens: u32,
|
||||
pub seed: Option<u64>,
|
||||
pub tokens: Vec<(u32, String, f32)>,
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
|
|||
max_new_tokens: 1,
|
||||
stop: vec![],
|
||||
details: false,
|
||||
seed: None,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
|
|||
validation_time,
|
||||
queue_time,
|
||||
inference_time,
|
||||
time_per_token
|
||||
time_per_token,
|
||||
seed
|
||||
)
|
||||
)]
|
||||
async fn generate(
|
||||
|
@ -118,6 +120,7 @@ async fn generate(
|
|||
.map(|((id, text), logprob)| (id, text, logprob))
|
||||
.collect();
|
||||
Some(Details {
|
||||
seed: response.seed,
|
||||
finish_reason: response.finish_reason,
|
||||
generated_tokens: response.generated_tokens,
|
||||
tokens,
|
||||
|
@ -162,6 +165,7 @@ async fn generate(
|
|||
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
|
||||
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
|
||||
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
|
||||
tracing::Span::current().record("seed", format!("{:?}", response.seed));
|
||||
tracing::info!("Output: {}", response.output_text);
|
||||
|
||||
// Send response
|
||||
|
|
|
@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
|
|||
if name == "word_embeddings.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
|
|||
from text_generation.models import Model
|
||||
from text_generation.models.types import GeneratedText, Batch
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -296,7 +296,10 @@ class CausalLM(Model):
|
|||
)
|
||||
with context_manager():
|
||||
logits, past = self.forward(
|
||||
batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
|
||||
batch.input_ids,
|
||||
batch.attention_mask,
|
||||
batch.position_ids,
|
||||
batch.past_key_values,
|
||||
)
|
||||
|
||||
# List of indices to cache
|
||||
|
@ -373,6 +376,12 @@ class CausalLM(Model):
|
|||
1
|
||||
).tolist()
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(
|
||||
GeneratedText(
|
||||
|
@ -383,6 +392,7 @@ class CausalLM(Model):
|
|||
token_ids=token_ids.squeeze(1).tolist(),
|
||||
logprobs=logprobs,
|
||||
reason=reason,
|
||||
seed=seed,
|
||||
)
|
||||
)
|
||||
# add to the next batch
|
||||
|
|
|
@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
|
|||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
|
|
@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
|
|||
}
|
||||
)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=dtype,
|
||||
load_in_8bit=quantize,
|
||||
trust_remote_code=True, # required
|
||||
).to(device).eval()
|
||||
)
|
||||
.to(device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
|
|||
from text_generation.models import Model
|
||||
from text_generation.models.types import GeneratedText, Batch
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
|
|||
logprobs = [float("nan")] + decoder_logprobs[
|
||||
-decoder_input_length:
|
||||
].tolist()
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(
|
||||
GeneratedText(
|
||||
|
@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
|
|||
token_ids=token_ids.tolist(),
|
||||
logprobs=logprobs,
|
||||
reason=reason,
|
||||
seed=seed,
|
||||
)
|
||||
)
|
||||
# add to the next batch
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
|
@ -39,6 +39,7 @@ class GeneratedText:
|
|||
token_ids: List[int]
|
||||
logprobs: List[float]
|
||||
reason: str
|
||||
seed: Optional[int]
|
||||
|
||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||
return generate_pb2.GeneratedText(
|
||||
|
@ -49,4 +50,5 @@ class GeneratedText:
|
|||
token_ids=self.token_ids,
|
||||
logprobs=self.logprobs,
|
||||
finish_reason=self.reason,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
|
|
@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2
|
|||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: Optional[int] = None):
|
||||
self.generator = torch.Generator()
|
||||
if seed is not None:
|
||||
self.generator.manual_seed(seed)
|
||||
else:
|
||||
self.generator.seed()
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
next_tokens = torch.multinomial(
|
||||
probs, num_samples=1, generator=self.generator
|
||||
).squeeze(1)
|
||||
return next_tokens
|
||||
|
||||
@property
|
||||
def seed(self) -> int:
|
||||
return self.generator.initial_seed()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
|
@ -36,7 +49,9 @@ class Greedy:
|
|||
|
||||
|
||||
class NextTokenChooser:
|
||||
def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
|
||||
def __init__(
|
||||
self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
|
||||
):
|
||||
warpers = LogitsProcessorList()
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
|
@ -53,7 +68,7 @@ class NextTokenChooser:
|
|||
sampling = True
|
||||
|
||||
self.warpers = warpers
|
||||
self.choice = Sampling() if sampling else Greedy()
|
||||
self.choice = Sampling(seed) if sampling else Greedy()
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
# Warp logits
|
||||
|
@ -66,11 +81,14 @@ class NextTokenChooser:
|
|||
|
||||
@classmethod
|
||||
def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
|
||||
# handle protobuf making default values 0
|
||||
seed = pb.seed if pb.HasField("seed") else None
|
||||
return NextTokenChooser(
|
||||
temperature=pb.temperature,
|
||||
top_k=pb.top_k,
|
||||
top_p=pb.top_p,
|
||||
do_sample=pb.do_sample,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue