Datawhale X 魔搭 AI夏令营第四期 AIGC方向 学习笔记(一)

本期主要任务是了解AI文生图的原理并进行相关实践

下面是对baseline部分代码的功能介绍:

安装Data-juicere和DiffSynth-Studio

!pip install simple-aesthetics-predictor
!pip install -v -e data-juicer
!pip uninstall pytorch-lightning -y
!pip install peft lightning pandas torchvision
!pip install -e DiffSynth-Studio

基本的通过pip安装,"!"控制语句在终端进行操作。simple-aesthetics-predictor 这个包,参考pypi上项目描述,是一个基于CLIP的美学预测器,用于预测图片的美学质量。"-v"、"-e"命令用于设定安装模式. data-juicer ,参考github上的原项目Readme文件,是一个“用于大语言模型的一站式数据处理系统”。peft 与参数高效微调相关,lightning 是用于简化训练过程的库,pandas和torchvision就不多说了。DiffSynth-Studio 则是一种用于实现图片和视频风格转换的引擎。


下载数据集

从modelscope上下载某个数据集,指定了目标数据集的路径,子集名称,拆分部分(训练集)和下载完成后的缓存目录。


保存数据集中的图片和元数据

os.makedirs("./data/lora_dataset/train", exist_ok=True)
os.makedirs("./data/data-juicer/input", exist_ok=True)
with open("./data/data-juicer/input/metadata.jsonl", "w") as f:
    for data_id, data in enumerate(tqdm(ds)):
        image = data["image"].convert("RGB")
        image.save(f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg")
        metadata = {"text": "二次元", "image": [f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg"]}
        f.write(json.dumps(metadata))
        f.write("\n")

这部分主要进行对下载得到的数据集的遍历,将其中的图片转化成RGB格式后保存到指定路径(../data/lora_dataset/train)。另外创建由文本和对应图片构成的字典作为元数据写入json文件保存


数据处理

在变量 data_juicer_config 中定义了数据处理的各项配置信息,并将其写入yaml文件中。之后调用dj-process命令开启数据处理,并通过该配置文件传入相关参数。

保存处理好的数据

主要是从 result.jsonl 文件中进行文本和图像的保存,并将文件名和文本信息存至csv文件中


训练模型 

from diffsynth import download_models
download_models(["Kolors", "SDXL-vae-fp16-fix"])

!python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py -h

下载模型;终端查看训练脚本输入参数

cmd = """
python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py \
  --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
  --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
  --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
  --lora_rank 16 \
  --lora_alpha 4.0 \
  --dataset_path data/lora_dataset_processed \
  --output_path ./models \
  --max_epochs 1 \
  --center_crop \
  --use_gradient_checkpointing \
  --precision "16-mixed"
""".strip()

os.system(cmd)

 这一段定义了训练过程需要在终端执行的命令,主要包含以下内容:指定了预训练需要的Unet模型路径、文本编码器模型路径和fp16VAE模型路径;指定lora的等级和alpha值相关参数;指定数据集路径、输出路径;指定最大训练轮数,使用中心裁剪、梯度检查点,和精度参数。


加载模型 

def load_lora(model, lora_rank, lora_alpha, lora_path):
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out"],
    )
    model = inject_adapter_in_model(lora_config, model)
    state_dict = torch.load(lora_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model

# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=[
                                 "models/kolors/Kolors/text_encoder",
                                 "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
                                 "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
                             ])
pipe = SDXLImagePipeline.from_model_manager(model_manager)

# Load LoRA
pipe.unet = load_lora(
    pipe.unet,
    lora_rank=16, # This parameter should be consistent with that in your training script.
    lora_alpha=2.0, # lora_alpha can control the weight of LoRA.
    lora_path="models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt"
)

 load_lora 函数加载loRA模型并进行相关参数配置(

model:要注入 LoRA 适配器的原始模型。

lora_rank:LoRA 适配器的秩,用于控制适配器的复杂度。

lora_alpha:LoRA 适配器的缩放因子,用于控制其权重。

lora_path:包含预训练 LoRA 权重的文件路径。) 

使用 inject_adapt_in_model 将loRA注入原始模型,加载loRA预训练的权重字典并应用至模型中。

后续部分并不熟悉各实例的作用,暂且一放。


生成图像

torch.manual_seed(0)
image = pipe(
    prompt="二次元,一个紫色短发小女孩,在家中沙发上坐着,双手托着腮,很无聊,全身,粉色连衣裙",
    negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
    cfg_scale=4,
    num_inference_steps=50, height=1024, width=1024,
)
image.save("1.jpg")

设置随机种子值使随机操作具有可重复性。使用pipe对象进行生成,给出正负面提示词,配置尺度,推理步数和图像尺寸参数。


在尝试了自己的一系列提示词后得到如下八张图,内容类似baseline原本给的,主线换成了足球:

主要存在的问题:部分画风不统一;部分图细节不佳;对“足球”一词的表现错误(应该是中文输入翻译问题);部分提示词的信息未有效表现出来

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值