hf_text-generation-inference/server/text_generation_server/models/flash_mixtral.py

30 lines
823 B
Python
Raw Permalink Normal View History

2023-12-11 06:43:40 -07:00
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
2023-12-11 06:49:52 -07:00
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
2023-12-11 06:43:40 -07:00
class FlashMixtral(BaseFlashMistral):
    """Flash-attention model wrapper for Mixtral checkpoints.

    Specialises ``BaseFlashMistral`` by supplying ``MixtralConfig`` as the
    configuration class and ``FlashMixtralForCausalLM`` as the model class.
    Every constructor argument is forwarded to the base class unchanged.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Build the Mixtral model via ``BaseFlashMistral.__init__``.

        Args:
            model_id: Model identifier passed through to the base class.
            revision: Optional revision passed through to the base class.
            quantize: Optional quantization mode passed through to the base class.
            dtype: Optional torch dtype passed through to the base class.
            trust_remote_code: Passed through to the base class.
        """
        # Only the config/model classes differ from the generic Mistral path;
        # everything else is handed to the base constructor as-is.
        super().__init__(
            config_cls=MixtralConfig,
            model_cls=FlashMixtralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )