change import statements for #14478
This commit is contained in:
parent
be5f1acc8f
commit
a70dfb64a8
|
@ -4,7 +4,7 @@ from functools import lru_cache
|
|||
|
||||
import torch
|
||||
from modules import errors, shared
|
||||
from modules.torch_utils import get_param
|
||||
from modules import torch_utils
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
|
@ -132,7 +132,7 @@ patch_module_list = [
|
|||
|
||||
|
||||
def manual_cast_forward(self, *args, **kwargs):
|
||||
org_dtype = get_param(self).dtype
|
||||
org_dtype = torch_utils.get_param(self).dtype
|
||||
self.to(dtype)
|
||||
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
|
|
@ -10,8 +10,7 @@ import torch.hub
|
|||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
from modules.torch_utils import get_param
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
@ -132,7 +131,7 @@ class InterrogateModels:
|
|||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = get_param(self.clip_model).dtype
|
||||
self.dtype = torch_utils.get_param(self.clip_model).dtype
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
|
|
|
@ -6,7 +6,7 @@ import sgm.models.diffusion
|
|||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules.torch_utils import get_param
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
|
@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
|
|||
def extend_sdxl(model):
|
||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
|
||||
dtype = get_param(model.model.diffusion_model).dtype
|
||||
dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||
model.model.diffusion_model.dtype = dtype
|
||||
model.model.conditioning_key = 'crossattn'
|
||||
model.cond_stage_key = 'txt'
|
||||
|
|
|
@ -6,8 +6,7 @@ import torch
|
|||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from modules import images, shared
|
||||
from modules.torch_utils import get_param
|
||||
from modules import images, shared, torch_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image):
|
|||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
|
||||
param = get_param(model)
|
||||
param = torch_utils.get_param(model)
|
||||
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
|
|||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
|
@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
|||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = get_param(self).device
|
||||
device = torch_utils.get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
|
|
|
@ -4,8 +4,7 @@ import torch
|
|||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
|
@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
|||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = get_param(self).device
|
||||
device = torch_utils.get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
|
|
|
@ -3,7 +3,7 @@ import types
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapped", [True, False])
|
||||
|
@ -14,6 +14,6 @@ def test_get_param(wrapped):
|
|||
if wrapped:
|
||||
# more or less how spandrel wraps a thing
|
||||
mod = types.SimpleNamespace(model=mod)
|
||||
p = get_param(mod)
|
||||
p = torch_utils.get_param(mod)
|
||||
assert p.dtype == torch.float16
|
||||
assert p.device == cpu
|
||||
|
|
Loading…
Reference in New Issue