Adding ctranslate2 quantization and inference: moving the contribution (#1)
* rebaseing the commit on preemo fork. * reformatting and changes. * update dockerfile * update changes for dockerfile * adapt path * rebaseing the commit on preemo fork. * reformatting and changes. * update dockerfile * update changes for dockerfile * adapt path --------- Co-authored-by: michaelfeil <me@michaelfeil.eu>
This commit is contained in:
parent
012c917b6f
commit
ff703cb867
|
@ -188,7 +188,7 @@ COPY server/Makefile server/Makefile
|
|||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install ".[bnb, accelerate, quantize]" --no-cache-dir
|
||||
pip install ".[bnb, accelerate, quantize, ct2]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
|
|
|
@ -25,6 +25,7 @@ enum Quantization {
|
|||
BitsandbytesNF4,
|
||||
BitsandbytesFP4,
|
||||
Gptq,
|
||||
Ct2,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Quantization {
|
||||
|
@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
Quantization::Ct2 => {
|
||||
write!(f, "ct2")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -104,7 +108,7 @@ struct Args {
|
|||
num_shard: Option<usize>,
|
||||
|
||||
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
|
||||
/// quantization on the fly, or `gptq`. 4bit quantization is available through
|
||||
/// quantization on the fly, `bnb` or `gptq`, or `ctranslate2`. 4bit quantization is available through
|
||||
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
|
||||
#[clap(long, env, value_enum)]
|
||||
quantize: Option<Quantization>,
|
||||
|
|
|
@ -21,7 +21,7 @@ install-torch:
|
|||
install: gen-server install-torch
|
||||
pip install pip --upgrade
|
||||
pip install -r requirements.txt
|
||||
pip install -e ".[bnb, accelerate]"
|
||||
pip install -e ".[bnb, accelerate, quantize, ct2]"
|
||||
|
||||
run-dev:
|
||||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||
|
|
|
@ -16,6 +16,7 @@ grpcio-reflection = "^1.51.1"
|
|||
grpc-interceptor = "^0.15.0"
|
||||
typer = "^0.6.1"
|
||||
accelerate = { version = "^0.19.0", optional = true }
|
||||
ctranslate2 = { version = "^3.20.0", optional = true }
|
||||
bitsandbytes = { version = "^0.40.0", optional = true }
|
||||
safetensors = "0.3.1"
|
||||
loguru = "^0.6.0"
|
||||
|
@ -35,6 +36,7 @@ datasets = { version = "^2.14.0", optional = true }
|
|||
accelerate = ["accelerate"]
|
||||
bnb = ["bitsandbytes"]
|
||||
quantize = ["texttable", "datasets", "accelerate"]
|
||||
ct2 = ["ctranslate2"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
grpcio-tools = "^1.51.1"
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
import pytest
|
||||
import torch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.models.ct2_causal_lm import CT2CausalLM
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def default_santacoder():
|
||||
return CT2CausalLM("bigcode/gpt_bigcode-santacoder", dtype=torch.float16)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="def",
|
||||
prefill_logprobs=True,
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_batch(default_pb_request):
|
||||
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
|
||||
prefill_logprobs=True,
|
||||
truncate=100,
|
||||
parameters=default_pb_parameters,
|
||||
stopping_parameters=default_pb_stop_parameters,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_fim_pb_batch(default_fim_pb_request):
|
||||
return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
|
||||
|
||||
|
||||
def test_ct2santa_generate_token_completion(default_santacoder, default_pb_batch):
|
||||
batch = CausalLMBatch.from_pb(
|
||||
default_pb_batch,
|
||||
default_santacoder.tokenizer,
|
||||
default_santacoder.dtype,
|
||||
default_santacoder.device,
|
||||
)
|
||||
next_batch = batch
|
||||
|
||||
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
||||
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
||||
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
assert generations[0].generated_text.text in (" test_get_all_users_with_", ' test_get_all_users(client):')
|
||||
assert generations[0].request_id == batch.requests[0].id
|
||||
assert (
|
||||
generations[0].generated_text.generated_tokens
|
||||
== batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
|
||||
def test_fim_ct2santacoder_generate_token_completion(
|
||||
default_santacoder, default_fim_pb_batch
|
||||
):
|
||||
batch = CausalLMBatch.from_pb(
|
||||
default_fim_pb_batch,
|
||||
default_santacoder.tokenizer,
|
||||
default_santacoder.dtype,
|
||||
default_santacoder.device,
|
||||
)
|
||||
next_batch = batch
|
||||
|
||||
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
|
||||
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
||||
generations, next_batch = default_santacoder.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
assert (
|
||||
generations[0].generated_text.text
|
||||
== """ineProperty(exports, "__esModule", { value"""
|
||||
)
|
||||
assert generations[0].request_id == batch.requests[0].id
|
||||
assert (
|
||||
generations[0].generated_text.generated_tokens
|
||||
== batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
|
@ -16,6 +16,7 @@ class Quantization(str, Enum):
|
|||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
||||
gptq = "gptq"
|
||||
ct2 = "ct2"
|
||||
|
||||
|
||||
class Dtype(str, Enum):
|
||||
|
@ -73,7 +74,7 @@ def serve(
|
|||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
if dtype is not None and quantize is not None:
|
||||
if dtype is not None and quantize is not None and quantize != Quantization.ct2:
|
||||
raise RuntimeError(
|
||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||
)
|
||||
|
@ -90,6 +91,7 @@ def download_weights(
|
|||
auto_convert: bool = True,
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
|
@ -169,6 +171,7 @@ def download_weights(
|
|||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
architecture = config.architectures[0]
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded
|
|||
from text_generation_server.models.santacoder import SantaCoder
|
||||
from text_generation_server.models.t5 import T5Sharded
|
||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||
from text_generation_server.models.ct2_causal_lm import CT2CausalLM
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
# in PyTorch 1.12 and later.
|
||||
|
@ -75,6 +76,7 @@ def get_model(
|
|||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
) -> Model:
|
||||
dtype_ct2 = dtype
|
||||
if dtype is None:
|
||||
dtype = torch.float16
|
||||
elif dtype == "float16":
|
||||
|
@ -84,6 +86,15 @@ def get_model(
|
|||
else:
|
||||
raise RuntimeError(f"Unknown dtype {dtype}")
|
||||
|
||||
if quantize is not None and "ct2" in quantize:
|
||||
return CT2CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype_ct2,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if "facebook/galactica" in model_id:
|
||||
return GalacticaSharded(
|
||||
model_id,
|
||||
|
|
|
@ -0,0 +1,359 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2023 Michael Feil.
|
||||
#
|
||||
# This code is loosely based on Huggingface text-generation-inference v0.9.3's causal_lm.py implementation.
|
||||
# While it remains licensed under Apache License, Version 2.0,
|
||||
# text-generation-inference itself on 7/28/2023 has changed its license.
|
||||
# This code remains unaffected by this change.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from opentelemetry import trace
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
)
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.models.types import (
|
||||
PrefillTokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
)
|
||||
|
||||
from text_generation_server.utils import Sampling
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
|
||||
try:
|
||||
import ctranslate2
|
||||
except ImportError:
|
||||
ctranslate2 = None
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class CT2CausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if ctranslate2 is None:
|
||||
raise ValueError(
|
||||
"for quantization with ct2, the installation requires the pip package ctranslate2. "
|
||||
"install via `text-generation-server[ct2]` or `pip install ctranslate2` is required.",
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
# Start CT2
|
||||
ct2_generator_kwargs = {
|
||||
"inter_threads": int(os.environ.get("TGI_CT2_INTER_THREADS", 1))
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
self.ct2_device = "cuda"
|
||||
ct2_generator_kwargs["intra_threads"] = int(
|
||||
os.environ.get("TGI_CT2_INTRA_THREADS", 1)
|
||||
)
|
||||
else:
|
||||
self.ct2_device = "cpu"
|
||||
ct2_generator_kwargs["intra_threads"] = int(
|
||||
os.environ.get(
|
||||
"TGI_CT2_INTRA_THREADS", multiprocessing.cpu_count() // 2
|
||||
)
|
||||
)
|
||||
|
||||
if dtype == torch.float16 and self.ct2_device == "cuda":
|
||||
ct2_compute_type = "float16"
|
||||
elif dtype == torch.bfloat16 and self.ct2_device == "cuda":
|
||||
ct2_compute_type = "bfloat16"
|
||||
elif self.ct2_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
|
||||
# float16 is not available on CPU
|
||||
# and int16 has no stable implementation
|
||||
ct2_compute_type = "float32"
|
||||
else:
|
||||
# default, int8 quantization.
|
||||
|
||||
if "cuda" in self.ct2_device:
|
||||
# int8 for int8 layers, float16 for non-quantized layers
|
||||
ct2_compute_type = "int8_float16"
|
||||
else:
|
||||
# int8 for int8 layers, float32 for non-quantized layers
|
||||
ct2_compute_type = "int8"
|
||||
|
||||
# Start CT2 - conversion
|
||||
out_dir = (
|
||||
Path(HUGGINGFACE_HUB_CACHE)
|
||||
/ "ct2models" / f"{model_id.replace('/','--')}--{ct2_compute_type}"
|
||||
)
|
||||
|
||||
if not os.path.exists(out_dir / "model.bin"):
|
||||
try:
|
||||
converter = ctranslate2.converters.TransformersConverter(
|
||||
model_id,
|
||||
activation_scales=None,
|
||||
load_as_float16=ct2_compute_type != "bfloat16",
|
||||
revision=revision,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
converter.convert(
|
||||
output_dir=out_dir,
|
||||
vmap=None,
|
||||
quantization=ct2_compute_type,
|
||||
force=True,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise ValueError(
|
||||
f"conversion with ctranslate2 for {model_id} failed : Error {ex}"
|
||||
)
|
||||
if not os.path.exists(out_dir / "model.bin"):
|
||||
raise ValueError(
|
||||
f"no ctranslate2 model for {model_id} found after conversion in {out_dir}"
|
||||
)
|
||||
|
||||
# Start CT2
|
||||
self.ct2_model = ctranslate2.Generator(
|
||||
str(out_dir),
|
||||
device=self.ct2_device,
|
||||
compute_type=ct2_compute_type,
|
||||
**ct2_generator_kwargs,
|
||||
)
|
||||
|
||||
class DummyModel(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
model = DummyModel()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=torch.int8 if "int8" in ct2_compute_type else torch.float16,
|
||||
device=torch.device(self.ct2_device),
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return CausalLMBatch
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
def forward_ct2(
|
||||
self,
|
||||
all_input_ids,
|
||||
input_lengths,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# CT2 forward requires a list of list of input tokens ids and lengths
|
||||
ids_input = (
|
||||
torch.nested.to_padded_tensor(
|
||||
torch.nested.nested_tensor(all_input_ids), 1234567
|
||||
)
|
||||
.flatten(1)
|
||||
.to(torch.int32)
|
||||
)
|
||||
# lengths of the padded ids_input, i.e. how often not pad=1234567 is used.
|
||||
lengths = np.array(input_lengths, dtype=np.int32)
|
||||
|
||||
if self.ct2_device == "cuda":
|
||||
lengths = torch.from_numpy(lengths).to(self.ct2_device)
|
||||
elif self.ct2_device == "cpu":
|
||||
ids_input = ids_input.numpy()
|
||||
|
||||
ids_input = ctranslate2.StorageView.from_array(ids_input)
|
||||
lengths = ctranslate2.StorageView.from_array(lengths)
|
||||
# now, forward through the network
|
||||
logits = self.ct2_model.forward_batch(ids_input, lengths)
|
||||
|
||||
# continue with logits as torch tensor
|
||||
if self.ct2_device == "cpu":
|
||||
# logits is a float32 torch cpu tensor
|
||||
logits = torch.from_numpy(np.asarray(logits))
|
||||
else:
|
||||
# logits is a float16 torch cuda tensor
|
||||
logits = torch.as_tensor(logits, device=self.ct2_device)
|
||||
return logits, None
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
self, batch: CausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
||||
logits, past = self.forward_ct2(batch.all_input_ids, batch.input_lengths)
|
||||
|
||||
# Results
|
||||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.input_lengths,
|
||||
batch.prefix_offsets,
|
||||
batch.read_offsets,
|
||||
logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
all_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Generated token
|
||||
next_token_logprob = logprobs[-1, next_token_id]
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids[:, 0], prefix_offset, read_offset
|
||||
)
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id_squeezed,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
if not stop:
|
||||
stopped = False
|
||||
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
if i % self.world_size == self.rank:
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
||||
)
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
else:
|
||||
seed = None
|
||||
|
||||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
else:
|
||||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
|
||||
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||
logits, -1
|
||||
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||
-new_input_length:-1
|
||||
].tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = PrefillTokens(
|
||||
prefill_token_ids, prefill_logprobs, prefill_texts
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
# Update values
|
||||
batch.input_ids[i, 0] = next_token_id
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.input_lengths[i] = new_input_length
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if stopped:
|
||||
return generations, None
|
||||
|
||||
# Slice unused values from prefill
|
||||
batch.input_ids = batch.input_ids[:, :1]
|
||||
|
||||
# Update attention_mask as we added a new token to input_ids
|
||||
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||
# Decrease right offset
|
||||
batch.padding_right_offset -= 1
|
||||
|
||||
# Update position_ids
|
||||
batch.position_ids = batch.position_ids[:, -1:] + 1
|
||||
|
||||
# Update past key values
|
||||
batch.past_key_values = past
|
||||
|
||||
return generations, batch
|
|
@ -42,7 +42,7 @@ class StaticWarper:
|
|||
self.static_next_logprob = None
|
||||
|
||||
def __call__(self, scores):
|
||||
if torch.cuda.is_available():
|
||||
if scores.device.type == "cuda":
|
||||
if self.cuda_graph is None:
|
||||
self.static_scores = scores
|
||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
|
Loading…
Reference in New Issue