Taking into account number of shards.
This commit is contained in:
parent
54d3c8157c
commit
fa912440b1
|
@ -1497,14 +1497,15 @@ struct ComputeType {
|
|||
|
||||
impl ComputeType {
|
||||
fn f16_flop(&self) -> Option<u64> {
|
||||
match &self.card[..] {
|
||||
let card_flop = match &self.card[..] {
|
||||
// https://www.nvidia.com/en-us/data-center/l4/
|
||||
"nvidia-l4" => Some(121 * 10u64.pow(12)),
|
||||
card => {
|
||||
tracing::warn!("Unkown compute for card {card}");
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
card_flop.map(|f| f * self.count as u64)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1813,8 +1814,6 @@ fn main() -> Result<(), LauncherError> {
|
|||
None => {
|
||||
// TODO figure out hardware optimal value
|
||||
let compute_type = compute_type(num_shard);
|
||||
tracing::info!("Compute type {compute_type:?}");
|
||||
tracing::info!("Config {config:?}");
|
||||
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
|
||||
let default = compute_optimal.unwrap_or(4096);
|
||||
let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
|
||||
|
|
Loading…
Reference in New Issue