diff --git a/proto/generate.proto b/proto/generate.proto index 16539f8b..921bd5c0 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -36,6 +36,8 @@ message NextTokenChooserParameters { float top_p = 3; /// apply sampling on the logits bool do_sample = 4; + /// random seed for sampling + optional uint64 seed = 5; } message StoppingCriteriaParameters { @@ -82,6 +84,8 @@ message GeneratedText { repeated float logprobs = 6; /// Finish reason string finish_reason = 7; + /// Seed + optional uint64 seed = 8; } message GenerateRequest { diff --git a/router/client/build.rs b/router/client/build.rs index 1876b217..7b577fcc 100644 --- a/router/client/build.rs +++ b/router/client/build.rs @@ -1,6 +1,7 @@ use std::fs; fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/generate.proto"); fs::create_dir("src/pb").unwrap_or(()); tonic_build::configure() .build_client(true) diff --git a/router/src/batcher.rs b/router/src/batcher.rs index 624ac82d..baf58af4 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -191,6 +191,7 @@ fn send_generated(finished: Vec, entries: &mut IntMap tokens: output.tokens, logprobs: output.logprobs, finish_reason: output.finish_reason, + seed: output.seed, queued: entry.time, start: entry.batch_time.unwrap(), // unwrap is always valid end: Instant::now(), @@ -208,6 +209,7 @@ pub(crate) struct InferResponse { pub(crate) tokens: Vec, pub(crate) logprobs: Vec, pub(crate) finish_reason: String, + pub(crate) seed: Option, pub(crate) queued: Instant, pub(crate) start: Instant, pub(crate) end: Instant, diff --git a/router/src/db.rs b/router/src/db.rs index 51de9d05..15007b64 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters { top_k: parameters.top_k as u32, top_p: parameters.top_p, do_sample: parameters.do_sample, + seed: parameters.seed, } } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 03711580..1aeac302 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters { pub stop: Vec, #[serde(default)] pub details: bool, + #[serde(default)] + pub seed: Option, } fn default_temperature() -> f32 { @@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters { max_new_tokens: default_max_new_tokens(), stop: vec![], details: false, + seed: None, } } @@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest { pub(crate) struct Details { pub finish_reason: String, pub generated_tokens: u32, + pub seed: Option, pub tokens: Vec<(u32, String, f32)>, } diff --git a/router/src/server.rs b/router/src/server.rs index 623dd07c..86041b96 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -55,6 +55,7 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json) -> Result<(), (StatusCode, Json generate_pb2.GeneratedText: return generate_pb2.GeneratedText( @@ -49,4 +50,5 @@ class GeneratedText: token_ids=self.token_ids, logprobs=self.logprobs, finish_reason=self.reason, + seed=self.seed, ) diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index 1ddeed6e..1d087a42 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2 class Sampling: + def __init__(self, seed: Optional[int] = None): + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) + else: + self.generator.seed() + def __call__(self, logits): probs = torch.nn.functional.softmax(logits, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + next_tokens = torch.multinomial( + probs, num_samples=1, generator=self.generator + ).squeeze(1) return next_tokens + @property + def seed(self) -> int: + return self.generator.initial_seed() + class Greedy: def __call__(self, logits): @@ -36,7 +49,9 @@ class Greedy: class NextTokenChooser: - def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False): + def __init__( + self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None + ): warpers = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` @@ -53,7 +68,7 @@ class NextTokenChooser: sampling = True self.warpers = warpers - self.choice = Sampling() if sampling else Greedy() + self.choice = Sampling(seed) if sampling else Greedy() def __call__(self, input_ids, scores): # Warp logits @@ -66,11 +81,14 @@ class NextTokenChooser: @classmethod def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser": + # handle protobuf making default values 0 + seed = pb.seed if pb.HasField("seed") else None return NextTokenChooser( temperature=pb.temperature, top_k=pb.top_k, top_p=pb.top_p, do_sample=pb.do_sample, + seed=seed, )