diff --git a/.github/workflows/server-tests.yaml b/.github/workflows/server-tests.yaml new file mode 100644 index 00000000..5bb4653a --- /dev/null +++ b/.github/workflows/server-tests.yaml @@ -0,0 +1,30 @@ +name: Server Tests + +on: + pull_request: + paths: + - "server/**" + - "proto/**" + +jobs: + run_tests: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.9 + - name: Loading cache. + uses: actions/cache@v2 + id: model_cache + with: + path: ~/.cache/huggingface/ + key: models + - name: Install server dependencies + run: | + make install-server + - name: Run tests + run: | + pip install pytest + pytest -sv server/tests diff --git a/README.md b/README.md index 92879707..8496ce8f 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets. - 45ms per token generation for BLOOM with 8xA100 80GB - Logits warpers (temperature scaling, topk ...) - Stop sequences +- Log probabilities ## Officially supported models diff --git a/proto/generate.proto b/proto/generate.proto index 0c67de03..16539f8b 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -27,7 +27,7 @@ message ClearCacheRequest {} /// Empty response message ClearCacheResponse {} -message LogitsWarperParameters { +message NextTokenChooserParameters { /// exponential scaling output probability distribution float temperature = 1; /// restricting to the k highest probability elements @@ -52,8 +52,8 @@ message Request { string inputs = 2; /// The number of tokens inside inputs uint32 input_length = 3; - /// Logits Warper Parameters - LogitsWarperParameters parameters = 4; + /// Next Token Chooser Parameters + NextTokenChooserParameters parameters = 4; /// Stopping Criteria Parameters StoppingCriteriaParameters stopping_parameters = 5; } @@ -71,11 +71,17 @@ message GeneratedText { /// Request Request request = 1; /// Output - string output = 2; + string output_text = 2; /// Number of generated tokens - uint32 tokens = 3; + uint32 generated_tokens = 3; + /// Tokens + repeated string tokens = 4; + /// Token IDs + repeated uint32 token_ids = 5; + /// Logprobs + repeated float logprobs = 6; /// Finish reason - string finish_reason = 4; + string finish_reason = 7; } message GenerateRequest { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index ae337dd6..295b009b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -7,7 +7,7 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v1::{ - Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters, + Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/src/batcher.rs b/router/src/batcher.rs index a72a6e44..1484434c 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -187,9 +187,13 @@ fn send_generated(finished: Vec, db: &Db) { let entry = db .remove(&output.request.unwrap().id) .expect("ID not found in db. 
This is a bug."); + let response = InferResponse { - output: output.output, + output_text: output.output_text, + generated_tokens: output.generated_tokens, + token_ids: output.token_ids, tokens: output.tokens, + logprobs: output.logprobs, finish_reason: output.finish_reason, queued: entry.time, start: entry.batch_time.unwrap(), // unwrap is always valid @@ -202,8 +206,11 @@ fn send_generated(finished: Vec, db: &Db) { #[derive(Debug)] pub(crate) struct InferResponse { - pub(crate) output: String, - pub(crate) tokens: u32, + pub(crate) output_text: String, + pub(crate) generated_tokens: u32, + pub(crate) token_ids: Vec, + pub(crate) tokens: Vec, + pub(crate) logprobs: Vec, pub(crate) finish_reason: String, pub(crate) queued: Instant, pub(crate) start: Instant, diff --git a/router/src/db.rs b/router/src/db.rs index 24fb7a09..df9f2b8e 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -5,7 +5,7 @@ use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; use text_generation_client::{ - Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters, + Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use tokio::sync::oneshot::Sender; use tokio::time::Instant; @@ -71,7 +71,7 @@ impl State { id: *id, inputs: entry.request.inputs.clone(), input_length: entry.input_length as u32, - parameters: Some(LogitsWarperParameters::from( + parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), stopping_parameters: Some(StoppingCriteriaParameters::from( @@ -162,7 +162,7 @@ impl Db { } } -impl From for LogitsWarperParameters { +impl From for NextTokenChooserParameters { fn from(parameters: GenerateParameters) -> Self { Self { temperature: parameters.temperature, diff --git a/router/src/lib.rs b/router/src/lib.rs index b6c694ee..03711580 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -21,7 +21,10 @@ pub(crate) struct GenerateParameters { pub do_sample: bool, #[serde(default = "default_max_new_tokens")] pub max_new_tokens: u32, + #[serde(default)] pub stop: Vec, + #[serde(default)] + pub details: bool, } fn default_temperature() -> f32 { @@ -52,6 +55,7 @@ fn default_parameters() -> GenerateParameters { do_sample: default_do_sample(), max_new_tokens: default_max_new_tokens(), stop: vec![], + details: false, } } @@ -62,10 +66,18 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Serialize)] +pub(crate) struct Details { + pub finish_reason: String, + pub generated_tokens: u32, + pub tokens: Vec<(u32, String, f32)>, +} + #[derive(Serialize)] pub(crate) struct GeneratedText { pub generated_text: String, - pub finish_reason: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option
, } #[derive(Serialize)] diff --git a/router/src/server.rs b/router/src/server.rs index 59296269..2e6c473f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,5 +1,5 @@ use crate::{ - Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, + Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; @@ -54,6 +54,7 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json { + let tokens = response + .token_ids + .into_iter() + .zip(response.tokens.into_iter()) + .zip(response.logprobs.into_iter()) + .map(|((id, text), logprob)| (id, text, logprob)) + .collect(); + Some(Details { + finish_reason: response.finish_reason, + generated_tokens: response.generated_tokens, + tokens, + }) + } + false => None, + }; + // Timings let total_time = start_time.elapsed(); let validation_time = response.queued - start_time; let queue_time = response.start - response.queued; let inference_time = response.end - response.start; - let time_per_token = inference_time / response.tokens; + let time_per_token = inference_time / response.generated_tokens; // Headers let mut headers = HeaderMap::new(); @@ -141,12 +162,12 @@ async fn generate( tracing::Span::current().record("queue_time", format!("{:?}", queue_time)); tracing::Span::current().record("inference_time", format!("{:?}", inference_time)); tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token)); - tracing::info!("Output: {}", response.output); + tracing::info!("Output: {}", response.output_text); // Send response let response = vec![GeneratedText { - generated_text: response.output, - finish_reason: response.finish_reason, + generated_text: response.output_text, + details, }]; Ok((headers, Json(response))) } diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 24cdafac..eb72b8a2 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -7,7 +7,7 @@ from text_generation.pb import generate_pb2 @pytest.fixture def default_pb_parameters(): - return generate_pb2.LogitsWarperParameters( + return generate_pb2.NextTokenChooserParameters( temperature=1.0, top_k=0, top_p=1.0, diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index c5dbaa3e..2a6e670e 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -128,10 +128,12 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch) assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert ( + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + ) assert generated_texts[0].request == default_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTest" assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -170,10 +172,12 @@ def 
test_causal_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert ( + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + ) assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -240,10 +244,10 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTest" assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -259,10 +263,12 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert ( + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + ) assert generated_texts[0].request == default_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -279,9 +285,11 @@ def test_batch_concatenate( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert ( + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + ) assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index f38776cc..683d9fdd 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -127,10 +127,11 @@ def test_causal_lm_generate_token_completion( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generated_texts[0].request == default_causal_lm_batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -150,12 +151,12 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784)" + assert generated_texts[0].output_text == "Test.java:784)" assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -171,12 +172,12 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." 
+ assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -242,12 +243,12 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784)" + assert generated_texts[0].output_text == "Test.java:784)" assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -263,10 +264,10 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -283,11 +284,11 @@ def test_batch_concatenate( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 94ec70d5..f1b11bc2 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -148,9 +148,9 @@ def test_seq2seq_lm_generate_token_completion( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 def test_seq2seq_lm_generate_token_completion_multi( @@ -166,12 +166,12 @@ def test_seq2seq_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few " + assert generated_texts[0].output_text == "a few " assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1] ) - assert generated_texts[0].tokens == 5 + assert generated_texts[0].generated_tokens == 5 generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert generated_texts == [] @@ -180,12 +180,12 @@ def test_seq2seq_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0] ) - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 def test_batch_concatenate( @@ -287,28 +287,28 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - 
assert generated_texts[0].output == "a few " + assert generated_texts[0].output_text == "a few " assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1] ) - assert generated_texts[0].tokens == 5 + assert generated_texts[0].generated_tokens == 5 generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0] ) - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 20e26419..3561a8ea 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -246,12 +246,8 @@ class BLOOMSharded(BLOOM): ) # Logits are sharded, so we need to gather them - logits_shard = outputs.logits[:, -1, :].contiguous() - - batch_size, vocab_shard_size = logits_shard.shape - vocab_size = self.world_size * vocab_shard_size - logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, logits_shard, group=self.process_group) - logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) + logits = torch.cat(logits, dim=2) return logits, outputs.past_key_values diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 72095858..5d62137f 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -22,6 +22,7 @@ class CausalLMBatch: # All tokens all_input_ids: List[torch.Tensor] + all_logprobs: List[Optional[torch.Tensor]] # Lengths of all generations present in the batch input_lengths: List[int] @@ -52,6 +53,7 @@ class CausalLMBatch: next_token_choosers = [] stopping_criterias = [] input_lengths = [] + all_logprobs = [] # Parse batch for r in pb.requests: @@ -61,6 +63,7 @@ class CausalLMBatch: stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + all_logprobs.append(None) pad_to_multiple_of = 8 if "gpu" in str(device) else None tokenized_inputs = tokenizer( @@ -78,6 +81,7 @@ class CausalLMBatch: attention_mask=tokenized_inputs["attention_mask"], past_key_values=None, all_input_ids=all_input_ids, + all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -95,6 +99,7 @@ class CausalLMBatch: requests = [] input_lengths = [] all_input_ids = [] + all_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -110,6 +115,7 @@ class CausalLMBatch: requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) + all_logprobs.extend(batch.all_logprobs) 
next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -217,6 +223,7 @@ class CausalLMBatch: attention_mask=attention_mask, past_key_values=past_key_values, all_input_ids=all_input_ids, + all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -291,6 +298,7 @@ class CausalLM(Model): next_batch_input_lengths = [] next_batch_input_ids = [] next_batch_all_input_ids = [] + next_batch_all_logprobs = [] # Metadata next_batch_size = 0 @@ -307,6 +315,7 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.all_logprobs, ) # For each member of the batch @@ -316,34 +325,59 @@ class CausalLM(Model): logits, next_token_chooser, stopping_criteria, - all_tokens, + all_input_ids, + all_logprobs, ) in enumerate(iterator): # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + tokens, logprobs = next_token_chooser(all_input_ids, logits) + next_token = tokens[-1].view(1, 1) # Append next token to all tokens - all_tokens = torch.cat([all_tokens, next_token]) + all_input_ids = torch.cat([all_input_ids, next_token]) + new_input_length = input_length + 1 + + if all_logprobs is None: + # logprobs of all prompt tokens (except the first one) and the generated token + all_logprobs = logprobs.gather(1, all_input_ids[1:]) + else: + # logprob of the generated token + next_token_logprob = logprobs[-1, next_token] + all_logprobs = torch.cat([all_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(all_tokens) + stop, reason = stopping_criteria(all_input_ids) if stop: # Decode all tokens - output = self.tokenizer.decode( - all_tokens.squeeze(-1), skip_special_tokens=True + output_text = self.tokenizer.decode( + all_input_ids.squeeze(-1), skip_special_tokens=True ) + # Slice with input_length to remove padding + token_ids = all_input_ids[-new_input_length:] + tokens = self.tokenizer.batch_decode(token_ids) + # Add NaN for the first prompt token + logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze( + 1 + ).tolist() + # Add to the list of finished generations with the original request generated_texts.append( GeneratedText( - request, output, stopping_criteria.current_tokens, reason + request=request, + output_text=output_text, + generated_tokens=stopping_criteria.current_tokens, + tokens=tokens, + token_ids=token_ids.squeeze(1).tolist(), + logprobs=logprobs, + reason=reason, ) ) # add to the next batch else: next_batch_keep_indices.append(i) next_batch_input_ids.append(next_token) - next_batch_all_input_ids.append(all_tokens) + next_batch_all_input_ids.append(all_input_ids) + next_batch_all_logprobs.append(all_logprobs) next_batch_size += 1 - new_input_length = input_length + 1 next_batch_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( next_batch_max_sequence_length, new_input_length @@ -397,6 +431,7 @@ class CausalLM(Model): attention_mask=next_batch_attention_mask, past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, + all_logprobs=next_batch_all_logprobs, input_lengths=next_batch_input_lengths, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 680ea43e..a713e69e 100644 --- a/server/text_generation/models/galactica.py +++ 
b/server/text_generation/models/galactica.py @@ -321,12 +321,8 @@ class GalacticaSharded(Galactica): ) # Logits are sharded, so we need to gather them - logits_shard = outputs.logits[:, -1, :].contiguous() - - batch_size, vocab_shard_size = logits_shard.shape - vocab_size = self.world_size * vocab_shard_size - logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, logits_shard, group=self.process_group) - logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) + logits = torch.cat(logits, dim=2) return logits, outputs.past_key_values diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 93e31b4a..e51ce60b 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -30,6 +30,7 @@ class Seq2SeqLMBatch: # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] + decoder_logprobs: List[Optional[torch.Tensor]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -60,6 +61,7 @@ class Seq2SeqLMBatch: decoder_input_ids = [] decoder_input_lengths = [] + decoder_logprobs = [] # Parse batch for r in pb.requests: @@ -72,6 +74,7 @@ class Seq2SeqLMBatch: stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + decoder_logprobs.append(None) # Tokenize batch pad_to_multiple_of = 8 if "gpu" in str(device) else None @@ -95,6 +98,7 @@ class Seq2SeqLMBatch: past_key_values=None, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=len(pb.requests), @@ -117,6 +121,7 @@ class Seq2SeqLMBatch: requests = [] input_lengths = [] decoder_input_lengths = [] + decoder_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -137,6 +142,7 @@ class Seq2SeqLMBatch: requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) + decoder_logprobs.extend(batch.decoder_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -286,6 +292,7 @@ class Seq2SeqLMBatch: past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -385,6 +392,7 @@ class Seq2SeqLM(Model): next_batch_input_lengths = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] + next_batch_decoder_logprobs = [] # Metadata next_batch_size = 0 @@ -399,6 +407,7 @@ class Seq2SeqLM(Model): batch.requests, batch.input_lengths, batch.decoder_input_lengths, + batch.decoder_logprobs, logits, batch.next_token_choosers, batch.stopping_criterias, @@ -411,38 +420,58 @@ class Seq2SeqLM(Model): request, input_length, decoder_input_length, + decoder_logprobs, logits, next_token_chooser, stopping_criteria, input_tokens, - decoder_tokens, + decoder_input_ids, ) in enumerate(iterator): - all_tokens = torch.cat([input_tokens, decoder_tokens]) # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + next_token, logprobs = 
next_token_chooser(decoder_input_ids, logits) # Append next token to decoder tokens - decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token]) + new_decoder_input_length = decoder_input_length + 1 + + next_token_logprob = logprobs[-1, next_token] + if decoder_logprobs is None: + decoder_logprobs = next_token_logprob + else: + decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(decoder_tokens) + stop, reason = stopping_criteria(decoder_input_ids) if stop: - # Decode tokens - output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True) + # Slice with decoder_input_length to remove padding + # Decode all tokens + token_ids = decoder_input_ids[-new_decoder_input_length:] + output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) + tokens = self.tokenizer.batch_decode(token_ids) + # Add NaN for the bos token + logprobs = [float("nan")] + decoder_logprobs[ + -new_decoder_input_length: + ].tolist() # Add to the list of finished generations with the original request generated_texts.append( GeneratedText( - request, output, stopping_criteria.current_tokens, reason + request=request, + output_text=output_text, + generated_tokens=stopping_criteria.current_tokens, + tokens=tokens, + token_ids=token_ids.tolist(), + logprobs=logprobs, + reason=reason, ) ) # add to the next batch else: next_batch_keep_indices.append(i) - next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0)) + next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) next_batch_size += 1 - new_decoder_input_length = decoder_input_length + 1 next_batch_input_lengths.append(input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length) + next_batch_decoder_logprobs.append(decoder_logprobs) next_batch_max_input_length = max( next_batch_max_input_length, input_length ) @@ -515,6 +544,7 @@ class Seq2SeqLM(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths, + decoder_logprobs=next_batch_decoder_logprobs, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 91ec75d0..e76cf697 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -30,14 +30,20 @@ class Batch(ABC): @dataclass class GeneratedText: request: generate_pb2.Request - output: str - tokens: int + output_text: str + generated_tokens: int + tokens: List[str] + token_ids: List[int] + logprobs: List[float] reason: str def to_pb(self) -> generate_pb2.GeneratedText: return generate_pb2.GeneratedText( request=self.request, - output=self.output, + output_text=self.output_text, + generated_tokens=self.generated_tokens, tokens=self.tokens, + token_ids=self.token_ids, + logprobs=self.logprobs, finish_reason=self.reason, ) diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index 661df05e..9dd91151 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -55,12 +55,16 @@ class NextTokenChooser: self.choice = Sampling() if sampling else Greedy() def __call__(self, input_ids, scores): + # Warp logits scores = self.warpers(input_ids, scores) + # Compute logprobs + logprobs = torch.log_softmax(scores, -1) + # Choose tokens 
            next_ids = self.choice(scores)
-        return next_ids.unsqueeze(-1)
+        return next_ids, logprobs
 
     @classmethod
-    def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
+    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
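
Note on the pattern above: `NextTokenChooser.__call__` now returns both the chosen token ids and the full log-probability distribution (`torch.log_softmax` over the warped scores), and the model code stores the logprob of each generated token so the router can expose it. The sketch below illustrates that pattern in isolation; it is a minimal, self-contained example using plain PyTorch, and the names (`pick_next_token`, `fake scores`) are illustrative rather than part of the repository's API.

```python
import torch


def pick_next_token(scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Greedy variant of the chooser pattern: return the chosen ids
    together with the log-probabilities the choice was made from."""
    logprobs = torch.log_softmax(scores, dim=-1)  # normalize the (possibly warped) scores
    next_ids = torch.argmax(scores, dim=-1)       # greedy choice; sampling would use torch.multinomial
    return next_ids, logprobs


# Toy usage: scores for a 3-token prefix over a 10-word vocabulary.
scores = torch.randn(3, 10)
next_ids, logprobs = pick_next_token(scores)
next_token = next_ids[-1].view(1, 1)           # token appended to the running sequence
next_token_logprob = logprobs[-1, next_token]  # its logprob, accumulated per generation step
print(next_token.item(), next_token_logprob.item())  # a log-probability is always <= 0
```

On the router side, when a request sets `details` to `true`, the response's `generated_text` is accompanied by a `Details` object with `finish_reason`, `generated_tokens`, and `tokens`, where each token entry is an `(id, text, logprob)` triple built by zipping `token_ids`, `tokens`, and `logprobs`; the first prompt (or BOS) token carries `NaN` since it has no preceding context to condition on.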