feat(router): add device and dtype info (#215)

parent ac8c0f6fe4
commit 343437c7b5
@@ -3,6 +3,8 @@ syntax = "proto3";
 package generate.v1;
 
 service TextGenerationService {
+    /// Model Info
+    rpc Info (InfoRequest) returns (InfoResponse) {}
     /// Service discovery
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
@@ -13,6 +15,15 @@ service TextGenerationService {
     rpc Decode (DecodeRequest) returns (DecodeResponse);
 }
 
+/// Empty request
+message InfoRequest {}
+
+message InfoResponse {
+    bool requires_padding = 1;
+    string dtype = 2;
+    string device_type = 3;
+}
+
 /// Empty request
 message ServiceDiscoveryRequest {}
 
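The new Info RPC lets the router ask a shard for its runtime characteristics. As a rough sketch (not part of this commit), the generated Python bindings used elsewhere in this diff can exercise the new messages directly; the field values here are illustrative only:

# Sketch: build the new messages via the generated generate_pb2 module.
from text_generation_server.pb.generate_pb2 import InfoRequest, InfoResponse

request = InfoRequest()  # the Info RPC takes an empty request
response = InfoResponse(
    requires_padding=True,   # illustrative value
    dtype="torch.float16",   # illustrative value
    device_type="cuda",      # illustrative value
)
print(response.requires_padding, response.dtype, response.device_type)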
@@ -54,6 +54,14 @@ impl Client {
         Ok(urls)
     }
 
+    /// Get model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<InfoResponse> {
+        let request = tonic::Request::new(InfoRequest {}).inject_context();
+        let response = self.stub.info(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
+pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
     Request, StoppingCriteriaParameters,

@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation};
+use crate::{Batch, Client, Generation, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -37,6 +37,17 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }
 
+    /// Get the model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<ShardInfo> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.info())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
@@ -12,7 +12,7 @@ use validation::Validation;
 
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
-pub struct ModelInfo {
+pub struct HubModelInfo {
     #[serde(rename(deserialize = "id"))]
     pub model_id: String,
     pub sha: Option<String>,
@@ -25,6 +25,10 @@ pub struct Info {
     pub model_id: String,
     #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
     pub model_sha: Option<String>,
+    #[schema(example = "torch.float16")]
+    pub model_dtype: String,
+    #[schema(example = "cuda")]
+    pub model_device_type: String,
     #[schema(nullable = true, example = "text-generation")]
     pub model_pipeline_tag: Option<String>,
     #[schema(example = "0.5.0")]
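With these fields in place, the router's model-info response carries the shard's dtype and device type alongside the Hub metadata. A hypothetical client-side check, assuming the handler shown later in this diff is mounted at GET /info on a router listening at 127.0.0.1:3000 and using the third-party requests package (route path and port are assumptions, not shown in this diff):

import requests  # assumed third-party dependency

# Hypothetical call against a locally running router.
info = requests.get("http://127.0.0.1:3000/info").json()
print(info["model_id"], info["model_dtype"], info["model_device_type"])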
@@ -10,7 +10,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use text_generation_router::{server, ModelInfo};
+use text_generation_router::{server, HubModelInfo};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;

@@ -128,7 +128,7 @@ fn main() -> Result<(), std::io::Error> {
 
     // Get Model info
     let model_info = match local_model {
-        true => ModelInfo {
+        true => HubModelInfo {
             model_id: tokenizer_name.clone(),
             sha: None,
             pipeline_tag: None,

@@ -154,6 +154,11 @@ fn main() -> Result<(), std::io::Error> {
             .clear_cache(None)
             .await
             .expect("Unable to clear cache");
+        // Get info from the shard
+        let shard_info = sharded_client
+            .info()
+            .await
+            .expect("Unable to get shard info");
         tracing::info!("Connected");
 
         // Binds on localhost

@@ -162,6 +167,7 @@ fn main() -> Result<(), std::io::Error> {
         // Run server
         server::run(
             model_info,
+            shard_info,
             compat_return_full_text,
             max_concurrent_requests,
             max_best_of,

@@ -237,7 +243,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }
 
 /// get model info from the Huggingface Hub
-pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
+pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
     let client = reqwest::Client::new();
     let mut builder = client.get(format!(
         "https://huggingface.co/api/models/{model_id}/revision/{revision}"
@@ -3,7 +3,7 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
     StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;

@@ -18,7 +18,7 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::time::Instant;
@@ -78,13 +78,19 @@ async fn compat_generate(
 responses((status = 200, description = "Served model info", body = Info))
 )]
 #[instrument]
-async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
+async fn get_model_info(
+    model_info: Extension<HubModelInfo>,
+    shard_info: Extension<ShardInfo>,
+) -> Json<Info> {
     let model_info = model_info.0;
+    let shard_info = shard_info.0;
     let info = Info {
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         model_id: model_info.model_id,
         model_sha: model_info.sha,
+        model_dtype: shard_info.dtype,
+        model_device_type: shard_info.device_type,
         model_pipeline_tag: model_info.pipeline_tag,
     };
     Json(info)
@@ -497,7 +503,8 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
-    model_info: ModelInfo,
+    model_info: HubModelInfo,
+    shard_info: ShardInfo,
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_best_of: usize,

@@ -641,6 +648,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(model_info))
+        .layer(Extension(shard_info))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
@@ -100,7 +100,11 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )
 
     @staticmethod

@@ -400,7 +400,11 @@ class CausalLM(Model):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )
 
     @property
@@ -343,7 +343,11 @@ class FlashCausalLM(Model):
         )
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )
 
     @property
@@ -63,6 +63,8 @@ class FlashLlama(FlashCausalLM):
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
 

@@ -184,6 +186,8 @@ class FlashLlamaSharded(FlashLlama):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
 
@@ -70,6 +70,8 @@ class FlashNeoXSharded(FlashNeoX):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
 
@@ -65,7 +65,11 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )
 
     @staticmethod

@@ -206,6 +210,8 @@ class FlashSantacoderSharded(FlashSantacoder):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
 
@@ -228,6 +228,8 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 

@@ -72,6 +72,8 @@ class GPTNeoxSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 
@@ -5,6 +5,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase
 
 from text_generation_server.models.types import Batch, GeneratedText
+from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
 

@@ -13,6 +14,8 @@ class Model(ABC):
     def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
+        requires_padding: bool,
+        dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
     ):
@@ -21,9 +24,19 @@ class Model(ABC):
 
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.requires_padding = requires_padding
+        self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer
 
+    @property
+    def info(self) -> InfoResponse:
+        return InfoResponse(
+            requires_padding=self.requires_padding,
+            dtype=str(self.dtype),
+            device_type=self.device.type,
+        )
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
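The base Model class now records requires_padding, dtype, and device, and its new info property serializes them into the InfoResponse message. The string forms it produces line up with the schema examples added to the router's Info struct above; a quick illustrative check (not part of the commit):

import torch

# str() of a torch dtype and .type of a torch device yield exactly the strings
# that InfoResponse carries ("torch.float16", "cuda"/"cpu").
print(str(torch.float16))            # -> "torch.float16"
print(torch.device("cuda:0").type)   # -> "cuda"
print(torch.device("cpu").type)      # -> "cpu"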
@@ -88,6 +88,8 @@ class OPTSharded(OPT):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 

@@ -54,7 +54,11 @@ class SantaCoder(CausalLM):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )
 
     def decode(self, generated_ids: List[int]) -> str:
@@ -460,7 +460,11 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
 
         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )
 
     @property

@@ -72,6 +72,8 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
 
@@ -26,6 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
+    async def Info(self, request, context):
+        return self.model.info
+
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
 
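For manual testing, the new Info handler can be exercised with a plain grpcio client against a running shard. Everything below is a sketch: the unix-socket address is an assumption (shards listen on a socket chosen by the launcher), and the stub name follows standard grpcio code generation for the TextGenerationService defined in the proto above.

import grpc

from text_generation_server.pb import generate_pb2, generate_pb2_grpc

# Assumed socket path; substitute the uds path your shard actually listens on.
channel = grpc.insecure_channel("unix:///tmp/text-generation-server-0")
stub = generate_pb2_grpc.TextGenerationServiceStub(channel)

info = stub.Info(generate_pb2.InfoRequest())
print(info.requires_padding, info.dtype, info.device_type)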