update to latest colossalai (#1951)
parent aba2a65d6a
commit 089f0f4c98
```diff
@@ -15,8 +15,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
```
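The only import change: `convert_to_torch_module` and the now-unused `ProcessGroup` go away, and `get_static_torch_model` takes over gathering the Gemini-sharded weights back into an ordinary `torch.nn.Module`. A minimal sketch of the new helper's call pattern as used later in this commit; the wrapper function name is hypothetical, for illustration only:

```python
import torch
from colossalai.nn.parallel.utils import get_static_torch_model

def to_plain_torch_module(gemini_model: torch.nn.Module) -> torch.nn.Module:
    """Gather Gemini-sharded weights into a plain, unsharded torch module."""
    torch.cuda.synchronize()  # let pending async Gemini ops finish first
    return get_static_torch_model(gemini_model)
```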
```diff
@@ -356,26 +355,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
     model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
     )
     return model
 
 
 def main(args):
     # config for colossalai
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})
 
     if args.seed is not None:
         gpc.set_seed(args.seed)
```
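Two simplifications land here: `gemini_zero_dpp` no longer threads an explicit `ProcessGroup` into `GeminiDDP`, and the launch config shrinks to `{}` because the batch size and gradient-accumulation entries were informational and gradient clipping moves into the optimizer (see the hunk further down). The chunk search window also doubles from 32 MB to 64 MB. A minimal sketch of the updated flow, assuming a `torchrun` launch; `ToyNet` is a placeholder stand-in for the real UNet:

```python
import colossalai
import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

colossalai.launch_from_torch(config={})  # empty config; no more BATCH/clip_grad_norm keys
model = GeminiDDP(
    ToyNet(),
    device=get_current_device(),
    placement_policy="auto",  # the script's --placement flag feeds this
    pin_memory=True,
    search_range_mb=64,       # chunk-size search window, doubled by this commit
)
```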
```diff
@@ -472,7 +462,7 @@ def main(args):
     )
 
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
         )
```
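`ColoInitContext` now receives the target device explicitly, so the UNet's parameters are created as ColoTensors on the current CUDA device rather than the context's default. A minimal sketch of the pattern, again with a placeholder model:

```python
import torch
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

# Parameters are allocated directly on the target device as ColoTensors,
# ready to be wrapped by GeminiDDP afterwards.
with ColoInitContext(device=get_current_device()):
    model = ToyNet()
```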
```diff
@@ -484,12 +474,19 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
+        args.learning_rate = (
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * gpc.get_world_size(ParallelMode.DATA)
+        )
 
-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(
+        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
+    )
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
```
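The learning-rate scale factor changes from a hard-coded `2` to the actual data-parallel world size, and gradient clipping moves from the launch config into `GeminiAdamOptimizer` via `clipping_norm`, so the ZeRO optimizer, which owns the sharded parameters, performs the clipping itself. A sketch of the resulting wiring, reusing the Gemini-wrapped `model` from the sketch above; all numeric values are illustrative, the script reads them from argparse:

```python
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

base_lr = 5e-6
grad_accum_steps = 1
train_batch_size = 4

# With 8 data-parallel ranks this yields 5e-6 * 1 * 4 * 8 = 1.6e-4.
world_size = gpc.get_world_size(ParallelMode.DATA)
lr = base_lr * grad_accum_steps * train_batch_size * world_size

# clipping_norm replaces the old "clip_grad_norm" launch-config entry.
optimizer = GeminiAdamOptimizer(model, lr=lr, initial_scale=2**5, clipping_norm=1.0)
```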
```diff
@@ -657,10 +654,11 @@ def main(args):
 
                 if global_step % args.save_steps == 0:
                     torch.cuda.synchronize()
+                    torch_unet = get_static_torch_model(unet)
                     if gpc.get_local_rank(ParallelMode.DATA) == 0:
                         pipeline = DiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=convert_to_torch_module(unet),
+                            unet=torch_unet,
                             revision=args.revision,
                         )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
```
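Note that the gather via `get_static_torch_model` is hoisted out of the rank-0 branch: every rank now participates, presumably because the gather is collective, and only data-parallel local rank 0 builds and writes the pipeline. The same substitution replaces `convert_to_torch_module` in the final save below. A condensed sketch; `save_checkpoint` is a hypothetical wrapper around this hunk's logic:

```python
import os

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.utils import get_static_torch_model
from diffusers import DiffusionPipeline

def save_checkpoint(unet, args, global_step: int) -> None:
    torch.cuda.synchronize()                   # all ranks must reach this point
    torch_unet = get_static_torch_model(unet)  # gather shards on every rank
    if gpc.get_local_rank(ParallelMode.DATA) == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=torch_unet,
            revision=args.revision,
        )
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        pipeline.save_pretrained(save_path)    # rank 0 persists the checkpoint
```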
```diff
@@ -670,7 +668,7 @@ def main(args):
             break
 
     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)
 
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
```