Taking into account number of shards.
This commit is contained in:
parent
54d3c8157c
commit
fa912440b1
|
@ -1497,14 +1497,15 @@ struct ComputeType {
|
||||||
|
|
||||||
impl ComputeType {
|
impl ComputeType {
|
||||||
fn f16_flop(&self) -> Option<u64> {
|
fn f16_flop(&self) -> Option<u64> {
|
||||||
match &self.card[..] {
|
let card_flop = match &self.card[..] {
|
||||||
// https://www.nvidia.com/en-us/data-center/l4/
|
// https://www.nvidia.com/en-us/data-center/l4/
|
||||||
"nvidia-l4" => Some(121 * 10u64.pow(12)),
|
"nvidia-l4" => Some(121 * 10u64.pow(12)),
|
||||||
card => {
|
card => {
|
||||||
tracing::warn!("Unkown compute for card {card}");
|
tracing::warn!("Unkown compute for card {card}");
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
card_flop.map(|f| f * self.count as u64)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1813,8 +1814,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
None => {
|
None => {
|
||||||
// TODO figure out hardware optimal value
|
// TODO figure out hardware optimal value
|
||||||
let compute_type = compute_type(num_shard);
|
let compute_type = compute_type(num_shard);
|
||||||
tracing::info!("Compute type {compute_type:?}");
|
|
||||||
tracing::info!("Config {config:?}");
|
|
||||||
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
|
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
|
||||||
let default = compute_optimal.unwrap_or(4096);
|
let default = compute_optimal.unwrap_or(4096);
|
||||||
let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
|
let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
|
||||||
|
|
Loading…
Reference in New Issue