From 73eb2ae255a8531049b2a4f6a1333b5cc55d8808 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 6 Jun 2024 20:31:27 +0000 Subject: [PATCH] fix: refactor and move changes to v3 proto --- proto/generate.proto | 2 -- proto/v3/generate.proto | 2 ++ router/client/src/v2/client.rs | 1 - router/client/src/v3/client.rs | 1 + router/client/src/v3/sharded_client.rs | 1 + router/src/infer/v2/queue.rs | 1 - router/src/infer/v3/queue.rs | 2 ++ .../models/custom_modeling/flash_llama_modeling.py | 2 +- server/text_generation_server/utils/weights.py | 1 - 9 files changed, 7 insertions(+), 6 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 366a5418..6351e37f 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -107,8 +107,6 @@ message Request { bool prefill_logprobs = 6; /// Return most likely n tokens uint32 top_n_tokens = 7; - /// LORA adapter index - optional uint32 adapter_index = 8; } message Batch { diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 01cc43fd..c7a8013b 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -134,6 +134,8 @@ message Request { repeated uint32 blocks = 9; /// Paged attention slots repeated uint32 slots = 10; + /// LORA adapter index + optional uint32 adapter_index = 11; } message Batch { diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs index ff1a70eb..9a2e6ac7 100644 --- a/router/client/src/v2/client.rs +++ b/router/client/src/v2/client.rs @@ -154,7 +154,6 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, - adapter_index: None, }); n_tokens += max_input_length; diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 9a3892fb..5ced4056 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -177,6 +177,7 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, + adapter_index: None, }); n_tokens += max_input_length; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 94002f55..300decca 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + adapter_index: None, }; let batch = Batch { id: u64::MAX, diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index e284b251..f0205697 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -290,7 +290,6 @@ impl State { entry.request.stopping_parameters.clone(), )), top_n_tokens: entry.request.top_n_tokens, - adapter_index: entry.request.adapter_index, }); // Set batch_time entry.batch_time = Some(Instant::now()); diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 0b66142a..fbfdf715 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -351,6 +351,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + adapter_index: entry.request.adapter_index, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -491,6 +492,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, + adapter_index: None, }, response_tx, span: info_span!("entry"), diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c84e2290..436c2f53 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -63,7 +63,7 @@ UP_PROJ = "up_proj" DOWN_PROJ = "down_proj" -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): # Only defined in granite. bias = getattr(config, "attention_bias", False) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index efede312..4d5fcb25 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -2,7 +2,6 @@ import os from pathlib import Path from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open, SafetensorError -from safetensors.torch import load_file import torch from loguru import logger from huggingface_hub import hf_hub_download