fix(server): Minor refactorization using new_zeros (#24)
- Fix some type hints, in particular the base tokenizer class
- Make use of the `tensor.new_zeros`/`new_empty` methods
- Simplify env var string parsing in the launcher
parent fcc2c5fcbf
commit e6d3eb5d5d
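The core of the refactor is replacing `torch.zeros`/`torch.empty` calls that spell out `dtype=` and `device=` with the `Tensor.new_zeros`/`Tensor.new_empty` methods, which inherit both from the tensor they are called on. A minimal PyTorch sketch of the equivalence (tensor names here are illustrative, not taken from the diff):

```python
import torch

# Stand-in for an existing tensor such as a batch's attention mask.
ref = torch.ones(2, 3, dtype=torch.int64)

# Old pattern: dtype and device must be threaded through explicitly.
a = torch.zeros((4, 3), dtype=ref.dtype, device=ref.device)

# New pattern: new_zeros/new_empty pick up dtype and device from `ref`.
b = ref.new_zeros((4, 3))
c = ref.new_empty((4, 3))  # uninitialized memory, same dtype/device as ref

assert a.dtype == b.dtype and a.device == b.device
assert c.shape == (4, 3)
```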
@@ -289,36 +289,25 @@ fn shard_manager(
     }
 
     let mut env = vec![
-        ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
-        (
-            "WORLD_SIZE".parse().unwrap(),
-            world_size.to_string().parse().unwrap(),
-        ),
-        ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
-        (
-            "MASTER_PORT".parse().unwrap(),
-            master_port.to_string().parse().unwrap(),
-        ),
-        (
-            "SAFETENSORS_FAST_GPU".parse().unwrap(),
-            "1".to_string().parse().unwrap(),
-        ),
+        ("RANK".into(), rank.to_string().into()),
+        ("WORLD_SIZE".into(), world_size.to_string().into()),
+        ("MASTER_ADDR".into(), master_addr.into()),
+        ("MASTER_PORT".into(), master_port.to_string().into()),
+        ("SAFETENSORS_FAST_GPU".into(), "1".into()),
     ];
 
     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
     // Useful when running inside a docker container
     if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
         env.push((
-            "HUGGINGFACE_HUB_CACHE".parse().unwrap(),
-            huggingface_hub_cache.parse().unwrap(),
+            "HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
         ));
     };
 
     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
         env.push((
-            "CUDA_VISIBLE_DEVICES".parse().unwrap(),
-            cuda_visible_devices.parse().unwrap(),
+            "CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
         ));
     };
 
@@ -74,10 +74,9 @@ impl Batcher {
 
         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
-        match response_rx.await.unwrap() {
-            Ok(output) => Ok(output),
-            Err(err) => Err(InferError::GenerationError(err.to_string())),
-        }
+        response_rx.await.unwrap().map_err(
+            |err| InferError::GenerationError(err.to_string())
+        )
     }
 }
 
@@ -23,5 +23,5 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         raise ValueError("sharded is not supported for AutoModel")
     try:
         return CausalLM(model_name, quantize=quantize)
-    except Exception as e:
+    except Exception:
         return Seq2SeqLM(model_name, quantize=quantize)
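The `get_model` hunk above keeps the existing try-one-architecture-then-fall-back behaviour and only drops the unused `as e` binding. A hedged sketch of the same fallback pattern using plain `transformers` Auto classes (the helper name and the choice of `ValueError` are assumptions, not part of the server code):

```python
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM


def load_any_model(model_name: str):
    # Hypothetical helper: try the decoder-only loader first; if the checkpoint
    # is an encoder-decoder model, AutoModelForCausalLM raises ValueError and
    # we fall back to the seq2seq loader instead.
    try:
        return AutoModelForCausalLM.from_pretrained(model_name)
    except ValueError:
        return AutoModelForSeq2SeqLM.from_pretrained(model_name)
```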
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.bloom.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,

@@ -34,7 +34,7 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         batch = super(BloomCausalLMBatch, cls).from_pb(
             pb=pb, tokenizer=tokenizer, device=device

@@ -203,9 +203,7 @@ class BLOOMSharded(BLOOM):
                 def linear(input, weight, bias):
                     size_out = input.size()[:-1] + (out_features,)
                     input = input.view(-1, in_features)
-                    out = torch.empty(
-                        size_out, device=input.device, dtype=input.dtype
-                    )
+                    out = input.new_empty(size_out)
                     out = bnb.matmul(
                         input,
                         weight,
@@ -1,17 +1,17 @@
 import torch
 
 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
 
 
 @dataclass
-class CausalLMBatch:
+class CausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
 

@@ -38,7 +38,7 @@ class CausalLMBatch:
     # Past metadata
     keys_head_dim_last: bool = True
 
-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,

@@ -47,7 +47,7 @@ class CausalLMBatch:
 
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []

@@ -130,20 +130,14 @@ class CausalLMBatch:
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
-                )
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids
 
             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_sequence_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
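The concatenation hunks above all follow the same pattern: allocate one padded tensor from an existing batch tensor, then copy each batch into its row range. A small sketch of that pattern with `new_zeros` (shapes and the left padding are illustrative assumptions, not the server's exact layout):

```python
import torch

# Two attention masks from batches of different sequence lengths.
mask_a = torch.ones(2, 5, dtype=torch.int64)
mask_b = torch.ones(3, 7, dtype=torch.int64)

total_batch_size = mask_a.shape[0] + mask_b.shape[0]
max_sequence_length = max(mask_a.shape[1], mask_b.shape[1])

# One zero-padded tensor allocated from a reference tensor, so dtype and
# device match without being passed explicitly.
attention_mask = mask_a.new_zeros((total_batch_size, max_sequence_length))

# Copy each batch into its row range, aligned to the right (left padding).
attention_mask[0:2, -mask_a.shape[1]:] = mask_a
attention_mask[2:5, -mask_b.shape[1]:] = mask_b

assert attention_mask.shape == (5, 7)
```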
@@ -171,8 +165,8 @@ class CausalLMBatch:
 
                 if batch.keys_head_dim_last:
                     padded_past_keys_shape = padded_past_values_shape
+                # seq_length is last for BLOOM
                 else:
-                    # seq_length is last for BLOOM
                     padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,

@@ -182,16 +176,8 @@ class CausalLMBatch:
 
                 # This will run only once per layer
                 if j == len(past_key_values):
-                    padded_past_keys = torch.zeros(
-                        padded_past_keys_shape,
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        padded_past_values_shape,
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
+                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
+                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                     past_key_values.append((padded_past_keys, padded_past_values))
 
                 # We slice the past keys and values to remove the padding from previous batches
@@ -6,7 +6,7 @@ from typing import List, Optional, Type
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.opt.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,

@@ -82,7 +82,7 @@ def escape_custom_split_sequence(text):
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []

@@ -278,9 +278,7 @@ class GalacticaSharded(Galactica):
                 def linear(input, weight, bias):
                     size_out = input.size()[:-1] + (out_features,)
                     input = input.view(-1, in_features)
-                    out = torch.empty(
-                        size_out, device=input.device, dtype=input.dtype
-                    )
+                    out = input.new_empty(size_out)
                     out = bnb.matmul(
                         input,
                         weight,
@@ -2,7 +2,7 @@ import torch
 
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from tokenizers import Tokenizer
+from transformers import PreTrainedTokenizerBase
 
 from text_generation.models.types import Batch, GeneratedText
 

@@ -10,7 +10,7 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, device: torch.device):
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
         self.device = device
 
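The tokenizer-related hunks all make the type-hint fix flagged in the commit message: `AutoTokenizer` (and `tokenizers.Tokenizer`) is not the type of the object actually passed around; `AutoTokenizer.from_pretrained` returns a concrete tokenizer class whose common ancestor is `PreTrainedTokenizerBase`. A short sketch of why the base class is the right annotation (the checkpoint name and helper are illustrative):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerBase

# AutoTokenizer is only a factory; the returned object is e.g. GPT2TokenizerFast,
# and every such class derives from PreTrainedTokenizerBase.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

assert isinstance(tokenizer, PreTrainedTokenizerBase)
assert not isinstance(tokenizer, AutoTokenizer)  # never an AutoTokenizer instance


def token_count(tokenizer: PreTrainedTokenizerBase, text: str) -> int:
    # The base-class annotation stays accurate for both slow and fast tokenizers.
    return len(tokenizer(text)["input_ids"])
```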
@@ -1,17 +1,17 @@
 import torch
 
 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
 
 
 @dataclass
-class Seq2SeqLMBatch:
+class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
 

@@ -41,7 +41,7 @@ class Seq2SeqLMBatch:
     max_input_length: int
     max_decoder_input_length: int
 
-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,

@@ -51,7 +51,7 @@ class Seq2SeqLMBatch:
 
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []

@@ -155,10 +155,8 @@ class Seq2SeqLMBatch:
 
             # Create padded tensor
             if input_ids is None:
-                input_ids = torch.zeros(
+                input_ids = batch.input_ids.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
                 )
             # Copy to correct indices
             input_ids[

@@ -167,10 +165,8 @@ class Seq2SeqLMBatch:
 
             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )
             # Copy to correct indices
             attention_mask[

@@ -179,10 +175,8 @@ class Seq2SeqLMBatch:
 
             # Create padded tensor
             if decoder_input_ids is None:
-                decoder_input_ids = torch.zeros(
+                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.decoder_input_ids.dtype,
-                    device=batch.decoder_input_ids.device,
                 )
             # Copy to correct indices
             decoder_input_ids[

@@ -191,10 +185,9 @@ class Seq2SeqLMBatch:
 
             # Create padded tensor
             if decoder_attention_mask is None:
-                decoder_attention_mask = torch.zeros(
+                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+                decoder_attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
-                    device=batch.attention_mask.device,  # we use `batch.attention_maks` for device here
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
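One detail worth noting in the hunk above: `new_zeros` does not have to be called on the tensor being replaced, only on any tensor that already has the desired dtype and device. Since `decoder_attention_mask` may still be `None` at this point, the encoder-side `attention_mask` serves as the allocation reference. A tiny sketch of that idea (names and shapes are illustrative):

```python
import torch

attention_mask = torch.ones(2, 5, dtype=torch.int64)
decoder_attention_mask = None

if decoder_attention_mask is None:
    # Borrow dtype/device from a tensor that is guaranteed to exist.
    decoder_attention_mask = attention_mask.new_zeros((2, 3))

assert decoder_attention_mask.dtype == attention_mask.dtype
assert decoder_attention_mask.device == attention_mask.device
```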
@@ -210,14 +203,12 @@ class Seq2SeqLMBatch:
 
             # Create padded tensor
             if encoder_last_hidden_state is None:
-                encoder_last_hidden_state = torch.zeros(
+                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                     (
                         total_batch_size,
                         max_input_length,
                         batch.encoder_last_hidden_state.shape[-1],
                     ),
-                    dtype=batch.encoder_last_hidden_state.dtype,
-                    device=batch.encoder_last_hidden_state.device,
                 )
 
             # Copy to correct indices

@@ -245,9 +236,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if k == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))
 
                 # We slice the past keys and values to remove the padding from previous batches
                 past_key_values[j][k][

@@ -271,9 +260,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if idx == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))
 
                 past_key_values[j][idx][
                     start_index:end_index, :, -batch.max_input_length :, :
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import List
 
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase
 
 from text_generation.pb import generate_pb2
 

@@ -17,7 +17,7 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Batch":
         raise NotImplementedError
 
@@ -12,7 +12,7 @@ from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
 from typing import List, Optional, Tuple
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,

@@ -114,7 +114,7 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences