Taking into account number of shards.

This commit is contained in:
Nicolas Patry 2024-11-04 11:15:47 +01:00
parent 54d3c8157c
commit fa912440b1
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
1 changed files with 3 additions and 4 deletions

View File

@ -1497,14 +1497,15 @@ struct ComputeType {
impl ComputeType {
fn f16_flop(&self) -> Option<u64> {
match &self.card[..] {
let card_flop = match &self.card[..] {
// https://www.nvidia.com/en-us/data-center/l4/
"nvidia-l4" => Some(121 * 10u64.pow(12)),
card => {
tracing::warn!("Unkown compute for card {card}");
None
}
}
};
card_flop.map(|f| f * self.count as u64)
}
}
@ -1813,8 +1814,6 @@ fn main() -> Result<(), LauncherError> {
None => {
// TODO figure out hardware optimal value
let compute_type = compute_type(num_shard);
tracing::info!("Compute type {compute_type:?}");
tracing::info!("Config {config:?}");
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
let default = compute_optimal.unwrap_or(4096);
let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);