122 lines
4.2 KiB
Python
122 lines
4.2 KiB
Python
|
"""
|
||
|
Copyright [2022] Victor C Hall
|
||
|
|
||
|
Licensed under the GNU Affero General Public License;
|
||
|
You may not use this code except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
https://www.gnu.org/licenses/agpl-3.0.en.html
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
"""
|
||
|
import torch
|
||
|
import os
|
||
|
import hashlib
|
||
|
import io
|
||
|
from PIL import Image, ImageOps
|
||
|
import random
|
||
|
from aspects import get_aspect_buckets
|
||
|
from torchvision import transforms
|
||
|
|
||
|
class LatentCacheItem:
    """
    caches image/caption latent pairs and index value to select appropriate random crop jitter
    """
    def __init__(self, imagelatent, captionembedding, cropjitteridx, resolution=None):
        """
        imagelatent: image latent tensor (whatever the vae produced for one jittered crop)
        captionembedding: caption embedding tensor
        cropjitteridx: index of random crop jitter to use
        resolution: resolution associated with the latent (presumably the target size the
            image was fit to before encoding); None when not provided
        """
        self.imagelatent = imagelatent
        self.captionembedding = captionembedding
        self.cropjitteridx = cropjitteridx
        # Default is None ("unknown"), not the builtin ``tuple`` type, which is what a
        # ``resolution = tuple`` default would silently store.
        self.resolution = resolution

    def __repr__(self):
        return f"lat: {self.imagelatent.shape} emb:{self.captionembedding.shape} cj:{self.cropjitteridx}"
|
||
|
|
||
|
class LatentCacheManager:
    """
    Manages a cache of latent vectors for a dataset.

    ``self.cache`` maps the sha256 hex digest of an image path to the list of
    LatentCacheItem objects (one per crop jitter) created for that image. Each entry
    is also written to ``<latent_cache_path>/<hash>.pt`` by ``add``.
    """
    def __init__(self, latent_cache_path="/.cache/latents", device=torch.device("cuda"), jitter_lim=8, vae=None):
        """
        Manages caching of image latents to disk,
        latent_cache_path: path to latent cache folder (created if it does not exist)
        device: device to use for creating latents (torch.device)
        vae: vae to use for creating latents
        jitter_lim: number of random crop jitters to use per image (default: 8)
        """
        assert vae is not None, "LatentCacheManager requires a vae to be passed in"

        self.cache = {}  # key: sha256 hash of image path, value: list of LatentCacheItem
        self.latentcachepath = latent_cache_path
        self.jitter_lim = jitter_lim
        self.device = device
        self.vae = vae

        # add() writes one "<hash>.pt" file per image *inside* this path, so it must be a
        # directory; writing a file here would make every later save fail.
        os.makedirs(self.latentcachepath, exist_ok=True)

        self.vae_on_device = False

    def set_vae(self, vae):
        """Replace the vae; it is not assumed to be on ``self.device`` yet."""
        self.vae = vae
        self.vae_on_device = False

    def delete_vae(self):
        """Drop the reference to the vae."""
        self.vae = None
        self.vae_on_device = False

    def vae_to_device(self, device=None):
        """
        Move the vae to ``device`` (defaults to ``self.device``) and mark it as on-device.

        An explicit ``device`` also becomes ``self.device`` so that input tensors built in
        ``add`` land on the same device as the vae.
        """
        if device is not None:
            self.device = device
        self.vae.to(self.device)
        self.vae_on_device = True

    def vae_to_cpu(self):
        """Move the vae to the CPU and mark it as off-device."""
        self.vae.to("cpu")
        self.vae_on_device = False

    @staticmethod
    def __hash(imagepath):
        """Return the sha256 hex digest of ``imagepath`` (str or os.PathLike)."""
        return hashlib.sha256(os.fspath(imagepath).encode("utf-8")).hexdigest()

    @staticmethod
    def _jitter_crop(image, target_resolution, max_jitter=0.02):
        """Return ``image`` fit to ``target_resolution`` with a small random bleed and crop offset."""
        bleed = random.uniform(0.0, max_jitter)
        # ImageOps.fit centering is a fraction in [0, 1]: (0.5, 0.5) is a centered crop and
        # (0.0, 0.0) pins the crop to the top-left. Jitter around the center so crops are
        # not systematically cut from the top-left corner.
        centering = (0.5 + random.uniform(-max_jitter, max_jitter),
                     0.5 + random.uniform(-max_jitter, max_jitter))
        return ImageOps.fit(image, target_resolution, method=Image.BICUBIC, bleed=bleed, centering=centering)

    def add(self, imagepath: str, captionembedding: torch.Tensor, target_resolution=(512,512)):
        """
        adds an item to the cache

        Encodes ``jitter_lim`` randomly jittered crops of the image at ``imagepath`` with the
        vae, stores them under the hash of ``imagepath`` (replacing any previous entry for
        that path), and saves that entry to ``<latent_cache_path>/<hash>.pt``.

        imagepath: path of the image file
        captionembedding: caption embedding tensor stored with every jittered latent
        target_resolution: (width, height) each crop is fit to before encoding
        """
        if not self.vae_on_device:
            self.vae_to_device(self.device)

        key = self.__hash(imagepath)
        to_tensor = transforms.ToTensor()

        items = []
        # context manager releases the image file handle once all crops are made;
        # no_grad because latents are only cached, so no autograd graph should be kept.
        with Image.open(imagepath) as image, torch.no_grad():
            for i in range(self.jitter_lim):
                jittered_image = self._jitter_crop(image, target_resolution)
                # convert to tensor: (C, H, W) float in [0, 1] -> add batch dim -> vae device
                # NOTE(review): confirm the value range / input format self.vae expects.
                pixels = to_tensor(jittered_image).unsqueeze(0).to(self.device)
                latent = self.vae(pixels)
                items.append(LatentCacheItem(imagelatent=latent,
                                             captionembedding=captionembedding,
                                             cropjitteridx=i,
                                             resolution=target_resolution))

        # replace rather than append so re-adding an image does not duplicate jitters
        self.cache[key] = items

        # one file per image holds only that image's entries
        torch.save(items, os.path.join(self.latentcachepath, f"{key}.pt"))

    def __getitem__(self, imagepath, cropjitteridx=0):
        """
        returns a LatentCacheItem by imagepath key

        Subscripting passes a single key, so ``manager[path]`` returns jitter 0 and
        ``manager[path, idx]`` selects jitter ``idx``; ``__getitem__(path, idx)`` also works.
        Raises KeyError for an unknown path and IndexError for an out-of-range jitter index.
        """
        if isinstance(imagepath, tuple):
            imagepath, cropjitteridx = imagepath
        return self.cache[self.__hash(imagepath)][cropjitteridx]
|