API: use finally: for state.end()

This commit is contained in:
Aarni Koskela 2023-06-30 13:11:49 +03:00
parent f44feb6a10
commit e430344347
1 changed files with 14 additions and 15 deletions

View File

@ -602,37 +602,35 @@ class Api:
shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}")
finally:
shared.state.end()
def create_hypernetwork(self, args: dict):
try:
shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}")
finally:
shared.state.end()
def preprocess(self, args: dict):
try:
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info = 'preprocess complete')
return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e:
shared.state.end()
except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e:
finally:
shared.state.end()
return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
@ -649,11 +647,11 @@ class Api:
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg:
shared.state.end()
except Exception as msg:
return models.TrainResponse(info=f"train embedding error: {msg}")
finally:
shared.state.end()
def train_hypernetwork(self, args: dict):
try:
@ -675,9 +673,10 @@ class Api:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError:
except Exception as exc:
return models.TrainResponse(info=f"train embedding error: {exc}")
finally:
shared.state.end()
return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try: