fix(server): Minor refactorization using new_zeros (#24)
- Fix some type hints, in particular the base tokenizer class
- Make use of the `tensor.new_zeros`/`new_empty` methods
- Simplify env var string parsing in the launcher
This commit is contained in:
parent fcc2c5fcbf
commit e6d3eb5d5d
@@ -289,36 +289,25 @@ fn shard_manager(
     }

     let mut env = vec![
-        ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
-        (
-            "WORLD_SIZE".parse().unwrap(),
-            world_size.to_string().parse().unwrap(),
-        ),
-        ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
-        (
-            "MASTER_PORT".parse().unwrap(),
-            master_port.to_string().parse().unwrap(),
-        ),
-        (
-            "SAFETENSORS_FAST_GPU".parse().unwrap(),
-            "1".to_string().parse().unwrap(),
-        ),
+        ("RANK".into(), rank.to_string().into()),
+        ("WORLD_SIZE".into(), world_size.to_string().into()),
+        ("MASTER_ADDR".into(), master_addr.into()),
+        ("MASTER_PORT".into(), master_port.to_string().into()),
+        ("SAFETENSORS_FAST_GPU".into(), "1".into()),
     ];

     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
     // Useful when running inside a docker container
     if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
         env.push((
-            "HUGGINGFACE_HUB_CACHE".parse().unwrap(),
-            huggingface_hub_cache.parse().unwrap(),
+            "HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
         ));
     };

     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
         env.push((
-            "CUDA_VISIBLE_DEVICES".parse().unwrap(),
-            cuda_visible_devices.parse().unwrap(),
+            "CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
        ));
     };
@@ -74,10 +74,9 @@ impl Batcher {

         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
-        match response_rx.await.unwrap() {
-            Ok(output) => Ok(output),
-            Err(err) => Err(InferError::GenerationError(err.to_string())),
-        }
+        response_rx.await.unwrap().map_err(
+            |err| InferError::GenerationError(err.to_string())
+        )
     }
 }
@@ -23,5 +23,5 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         raise ValueError("sharded is not supported for AutoModel")
     try:
         return CausalLM(model_name, quantize=quantize)
-    except Exception as e:
+    except Exception:
         return Seq2SeqLM(model_name, quantize=quantize)
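The only change here is dropping the unused `as e` binding: the fallback never inspects the exception, it simply retries with a sequence-to-sequence model. A minimal, self-contained sketch of the pattern, using hypothetical stand-in classes rather than the repo's real ones:

```python
class CausalStub:
    def __init__(self, name: str):
        raise ValueError(f"{name} has no causal LM head")

class Seq2SeqStub:
    def __init__(self, name: str):
        self.name = name

def get_model_sketch(model_name: str):
    try:
        return CausalStub(model_name)
    except Exception:  # the exception object is never used, so it is not bound
        return Seq2SeqStub(model_name)

assert isinstance(get_model_sketch("t5-small"), Seq2SeqStub)
```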
@@ -5,7 +5,7 @@ from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.bloom.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -34,7 +34,7 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         batch = super(BloomCausalLMBatch, cls).from_pb(
             pb=pb, tokenizer=tokenizer, device=device
@@ -203,9 +203,7 @@ class BLOOMSharded(BLOOM):
             def linear(input, weight, bias):
                 size_out = input.size()[:-1] + (out_features,)
                 input = input.view(-1, in_features)
-                out = torch.empty(
-                    size_out, device=input.device, dtype=input.dtype
-                )
+                out = input.new_empty(size_out)
                 out = bnb.matmul(
                     input,
                     weight,
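`new_empty` (and `new_zeros` in the hunks below) are behavior-preserving shorthands: they allocate on the same device and with the same dtype as the tensor they are called on, which is exactly what the removed `device=`/`dtype=` arguments spelled out. A quick sketch with invented shapes:

```python
import torch

# Invented shapes for illustration.
input = torch.ones(4, 8, dtype=torch.float16)
out_features = 16
size_out = input.size()[:-1] + (out_features,)  # torch.Size([4, 16])

out_old = torch.empty(size_out, device=input.device, dtype=input.dtype)
out_new = input.new_empty(size_out)

# Same shape, dtype, and device; contents are uninitialized in both cases.
assert out_old.shape == out_new.shape == (4, 16)
assert out_new.dtype == input.dtype and out_new.device == input.device
```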
@@ -1,17 +1,17 @@
 import torch

 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria


 @dataclass
-class CausalLMBatch:
+class CausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
@@ -38,7 +38,7 @@ class CausalLMBatch:
     # Past metadata
     keys_head_dim_last: bool = True

-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -47,7 +47,7 @@ class CausalLMBatch:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -130,20 +130,14 @@ class CausalLMBatch:
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
-                )
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids

             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_sequence_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )

             # We need to slice the attention mask to remove padding from previous steps
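This hunk shows the pattern the commit title refers to: when concatenating batches, the padded buffer is allocated from an existing batch tensor, so dtype and device follow automatically, and each batch is then copied into its right-aligned slice. A runnable sketch with made-up batch sizes and lengths:

```python
import torch

# Two toy batches of attention masks with different sequence lengths.
mask_a = torch.ones(2, 3, dtype=torch.int64)  # 2 sequences of length 3
mask_b = torch.ones(1, 5, dtype=torch.int64)  # 1 sequence of length 5
total_batch_size, max_sequence_length = 3, 5

# new_zeros inherits dtype and device from mask_a; no explicit arguments.
attention_mask = mask_a.new_zeros((total_batch_size, max_sequence_length))

# Copy each batch into its slice, right-aligned as in concatenate().
attention_mask[0:2, -3:] = mask_a
attention_mask[2:3, -5:] = mask_b
```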
@@ -171,8 +165,8 @@ class CausalLMBatch:

                 if batch.keys_head_dim_last:
                     padded_past_keys_shape = padded_past_values_shape
-                # seq_length is last for BLOOM
                 else:
+                    # seq_length is last for BLOOM
                     padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
@@ -182,16 +176,8 @@ class CausalLMBatch:

                 # This will run only once per layer
                 if j == len(past_key_values):
-                    padded_past_keys = torch.zeros(
-                        padded_past_keys_shape,
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        padded_past_values_shape,
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
+                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
+                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                     past_key_values.append((padded_past_keys, padded_past_values))

                 # We slice the past keys and values to remove the padding from previous batches
@@ -6,7 +6,7 @@ from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.opt.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -82,7 +82,7 @@ def escape_custom_split_sequence(text):
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -278,9 +278,7 @@ class GalacticaSharded(Galactica):
             def linear(input, weight, bias):
                 size_out = input.size()[:-1] + (out_features,)
                 input = input.view(-1, in_features)
-                out = torch.empty(
-                    size_out, device=input.device, dtype=input.dtype
-                )
+                out = input.new_empty(size_out)
                 out = bnb.matmul(
                     input,
                     weight,
@@ -2,7 +2,7 @@ import torch

 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from tokenizers import Tokenizer
+from transformers import PreTrainedTokenizerBase

 from text_generation.models.types import Batch, GeneratedText
@@ -10,7 +10,7 @@ B = TypeVar("B", bound=Batch)


 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, device: torch.device):
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
         self.device = device
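This is the type-hint fix from the commit message. `AutoTokenizer` (and `tokenizers.Tokenizer`, which `Model.__init__` previously named) is not what these functions actually receive: `AutoTokenizer.from_pretrained` is a factory that returns a `PreTrainedTokenizer` or `PreTrainedTokenizerFast`, and `PreTrainedTokenizerBase` is their common base class. A sketch (the `gpt2` checkpoint is arbitrary and needs a network connection or local cache to load):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerBase

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The factory never returns an AutoTokenizer instance...
assert not isinstance(tokenizer, AutoTokenizer)
# ...but the returned object is always a PreTrainedTokenizerBase.
assert isinstance(tokenizer, PreTrainedTokenizerBase)

def decode_ids(tokenizer: PreTrainedTokenizerBase, ids: list) -> str:
    # Under this annotation, any PreTrainedTokenizerBase method is safe.
    return tokenizer.decode(ids)
```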
@@ -1,17 +1,17 @@
 import torch

 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria


 @dataclass
-class Seq2SeqLMBatch:
+class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
@@ -41,7 +41,7 @@ class Seq2SeqLMBatch:
     max_input_length: int
     max_decoder_input_length: int

-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -51,7 +51,7 @@ class Seq2SeqLMBatch:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -155,10 +155,8 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if input_ids is None:
-                input_ids = torch.zeros(
+                input_ids = batch.input_ids.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
                 )
             # Copy to correct indices
             input_ids[
@@ -167,10 +165,8 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )
             # Copy to correct indices
             attention_mask[
@@ -179,10 +175,8 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if decoder_input_ids is None:
-                decoder_input_ids = torch.zeros(
+                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.decoder_input_ids.dtype,
-                    device=batch.decoder_input_ids.device,
                 )
             # Copy to correct indices
             decoder_input_ids[
@@ -191,10 +185,9 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if decoder_attention_mask is None:
-                decoder_attention_mask = torch.zeros(
+                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+                decoder_attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
-                    device=batch.attention_mask.device,  # we use `batch.attention_maks` for device here
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
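One subtlety in this hunk, called out by the new comment line: unlike the other call sites, `new_zeros` cannot be called on the tensor being padded, because `batch.decoder_attention_mask` may be `None` at this point. dtype and device are instead borrowed from `batch.attention_mask`, which always exists. A toy-shaped sketch:

```python
import torch

# Donor tensor that always exists (stands in for batch.attention_mask).
attention_mask = torch.ones(2, 5, dtype=torch.int64)

decoder_attention_mask = None  # may legitimately be absent
if decoder_attention_mask is None:
    # dtype/device come from the donor, not from the (absent) target tensor.
    decoder_attention_mask = attention_mask.new_zeros((2, 3))

assert decoder_attention_mask.dtype == attention_mask.dtype
assert decoder_attention_mask.device == attention_mask.device
```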
@@ -210,14 +203,12 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if encoder_last_hidden_state is None:
-                encoder_last_hidden_state = torch.zeros(
+                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                     (
                         total_batch_size,
                         max_input_length,
                         batch.encoder_last_hidden_state.shape[-1],
                     ),
-                    dtype=batch.encoder_last_hidden_state.dtype,
-                    device=batch.encoder_last_hidden_state.device,
                 )

             # Copy to correct indices
@@ -245,9 +236,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if k == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))

                 # We slice the past keys and values to remove the padding from previous batches
                 past_key_values[j][k][
@@ -271,9 +260,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if idx == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))

                 past_key_values[j][idx][
                     start_index:end_index, :, -batch.max_input_length :, :
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import List

-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase

 from text_generation.pb import generate_pb2
@@ -17,7 +17,7 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Batch":
         raise NotImplementedError
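Making `CausalLMBatch` and `Seq2SeqLMBatch` inherit from this `Batch` ABC (see the earlier hunks) gives `B = TypeVar("B", bound=Batch)` in model.py something real to bind to, and lets Python enforce the interface. A simplified sketch of how the ABC behaves — the signature here is trimmed, the real `from_pb` also takes a tokenizer and a `torch.device`:

```python
from abc import ABC, abstractmethod

class BatchSketch(ABC):
    @classmethod
    @abstractmethod
    def from_pb(cls, pb) -> "BatchSketch":
        raise NotImplementedError

class GoodBatch(BatchSketch):
    @classmethod
    def from_pb(cls, pb) -> "GoodBatch":
        return cls()

class BadBatch(BatchSketch):
    pass  # forgot to implement from_pb

GoodBatch.from_pb(None)  # fine
try:
    BadBatch()           # the ABC machinery rejects incomplete subclasses
except TypeError as err:
    print(err)           # "Can't instantiate abstract class BadBatch..."
```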
@@ -12,7 +12,7 @@ from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
 from typing import List, Optional, Tuple
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -114,7 +114,7 @@ class StoppingCriteria:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences