2023-05-30 10:25:19 -06:00
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
|
|
|
from opentelemetry import trace
|
2023-06-08 06:51:52 -06:00
|
|
|
from transformers import AutoTokenizer
|
|
|
|
from typing import Optional
|
2023-05-30 10:25:19 -06:00
|
|
|
|
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
|
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
|
|
|
RWConfig,
|
|
|
|
FlashRWForCausalLM,
|
|
|
|
)
|
|
|
|
from text_generation_server.utils import (
|
|
|
|
initialize_torch_distributed,
|
|
|
|
weight_files,
|
2023-06-08 06:51:52 -06:00
|
|
|
Weights,
|
2023-05-30 10:25:19 -06:00
|
|
|
)
|
|
|
|
|
|
|
|
# Module-level OpenTelemetry tracer, keyed by this module's import name.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
|
|
|
2023-06-08 06:51:52 -06:00
|
|
|
class FlashRWSharded(FlashCausalLM):
    """`FlashCausalLM` backed by a `FlashRWForCausalLM` built from safetensors.

    Construction initializes torch.distributed, loads the tokenizer and
    `RWConfig` for the requested checkpoint, wraps the checkpoint's
    ``.safetensors`` files in a `Weights` object bound to the process group,
    and hands the resulting model to the `FlashCausalLM` constructor.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then initialize the parent.

        Args:
            model_id: Checkpoint identifier passed to the tokenizer loader,
                the config loader and `weight_files`.
            revision: Optional revision forwarded to those same loaders.
            quantize: Value stored on ``config.quantize`` before the model
                is instantiated.
            dtype: Dtype used for the weights and passed to the parent;
                falls back to ``torch.float16`` when ``None``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Guard clause: this implementation has no CPU path.
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashRW is only available on GPU")

        # Each rank uses the CUDA device whose index equals its rank.
        device = torch.device(f"cuda:{rank}")
        if dtype is None:
            dtype = torch.float16

        tokenizer_options = {
            "revision": revision,
            "padding_side": "left",
            "truncation_side": "left",
            "trust_remote_code": trust_remote_code,
        }
        tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_options)

        config = RWConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        # Synchronize the process group before resolving the weight files.
        torch.distributed.barrier(group=self.process_group)

        safetensors_paths = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        # NOTE(review): the alias presumably lets "lm_head.weight" resolve to
        # the word-embedding tensor (tied weights) -- confirm against Weights.
        embedding_aliases = {"transformer.word_embeddings.weight": ["lm_head.weight"]}
        weights = Weights(
            safetensors_paths,
            device,
            dtype,
            process_group=self.process_group,
            aliases=embedding_aliases,
        )

        # The quantization setting must be on the config before model creation.
        config.quantize = quantize
        rw_model = FlashRWForCausalLM(config, weights)

        # Synchronize again once the model has been built.
        torch.distributed.barrier(group=self.process_group)

        placed_model = rw_model.to(device)
        transformer = rw_model.transformer
        super().__init__(
            model=placed_model,
            tokenizer=tokenizer,
            num_layers=len(transformer.h),
            num_kv_heads=transformer.cache_size,
            head_size=transformer.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
|