Adding A100 compute. (#2806)
This commit is contained in:
parent
5df8059037
commit
d96dcb1797
|
@ -172,7 +172,9 @@ struct RawConfig {
|
|||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: Option<bool>,
|
||||
#[serde(rename = "num_experts_per_tok")]
|
||||
experts: Option<usize>,
|
||||
num_experts_per_token: Option<usize>,
|
||||
#[serde(rename = "n_shared_experts")]
|
||||
num_shared_experts: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
@ -196,7 +198,8 @@ struct Config {
|
|||
model_type: Option<String>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: bool,
|
||||
experts: Option<usize>,
|
||||
num_experts_per_token: usize,
|
||||
num_shared_experts: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
|
@ -210,11 +213,9 @@ impl Config {
|
|||
let num_kv_heads = self.num_kv_heads? as u64;
|
||||
let head_dim = self.head_dim? as u64;
|
||||
let hidden_size = self.hidden_size? as u64;
|
||||
let intermediate_size = if let Some(experts) = self.experts {
|
||||
(self.intermediate_size? * experts) as u64
|
||||
} else {
|
||||
self.intermediate_size? as u64
|
||||
};
|
||||
let intermediate_size = (self.intermediate_size?
|
||||
* (self.num_experts_per_token + self.num_shared_experts))
|
||||
as u64;
|
||||
let num_layers = self.num_layers? as u64;
|
||||
|
||||
let q_flops = 2 * num_heads * head_dim * hidden_size;
|
||||
|
@ -257,7 +258,8 @@ impl From<RawConfig> for Config {
|
|||
let model_type = other.model_type;
|
||||
let vision_config = other.vision_config;
|
||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||
let experts = other.experts;
|
||||
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
|
||||
let num_shared_experts = other.num_shared_experts.unwrap_or(0);
|
||||
Config {
|
||||
max_position_embeddings,
|
||||
quantize,
|
||||
|
@ -270,7 +272,8 @@ impl From<RawConfig> for Config {
|
|||
num_kv_heads,
|
||||
intermediate_size,
|
||||
num_layers,
|
||||
experts,
|
||||
num_experts_per_token,
|
||||
num_shared_experts,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1547,6 +1550,7 @@ impl ComputeType {
|
|||
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
|
||||
"nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)),
|
||||
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
|
||||
"nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
|
||||
"nvidia-a100" => Some(312 * 10u64.pow(12)),
|
||||
card => {
|
||||
tracing::warn!("Unkown compute for card {card}");
|
||||
|
|
Loading…
Reference in New Issue