Datawhale X 魔搭 AI夏令营第四期 AIGC方向 task02笔记

赛事任务:

1.参赛者需在可图Kolors 模型的基础上训练LoRA 模型,生成无限风格,如水墨画风格、水彩风格、赛博朋克风格、日漫风格......

2.基于LoRA模型生成 8 张图片组成连贯故事,故事内容可自定义

本次笔记内容为用lora微调可图模型的代码(GitHub - modelscope/DiffSynth-Studio: Enjoy the magic of Diffusion models!Enjoy the magic of Diffusion models! Contribute to modelscope/DiffSynth-Studio development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/modelscope/DiffSynth-Studio.git)分析,在linux中输入如下命令(运行py文件和参数配置)即可运行

python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py \ # 选择使用可图的Lora训练脚本DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py
  --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \ # 选择unet模型
  --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \ # 选择text_encoder
  --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \ # 选择vae模型
  --lora_rank 16 \ # lora_rank 16 表示在权衡模型表达能力和训练效率时,选择了使用 16 作为秩,适合在不显著降低模型性能的前提下,通过 LoRA 减少计算和内存的需求
  --lora_alpha 4.0 \ # 设置 LoRA 的 alpha 值,影响调整的强度
  --dataset_path data/lora_dataset_processed \ # 指定数据集路径,用于训练模型
  --output_path ./models \ # 指定输出路径,用于保存模型
  --max_epochs 1 \ # 设置最大训练轮数为 1
  --center_crop \ # 启用中心裁剪,这通常用于图像预处理
  --use_gradient_checkpointing \ # 启用梯度检查点技术,以节省内存
  --precision "16-mixed" # 指定训练时的精度为混合 16 位精度(half precision),这可以加速训练并减少显存使用

 加载配置参数后初始化模型,进行训练

    args = parse_args()
    model = LightningModel(
        torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
        pretrained_weights=[
            args.pretrained_unet_path,
            args.pretrained_text_encoder_path,
            args.pretrained_fp16_vae_path,
        ],
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules
    )
    launch_training_task(model, args)

先加载数据集,配置训练轮数图像大小尺寸,再做迭代器(在一轮训练中分多批次拿出数据进行训练),然后定义训练策略后进行训练

    dataset = TextImageDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch * args.batch_size,
        height=args.height,
        width=args.width,
        center_crop=args.center_crop,
        random_flip=args.random_flip
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers
    )

    # train
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision=args.precision,
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
    )
    trainer.fit(model=model, train_dataloaders=train_loader)

使用trainer.fit(model=model, train_dataloaders=train_loader)进行训练,trainer.fit容易被误导为Pytorch-Lightning(import pytorch_lightning as pl导入)中的类方法,但在本baseline中用!pip uninstall pytorch-lightning -y卸载,实际使用的是pl.Trainer(import lightning as pl导入)。

模型定义:

class LightningModel(LightningModelForT2ILoRA):
    def __init__(
        self,
        torch_dtype=torch.float16, pretrained_weights=[],
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
    ):
        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        model_manager.load_models(pretrained_weights)
        self.pipe = SDXLImagePipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1100)

        # Convert the vae encoder to torch.float16
        self.pipe.vae_encoder.to(torch_dtype)

        self.freeze_parameters()
        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值