Merge pull request #14467 from akx/drop-basicsr
Drop basicsr dependency
commit 16848f950b
[CI test workflow]

@@ -57,7 +57,7 @@ jobs:
           2>&1 | tee output.txt &
       - name: Run tests
         run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
[face restoration helper module]

@@ -17,6 +17,28 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+    """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+    assert img.shape[2] == 3, "image must be RGB"
+    if img.dtype == "float64":
+        img = img.astype("float32")
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+    """
+    Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+    """
+    tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
+    assert tensor.dim() == 3, "tensor must be RGB"
+    img_np = tensor.numpy().transpose(1, 2, 0)
+    if img_np.shape[2] == 1:  # gray image, no RGB/BGR required
+        return np.squeeze(img_np, axis=2)
+    return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+
+
 def create_face_helper(device) -> FaceRestoreHelper:
     from facexlib.detection import retinaface
     from facexlib.utils.face_restoration_helper import FaceRestoreHelper
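These two helpers replace basicsr's img2tensor/tensor2img pair. As a sanity check, a minimal round-trip sketch, assuming the two functions above are in scope (the random image and shapes are illustrative, not from the commit):

import numpy as np

# Fake BGR face crop in [0..1], matching what restore_with_face_helper
# produces via `cropped_face / 255.0`.
bgr = np.random.rand(512, 512, 3).astype("float32")

tensor = bgr_image_to_rgb_tensor(bgr)       # CHW RGB float32 tensor
assert tuple(tensor.shape) == (3, 512, 512)

back = rgb_tensor_to_bgr_image(tensor)      # HWC BGR image in [0..1]
assert back.shape == (512, 512, 3)
assert np.allclose(back, bgr, atol=1e-6)    # lossless round trip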
@@ -36,14 +58,13 @@ def create_face_helper(device) -> FaceRestoreHelper:
 def restore_with_face_helper(
     np_image: np.ndarray,
     face_helper: FaceRestoreHelper,
-    restore_face: Callable[[np.ndarray], np.ndarray],
+    restore_face: Callable[[torch.Tensor], torch.Tensor],
 ) -> np.ndarray:
     """
     Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.

     `restore_face` should take a cropped face image and return a restored face image.
     """
-    from basicsr.utils import img2tensor, tensor2img
     from torchvision.transforms.functional import normalize
     np_image = np_image[:, :, ::-1]
     original_resolution = np_image.shape[0:2]
@@ -56,23 +77,19 @@ def restore_with_face_helper(
     face_helper.align_warp_face()
     logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
     for cropped_face in face_helper.cropped_faces:
-        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+        cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
         normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
         cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

         try:
             with torch.no_grad():
-                restored_face = tensor2img(
-                    restore_face(cropped_face_t),
-                    rgb2bgr=True,
-                    min_max=(-1, 1),
-                )
+                cropped_face_t = restore_face(cropped_face_t)
             devices.torch_gc()
         except Exception:
             errors.report('Failed face-restoration inference', exc_info=True)
-            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

-        restored_face = restored_face.astype('uint8')
+        restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+        restored_face = (restored_face * 255.0).astype('uint8')
         face_helper.add_restored_face(restored_face)

     logger.debug("Merging restored faces into image")
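After this hunk, `restore_face` receives and returns batched (1, 3, H, W) RGB tensors normalized to [-1, 1]; the NumPy/BGR conversions now happen once inside `restore_with_face_helper`. A hedged sketch of a conforming callback, with `FakeRestorer` as an illustrative stand-in (not code from this commit):

import torch


class FakeRestorer(torch.nn.Module):
    """Illustrative stand-in for a face-restoration net (e.g. CodeFormer)."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clamp(-1.0, 1.0)  # identity "restoration"


net = FakeRestorer()


def restore_face(cropped_face_t: torch.Tensor) -> torch.Tensor:
    # In and out: (1, 3, H, W) RGB tensors in [-1, 1]; the BGR/uint8
    # conversion is handled by restore_with_face_helper afterwards.
    return net(cropped_face_t)


out = restore_face(torch.zeros(1, 3, 512, 512))
assert out.shape == (1, 3, 512, 512)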
@@ -126,7 +143,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration):
     def restore_with_helper(
         self,
         np_image: np.ndarray,
-        restore_face: Callable[[np.ndarray], np.ndarray],
+        restore_face: Callable[[torch.Tensor], torch.Tensor],
     ) -> np.ndarray:
         try:
             if self.net is None:
[textual inversion training module]

@@ -11,7 +11,6 @@ import safetensors.torch

 import numpy as np
 from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter

 from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
 import modules.textual_inversion.dataset
@@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
     })

 def tensorboard_setup(log_directory):
+    from torch.utils.tensorboard import SummaryWriter
     os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
     return SummaryWriter(
         log_dir=os.path.join(log_directory, "tensorboard"),
@@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     old_parallel_processing_allowed = shared.parallel_processing_allowed

+    tensorboard_writer = None
     if shared.opts.training_enable_tensorboard:
-        tensorboard_writer = tensorboard_setup(log_directory)
+        try:
+            tensorboard_writer = tensorboard_setup(log_directory)
+        except ImportError:
+            errors.report("Error initializing tensorboard", exc_info=True)

     pin_memory = shared.opts.pin_memory
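Moving the SummaryWriter import inside `tensorboard_setup` makes tensorboard a soft dependency: it is only imported when tensorboard logging is enabled, and the `except ImportError` above (together with the `tensorboard_writer` truthiness guard in the next hunk) lets training proceed without it. A minimal standalone sketch of the same lazy-import pattern, with illustrative names:

import logging

logger = logging.getLogger(__name__)


def optional_writer_setup(log_dir: str):
    # Deferred import: torch.utils.tensorboard is only required
    # if this function actually runs.
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)


writer = None
try:
    writer = optional_writer_setup("logs/tensorboard")
except ImportError:
    logger.warning("tensorboard unavailable; continuing without it")

if writer:  # mirrors the `if tensorboard_writer and ...` guard below
    writer.add_scalar("loss", 0.5, global_step=1)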
@@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"

-                if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+                if tensorboard_writer and shared.opts.training_tensorboard_save_images:
                     tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                 if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
[requirements file]

@@ -2,7 +2,6 @@ GitPython
 Pillow
 accelerate
-basicsr
 blendmodes
 clean-fid
 einops
[pinned requirements file]

@@ -1,7 +1,6 @@
 GitPython==3.1.32
 Pillow==9.5.0
 accelerate==0.21.0
-basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
 einops==0.4.1