import torch
import torch.distributed

from typing import Optional, Type

from transformers import (
    PreTrainedTokenizerBase,
)

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2


class BloomCausalLMBatch(CausalLMBatch):
    """Batch type that clears ``keys_head_dim_last`` on every parsed batch."""

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a batch via the parent ``from_pb`` and flag its key layout.

        All four arguments are forwarded unchanged, by keyword, to the parent
        implementation.

        Returns:
            The batch produced by the parent, with ``keys_head_dim_last`` set
            to ``False``.
        """
        parsed = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        # NOTE(review): presumably read by CausalLMBatch when it handles
        # cached keys -- confirm against the parent class.
        parsed.keys_head_dim_last = False
        return parsed


class BLOOMSharded(CausalLM):
    """``CausalLM`` subclass that uses ``BloomCausalLMBatch`` batches."""

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Return the batch class used with this model."""
        return BloomCausalLMBatch

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run ``self.model.forward`` with caching enabled.

        The inputs are passed through by keyword, together with
        ``use_cache=True``. The wrapped model is expected to return a pair
        ``(outputs, speculative_logits)``.

        Returns:
            A 3-tuple ``(outputs.logits, speculative_logits,
            outputs.past_key_values)``.
        """
        model_result = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        outputs, speculative_logits = model_result
        return outputs.logits, speculative_logits, outputs.past_key_values