From fa912440b1dc6d40044e31c4177dd41473a3820d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 4 Nov 2024 11:15:47 +0100 Subject: [PATCH] Taking into account number of shards. --- launcher/src/main.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a68d99e0..8a9aa78b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1497,14 +1497,15 @@ struct ComputeType { impl ComputeType { fn f16_flop(&self) -> Option { - match &self.card[..] { + let card_flop = match &self.card[..] { // https://www.nvidia.com/en-us/data-center/l4/ "nvidia-l4" => Some(121 * 10u64.pow(12)), card => { tracing::warn!("Unkown compute for card {card}"); None } - } + }; + card_flop.map(|f| f * self.count as u64) } } @@ -1813,8 +1814,6 @@ fn main() -> Result<(), LauncherError> { None => { // TODO figure out hardware optimal value let compute_type = compute_type(num_shard); - tracing::info!("Compute type {compute_type:?}"); - tracing::info!("Config {config:?}"); let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref()); let default = compute_optimal.unwrap_or(4096); let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);