feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
This commit is contained in:
OlivierDehaene 2023-01-30 15:36:16 +01:00 committed by GitHub
parent 1539d3cbbe
commit cd298bc5e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 78 additions and 16 deletions

View File

@ -36,6 +36,8 @@ message NextTokenChooserParameters {
float top_p = 3;
/// apply sampling on the logits
bool do_sample = 4;
/// random seed for sampling
optional uint64 seed = 5;
}
message StoppingCriteriaParameters {
@ -82,6 +84,8 @@ message GeneratedText {
repeated float logprobs = 6;
/// Finish reason
string finish_reason = 7;
/// Seed
optional uint64 seed = 8;
}
message GenerateRequest {

View File

@ -1,6 +1,7 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto");
fs::create_dir("src/pb").unwrap_or(());
tonic_build::configure()
.build_client(true)

View File

@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
tokens: output.tokens,
logprobs: output.logprobs,
finish_reason: output.finish_reason,
seed: output.seed,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid
end: Instant::now(),
@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
pub(crate) tokens: Vec<String>,
pub(crate) logprobs: Vec<f32>,
pub(crate) finish_reason: String,
pub(crate) seed: Option<u64>,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) end: Instant,

View File

@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
seed: parameters.seed,
}
}
}

View File

@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
pub stop: Vec<String>,
#[serde(default)]
pub details: bool,
#[serde(default)]
pub seed: Option<u64>,
}
fn default_temperature() -> f32 {
@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
max_new_tokens: default_max_new_tokens(),
stop: vec![],
details: false,
seed: None,
}
}
@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
pub(crate) struct Details {
pub finish_reason: String,
pub generated_tokens: u32,
pub seed: Option<u64>,
pub tokens: Vec<(u32, String, f32)>,
}

View File

@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
max_new_tokens: 1,
stop: vec![],
details: false,
seed: None,
},
},
)
@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
validation_time,
queue_time,
inference_time,
time_per_token
time_per_token,
seed
)
)]
async fn generate(
@ -118,6 +120,7 @@ async fn generate(
.map(|((id, text), logprob)| (id, text, logprob))
.collect();
Some(Details {
seed: response.seed,
finish_reason: response.finish_reason,
generated_tokens: response.generated_tokens,
tokens,
@ -162,6 +165,7 @@ async fn generate(
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
tracing::Span::current().record("seed", format!("{:?}", response.seed));
tracing::info!("Output: {}", response.output_text);
// Send response

View File

@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,

View File

@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@dataclass
@ -296,7 +296,10 @@ class CausalLM(Model):
)
with context_manager():
logits, past = self.forward(
batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
batch.input_ids,
batch.attention_mask,
batch.position_ids,
batch.past_key_values,
)
# List of indices to cache
@ -373,6 +376,12 @@ class CausalLM(Model):
1
).tolist()
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
@ -383,6 +392,7 @@ class CausalLM(Model):
token_ids=token_ids.squeeze(1).tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
)
)
# add to the next batch

View File

@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,

View File

@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
}
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
load_in_8bit=quantize,
trust_remote_code=True, # required
).to(device).eval()
self.model = (
AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
load_in_8bit=quantize,
trust_remote_code=True, # required
)
.to(device)
.eval()
)
super(CausalLM, self).__init__(
tokenizer=tokenizer,

View File

@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
@dataclass
@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
logprobs = [float("nan")] + decoder_logprobs[
-decoder_input_length:
].tolist()
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(
@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
token_ids=token_ids.tolist(),
logprobs=logprobs,
reason=reason,
seed=seed,
)
)
# add to the next batch

View File

@ -2,7 +2,7 @@ import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
from typing import List, Optional
from transformers import PreTrainedTokenizerBase
@ -39,6 +39,7 @@ class GeneratedText:
token_ids: List[int]
logprobs: List[float]
reason: str
seed: Optional[int]
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(
@ -49,4 +50,5 @@ class GeneratedText:
token_ids=self.token_ids,
logprobs=self.logprobs,
finish_reason=self.reason,
seed=self.seed,
)

View File

@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2
class Sampling:
    """Pick next tokens by multinomial sampling over softmaxed logits.

    The sampler owns a private ``torch.Generator`` so that the draws of one
    request are reproducible from its seed and independent of the global
    RNG state.
    """

    def __init__(self, seed: Optional[int] = None):
        """Create the sampler.

        Args:
            seed: Seed for the private generator. When ``None`` the generator
                draws a non-deterministic seed itself; either way the value
                in use is readable through the ``seed`` property.
        """
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)
        else:
            self.generator.seed()

    def __call__(self, logits):
        """Sample one token id per row of ``logits``.

        Args:
            logits: Scores whose last dimension is softmaxed into the
                sampling distribution.

        Returns:
            The sampled indices with the ``num_samples`` dimension squeezed
            out.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # torch.multinomial rejects a generator whose device type differs from
        # its input (e.g. CPU generator with CUDA probabilities). Rebuild the
        # generator on the input's device, keeping the same seed so the
        # reported seed still reproduces the run. This only triggers when the
        # device type changes, normally on the very first call.
        if self.generator.device.type != probs.device.type:
            seed = self.generator.initial_seed()
            self.generator = torch.Generator(device=probs.device)
            self.generator.manual_seed(seed)

        # Single draw, always through the private generator: no un-seeded
        # draw is made against the global RNG.
        return torch.multinomial(
            probs, num_samples=1, generator=self.generator
        ).squeeze(1)

    @property
    def seed(self) -> int:
        """Seed the private generator was initialised with."""
        return self.generator.initial_seed()
class Greedy:
def __call__(self, logits):
@ -36,7 +49,9 @@ class Greedy:
class NextTokenChooser:
def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
def __init__(
self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
@ -53,7 +68,7 @@ class NextTokenChooser:
sampling = True
self.warpers = warpers
self.choice = Sampling() if sampling else Greedy()
self.choice = Sampling(seed) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
@ -66,11 +81,14 @@ class NextTokenChooser:
@classmethod
def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
    """Build a ``NextTokenChooser`` from its protobuf parameters.

    ``seed`` is an optional protobuf field: reading it when unset would
    yield the default ``0``, so field presence is checked explicitly and an
    unset seed is forwarded as ``None``.
    """
    if pb.HasField("seed"):
        requested_seed = pb.seed
    else:
        requested_seed = None

    chooser_kwargs = dict(
        temperature=pb.temperature,
        top_k=pb.top_k,
        top_p=pb.top_p,
        do_sample=pb.do_sample,
        seed=requested_seed,
    )
    return NextTokenChooser(**chooser_kwargs)