fix: refactor and move changes to v3 proto

drbh committed 2024-06-06 20:31:27 +00:00
parent c927376725
commit 73eb2ae255
9 changed files with 7 additions and 6 deletions

@@ -107,8 +107,6 @@ message Request {
     bool prefill_logprobs = 6;
     /// Return most likely n tokens
     uint32 top_n_tokens = 7;
-    /// LORA adapter index
-    optional uint32 adapter_index = 8;
 }

 message Batch {

@@ -134,6 +134,8 @@ message Request {
     repeated uint32 blocks = 9;
     /// Paged attention slots
     repeated uint32 slots = 10;
+    /// LORA adapter index
+    optional uint32 adapter_index = 11;
 }

 message Batch {
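
Note the tag change that comes with the move: adapter_index was field 8 in the v2 Request, but in the v3 message the paged-attention fields already occupy 9 (blocks) and 10 (slots), so it lands at 11. Because the field is declared proto3 optional, it has explicit presence: a consumer can tell "no adapter requested" apart from "adapter 0". A minimal sketch of what that buys on the Python server side (the generate_pb2 module path is an assumption, not shown in this commit):

    # Sketch: proto3 `optional` tracks presence, so an unset adapter index
    # is distinguishable from an explicit request for adapter 0.
    # (Module path assumed; it is not part of this diff.)
    from text_generation_server.pb import generate_pb2

    req = generate_pb2.Request(id=0, top_n_tokens=20)
    assert not req.HasField("adapter_index")  # unset: no adapter requested
    req.adapter_index = 0
    assert req.HasField("adapter_index")  # set, even though the value is 0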

@@ -154,7 +154,6 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
-                adapter_index: None,
             });
             n_tokens += max_input_length;

@@ -177,6 +177,7 @@ impl Client {
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
+                adapter_index: None,
             });
             n_tokens += max_input_length;

@@ -244,6 +244,7 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
+            adapter_index: None,
         };
         let batch = Batch {
             id: u64::MAX,

@@ -290,7 +290,6 @@ impl State {
                     entry.request.stopping_parameters.clone(),
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
-                adapter_index: entry.request.adapter_index,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());

@@ -351,6 +351,7 @@ impl State {
                 top_n_tokens: entry.request.top_n_tokens,
                 blocks,
                 slots,
+                adapter_index: entry.request.adapter_index,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -491,6 +492,7 @@ mod tests {
                     stop_sequences: vec![],
                 },
                 top_n_tokens: 0,
+                adapter_index: None,
             },
             response_tx,
             span: info_span!("entry"),

@@ -63,7 +63,7 @@ UP_PROJ = "up_proj"
 DOWN_PROJ = "down_proj"


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
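
This hunk shows only the signature change; the body and call sites are not part of this commit. Presumably the new layer_id lets the loader key adapter (LoRA) weights per layer. A hedged sketch of a call site under that assumption (everything except load_attention is hypothetical):

    # Hypothetical call site: each attention module passes its layer index
    # down so adapter weights can later be looked up per (layer, projection).
    self.query_key_value = load_attention(config, prefix, weights, layer_id)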

@@ -2,7 +2,6 @@ import os
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
-from safetensors.torch import load_file
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
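
The final hunk drops a now-unused import; the safetensors access that remains in this module goes through safe_open, which reads tensors lazily instead of load_file's whole-file load. A minimal sketch of that pattern (the file path is a placeholder):

    from safetensors import safe_open

    # Lazy, per-tensor reads: only the requested tensors are materialized,
    # unlike safetensors.torch.load_file, which loads every tensor at once.
    with safe_open("model.safetensors", framework="pt") as f:
        for name in f.keys():
            tensor = f.get_tensor(name)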