Adding A100 compute. (#2806)

This commit is contained in:
Nicolas Patry 2024-12-06 22:49:15 +05:30 committed by GitHub
parent 5df8059037
commit d96dcb1797
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 13 additions and 9 deletions

View File

@@ -172,7 +172,9 @@ struct RawConfig {
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
#[serde(rename = "num_experts_per_tok")]
experts: Option<usize>,
num_experts_per_token: Option<usize>,
#[serde(rename = "n_shared_experts")]
num_shared_experts: Option<usize>,
}
#[derive(Deserialize)]
@@ -196,7 +198,8 @@ struct Config {
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
experts: Option<usize>,
num_experts_per_token: usize,
num_shared_experts: usize,
}
impl Config {
@@ -210,11 +213,9 @@ impl Config {
let num_kv_heads = self.num_kv_heads? as u64;
let head_dim = self.head_dim? as u64;
let hidden_size = self.hidden_size? as u64;
let intermediate_size = if let Some(experts) = self.experts {
(self.intermediate_size? * experts) as u64
} else {
self.intermediate_size? as u64
};
let intermediate_size = (self.intermediate_size?
* (self.num_experts_per_token + self.num_shared_experts))
as u64;
let num_layers = self.num_layers? as u64;
let q_flops = 2 * num_heads * head_dim * hidden_size;
@@ -257,7 +258,8 @@ impl From<RawConfig> for Config {
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
let experts = other.experts;
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
let num_shared_experts = other.num_shared_experts.unwrap_or(0);
Config {
max_position_embeddings,
quantize,
@@ -270,7 +272,8 @@ impl From<RawConfig> for Config {
num_kv_heads,
intermediate_size,
num_layers,
experts,
num_experts_per_token,
num_shared_experts,
}
}
}
@@ -1547,6 +1550,7 @@ impl ComputeType {
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
"nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
"nvidia-a100" => Some(312 * 10u64.pow(12)),
card => {
tracing::warn!("Unkown compute for card {card}");