2023-12-25 14:01:02 -07:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import logging
|
2022-09-10 04:53:10 -06:00
|
|
|
|
2022-09-07 03:32:28 -06:00
|
|
|
import torch
|
|
|
|
|
2023-12-25 14:01:02 -07:00
|
|
|
from modules import (
|
|
|
|
devices,
|
|
|
|
errors,
|
|
|
|
face_restoration,
|
|
|
|
face_restoration_utils,
|
|
|
|
modelloader,
|
|
|
|
shared,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# Release asset that modelloader.load_models is given as model_url when
# looking for the CodeFormer v0.1.0 weights.
model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
# Passed to modelloader.load_models as download_name.
model_download_name = "codeformer-v0.1.0.pth"

# Module-wide restorer instance, assigned by setup_model(); other modules read
# it (e.g. postprocessing_codeformer.py). It stays None until setup succeeds.
codeformer: face_restoration.FaceRestoration | None = None
|
2022-09-07 03:32:28 -06:00
|
|
|
|
2022-09-26 08:29:50 -06:00
|
|
|
|
2023-12-25 14:01:02 -07:00
|
|
|
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
    """Face restorer backed by the CodeFormer network.

    This subclass supplies the CodeFormer-specific pieces: loading the network
    weights, choosing the target device, and running per-face inference. The
    per-image work is handed to the inherited ``restore_with_helper``.
    """

    def name(self) -> str:
        """Return the name this restorer is identified by."""
        return "CodeFormer"

    def load_net(self) -> torch.nn.Module:
        """Locate and load the CodeFormer weights.

        ``modelloader.load_models`` is given ``model_url`` and
        ``model_download_name``, presumably so it can fetch the weights when
        none are found under ``self.model_path``. Only ``.pth`` files are
        considered.

        Returns:
            The ``.model`` attribute of the object returned by
            ``modelloader.load_spandrel_model``, loaded with
            ``device=devices.device_codeformer``.

        Raises:
            ValueError: if ``modelloader.load_models`` yields no model path.
        """
        for model_path in modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        ):
            # Only the first candidate is used. expected_architecture asks the
            # loader to check that the weights are a CodeFormer network.
            return modelloader.load_spandrel_model(
                model_path,
                device=devices.device_codeformer,
                expected_architecture='CodeFormer',
            ).model

        raise ValueError("No codeformer model found")

    def get_device(self):
        """Return the device configured for CodeFormer (``devices.device_codeformer``)."""
        return devices.device_codeformer

    def restore(self, np_image, w: float | None = None):
        """Restore the faces in *np_image* with CodeFormer.

        Args:
            np_image: Input image, passed unchanged to ``restore_with_helper``
                (presumably an HxWxC numpy array; confirm against the helper).
            w: Weight forwarded to the network as its ``w`` keyword (presumably
                CodeFormer's fidelity weight). When None, the value of
                ``shared.opts.code_former_weight`` is used, falling back to 0.5
                if that option is not defined.

        Returns:
            Whatever ``restore_with_helper`` returns for the image.
        """
        if w is None:
            w = getattr(shared.opts, "code_former_weight", 0.5)

        def restore_face(cropped_face_t):
            # Explicit check rather than ``assert``: asserts are stripped under
            # ``python -O``. This still narrows ``self.net`` for type checkers.
            if self.net is None:
                raise RuntimeError("CodeFormer network is not loaded")
            # Only the first element of the network's output is used.
            return self.net(cropped_face_t, w=w, adain=True)[0]

        return self.restore_with_helper(np_image, restore_face)
|
2022-09-07 03:32:28 -06:00
|
|
|
|
|
|
|
|
2023-12-25 14:01:02 -07:00
|
|
|
def setup_model(dirname: str) -> None:
    """Create the CodeFormer restorer and register it.

    Builds a ``FaceRestorerCodeFormer`` from *dirname* (passed straight to its
    constructor), stores it in the module-level ``codeformer`` and appends it
    to ``shared.face_restorers``. Any exception raised along the way is
    reported through ``errors.report`` instead of propagating.
    """
    global codeformer

    try:
        restorer = FaceRestorerCodeFormer(dirname)
        # Publish before registering, so the global is set even if the append fails.
        codeformer = restorer
        shared.face_restorers.append(restorer)
    except Exception:
        errors.report("Error setting up CodeFormer", exc_info=True)
|