Merge pull request #10803 from klimaleksus/refactoring-for-embedding-merge
Refactor EmbeddingDatabase.register_embedding() to allow unregistering
This commit is contained in:
commit
881de0df38
|
@ -119,16 +119,29 @@ class EmbeddingDatabase:
|
||||||
self.embedding_dirs.clear()
|
self.embedding_dirs.clear()
|
||||||
|
|
||||||
def register_embedding(self, embedding, model):
|
def register_embedding(self, embedding, model):
|
||||||
self.word_embeddings[embedding.name] = embedding
|
return self.register_embedding_by_name(embedding, model, embedding.name)
|
||||||
|
|
||||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
|
||||||
|
|
||||||
|
def register_embedding_by_name(self, embedding, model, name):
|
||||||
|
ids = model.cond_stage_model.tokenize([name])[0]
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
|
if name in self.word_embeddings:
|
||||||
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
|
# remove old one from the lookup list
|
||||||
|
lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
|
||||||
|
else:
|
||||||
|
lookup = self.ids_lookup[first_id]
|
||||||
|
if embedding is not None:
|
||||||
|
lookup += [(ids, embedding)]
|
||||||
|
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
|
||||||
|
if embedding is None:
|
||||||
|
# unregister embedding with specified name
|
||||||
|
if name in self.word_embeddings:
|
||||||
|
del self.word_embeddings[name]
|
||||||
|
if len(self.ids_lookup[first_id])==0:
|
||||||
|
del self.ids_lookup[first_id]
|
||||||
|
return None
|
||||||
|
self.word_embeddings[name] = embedding
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def get_expected_shape(self):
|
def get_expected_shape(self):
|
||||||
|
|
Loading…
Reference in New Issue