feat(server): Support SantaCoder (#26)
parent f7ac394935
commit 15511edc01

@@ -25,6 +25,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [SantaCoder](https://huggingface.co/bigcode/santacoder)

 Other models are supported on a best effort basis using:

@@ -1,4 +1,5 @@
 use clap::Parser;
+use serde_json::Value;
 use std::env;
 use std::io::{BufRead, BufReader, Read};
 use std::path::Path;
@@ -11,7 +12,6 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use serde_json::Value;
 use subprocess::{Popen, PopenConfig, PopenError, Redirection};

 /// App Configuration
@@ -299,16 +299,12 @@ fn shard_manager(
     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
     // Useful when running inside a docker container
     if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
-        env.push((
-            "HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
-        ));
+        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };

     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
-        env.push((
-            "CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
-        ));
+        env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
     };

     // Start process
@@ -74,9 +74,10 @@ impl Batcher {

         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
-        response_rx.await.unwrap().map_err(
-            |err| InferError::GenerationError(err.to_string())
-        )
+        response_rx
+            .await
+            .unwrap()
+            .map_err(|err| InferError::GenerationError(err.to_string()))
     }
 }

@@ -94,7 +94,9 @@ fn validation_worker(
 ) {
     // Loop over requests
     while let Some((request, response_tx)) = receiver.blocking_recv() {
-        response_tx.send(validate(request, &tokenizer, max_input_length)).unwrap_or(())
+        response_tx
+            .send(validate(request, &tokenizer, max_input_length))
+            .unwrap_or(())
     }
 }

@@ -117,8 +119,9 @@ fn validate(
     }
     if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
         return Err(ValidationError::StopSequence(
-            MAX_STOP_SEQUENCES, request.parameters.stop.len(),
-        ))
+            MAX_STOP_SEQUENCES,
+            request.parameters.stop.len(),
+        ));
     }

     // Get the number of tokens in the input
@@ -127,14 +130,11 @@ fn validate(
             let input_length = inputs.len();

             if input_length > max_input_length {
-                Err(ValidationError::InputLength(
-                    input_length,
-                    max_input_length,
-                ))
+                Err(ValidationError::InputLength(input_length, max_input_length))
             } else {
                 Ok((input_length, request))
             }
-        },
+        }
         Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
     }
 }
@@ -1,7 +1,5 @@
 import pytest

-from transformers import AutoTokenizer
-
 from text_generation.pb import generate_pb2

@@ -18,24 +16,3 @@ def default_pb_parameters():
 @pytest.fixture
 def default_pb_stop_parameters():
     return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
-
-
-@pytest.fixture(scope="session")
-def bloom_560m_tokenizer():
-    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
-
-
-@pytest.fixture(scope="session")
-def gpt2_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
-    tokenizer.pad_token_id = 50256
-    return tokenizer
-
-
-@pytest.fixture(scope="session")
-def mt0_small_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained(
-        "bigscience/mt0-small", padding_side="left"
-    )
-    tokenizer.bos_token_id = 0
-    return tokenizer
@@ -2,12 +2,23 @@ import pytest
 import torch

 from copy import copy
+from transformers import AutoTokenizer

 from text_generation.pb import generate_pb2
 from text_generation.models.causal_lm import CausalLMBatch
 from text_generation.models.bloom import BloomCausalLMBatch, BLOOM


+@pytest.fixture(scope="session")
+def default_bloom():
+    return BLOOM("bigscience/bloom-560m")
+
+
+@pytest.fixture(scope="session")
+def bloom_560m_tokenizer():
+    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
+
+
 @pytest.fixture
 def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
@@ -44,11 +55,6 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     )

-
-@pytest.fixture(scope="session")
-def default_bloom():
-    return BLOOM("bigscience/bloom-560m")
-

 def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     batch = default_bloom_batch

@@ -2,11 +2,24 @@ import pytest
 import torch

 from copy import copy
+from transformers import AutoTokenizer

 from text_generation.pb import generate_pb2
 from text_generation.models.causal_lm import CausalLM, CausalLMBatch


+@pytest.fixture(scope="session")
+def default_causal_lm():
+    return CausalLM("gpt2")
+
+
+@pytest.fixture(scope="session")
+def gpt2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
+    tokenizer.pad_token_id = 50256
+    return tokenizer
+
+
 @pytest.fixture
 def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
@@ -39,11 +52,6 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))

-
-@pytest.fixture(scope="session")
-def default_causal_lm():
-    return CausalLM("gpt2")
-

 def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     batch = default_causal_lm_batch

@@ -0,0 +1,93 @@
+import pytest
+
+from text_generation.pb import generate_pb2
+from text_generation.models.causal_lm import CausalLMBatch
+from text_generation.models.santacoder import SantaCoder
+
+
+@pytest.fixture(scope="session")
+def default_santacoder():
+    return SantaCoder("bigcode/santacoder")
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="def",
+        input_length=1,
+        parameters=default_pb_parameters,
+        stopping_parameters=default_pb_stop_parameters,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_length=5,
+        parameters=default_pb_parameters,
+        stopping_parameters=default_pb_stop_parameters,
+    )
+
+
+@pytest.fixture
+def default_fim_pb_batch(default_fim_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
+
+
+def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
+    batch = CausalLMBatch.from_pb(
+        default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+    )
+    next_batch = batch
+
+    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
+        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output_text == "def test_get_all_users_with_"
+    assert generated_texts[0].request == batch.requests[0]
+    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert (
+        generated_texts[0].generated_tokens
+        == batch.stopping_criterias[0].max_new_tokens
+    )
+
+
+def test_fim_santacoder_generate_token_completion(
+    default_santacoder, default_fim_pb_batch
+):
+    batch = CausalLMBatch.from_pb(
+        default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
+    )
+    next_batch = batch
+
+    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
+        generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_santacoder.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert (
+        generated_texts[0].output_text
+        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+    )
+    assert generated_texts[0].request == batch.requests[0]
+    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
+    assert (
+        generated_texts[0].generated_tokens
+        == batch.stopping_criterias[0].max_new_tokens
+    )
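Note on the FIM test above: the `default_fim_pb_request` input uses the fill-in-the-middle prompt format that SantaCoder expects. A minimal sketch of how such a prompt can be assembled from the special tokens declared further down in the new SantaCoder model file (the `build_fim_prompt` helper is illustrative only, not part of this commit):

# Illustrative only, not part of this commit: assembles the FIM prompt used in
# test_fim_santacoder_generate_token_completion from SantaCoder's special tokens.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"


def build_fim_prompt(prefix: str, suffix: str) -> str:
    # The model is expected to generate the "middle" that connects prefix and suffix.
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


assert build_fim_prompt("def", "world") == "<fim-prefix>def<fim-suffix>world<fim-middle>"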
@@ -3,10 +3,26 @@ import torch

 from copy import copy

+from transformers import AutoTokenizer
+
 from text_generation.pb import generate_pb2
 from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


+@pytest.fixture(scope="session")
+def mt0_small_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bigscience/mt0-small", padding_side="left"
+    )
+    tokenizer.bos_token_id = 0
+    return tokenizer
+
+
+@pytest.fixture(scope="session")
+def default_seq2seq_lm():
+    return Seq2SeqLM("bigscience/mt0-small")
+
+
 @pytest.fixture
 def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
@@ -41,11 +57,6 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))

-
-@pytest.fixture(scope="session")
-def default_seq2seq_lm():
-    return Seq2SeqLM("bigscience/mt0-small")
-

 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     batch = default_seq2seq_lm_batch
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
@@ -1,10 +1,28 @@
+import torch
+
 from text_generation.models.model import Model
 from text_generation.models.causal_lm import CausalLM
 from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
+from text_generation.models.santacoder import SantaCoder

-__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM", "get_model"]
+__all__ = [
+    "Model",
+    "BLOOM",
+    "BLOOMSharded",
+    "CausalLM",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "get_model",
+]
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True


 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@@ -18,6 +36,8 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
             return GalacticaSharded(model_name, quantize=quantize)
         else:
             return Galactica(model_name, quantize=quantize)
+    elif "santacoder" in model_name:
+        return SantaCoder(model_name, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
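With the new branch above, any model id containing "santacoder" is routed to the SantaCoder class. A hedged usage sketch of the dispatch, assuming the `text_generation` server package is installed and the weights can be fetched:

# Sketch only: model ids containing "santacoder" now resolve to SantaCoder.
from text_generation.models import SantaCoder, get_model

model = get_model("bigcode/santacoder", sharded=False, quantize=False)
assert isinstance(model, SantaCoder)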
@@ -5,7 +5,12 @@ from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+    PreTrainedTokenizerBase,
+)
 from transformers.models.bloom.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -34,7 +39,10 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         batch = super(BloomCausalLMBatch, cls).from_pb(
             pb=pb, tokenizer=tokenizer, device=device
@@ -70,13 +78,6 @@ class BLOOMSharded(BLOOM):
         )
         config.pad_token_id = 3
-
-        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-        # in PyTorch 1.12 and later.
-        torch.backends.cuda.matmul.allow_tf32 = True
-
-        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-        torch.backends.cudnn.allow_tf32 = True

         # Only download weights for small models
         if self.master and model_name == "bigscience/bloom-560m":
             download_weights(model_name, extension=".safetensors")
@@ -47,7 +47,10 @@ class CausalLMBatch(Batch):

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -71,6 +74,7 @@ class CausalLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             pad_to_multiple_of=pad_to_multiple_of,
+            return_token_type_ids=False,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

@@ -253,6 +257,11 @@ class CausalLM(Model):
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch

+    def decode(self, generated_ids: List[int]) -> str:
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
+        )
+
     def forward(
         self, input_ids, attention_mask, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -338,11 +347,11 @@ class CausalLM(Model):
                 ),
             )
             if stop:
-                # Decode all tokens
-                output_text = self.tokenizer.decode(
-                    all_input_ids.squeeze(-1), skip_special_tokens=True,
-                    cleanup_tokenization_spaces=False
+                # Decode generated tokens
+                generated_text = self.decode(
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
+                output_text = request.inputs + generated_text
                 # Slice with input_length to remove padding
                 token_ids = all_input_ids[-new_input_length:]
                 tokens = self.tokenizer.batch_decode(token_ids)
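The change above decodes only the newly generated tokens (the last `current_tokens` entries of `all_input_ids`) and prepends the raw request input, instead of re-decoding the whole padded sequence. A small standalone sketch of that slicing with dummy ids, no model involved:

# Dummy-id sketch of the slicing used above: keep only the generated tail.
import torch

all_input_ids = torch.tensor([[11], [22], [33], [44], [55]])  # prompt + generated, shape [seq_len, 1]
current_tokens = 2  # number of tokens generated so far

generated_ids = all_input_ids[-current_tokens:, 0]
print(generated_ids.tolist())  # [44, 55] -> only these are decoded, then request.inputs is prepended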
@@ -6,7 +6,12 @@ from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoConfig,
+    PreTrainedTokenizerBase,
+)
 from transformers.models.opt.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -82,7 +87,10 @@ def escape_custom_split_sequence(text):
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -99,8 +107,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
         )

+        # Tokenize batch
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_token_type_ids=False,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

@@ -124,6 +138,12 @@ class Galactica(CausalLM):
     def batch_type(self) -> Type[CausalLMBatch]:
         return GalacticaCausalLMBatch

+    def decode(self, generated_ids: List[int]) -> str:
+        # Do not skip special tokens as they are used for custom parsing rules of the generated text
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
+        )
+

 class GalacticaSharded(Galactica):
     def __init__(self, model_name: str, quantize: bool = False):
@@ -0,0 +1,87 @@
+import torch
+import torch.distributed
+
+from typing import Optional, List, Tuple
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from text_generation.models import CausalLM
+
+FIM_PREFIX = "<fim-prefix>"
+FIM_MIDDLE = "<fim-middle>"
+FIM_SUFFIX = "<fim-suffix>"
+FIM_PAD = "<fim-pad>"
+EOD = "<|endoftext|>"
+
+
+class SantaCoder(CausalLM):
+    def __init__(self, model_name: str, quantize=False):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    EOD,
+                    FIM_PREFIX,
+                    FIM_MIDDLE,
+                    FIM_SUFFIX,
+                    FIM_PAD,
+                ],
+                "pad_token": EOD,
+            }
+        )
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=dtype,
+            device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
+            trust_remote_code=True,  # required
+        ).eval()
+
+        super(CausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    def decode(self, generated_ids: List[int]) -> str:
+        # Do not skip special tokens as they are used for custom parsing rules of the generated text
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
+        )
+
+    def forward(
+        self, input_ids, attention_mask, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        # FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have
+        # the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length D]
+        # this leads to position_ids being wrong
+
+        input_length = input_ids.shape[-1]
+        past_key_values_length = (
+            0 if past_key_values is None else past_key_values[0][0].shape[-1]
+        )
+        position_ids = torch.arange(
+            past_key_values_length,
+            input_length + past_key_values_length,
+            dtype=torch.long,
+            device=input_ids.device,
+        ).view(1, input_length)
+
+        # Model Forward
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            use_cache=True,
+        )
+        return outputs.logits, outputs.past_key_values
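The `forward` override above builds `position_ids` by hand because of the `past_key_values` shape issue noted in the FIXME. A standalone sketch of that arithmetic, with values chosen purely for illustration:

# Standalone sketch of the position-id computation in SantaCoder.forward above.
import torch

input_length = 1            # tokens in the current forward pass
past_key_values_length = 4  # tokens already stored in the KV cache

position_ids = torch.arange(
    past_key_values_length,
    input_length + past_key_values_length,
    dtype=torch.long,
).view(1, input_length)

print(position_ids)  # tensor([[4]]): the new token sits at absolute position 4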
@@ -51,7 +51,10 @@ class Seq2SeqLMBatch(Batch):

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -83,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             pad_to_multiple_of=pad_to_multiple_of,
+            return_token_type_ids=False,
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@@ -318,6 +322,9 @@ class Seq2SeqLM(Model):
     def batch_type(self) -> Type[Seq2SeqLMBatch]:
         return Seq2SeqLMBatch

+    def decode(self, decoder_ids: List[int]) -> str:
+        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+
     def forward(
         self,
         input_ids,
@@ -438,7 +445,7 @@ class Seq2SeqLM(Model):
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
                 token_ids = decoder_input_ids[-new_decoder_input_length:]
-                output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+                output_text = self.decode(token_ids)
                 tokens = self.tokenizer.batch_decode(token_ids)
                 # Add NaN for the bos token
                 logprobs = [float("nan")] + decoder_logprobs[
@@ -17,7 +17,10 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Batch":
         raise NotImplementedError

@@ -114,7 +114,9 @@ class StoppingCriteria:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences