feat(server): clear cache on error (#143)
This commit is contained in:
parent
8e8dd984d8
commit
f000068944
|
@ -21,8 +21,10 @@ message ServiceDiscoveryResponse {
|
||||||
repeated string urls = 1;
|
repeated string urls = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Empty request
|
message ClearCacheRequest {
|
||||||
message ClearCacheRequest {}
|
/// Optional batch id
|
||||||
|
optional uint64 id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
/// Empty response
|
/// Empty response
|
||||||
message ClearCacheResponse {}
|
message ClearCacheResponse {}
|
||||||
|
|
|
@ -56,8 +56,8 @@ impl Client {
|
||||||
|
|
||||||
/// Clear the past generations cache
|
/// Clear the past generations cache
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
|
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||||
self.stub.clear_cache(request).await?;
|
self.stub.clear_cache(request).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,11 +40,11 @@ impl ShardedClient {
|
||||||
|
|
||||||
/// Clear the past generations cache
|
/// Clear the past generations cache
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.clear_cache())
|
.map(|client| client.clear_cache(batch_id))
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.into_iter().collect()
|
join_all(futures).await.into_iter().collect()
|
||||||
}
|
}
|
||||||
|
|
|
@ -330,6 +330,7 @@ async fn prefill(
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
let batch_id = batch.id;
|
||||||
|
|
||||||
match client.prefill(batch).await {
|
match client.prefill(batch).await {
|
||||||
Ok((generations, next_batch)) => {
|
Ok((generations, next_batch)) => {
|
||||||
|
@ -340,6 +341,7 @@ async fn prefill(
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
let _ = client.clear_cache(Some(batch_id)).await;
|
||||||
send_errors(err, entries);
|
send_errors(err, entries);
|
||||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||||
None
|
None
|
||||||
|
|
|
@ -136,7 +136,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
.expect("Could not connect to server");
|
.expect("Could not connect to server");
|
||||||
// Clear the cache; useful if the webserver rebooted
|
// Clear the cache; useful if the webserver rebooted
|
||||||
sharded_client
|
sharded_client
|
||||||
.clear_cache()
|
.clear_cache(None)
|
||||||
.await
|
.await
|
||||||
.expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
|
@ -17,7 +17,9 @@ class Cache:
|
||||||
self.cache[entry.batch_id] = entry
|
self.cache[entry.batch_id] = entry
|
||||||
|
|
||||||
def delete(self, batch_id: int):
|
def delete(self, batch_id: int):
|
||||||
del self.cache[batch_id]
|
batch = self.pop(batch_id)
|
||||||
|
if batch is not None:
|
||||||
|
del batch
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
*.py
|
*.py
|
||||||
|
*.pyi
|
||||||
*.py-e
|
*.py-e
|
|
@ -30,7 +30,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
||||||
|
|
||||||
async def ClearCache(self, request, context):
|
async def ClearCache(self, request, context):
|
||||||
self.cache.clear()
|
if request.HasField("id"):
|
||||||
|
self.cache.delete(request.id)
|
||||||
|
else:
|
||||||
|
self.cache.clear()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return generate_pb2.ClearCacheResponse()
|
return generate_pb2.ClearCacheResponse()
|
||||||
|
|
||||||
async def Prefill(self, request, context):
|
async def Prefill(self, request, context):
|
||||||
|
|
Loading…
Reference in New Issue