fix: refactor and move changes to v3 proto
parent c927376725
commit 73eb2ae255
@@ -107,8 +107,6 @@ message Request {
     bool prefill_logprobs = 6;
     /// Return most likely n tokens
     uint32 top_n_tokens = 7;
-    /// LORA adapter index
-    optional uint32 adapter_index = 8;
 }

 message Batch {
@@ -134,6 +134,8 @@ message Request {
     repeated uint32 blocks = 9;
     /// Paged attention slots
     repeated uint32 slots = 10;
+    /// LORA adapter index
+    optional uint32 adapter_index = 11;
 }

 message Batch {
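For context, a minimal sketch of how the new field surfaces on the Rust side, assuming prost-generated types; the stand-in struct below reproduces only the fields visible in the two proto hunks above, and the adapter index value is illustrative.

    // Stand-in for the prost-generated v3 Request message (not the real
    // generated code); proto3 `optional uint32` maps to Option<u32> in prost.
    #[allow(dead_code)]
    struct Request {
        prefill_logprobs: bool,
        top_n_tokens: u32,
        blocks: Vec<u32>,
        slots: Vec<u32>,
        adapter_index: Option<u32>,
    }

    fn main() {
        // A request routed to a hypothetical LoRA adapter with index 3;
        // base-model requests pass None instead.
        let request = Request {
            prefill_logprobs: true,
            top_n_tokens: 20,
            blocks: vec![0],
            slots: (0..16).collect(),
            adapter_index: Some(3),
        };
        assert_eq!(request.adapter_index, Some(3));
    }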
@@ -154,7 +154,6 @@ impl Client {
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
-               adapter_index: None,
            });
            n_tokens += max_input_length;

@@ -177,6 +177,7 @@ impl Client {
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
+               adapter_index: None,
            });
            n_tokens += max_input_length;

@@ -244,6 +244,7 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
+           adapter_index: None,
        };
        let batch = Batch {
            id: u64::MAX,
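A small sketch of what this hunk implies for the liveness probe, using assumed stand-in types: the health request pins block 0 and passes adapter_index: None, so the check always exercises the base model rather than any LoRA adapter.

    // Stand-in type for illustration; field names follow the hunk above.
    #[allow(dead_code)]
    struct HealthRequest {
        blocks: Vec<u32>,
        slots: Vec<u32>,
        adapter_index: Option<u32>,
    }

    fn health_request() -> HealthRequest {
        HealthRequest {
            // Block 0 is reserved for health checks.
            blocks: vec![0],
            slots: (0..16).collect(),
            // None selects the base model; no adapter is involved.
            adapter_index: None,
        }
    }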
@@ -290,7 +290,6 @@ impl State {
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
-               adapter_index: entry.request.adapter_index,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -351,6 +351,7 @@ impl State {
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
+               adapter_index: entry.request.adapter_index,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
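To make the data flow explicit, a hedged sketch with stand-in types: the queue entry carries the validated request's adapter_index, and the v3 batching path above copies it unchanged into the outgoing proto request.

    // Stand-ins for the queue entry and proto request types; only the
    // fields relevant to this hunk are reproduced.
    #[allow(dead_code)]
    struct ValidRequest {
        top_n_tokens: u32,
        adapter_index: Option<u32>,
    }

    #[allow(dead_code)]
    struct Entry {
        request: ValidRequest,
    }

    #[allow(dead_code)]
    struct ProtoRequest {
        top_n_tokens: u32,
        blocks: Vec<u32>,
        slots: Vec<u32>,
        adapter_index: Option<u32>,
    }

    fn to_proto(entry: &Entry, blocks: Vec<u32>, slots: Vec<u32>) -> ProtoRequest {
        ProtoRequest {
            top_n_tokens: entry.request.top_n_tokens,
            blocks,
            slots,
            // Copied verbatim from the entry: no adapter resolution happens here.
            adapter_index: entry.request.adapter_index,
        }
    }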
@@ -491,6 +492,7 @@ mod tests {
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
+               adapter_index: None,
            },
            response_tx,
            span: info_span!("entry"),
@@ -63,7 +63,7 @@ UP_PROJ = "up_proj"
 DOWN_PROJ = "down_proj"


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)

@@ -2,7 +2,6 @@ import os
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
-from safetensors.torch import load_file
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download