fix(server): Minor refactorization using new_zeros (#24)

- Fix some type hints, in particular for the base tokenizer class
- Make use of the `tensor.new_zeros/new_empty` methods (see the sketch below)
- Simplify env var string parsing in launcher
Nick Hill authored 2023-01-17 00:10:22 -08:00, committed by GitHub
parent fcc2c5fcbf
commit e6d3eb5d5d
10 changed files with 46 additions and 89 deletions
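For illustration, a standalone PyTorch sketch of the pattern this commit applies (not code from the diff itself): `new_zeros`/`new_empty` inherit dtype and device from an existing tensor, replacing the more verbose explicit `torch.zeros`/`torch.empty` calls.

import torch

# Reference tensor standing in for e.g. batch.input_ids; names are illustrative.
ref = torch.ones(2, 3, dtype=torch.float16)

# Before: dtype and device must be threaded through by hand.
a = torch.zeros((4, 3), dtype=ref.dtype, device=ref.device)

# After: the reference tensor supplies both automatically.
b = ref.new_zeros((4, 3))
c = ref.new_empty((4, 3))  # uninitialized contents, same dtype/device

assert a.dtype == b.dtype == c.dtype
assert a.device == b.device == c.device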

View File

@@ -289,36 +289,25 @@ fn shard_manager(
}
let mut env = vec![
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
(
"WORLD_SIZE".parse().unwrap(),
world_size.to_string().parse().unwrap(),
),
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
(
"SAFETENSORS_FAST_GPU".parse().unwrap(),
"1".to_string().parse().unwrap(),
),
("RANK".into(), rank.to_string().into()),
("WORLD_SIZE".into(), world_size.to_string().into()),
("MASTER_ADDR".into(), master_addr.into()),
("MASTER_PORT".into(), master_port.to_string().into()),
("SAFETENSORS_FAST_GPU".into(), "1".into()),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
env.push((
"HUGGINGFACE_HUB_CACHE".parse().unwrap(),
huggingface_hub_cache.parse().unwrap(),
"HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
));
};
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
env.push((
"CUDA_VISIBLE_DEVICES".parse().unwrap(),
cuda_visible_devices.parse().unwrap(),
"CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
));
};

View File

@@ -74,10 +74,9 @@ impl Batcher {
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender
- match response_rx.await.unwrap() {
- Ok(output) => Ok(output),
- Err(err) => Err(InferError::GenerationError(err.to_string())),
- }
+ response_rx.await.unwrap().map_err(
+ |err| InferError::GenerationError(err.to_string())
+ )
}
}

View File

@@ -23,5 +23,5 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_name, quantize=quantize)
- except Exception as e:
+ except Exception:
return Seq2SeqLM(model_name, quantize=quantize)

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -34,7 +34,7 @@ torch.manual_seed(0)
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
- cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+ cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, device=device
@@ -203,9 +203,7 @@ class BLOOMSharded(BLOOM):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
- out = torch.empty(
- size_out, device=input.device, dtype=input.dtype
- )
+ out = input.new_empty(size_out)
out = bnb.matmul(
input,
weight,

View File

@@ -1,17 +1,17 @@
import torch
from dataclasses import dataclass
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
- from text_generation.models.types import GeneratedText
+ from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@dataclass
- class CausalLMBatch:
+ class CausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
@@ -38,7 +38,7 @@ class CausalLMBatch:
# Past metadata
keys_head_dim_last: bool = True
- def to_pb(self):
+ def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@@ -47,7 +47,7 @@ class CausalLMBatch:
@classmethod
def from_pb(
- cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+ cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
@@ -130,20 +130,14 @@ class CausalLMBatch:
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
- input_ids = torch.empty(
- (total_batch_size, 1),
- dtype=batch.input_ids.dtype,
- device=batch.input_ids.device,
- )
+ input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
- attention_mask = torch.zeros(
+ attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_sequence_length),
- dtype=batch.attention_mask.dtype,
- device=batch.attention_mask.device,
)
# We need to slice the attention mask to remove padding from previous steps
@@ -171,8 +165,8 @@ class CausalLMBatch:
if batch.keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
- # seq_length is last for BLOOM
else:
+ # seq_length is last for BLOOM
padded_past_keys_shape = (
total_batch_size,
num_heads,
@@ -182,16 +176,8 @@ class CausalLMBatch:
# This will run only once per layer
if j == len(past_key_values):
- padded_past_keys = torch.zeros(
- padded_past_keys_shape,
- dtype=past_keys.dtype,
- device=past_keys.device,
- )
- padded_past_values = torch.zeros(
- padded_past_values_shape,
- dtype=past_values.dtype,
- device=past_values.device,
- )
+ padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
+ padded_past_values = past_values.new_zeros(padded_past_values_shape)
past_key_values.append((padded_past_keys, padded_past_values))
# We slice the past keys and values to remove the padding from previous batches
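The concatenation logic above, restated as a self-contained sketch (shapes and the helper name are assumed for illustration): each batch's mask is copied into one freshly allocated padded tensor, and `new_zeros` keeps the result on the source batch's device with its dtype.

import torch

def concat_masks(masks, max_len):
    # Allocate one left-padded tensor covering all batches; new_zeros
    # inherits dtype/device from the first mask.
    total = sum(m.shape[0] for m in masks)
    out = masks[0].new_zeros((total, max_len))
    start = 0
    for m in masks:
        end = start + m.shape[0]
        # Copy into the rightmost columns, mirroring the left-padding used here.
        out[start:end, -m.shape[1]:] = m
        start = end
    return out

merged = concat_masks([torch.ones(2, 3), torch.ones(1, 5)], max_len=5)
assert merged.shape == (3, 5)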

View File

@@ -6,7 +6,7 @@ from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -82,7 +82,7 @@ def escape_custom_split_sequence(text):
class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
- cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+ cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
) -> "GalacticaCausalLMBatch":
inputs = []
next_token_choosers = []
@@ -278,9 +278,7 @@ class GalacticaSharded(Galactica):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
- out = torch.empty(
- size_out, device=input.device, dtype=input.dtype
- )
+ out = input.new_empty(size_out)
out = bnb.matmul(
input,
weight,

View File

@@ -2,7 +2,7 @@ import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
- from tokenizers import Tokenizer
+ from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText
@@ -10,7 +10,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
- def __init__(self, tokenizer: Tokenizer, device: torch.device):
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.device = device
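A note on why the annotation changes (my gloss, with a sketch that assumes network access to download a tokenizer): `AutoTokenizer` is only a factory, and `AutoTokenizer.from_pretrained` returns a concrete `PreTrainedTokenizer` or `PreTrainedTokenizerFast`, both subclasses of `PreTrainedTokenizerBase`, so the base class is the accurate hint for "any tokenizer".

from transformers import AutoTokenizer, PreTrainedTokenizerBase

# The factory never returns an AutoTokenizer instance; it returns a concrete
# tokenizer class derived from PreTrainedTokenizerBase.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(type(tokenizer).__name__)  # e.g. GPT2TokenizerFast
assert isinstance(tokenizer, PreTrainedTokenizerBase)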

View File

@@ -1,17 +1,17 @@
import torch
from dataclasses import dataclass
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
- from text_generation.models.types import GeneratedText
+ from text_generation.models.types import GeneratedText, Batch
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@dataclass
- class Seq2SeqLMBatch:
+ class Seq2SeqLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
@@ -41,7 +41,7 @@ class Seq2SeqLMBatch:
max_input_length: int
max_decoder_input_length: int
- def to_pb(self):
+ def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
@@ -51,7 +51,7 @@ class Seq2SeqLMBatch:
@classmethod
def from_pb(
- cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+ cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
@@ -155,10 +155,8 @@ class Seq2SeqLMBatch:
# Create padded tensor
if input_ids is None:
- input_ids = torch.zeros(
+ input_ids = batch.input_ids.new_zeros(
(total_batch_size, max_input_length),
- dtype=batch.input_ids.dtype,
- device=batch.input_ids.device,
)
# Copy to correct indices
input_ids[
@@ -167,10 +165,8 @@ class Seq2SeqLMBatch:
# Create padded tensor
if attention_mask is None:
- attention_mask = torch.zeros(
+ attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length),
- dtype=batch.attention_mask.dtype,
- device=batch.attention_mask.device,
)
# Copy to correct indices
attention_mask[
@@ -179,10 +175,8 @@ class Seq2SeqLMBatch:
# Create padded tensor
if decoder_input_ids is None:
- decoder_input_ids = torch.zeros(
+ decoder_input_ids = batch.decoder_input_ids.new_zeros(
(total_batch_size, max_decoder_input_length),
- dtype=batch.decoder_input_ids.dtype,
- device=batch.decoder_input_ids.device,
)
# Copy to correct indices
decoder_input_ids[
@@ -191,10 +185,9 @@ class Seq2SeqLMBatch:
# Create padded tensor
if decoder_attention_mask is None:
- decoder_attention_mask = torch.zeros(
+ # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+ decoder_attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_decoder_input_length),
- dtype=batch.attention_mask.dtype, # As decoder_attention_mask might not exist,
- device=batch.attention_mask.device, # we use `batch.attention_maks` for device here
)
# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
# this batch. All generations are of length `batch.max_decoder_input_length`.
@@ -210,14 +203,12 @@ class Seq2SeqLMBatch:
# Create padded tensor
if encoder_last_hidden_state is None:
- encoder_last_hidden_state = torch.zeros(
+ encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
(
total_batch_size,
max_input_length,
batch.encoder_last_hidden_state.shape[-1],
),
- dtype=batch.encoder_last_hidden_state.dtype,
- device=batch.encoder_last_hidden_state.device,
)
# Copy to correct indices
@@ -245,9 +236,7 @@ class Seq2SeqLMBatch:
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(past_key_values[j]):
- past_key_values[j].append(
- torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
- )
+ past_key_values[j].append(t.new_zeros(padded_t_shape))
# We slice the past keys and values to remove the padding from previous batches
past_key_values[j][k][
@@ -271,9 +260,7 @@ class Seq2SeqLMBatch:
# Initialize tensors
# This will run only once per layer and per past tensor
if idx == len(past_key_values[j]):
- past_key_values[j].append(
- torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
- )
+ past_key_values[j].append(t.new_zeros(padded_t_shape))
past_key_values[j][idx][
start_index:end_index, :, -batch.max_input_length :, :
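One subtlety from the hunks above, as a self-contained sketch (tensor shapes assumed): `decoder_attention_mask` can be None before the first concatenation, so it cannot anchor the `new_zeros` call; the sibling `attention_mask`, which always exists, supplies the dtype and device instead.

import torch

attention_mask = torch.ones(2, 4)  # always present on a batch
decoder_attention_mask = None      # may not exist yet

if decoder_attention_mask is None:
    # Borrow dtype/device from the sibling mask, as the diff's comment notes.
    decoder_attention_mask = attention_mask.new_zeros((2, 6))

assert decoder_attention_mask.dtype == attention_mask.dtype
assert decoder_attention_mask.device == attention_mask.device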

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
- from transformers import AutoTokenizer
+ from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2
@@ -17,7 +17,7 @@ class Batch(ABC):
@classmethod
@abstractmethod
def from_pb(
- cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+ cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
) -> "Batch":
raise NotImplementedError
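For context, a simplified sketch of the contract this file defines (a plain dict stands in for the `generate_pb2.Batch` protobuf so the example is self-contained); this abstract base class is what the `CausalLMBatch(Batch)` and `Seq2SeqLMBatch(Batch)` changes above opt in to.

from abc import ABC, abstractmethod
from dataclasses import dataclass

class Batch(ABC):
    @classmethod
    @abstractmethod
    def from_pb(cls, pb: dict) -> "Batch":
        raise NotImplementedError

@dataclass
class CausalLMBatchSketch(Batch):
    batch_id: int

    @classmethod
    def from_pb(cls, pb: dict) -> "CausalLMBatchSketch":
        # The real implementation also tokenizes inputs and moves them to a device.
        return cls(batch_id=pb["id"])

batch = CausalLMBatchSketch.from_pb({"id": 7})
assert isinstance(batch, Batch)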

View File

@@ -12,7 +12,7 @@ from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
- from transformers import AutoTokenizer
+ from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
@@ -114,7 +114,7 @@ class StoppingCriteria:
@classmethod
def from_pb(
- cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+ cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences