diff --git a/proto/generate.proto b/proto/generate.proto
index 5081ce1..0bac435 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -21,8 +21,10 @@ message ServiceDiscoveryResponse {
     repeated string urls = 1;
 }
 
-/// Empty request
-message ClearCacheRequest {}
+message ClearCacheRequest {
+    /// Optional batch id
+    optional uint64 id = 1;
+}
 
 /// Empty response
 message ClearCacheResponse {}
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 1f0d23f..1b2086a 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -56,8 +56,8 @@ impl Client {
 
     /// Clear the past generations cache
     #[instrument(skip(self))]
-    pub async fn clear_cache(&mut self) -> Result<()> {
-        let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
+    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
+        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
         self.stub.clear_cache(request).await?;
         Ok(())
     }
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 2e662ca..7f0ec6f 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -40,11 +40,11 @@ impl ShardedClient {
 
     /// Clear the past generations cache
     #[instrument(skip(self))]
-    pub async fn clear_cache(&mut self) -> Result<()> {
+    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| client.clear_cache())
+            .map(|client| client.clear_cache(batch_id))
             .collect();
         join_all(futures).await.into_iter().collect()
     }
diff --git a/router/src/infer.rs b/router/src/infer.rs
index ae151d8..5eafc3e 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -330,6 +330,7 @@ async fn prefill(
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
+    let batch_id = batch.id;
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
@@ -340,6 +341,7 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
             None
diff --git a/router/src/main.rs b/router/src/main.rs
index 2ccf66b..81c6aee 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -136,7 +136,7 @@ fn main() -> Result<(), std::io::Error> {
                 .expect("Could not connect to server");
             // Clear the cache; useful if the webserver rebooted
             sharded_client
-                .clear_cache()
+                .clear_cache(None)
                 .await
                 .expect("Unable to clear cache");
             tracing::info!("Connected");
diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py
index 72dc485..5556529 100644
--- a/server/text_generation_server/cache.py
+++ b/server/text_generation_server/cache.py
@@ -17,7 +17,9 @@ class Cache:
         self.cache[entry.batch_id] = entry
 
     def delete(self, batch_id: int):
-        del self.cache[batch_id]
+        batch = self.pop(batch_id)
+        if batch is not None:
+            del batch
 
     def clear(self):
         self.cache.clear()
diff --git a/server/text_generation_server/pb/.gitignore b/server/text_generation_server/pb/.gitignore
index 8527ad1..2621a19 100644
--- a/server/text_generation_server/pb/.gitignore
+++ b/server/text_generation_server/pb/.gitignore
@@ -1,2 +1,3 @@
 *.py
+*.pyi
 *.py-e
\ No newline at end of file
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 0b75c3c..3e3789b 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -30,7 +30,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
 
     async def ClearCache(self, request, context):
-        self.cache.clear()
+        if request.HasField("id"):
+            self.cache.delete(request.id)
+        else:
+            self.cache.clear()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):