fix(server): Minor refactorization using new_zeros (#24)

- Fix some type hints, in particular the base tokenizer class
- Make use of the `tensor.new_zeros`/`new_empty` methods (see the sketch below)
- Simplify env var string parsing in launcher
Author: Nick Hill (committed by GitHub), 2023-01-17 00:10:22 -08:00
parent fcc2c5fcbf
commit e6d3eb5d5d
10 changed files with 46 additions and 89 deletions
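The pattern behind the `new_zeros`/`new_empty` change, shown outside the diff for clarity: `Tensor.new_zeros` and `Tensor.new_empty` allocate a tensor that inherits the dtype and device of the tensor they are called on, so the explicit `dtype=`/`device=` keyword arguments can be dropped. A minimal sketch of the before/after (the tensor names here are illustrative, not taken from the diff):

```python
import torch

# Reference tensor whose dtype/device the padded buffer should inherit.
attention_mask = torch.ones(2, 5, dtype=torch.int64, device="cpu")

# Before: dtype and device spelled out explicitly.
padded_before = torch.zeros(
    (4, 5),
    dtype=attention_mask.dtype,
    device=attention_mask.device,
)

# After: new_zeros/new_empty inherit dtype and device from the source tensor.
padded_after = attention_mask.new_zeros((4, 5))
# new_empty skips zero-initialization; fine when every element is overwritten right after.
uninitialized = attention_mask.new_empty((4, 1))

assert padded_before.dtype == padded_after.dtype
assert padded_before.device == padded_after.device
```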


@@ -289,36 +289,25 @@ fn shard_manager(
     }
     let mut env = vec![
-        ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
-        (
-            "WORLD_SIZE".parse().unwrap(),
-            world_size.to_string().parse().unwrap(),
-        ),
-        ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
-        (
-            "MASTER_PORT".parse().unwrap(),
-            master_port.to_string().parse().unwrap(),
-        ),
-        (
-            "SAFETENSORS_FAST_GPU".parse().unwrap(),
-            "1".to_string().parse().unwrap(),
-        ),
+        ("RANK".into(), rank.to_string().into()),
+        ("WORLD_SIZE".into(), world_size.to_string().into()),
+        ("MASTER_ADDR".into(), master_addr.into()),
+        ("MASTER_PORT".into(), master_port.to_string().into()),
+        ("SAFETENSORS_FAST_GPU".into(), "1".into()),
     ];
     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
     // Useful when running inside a docker container
     if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
         env.push((
-            "HUGGINGFACE_HUB_CACHE".parse().unwrap(),
-            huggingface_hub_cache.parse().unwrap(),
+            "HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
         ));
     };
     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
         env.push((
-            "CUDA_VISIBLE_DEVICES".parse().unwrap(),
-            cuda_visible_devices.parse().unwrap(),
+            "CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
        ));
     };


@@ -74,10 +74,9 @@ impl Batcher {
         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
-        match response_rx.await.unwrap() {
-            Ok(output) => Ok(output),
-            Err(err) => Err(InferError::GenerationError(err.to_string())),
-        }
+        response_rx.await.unwrap().map_err(
+            |err| InferError::GenerationError(err.to_string())
+        )
     }
 }


@ -23,5 +23,5 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")
try: try:
return CausalLM(model_name, quantize=quantize) return CausalLM(model_name, quantize=quantize)
except Exception as e: except Exception:
return Seq2SeqLM(model_name, quantize=quantize) return Seq2SeqLM(model_name, quantize=quantize)


@@ -5,7 +5,7 @@ from typing import List, Optional, Type
 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.bloom.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -34,7 +34,7 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         batch = super(BloomCausalLMBatch, cls).from_pb(
             pb=pb, tokenizer=tokenizer, device=device
@@ -203,9 +203,7 @@ class BLOOMSharded(BLOOM):
             def linear(input, weight, bias):
                 size_out = input.size()[:-1] + (out_features,)
                 input = input.view(-1, in_features)
-                out = torch.empty(
-                    size_out, device=input.device, dtype=input.dtype
-                )
+                out = input.new_empty(size_out)
                 out = bnb.matmul(
                     input,
                     weight,

@@ -1,17 +1,17 @@
 import torch
 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
 @dataclass
-class CausalLMBatch:
+class CausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
@@ -38,7 +38,7 @@ class CausalLMBatch:
     # Past metadata
     keys_head_dim_last: bool = True
-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -47,7 +47,7 @@ class CausalLMBatch:
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -130,20 +130,14 @@ class CausalLMBatch:
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
-                )
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids
             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_sequence_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )
             # We need to slice the attention mask to remove padding from previous steps
@@ -171,8 +165,8 @@ class CausalLMBatch:
                 if batch.keys_head_dim_last:
                     padded_past_keys_shape = padded_past_values_shape
-                # seq_length is last for BLOOM
                 else:
+                    # seq_length is last for BLOOM
                     padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
@@ -182,16 +176,8 @@ class CausalLMBatch:
                 # This will run only once per layer
                 if j == len(past_key_values):
-                    padded_past_keys = torch.zeros(
-                        padded_past_keys_shape,
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        padded_past_values_shape,
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
+                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
+                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                     past_key_values.append((padded_past_keys, padded_past_values))
                 # We slice the past keys and values to remove the padding from previous batches


@@ -6,7 +6,7 @@ from typing import List, Optional, Type
 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.opt.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -82,7 +82,7 @@ def escape_custom_split_sequence(text):
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -278,9 +278,7 @@ class GalacticaSharded(Galactica):
             def linear(input, weight, bias):
                 size_out = input.size()[:-1] + (out_features,)
                 input = input.view(-1, in_features)
-                out = torch.empty(
-                    size_out, device=input.device, dtype=input.dtype
-                )
+                out = input.new_empty(size_out)
                 out = bnb.matmul(
                     input,
                     weight,


@@ -2,7 +2,7 @@ import torch
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from tokenizers import Tokenizer
+from transformers import PreTrainedTokenizerBase
 from text_generation.models.types import Batch, GeneratedText
@@ -10,7 +10,7 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, device: torch.device):
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
         self.device = device
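Context for the tokenizer annotation change: `AutoTokenizer` is a factory class, not a useful type for annotations. `AutoTokenizer.from_pretrained` returns a `PreTrainedTokenizer` or `PreTrainedTokenizerFast`, and both derive from `PreTrainedTokenizerBase`, so the base class is the accurate hint for parameters that accept any tokenizer. A small illustrative sketch (the `encode_inputs` helper and the model name are assumptions, not part of this commit):

```python
from typing import List

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def encode_inputs(texts: List[str], tokenizer: PreTrainedTokenizerBase):
    # Any concrete tokenizer (slow or fast) satisfies this annotation.
    return tokenizer(texts, return_tensors="pt")


tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
assert isinstance(tokenizer, PreTrainedTokenizerBase)  # holds for fast and slow tokenizers
inputs = encode_inputs(["Hello world"], tokenizer)
```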


@@ -1,17 +1,17 @@
 import torch
 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
 @dataclass
-class Seq2SeqLMBatch:
+class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
@@ -41,7 +41,7 @@ class Seq2SeqLMBatch:
     max_input_length: int
     max_decoder_input_length: int
-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -51,7 +51,7 @@ class Seq2SeqLMBatch:
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -155,10 +155,8 @@ class Seq2SeqLMBatch:
             # Create padded tensor
             if input_ids is None:
-                input_ids = torch.zeros(
+                input_ids = batch.input_ids.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
                 )
             # Copy to correct indices
             input_ids[
@@ -167,10 +165,8 @@ class Seq2SeqLMBatch:
             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )
             # Copy to correct indices
             attention_mask[
@@ -179,10 +175,8 @@ class Seq2SeqLMBatch:
             # Create padded tensor
             if decoder_input_ids is None:
-                decoder_input_ids = torch.zeros(
+                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.decoder_input_ids.dtype,
-                    device=batch.decoder_input_ids.device,
                 )
             # Copy to correct indices
             decoder_input_ids[
@@ -191,10 +185,9 @@ class Seq2SeqLMBatch:
             # Create padded tensor
             if decoder_attention_mask is None:
-                decoder_attention_mask = torch.zeros(
+                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+                decoder_attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
-                    device=batch.attention_mask.device,  # we use `batch.attention_maks` for device here
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
@@ -210,14 +203,12 @@ class Seq2SeqLMBatch:
             # Create padded tensor
             if encoder_last_hidden_state is None:
-                encoder_last_hidden_state = torch.zeros(
+                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                     (
                         total_batch_size,
                         max_input_length,
                         batch.encoder_last_hidden_state.shape[-1],
                     ),
-                    dtype=batch.encoder_last_hidden_state.dtype,
-                    device=batch.encoder_last_hidden_state.device,
                 )
             # Copy to correct indices
@@ -245,9 +236,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if k == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))
                 # We slice the past keys and values to remove the padding from previous batches
                 past_key_values[j][k][
@@ -271,9 +260,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if idx == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))
                 past_key_values[j][idx][
                     start_index:end_index, :, -batch.max_input_length :, :


@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import List
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase
 from text_generation.pb import generate_pb2
@@ -17,7 +17,7 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Batch":
         raise NotImplementedError
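Since `CausalLMBatch` and `Seq2SeqLMBatch` now subclass this `Batch` ABC explicitly, the `B = TypeVar("B", bound=Batch)` used in `model.py` can actually constrain them. A simplified sketch of that contract (toy classes under assumed simplifications, not the repository's actual ones):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Type, TypeVar


class Batch(ABC):
    @classmethod
    @abstractmethod
    def from_pb(cls, pb, tokenizer, device) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def to_pb(self):
        raise NotImplementedError


B = TypeVar("B", bound=Batch)


@dataclass
class ToyBatch(Batch):
    # Illustrative stand-in for CausalLMBatch / Seq2SeqLMBatch.
    requests: List[str]

    @classmethod
    def from_pb(cls, pb, tokenizer, device) -> "ToyBatch":
        return cls(requests=list(pb))

    def to_pb(self):
        return self.requests


def roundtrip(batch_cls: Type[B], pb, tokenizer=None, device=None) -> B:
    # Type checkers can verify that anything passed here implements from_pb/to_pb.
    return batch_cls.from_pb(pb, tokenizer, device)


batch = roundtrip(ToyBatch, ["a", "b"])
assert batch.to_pb() == ["a", "b"]
```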


@@ -12,7 +12,7 @@ from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
 from typing import List, Optional, Tuple
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -114,7 +114,7 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences