fix: refactor and move changes to v3 proto

drbh 2024-06-06 20:31:27 +00:00
parent c927376725
commit 73eb2ae255
9 changed files with 7 additions and 6 deletions


@@ -107,8 +107,6 @@ message Request {
     bool prefill_logprobs = 6;
     /// Return most likely n tokens
     uint32 top_n_tokens = 7;
-    /// LORA adapter index
-    optional uint32 adapter_index = 8;
 }
 message Batch {


@@ -134,6 +134,8 @@ message Request {
     repeated uint32 blocks = 9;
     /// Paged attention slots
     repeated uint32 slots = 10;
+    /// LORA adapter index
+    optional uint32 adapter_index = 11;
 }
 message Batch {
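Taken together, the two proto hunks move the `adapter_index` field from the v2 `Request` message (field 8) to the v3 `Request` message (field 11). Since the router's proto code is generated with prost/tonic, an `optional uint32` surfaces on the generated struct as an `Option<u32>`, which is why every Rust construction site in the hunks below has to add or drop an explicit `adapter_index` entry. A minimal sketch of that mapping, using a hand-written stand-in for the generated struct rather than the real one:

// Illustration only: a trimmed-down stand-in for the prost-generated v3 Request.
// The real struct carries many more fields (inputs, parameters, blocks, slots, ...).
#[derive(Debug, Default, Clone)]
pub struct Request {
    pub id: u64,
    pub top_n_tokens: u32,
    /// `optional uint32 adapter_index = 11;` in the proto maps to `Option<u32>` here.
    pub adapter_index: Option<u32>,
}

fn main() {
    // A request pinned to LoRA adapter 3 on this shard.
    let with_adapter = Request {
        id: 0,
        top_n_tokens: 20,
        adapter_index: Some(3),
    };

    // Health checks and warmup requests carry no adapter, so they set None explicitly.
    let plain = Request {
        adapter_index: None,
        ..with_adapter.clone()
    };

    println!("{:?} / {:?}", with_adapter.adapter_index, plain.adapter_index);
}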


@@ -154,7 +154,6 @@ impl Client {
         }),
         prefill_logprobs: true,
         top_n_tokens: 20,
-        adapter_index: None,
     });
     n_tokens += max_input_length;


@@ -177,6 +177,7 @@ impl Client {
         }),
         prefill_logprobs: true,
         top_n_tokens: 20,
+        adapter_index: None,
     });
     n_tokens += max_input_length;


@@ -244,6 +244,7 @@ impl Health for ShardedClient {
         // Block 0 is reserved for health checks
         blocks: vec![0],
         slots: (0..16).collect(),
+        adapter_index: None,
     };
     let batch = Batch {
         id: u64::MAX,


@@ -290,7 +290,6 @@ impl State {
             entry.request.stopping_parameters.clone(),
         )),
         top_n_tokens: entry.request.top_n_tokens,
-        adapter_index: entry.request.adapter_index,
     });
     // Set batch_time
     entry.batch_time = Some(Instant::now());


@@ -351,6 +351,7 @@ impl State {
         top_n_tokens: entry.request.top_n_tokens,
         blocks,
         slots,
+        adapter_index: entry.request.adapter_index,
     });
     // Set batch_time
     entry.batch_time = Some(Instant::now());

@@ -491,6 +492,7 @@ mod tests {
             stop_sequences: vec![],
         },
         top_n_tokens: 0,
+        adapter_index: None,
     },
     response_tx,
     span: info_span!("entry"),
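The two queue hunks above show the forwarding path: the v2 `State` stops copying the adapter index into its proto requests, the v3 `State` forwards `entry.request.adapter_index` unchanged, and the test entry defaults it to `None`. A small sketch of that forwarding step, with simplified stand-ins for the router's `Entry` and request types (not the real definitions):

// Illustration only: simplified stand-ins for the router's queue types.
#[derive(Clone, Debug)]
struct ValidGenerateRequest {
    top_n_tokens: u32,
    adapter_index: Option<u32>,
}

struct Entry {
    request: ValidGenerateRequest,
}

// Trimmed-down v3 proto request, as in the earlier sketch.
#[derive(Debug)]
struct ProtoRequest {
    top_n_tokens: u32,
    adapter_index: Option<u32>,
}

// The v3 queue simply forwards whatever the validated request carried.
fn to_proto(entry: &Entry) -> ProtoRequest {
    ProtoRequest {
        top_n_tokens: entry.request.top_n_tokens,
        adapter_index: entry.request.adapter_index,
    }
}

fn main() {
    let entry = Entry {
        request: ValidGenerateRequest {
            top_n_tokens: 0,
            adapter_index: Some(1),
        },
    };
    println!("{:?}", to_proto(&entry));
}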


@@ -63,7 +63,7 @@ UP_PROJ = "up_proj"
 DOWN_PROJ = "down_proj"
-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)


@@ -2,7 +2,6 @@ import os
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
-from safetensors.torch import load_file
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download