change import statements for #14478

This commit is contained in:
AUTOMATIC1111 2023-12-31 22:38:30 +03:00
parent be5f1acc8f
commit a70dfb64a8
7 changed files with 14 additions and 17 deletions

View File

@ -4,7 +4,7 @@ from functools import lru_cache
import torch import torch
from modules import errors, shared from modules import errors, shared
from modules.torch_utils import get_param from modules import torch_utils
if sys.platform == "darwin": if sys.platform == "darwin":
from modules import mac_specific from modules import mac_specific
@ -132,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs): def manual_cast_forward(self, *args, **kwargs):
org_dtype = get_param(self).dtype org_dtype = torch_utils.get_param(self).dtype
self.to(dtype) self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

View File

@ -10,8 +10,7 @@ import torch.hub
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
from modules.torch_utils import get_param
blip_image_eval_size = 384 blip_image_eval_size = 384
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
@ -132,7 +131,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate) self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = get_param(self.clip_model).dtype self.dtype = torch_utils.get_param(self.clip_model).dtype
def send_clip_to_ram(self): def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory: if not shared.opts.interrogate_keep_models_in_memory:

View File

@ -6,7 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser from modules import devices, shared, prompt_parser
from modules.torch_utils import get_param from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model): def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype = get_param(model.model.diffusion_model).dtype dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn' model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt' model.cond_stage_key = 'txt'

View File

@ -6,8 +6,7 @@ import torch
import tqdm import tqdm
from PIL import Image from PIL import Image
from modules import images, shared from modules import images, shared, torch_utils
from modules.torch_utils import get_param
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image):
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
param = get_param(model) param = torch_utils.get_param(model)
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad(): with torch.no_grad():

View File

@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional from typing import Optional
from modules.torch_utils import get_param from modules import torch_utils
class BertSeriesConfig(BertConfig): class BertSeriesConfig(BertConfig):
@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init() self.post_init()
def encode(self,c): def encode(self,c):
device = get_param(self).device device = torch_utils.get_param(self).device
text = self.tokenizer(c, text = self.tokenizer(c,
truncation=True, truncation=True,
max_length=77, max_length=77,

View File

@ -4,8 +4,7 @@ import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional from typing import Optional
from modules import torch_utils
from modules.torch_utils import get_param
class BertSeriesConfig(BertConfig): class BertSeriesConfig(BertConfig):
@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init() self.post_init()
def encode(self,c): def encode(self,c):
device = get_param(self).device device = torch_utils.get_param(self).device
text = self.tokenizer(c, text = self.tokenizer(c,
truncation=True, truncation=True,
max_length=77, max_length=77,

View File

@ -3,7 +3,7 @@ import types
import pytest import pytest
import torch import torch
from modules.torch_utils import get_param from modules import torch_utils
@pytest.mark.parametrize("wrapped", [True, False]) @pytest.mark.parametrize("wrapped", [True, False])
@ -14,6 +14,6 @@ def test_get_param(wrapped):
if wrapped: if wrapped:
# more or less how spandrel wraps a thing # more or less how spandrel wraps a thing
mod = types.SimpleNamespace(model=mod) mod = types.SimpleNamespace(model=mod)
p = get_param(mod) p = torch_utils.get_param(mod)
assert p.dtype == torch.float16 assert p.dtype == torch.float16
assert p.device == cpu assert p.device == cpu