Merge pull request #136 from qslug/docker-deps-fix
Mixed mode captioning and QoL tweaks
This commit is contained in:
commit
91ab92d298
|
@ -32,6 +32,28 @@
|
||||||
"Come visit us at [EveryDream Discord](https://discord.gg/uheqxU6sXN)"
|
"Come visit us at [EveryDream Discord](https://discord.gg/uheqxU6sXN)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ffff47f7",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Optional Speed Test\n",
|
||||||
|
"If all goes well you may find yourself downloading (or pushing to the cloud) 2-8GB of model data per saved checkpoint. Make sure your pod is not a dud. ~1000Mbit/s up/dn is probably good, though the location of the pod also makes a difference.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "934ba107",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import speedtest\n",
|
||||||
|
"st = speedtest.Speedtest()\n",
|
||||||
|
"print(f\"Your download speed: {round(st.download() / 1000 / 1000, 1)} Mbit/s\")\n",
|
||||||
|
"print(f\"Your upload speed: {round(st.upload() / 1000 / 1000, 1)} Mbit/s\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "7c73894e-3b5e-4268-9f83-ed89bd4569f2",
|
"id": "7c73894e-3b5e-4268-9f83-ed89bd4569f2",
|
||||||
|
|
|
@ -174,7 +174,7 @@ class Dataset:
|
||||||
# Use file name for caption only as a last resort
|
# Use file name for caption only as a last resort
|
||||||
@classmethod
|
@classmethod
|
||||||
def __ensure_caption(cls, cfg: ImageConfig, file: str):
|
def __ensure_caption(cls, cfg: ImageConfig, file: str):
|
||||||
if cfg.main_prompts or cfg.tags:
|
if cfg.main_prompts:
|
||||||
return cfg
|
return cfg
|
||||||
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
||||||
return cfg.merge(cap_cfg)
|
return cfg.merge(cap_cfg)
|
||||||
|
@ -217,9 +217,13 @@ class Dataset:
|
||||||
items = []
|
items = []
|
||||||
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
||||||
config = self.image_configs[image]
|
config = self.image_configs[image]
|
||||||
|
|
||||||
if len(config.main_prompts) > 1:
|
if len(config.main_prompts) > 1:
|
||||||
logging.warning(f" *** Found multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
logging.warning(f" *** Found multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
||||||
|
|
||||||
|
if len(config.main_prompts) < 1:
|
||||||
|
logging.warning(f" *** No main_prompts for image {image}")
|
||||||
|
|
||||||
tags = []
|
tags = []
|
||||||
tag_weights = []
|
tag_weights = []
|
||||||
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
|
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
|
||||||
|
|
|
@ -74,6 +74,7 @@ RUN echo "source ${VIRTUAL_ENV}/bin/activate" >> /root/.bashrc
|
||||||
ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
|
ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
|
||||||
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudart.so
|
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudart.so
|
||||||
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libnvrtc.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvrtc.so
|
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libnvrtc.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvrtc.so
|
||||||
|
|
||||||
ADD requirements-runtime.txt /
|
ADD requirements-runtime.txt /
|
||||||
RUN pip install --no-cache-dir -r requirements-runtime.txt
|
RUN pip install --no-cache-dir -r requirements-runtime.txt
|
||||||
|
|
||||||
|
@ -81,7 +82,6 @@ WORKDIR /workspace
|
||||||
RUN git clone https://github.com/victorchall/EveryDream2trainer
|
RUN git clone https://github.com/victorchall/EveryDream2trainer
|
||||||
|
|
||||||
WORKDIR /workspace/EveryDream2trainer
|
WORKDIR /workspace/EveryDream2trainer
|
||||||
# RUN git checkout torch2
|
|
||||||
RUN python utils/get_yamls.py && \
|
RUN python utils/get_yamls.py && \
|
||||||
mkdir -p logs && mkdir -p input
|
mkdir -p logs && mkdir -p input
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,10 @@ ipyevents
|
||||||
ipywidgets
|
ipywidgets
|
||||||
jupyter-archive
|
jupyter-archive
|
||||||
jupyterlab
|
jupyterlab
|
||||||
|
lion-pytorch
|
||||||
piexif==1.1.3
|
piexif==1.1.3
|
||||||
pyfakefs
|
pyfakefs
|
||||||
pynvml==11.5.0
|
pynvml==11.5.0
|
||||||
|
speedtest-cli
|
||||||
tensorboard==2.12.0
|
tensorboard==2.12.0
|
||||||
wandb
|
wandb
|
||||||
|
|
|
@ -13,10 +13,11 @@ then
|
||||||
service ssh start
|
service ssh start
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
tensorboard --logdir /workspace/EveryDream2trainer/logs --host 0.0.0.0 &
|
||||||
|
|
||||||
# RunPod JupyterLab
|
# RunPod JupyterLab
|
||||||
if [[ $JUPYTER_PASSWORD ]]
|
if [[ $JUPYTER_PASSWORD ]]
|
||||||
then
|
then
|
||||||
tensorboard --logdir /workspace/EveryDream2trainer/logs --host 0.0.0.0 &
|
|
||||||
jupyter nbextension enable --py widgetsnbextension
|
jupyter nbextension enable --py widgetsnbextension
|
||||||
jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
|
jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
|
||||||
jupyter lab --allow-root --no-browser --port=8888 --ip=* --ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' --ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace/EveryDream2trainer
|
jupyter lab --allow-root --no-browser --port=8888 --ip=* --ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' --ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace/EveryDream2trainer
|
||||||
|
|
|
@ -75,7 +75,7 @@ class TestResolve(unittest.TestCase):
|
||||||
'path': DATA_PATH,
|
'path': DATA_PATH,
|
||||||
}
|
}
|
||||||
|
|
||||||
items = resolver.resolve(data_root_spec, ARGS)
|
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
|
||||||
image_paths = [item.pathname for item in items]
|
image_paths = [item.pathname for item in items]
|
||||||
image_captions = [item.caption for item in items]
|
image_captions = [item.caption for item in items]
|
||||||
captions = [caption.get_caption() for caption in image_captions]
|
captions = [caption.get_caption() for caption in image_captions]
|
||||||
|
@ -88,7 +88,7 @@ class TestResolve(unittest.TestCase):
|
||||||
self.assertEqual(len(undersized_images), 1)
|
self.assertEqual(len(undersized_images), 1)
|
||||||
|
|
||||||
def test_json_resolve_with_str(self):
|
def test_json_resolve_with_str(self):
|
||||||
items = resolver.resolve(JSON_ROOT_PATH, ARGS)
|
items = sorted(resolver.resolve(JSON_ROOT_PATH, ARGS), key=lambda i: i.pathname)
|
||||||
image_paths = [item.pathname for item in items]
|
image_paths = [item.pathname for item in items]
|
||||||
image_captions = [item.caption for item in items]
|
image_captions = [item.caption for item in items]
|
||||||
captions = [caption.get_caption() for caption in image_captions]
|
captions = [caption.get_caption() for caption in image_captions]
|
||||||
|
@ -124,14 +124,14 @@ class TestResolve(unittest.TestCase):
|
||||||
JSON_ROOT_PATH,
|
JSON_ROOT_PATH,
|
||||||
]
|
]
|
||||||
|
|
||||||
items = resolver.resolve(data_root_spec, ARGS)
|
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
|
||||||
image_paths = [item.pathname for item in items]
|
image_paths = [item.pathname for item in items]
|
||||||
image_captions = [item.caption for item in items]
|
image_captions = [item.caption for item in items]
|
||||||
captions = [caption.get_caption() for caption in image_captions]
|
captions = [caption.get_caption() for caption in image_captions]
|
||||||
|
|
||||||
self.assertEqual(len(items), 6)
|
self.assertEqual(len(items), 6)
|
||||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)
|
self.assertEqual(set(image_paths), set([IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2))
|
||||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'])
|
self.assertEqual(set(captions), {'caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'})
|
||||||
|
|
||||||
undersized_images = list(filter(lambda i: i.is_undersized, items))
|
undersized_images = list(filter(lambda i: i.is_undersized, items))
|
||||||
self.assertEqual(len(undersized_images), 2)
|
self.assertEqual(len(undersized_images), 2)
|
|
@ -100,6 +100,24 @@ class TestDataset(TestCase):
|
||||||
self.assertEqual(expected, actual)
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
|
|
||||||
|
def test_captions_are_read_from_filename_if_no_main_prompt(self):
|
||||||
|
self.fs.create_file("filename main prompt, filename tag.jpg")
|
||||||
|
self.fs.create_file("filename main prompt, filename tag.yaml",
|
||||||
|
contents=dedent("""
|
||||||
|
caption:
|
||||||
|
tags:
|
||||||
|
- tag: standalone yaml tag
|
||||||
|
"""))
|
||||||
|
actual = Dataset.from_path(".").image_configs
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"./filename main prompt, filename tag.jpg": ImageConfig(
|
||||||
|
main_prompts="filename main prompt",
|
||||||
|
tags= [ Tag("filename tag"), Tag("standalone yaml tag") ]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
|
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
|
||||||
self.fs.create_file("image_1.jpg")
|
self.fs.create_file("image_1.jpg")
|
||||||
self.fs.create_file("image_1.yaml", contents=dedent("""
|
self.fs.create_file("image_1.yaml", contents=dedent("""
|
||||||
|
@ -358,4 +376,4 @@ class TestDataset(TestCase):
|
||||||
self.assertEqual(actual[2].caption.rating(), 1.0)
|
self.assertEqual(actual[2].caption.rating(), 1.0)
|
||||||
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
|
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
|
||||||
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
|
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
|
||||||
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
Loading…
Reference in New Issue