Apply correct inference precision implementation

Kohaku-Blueleaf 2024-01-09 23:13:34 +08:00
parent c2c05fcca8
commit e00365962b
1 changed file with 33 additions and 9 deletions


@@ -132,6 +132,21 @@ patch_module_list = [
 ]


+def cast_output(result):
+    if isinstance(result, tuple):
+        result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
+    elif isinstance(result, torch.Tensor):
+        result = result.to(dtype_inference)
+    return result
+
+
+def autocast_with_cast_output(self, *args, **kwargs):
+    result = self.org_forward(*args, **kwargs)
+    if dtype_inference != dtype:
+        result = cast_output(result)
+    return result
+
+
 def manual_cast_forward(target_dtype):
     def forward_wrapper(self, *args, **kwargs):
         if any(
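The cast_output helper added above is easy to exercise on its own. A minimal, runnable sketch (in the patched file, dtype_inference is a module-level global; here a local stand-in is used):

import torch

dtype_inference = torch.float32  # stand-in for the module-level global

def cast_output(result):
    if isinstance(result, tuple):
        result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
    elif isinstance(result, torch.Tensor):
        result = result.to(dtype_inference)
    return result

out = cast_output((torch.ones(2, dtype=torch.float16), "text", 3))
assert out[0].dtype == torch.float32   # tensor cast to the inference dtype
assert out[1:] == ("text", 3)          # non-tensor items pass through unchanged

Non-tensor items pass through untouched, which keeps forwards that return mixed tuples safe.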
@@ -149,15 +164,7 @@ def manual_cast_forward(target_dtype):
             self.to(org_dtype)

         if target_dtype != dtype_inference:
-            if isinstance(result, tuple):
-                result = tuple(
-                    i.to(dtype_inference)
-                    if isinstance(i, torch.Tensor)
-                    else i
-                    for i in result
-                )
-            elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
+            result = cast_output(result)
         return result

     return forward_wrapper
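This hunk replaces the inline tuple/tensor casting in forward_wrapper with a call to the new cast_output helper; the surrounding wrapper still temporarily casts the module to the compute dtype and restores it afterwards. A toy illustration of that temporary-cast pattern (all names are local to the sketch; bfloat16 is used so it runs on CPU):

import torch

lin = torch.nn.Linear(2, 2)               # fp32 "storage" weights
x = torch.ones(1, 2, dtype=torch.bfloat16)

org_dtype = lin.weight.dtype               # remember the original dtype
lin.to(torch.bfloat16)                     # cast the module for the forward pass
y = lin(x)
lin.to(org_dtype)                          # restore afterwards, as forward_wrapper does
y = y.to(torch.float32)                    # then cast the output to the inference dtype
assert lin.weight.dtype == torch.float32 and y.dtype == torch.float32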
@@ -178,6 +185,20 @@ def manual_cast(target_dtype):
             module_type.forward = module_type.org_forward


+@contextlib.contextmanager
+def precision_full_with_autocast(autocast_ctx):
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = autocast_with_cast_output
+        module_type.org_forward = org_forward
+    try:
+        with autocast_ctx:
+            yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
+
+
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
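precision_full_with_autocast patches the forward of every type in patch_module_list so outputs get cast up to the inference dtype, and the try/finally restores the originals on exit even if the body raises. A self-contained sketch of that patch-and-restore pattern, with contextlib.nullcontext() standing in for torch.autocast("cuda") and a one-entry stand-in for patch_module_list:

import contextlib
import torch

dtype = torch.bfloat16            # stand-ins for the module-level globals
dtype_inference = torch.float32
patch_module_list = [torch.nn.Linear]

def autocast_with_cast_output(self, *args, **kwargs):
    result = self.org_forward(*args, **kwargs)
    if dtype_inference != dtype and isinstance(result, torch.Tensor):
        result = result.to(dtype_inference)
    return result

@contextlib.contextmanager
def precision_full_with_autocast(autocast_ctx):
    for module_type in patch_module_list:
        org_forward = module_type.forward
        module_type.forward = autocast_with_cast_output
        module_type.org_forward = org_forward
    try:
        with autocast_ctx:
            yield None
    finally:
        for module_type in patch_module_list:
            module_type.forward = module_type.org_forward

lin = torch.nn.Linear(2, 2).to(dtype)
x = torch.ones(1, 2, dtype=dtype)
with precision_full_with_autocast(contextlib.nullcontext()):
    assert lin(x).dtype == torch.float32   # patched: outputs cast up
assert lin(x).dtype == torch.bfloat16      # originals restored on exit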
@@ -191,6 +212,9 @@ def autocast(disable=False):
     if has_xpu() or has_mps() or cuda_no_autocast():
         return manual_cast(dtype_inference)

+    if dtype_inference == torch.float32 and dtype != torch.float32:
+        return precision_full_with_autocast(torch.autocast("cuda"))
+
     return torch.autocast("cuda")
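The new branch targets setups where weights are stored in half precision (dtype != float32) but full-precision inference is requested (dtype_inference == float32): plain autocast would hand back half-precision outputs, so the patched forwards cast results up. A condensed, illustrative restatement of the visible dispatch (pick_path and the manual flag are hypothetical names, not part of the commit):

import torch

def pick_path(dtype, dtype_inference, manual=False):
    if manual:  # stands in for has_xpu() / has_mps() / cuda_no_autocast()
        return "manual_cast"
    if dtype_inference == torch.float32 and dtype != torch.float32:
        return "precision_full_with_autocast"
    return "autocast"

assert pick_path(torch.float16, torch.float16) == "autocast"
assert pick_path(torch.float16, torch.float32) == "precision_full_with_autocast"
assert pick_path(torch.float32, torch.float32) == "autocast"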