feat(server): clear cache on error (#143)
This commit is contained in:
parent
8e8dd984d8
commit
f000068944
|
@ -21,8 +21,10 @@ message ServiceDiscoveryResponse {
|
|||
repeated string urls = 1;
|
||||
}
|
||||
|
||||
/// Clear-cache request; leave `id` unset to clear every cached batch
|
||||
message ClearCacheRequest {}
|
||||
message ClearCacheRequest {
|
||||
/// Optional batch id
|
||||
optional uint64 id = 1;
|
||||
}
|
||||
|
||||
/// Empty response
|
||||
message ClearCacheResponse {}
|
||||
|
|
|
@ -56,8 +56,8 @@ impl Client {
|
|||
|
||||
/// Clear the past generations cache (only the given batch when `batch_id` is set)
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -40,11 +40,11 @@ impl ShardedClient {
|
|||
|
||||
/// Clear the past generations cache on every shard (only the given batch when `batch_id` is set)
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache())
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
|
|
@ -330,6 +330,7 @@ async fn prefill(
|
|||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
|
@ -340,6 +341,7 @@ async fn prefill(
|
|||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
None
|
||||
|
|
|
@ -136,7 +136,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.expect("Could not connect to server");
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache()
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
|
|
@ -17,7 +17,9 @@ class Cache:
|
|||
self.cache[entry.batch_id] = entry
|
||||
|
||||
def delete(self, batch_id: int):
|
||||
del self.cache[batch_id]
|
||||
batch = self.pop(batch_id)
|
||||
if batch is not None:
|
||||
del batch
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
*.py
|
||||
*.pyi
|
||||
*.py-e
|
|
@ -30,7 +30,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
||||
|
||||
async def ClearCache(self, request, context):
|
||||
self.cache.clear()
|
||||
if request.HasField("id"):
|
||||
self.cache.delete(request.id)
|
||||
else:
|
||||
self.cache.clear()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
return generate_pb2.ClearCacheResponse()
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
|
|
Loading…
Reference in New Issue