feat(server): clear cache on error (#143)

OlivierDehaene 2023-03-28 11:29:35 +02:00 committed by GitHub
parent 8e8dd984d8
commit f000068944
8 changed files with 21 additions and 9 deletions

proto/generate.proto

@@ -21,8 +21,10 @@ message ServiceDiscoveryResponse {
     repeated string urls = 1;
 }
 
-/// Empty request
-message ClearCacheRequest {}
+message ClearCacheRequest {
+    /// Optional batch id
+    optional uint64 id = 1;
+}
 
 /// Empty response
 message ClearCacheResponse {}
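Note: a proto3 scalar marked `optional` gets explicit presence tracking, so "id not set" stays distinguishable from "id == 0" on the wire. With the tonic/prost codegen used on the Rust side, such a field surfaces as an `Option`. A minimal sketch of the shape of the generated message (the derive list is an assumption, not copied from the generated code):

    // Sketch of the prost-generated message for the proto above.
    // `optional uint64 id = 1` maps to Option<u64>; derives are assumed.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ClearCacheRequest {
        /// Optional batch id; None means "clear the whole cache".
        pub id: Option<u64>,
    }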

router/client/src/client.rs

@@ -56,8 +56,8 @@ impl Client {
 
     /// Clear the past generations cache
     #[instrument(skip(self))]
-    pub async fn clear_cache(&mut self) -> Result<()> {
-        let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
+    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
+        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
         self.stub.clear_cache(request).await?;
         Ok(())
     }
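Note: threading the `Option<u64>` through keeps both behaviours reachable from one method. A hedged usage sketch (`Client` and `Result` are this crate's types, so this only compiles inside the router's client crate):

    // Usage sketch, assuming this crate's Client and Result types.
    async fn example(client: &mut Client) -> Result<()> {
        client.clear_cache(None).await?; // drop every cached batch
        client.clear_cache(Some(42)).await?; // drop only batch 42
        Ok(())
    }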

router/client/src/sharded_client.rs

@@ -40,11 +40,11 @@ impl ShardedClient {
 
     /// Clear the past generations cache
    #[instrument(skip(self))]
-    pub async fn clear_cache(&mut self) -> Result<()> {
+    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| client.clear_cache())
+            .map(|client| client.clear_cache(batch_id))
             .collect();
         join_all(futures).await.into_iter().collect()
     }
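Note: `join_all(futures).await` yields a `Vec<Result<()>>`, and the trailing `collect()` folds it into a single `Result<()>` that short-circuits on the first shard error. A standalone demonstration of that collect pattern (assumes the `futures` and `tokio` crates as dependencies):

    // Demonstrates collecting an iterator of Result<(), E> into Result<(), E>:
    // the first Err wins; all Ok(()) values collapse into a single Ok(()).
    use futures::future::join_all;

    #[tokio::main]
    async fn main() {
        let shards = (0..3).map(|i| async move {
            if i == 1 { Err(format!("shard {i} failed")) } else { Ok(()) }
        });
        let merged: Result<(), String> = join_all(shards).await.into_iter().collect();
        assert_eq!(merged, Err("shard 1 failed".to_string()));
    }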

router/src/infer.rs

@@ -330,6 +330,7 @@ async fn prefill(
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
+    let batch_id = batch.id;
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
@@ -340,6 +341,7 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
             None
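Note: `batch.id` is copied out first because `batch` is moved into `client.prefill`. The cleanup on the error path is deliberately best-effort: `let _ =` discards the `clear_cache` result, so a failing cleanup RPC cannot mask the original prefill error that `send_errors` forwards to the waiting entries.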

router/src/main.rs

@@ -136,7 +136,7 @@ fn main() -> Result<(), std::io::Error> {
         .expect("Could not connect to server");
     // Clear the cache; useful if the webserver rebooted
     sharded_client
-        .clear_cache()
+        .clear_cache(None)
         .await
         .expect("Unable to clear cache");
     tracing::info!("Connected");
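Note: at startup the router has no batch id to target, so `None` wipes every shard's cache wholesale, preserving the pre-existing "webserver rebooted" recovery behaviour.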

server/text_generation_server/cache.py

@@ -17,7 +17,9 @@ class Cache:
         self.cache[entry.batch_id] = entry
 
     def delete(self, batch_id: int):
-        del self.cache[batch_id]
+        batch = self.pop(batch_id)
+        if batch is not None:
+            del batch
 
     def clear(self):
         self.cache.clear()
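Note: assuming `Cache.pop` behaves like `dict.pop(batch_id, None)` (it is not shown in this diff), `delete` becomes idempotent: the old `del self.cache[batch_id]` raised `KeyError` for an unknown batch id, which now matters because the router fires ClearCache on error paths where the batch may never have been cached.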

server/text_generation_server/pb/.gitignore

@@ -1,2 +1,3 @@
 *.py
 *.pyi
+*.py-e
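Note: `*.py-e` files are presumably the backup artifacts that BSD/macOS `sed -i -e` leaves beside the generated protobuf stubs when the server Makefile rewrites their imports; ignoring them keeps the pb directory clean.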

server/text_generation_server/server.py

@@ -30,7 +30,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
 
     async def ClearCache(self, request, context):
-        self.cache.clear()
+        if request.HasField("id"):
+            self.cache.delete(request.id)
+        else:
+            self.cache.clear()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
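Note: `request.HasField("id")` is only valid because `id` is declared `optional` in the proto; calling `HasField` on a plain proto3 scalar raises `ValueError`, since such fields carry no presence bit. `torch.cuda.empty_cache()` then returns the allocator's unused cached blocks to the driver, so memory freed by the dropped batch becomes visible outside the process.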