【文生图】Stable Diffusion XL 1.0模型Full Fine-tuning指南(U-Net全参微调)

前言

Stable Diffusion是计算机视觉领域的一个生成式大模型,能够进行文生图(txt2img)和图生图(img2img)等图像生成任务。Stable Diffusion的开源公布,以及随之而来的一系列借助Stable Diffusion为基础的工作使得人工智能绘画领域呈现出前所未有的高品质创作与创意。

今年7月Stability AI 正式推出了 Stable Diffusion XL(SDXL)1.0,这是当前图像生成领域最好的开源模型。文生图模型又完成了进化过程中的一次重要迭代,SDXL 1.0几乎能够生成任何艺术风格的高质量图像,并且是实现逼真效果的最佳开源模型。该模型在色彩的鲜艳度和准确度方面做了很好的调整,对比度、光线和阴影都比上一代更好,并全部采用原生1024x1024分辨率。除此之外,SDXL 1.0 对于难以生成的概念有了很大改善,例如手、文本以及空间的排列。

目前关于文生图(text2img)模型的训练教程多集中在LoRA、DreamBooth、Text Inversion等模型,且训练方式大多也依赖于可视化UI界面工具,如SD WebUI、AI 绘画一键启动软件等等。而Full Fine-tuning的详细教程可以说几乎没有,所以这里记录一下我在微调SDXL Base模型过程中所参考的资料,以及一些训练参数的说明。

重要教程链接

以海报生成微调为例

总体流程

在这里插入图片描述

数据获取

使用科研机构、公司以及Kaggle平台的公开数据集,具体如下:

POSTER-TEXT

数据集POSTER-TEXT是关于电商海报图片的文本图像生成任务,它包含114,009条记录,由阿里巴巴集团提供。包括原始海报图及擦除海报图中文字后的图片。
Paper:TextPainter: Multimodal Text Image Generation with Visual-harmony and Text-comprehension for Poster Design ACM MM 2023.
Source:https://tianchi.aliyun.com/dataset/160034

AutoPoster

数据集AutoPoster-Dataset是关于电商海报图片的自动化生成任务,它包含 76000 条记录,由阿里巴巴集团提供。

在一些图像中存在重复标注的问题。该论文提到训练集中有69,249张图像,测试集中有7,711张图像。但实际上,在去除重复数据后,训练集中有68,866张唯一的广告海报图像,测试集中有7,671张唯一的图像。

Paper: AutoPoster: A Highly Automatic and Content-aware Design System for Advertising Poster Generation ACM MM 2023
Source:https://tianchi.aliyun.com/dataset/159829

CGL-Dataset

Paper: Composition-aware Graphic Layout GAN for Visual-textual Presentation Designs IJCAI 2022
Github: https://github.com/minzhouGithub/CGL-GAN
Source:https://tianchi.aliyun.com/dataset/142692

PKU PosterLayout

作为第一个包含复杂布局的公共数据集,它在建模布局内关系方面提供了更多的困难,并代表了需要复杂布局的扩展任务。包含9,974个训练图片和905个测试图片。

  • 领域多样性
    从多个来源收集了数据,包括电子商务海报数据集和多个图片库网站。图像在域、质量和分辨率方面都是多样化的,这导致了数据分布的变化,并使数据集更加通用。
  • 内容多样性
    定义了九个类别,涵盖大多数产品,包括食品/饮料、化妆品/配件、电子产品/办公用品、玩具/仪器、服装、运动/交通、杂货、电器/装饰和新鲜农产品。

Paper: A New Dataset and Benchmark for Content-aware Visual-Textual Presentation Layout CVPR 2023
Github: https://github.com/PKU-ICST-MIPL/PosterLayout-CVPR2023
Source:http://59.108.48.34/tiki/PosterLayout/

PosterT80K

电商海报图片,但是数据未公开,单位为中科大和阿里巴巴
Paper: TextPainter: Multimodal Text Image Generation with
Visual-harmony and Text-comprehension for Poster Design
ACM MM 2023
Source:None

Movie & TV Series & Anime Posters

Kaggle上的公开数据,需要从提供的csv或json文件中的图片url地址自己写个下载脚本。\

Source:

  • https://www.kaggle.com/datasets/bourdier/all-tv-series-details-dataset
    file prefix: https://www.themoviedb.org/t/p/w600_and_h900_bestv2/xx.jpg
  • https://www.kaggle.com/datasets/crawlfeeds/movies-and-tv-shows-dataset
  • https://www.kaggle.com/datasets/phiitm/movie-posters
  • https://www.kaggle.com/datasets/ostamand/tmdb-box-office-prediction-posters
  • https://www.kaggle.com/datasets/dbdmobile/myanimelist-dataset
  • https://www.kaggle.com/zakarihachemi/datasets
  • https://www.kaggle.com/datasets/rezaunderfit/48k-imdb-movies-data

以第一个数据源为例:

import csv
import os
import requests
import warnings
warnings.filterwarnings('ignore')

csv_file = r"C:\Users\xxx\Downloads\tvs.csv"
url_prefix = 'https://www.themoviedb.org/t/p/w600_and_h900_bestv2'
save_root_path = r"D:\dataset\download_data\tv_series"


def parse_csv(path):
    cnt = 0
    s = requests.Session()
    s.verify = False    # 全局关闭ssl验证
    with open(path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            raw_img_url = row['poster_path']    # url item
            img_url = url_prefix + raw_img_url
            if raw_img_url == '':
                continue
            try:
                img_file = s.get(img_url, verify=False)
            except Exception as e:
                print(repr(e))
                print("错误的状态响应码:{}".format(img_file.status_code))

            if img_file.status_code == 200:
                img_name = raw_img_url.split('/')[-1]
                # img_name = row['url'].split('/')[-2] + '.jpg'
                save_path = os.path.join(save_root_path, img_name)
                with open(save_path, 'wb') as img:
                    img.write(img_file.content)

                cnt += 1
                print(cnt, 'saved!')

    print("Done!")


if __name__ == '__main__':
    if not os.path.exists(save_root_path):
        os.makedirs(save_root_path)
    parse_csv(csv_file)

数据清洗与标注

数据筛选标准:低于512和超过1024的去除,文件大小/分辨率<0.0005的去除,dpi小于96的去除。这里的0.0005是根据SD所生成图片的文件大小(kb)和分辨率所确定的一个主观参数标准,用于保证图片质量。统计八张SD所生成图片在该指标下的数值如下:

SD生成的文件大小/图像分辨率:0.00129, 0.0012, 0.0011, 0.00136, 0.0014, 0.0015, 0.0013, 0.00149

图片标注:使用BLIP和Waifu模型自动标注,上文给出的那个知乎链接中有详细的说明,这里不做赘述。

模型训练

  1. 模型准备:
    SDXL 1.0 的 vea 由于数字水印的问题生成的图像会有彩色条纹的伪影存在,可以使用 0.9 的 vae 文件解决这个问题。这里建议使用整合好的Base模型(链接),HuggingFace官方也有相关的说明如下:

SDXL’s VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely --pretrained_vae_model_name_or_path that lets you specify the location of a better VAE (such as this one).

  1. 部分训练参数说明
    • mixed_precision: 是否使用混合精度,根据unet模型决定,[“no”, “fp16”, “bf16”]
    • save_percision: 保存模型精度,[float, fp16, bf16], "float"表示torch.float32
    • vae: 训练和推理的所使用的指定vae模型
    • keep_tokens: 训练中会将tags随机打乱,如果设置为n,则前n个tags的顺序不会被打乱
    • shuffle_caption:bool,打乱标签,可增强模型泛化性
    • sdxl_train_util.py的_load_target_model()判断是否从单个safetensor读取模型,有修改StableDiffusionXLPipeline读取模型的代码
    • sample_every_n_steps: 间隔多少训练迭代次数推理一次
    • noise_offset: 避免生成图像的平均亮度值为0.5,对于logo、立体裁切图像和自然光亮与黑暗场景图片生成有很大提升
    • optimizer_args: 优化器的额外设定参数,比如weight_decay和betas等等,在train_util.py中定义。
    • clip_threshold: AdaFactor优化器的参数,在optimizer_args里新增该参数,默认值1.0。参考:https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
    • gradient checkpointing:用额外的计算时间换取GPU内存,使得有限的GPU显存下训练更大的模型
    • lr_scheduler: 设置学习率变化规则,如linear,cosine,constant_with_warmup
    • lr_scheduler_args: 设置该规则下的具体参数,可参考pytorch文档

模型评估

目前,AIGC领域的测评过程整体上还是比较主观,但这里还是通过美学评分(Aesthetics)和CLIP score指标来分别衡量生成的图片质量与文图匹配度。评测代码基于GhostMix的作者开发的GhostReview,笔者仅取其中的一部分并做了一些优化,请结合着原作者的代码理解,具体代码如下:

import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn as nn
import clip
import os
import torch.nn.functional as F
import pandas as pd
from PIL import Image
import scipy

class MLP(pl.LightningModule):
    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


def normalized(a, axis=-1, order=2):
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)


def PredictionLAION(image, laion_model, clip_model, clip_process, device='cpu'):
    image = clip_process(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    im_emb_arr = normalized(image_features.cpu().detach().numpy())
    prediction = laion_model(torch.from_numpy(im_emb_arr).to(device).type(torch.FloatTensor))
    return float(prediction)


# ClipScore for 1 image
# 1张图片的ClipScore
def get_clip_score(image, text, clip_model, preprocess, device='cpu'):
    # Preprocess the image and tokenize the text
    image_input = preprocess(image).unsqueeze(0)
    text_input = clip.tokenize([text], truncate=True)

    # Move the inputs to GPU if available
    image_input = image_input.to(device)
    text_input = text_input.to(device)

    # Generate embeddings for the image and text
    with torch.no_grad():
        image_features = clip_model.encode_image(image_input)
        text_features = clip_model.encode_text(text_input)

    # Normalize the features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Calculate the cosine similarity to get the CLIP score
    clip_score = torch.matmul(image_features, text_features.T).item()

    return clip_score


if __name__ == '__main__':
    # 读取图片路径
    ImgRoot = './Image/ImageRating'
    DataFramePath = './dataresult/MyImageRating'    # all prompts results of each model
    ModelSummaryFile = './ImageRatingSummary/MyModelSummary_Total.csv'

    PromptsFolder = os.listdir(ImgRoot)
    if not os.path.exists(DataFramePath):
        os.makedirs(DataFramePath)

    # 读取图片对应的Prompts
    PromptDataFrame = pd.read_csv('./PromptsForReviews/mytest.csv')
    PromptsList = list(PromptDataFrame['Prompts'])

    # 载入评估模型
    device = "cuda" if torch.cuda.is_available() else "cpu"
    MLP_Model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT L 14
    # load LAION aesthetics model
    state_dict = torch.load("./models/sac+logos+ava1-l14-linearMSE.pth", map_location=torch.device(device))
    MLP_Model.load_state_dict(state_dict)
    MLP_Model.to(device)
    MLP_Model.eval()
    # Load the pre-trained CLIP model and the image
    CLIP_Model, CLIP_Preprocess = clip.load('ViT-L/14', device=device, download_root='./models/clip')  # RN50x64
    CLIP_Model.to(device)
    CLIP_Model.eval()

    # 跳过已经做过的Prompts
    try:
        DataSummaryDone = pd.read_csv(ModelSummaryFile)
        PromptsNotDone = [i for i in PromptsFolder if i not in list(DataSummaryDone['Model'])]
    except:
        DataSummaryDone = pd.DataFrame()
        PromptsNotDone = [i for i in PromptsFolder]
    if not PromptsNotDone:
        import sys
        sys.exit("There are no models to analyze.")

    for i, name in enumerate(PromptsNotDone):
        FolderPath = os.path.join(ImgRoot, str(name))
        ImageInFolder = os.listdir(FolderPath)
        DataCollect = pd.DataFrame()
        for j, img in enumerate(ImageInFolder):
            prompt_index = int(img.split('-')[1])
            txt = PromptsList[prompt_index]
            ImagePath = os.path.join(FolderPath, img)
            Img = Image.open(ImagePath)
            # Clipscore
            ImgClipScore = get_clip_score(Img, txt, CLIP_Model, CLIP_Preprocess, device)

            # aesthetics scorer
            # ImageScore = predict(Img)

            # LAION aesthetics scorer
            ImageLAIONScore = PredictionLAION(Img, MLP_Model, CLIP_Model, CLIP_Preprocess, device)

            # temp = list(ImageScore)
            temp = list()
            temp.append(float(ImgClipScore))
            temp.append(ImageLAIONScore)
            temp = pd.DataFrame(temp)
            DataCollect = pd.concat([DataCollect, temp], axis=1)
            print("Model:{}/{}, image:{}/{}".format(i+1, len(PromptsNotDone), j+1, len(ImageInFolder)))
        DataCollect = DataCollect.T
        DataCollect['ImageIndex'] = [i + 1 for i in range(len(ImageInFolder))]

        DataCollect.columns = ['ClipScore', 'LAIONScore', 'ImageIndex']

        # 保存原数据
        DataCollect.to_csv(os.path.join(DataFramePath, str(name) + '.csv'), index=False)
        print("One Results File Saved!")
    print('Image rating complete!')


    # do some calculation
    ModelSummary = pd.DataFrame()
    for i in PromptsNotDone:
        DataCollect = pd.read_csv(os.path.join('dataresult/MyImageRating', str(i) + '.csv'))
        temp = pd.DataFrame(DataCollect['LAIONScore'].describe()).T
        # 计算数据的偏度
        temp['skew'] = scipy.stats.skew(DataCollect['LAIONScore'], axis=0, bias=True, nan_policy="propagate")
        # 计算数据的峰度
        temp['kurtosis'] = scipy.stats.kurtosis(DataCollect['LAIONScore'], axis=0, fisher=True, bias=True,
                                                nan_policy="propagate")
        temp.columns = [i + '_LAIONScore' for i in list(temp.columns)]
        # temp['RatingScore_mean']=np.mean(DataCollect['Rating'])
        # temp['RatingScore_std']=np.std(DataCollect['Rating'])
        temp['Clipscore_mean'] = np.mean(DataCollect['ClipScore'])
        temp['Clipscore_std'] = np.std(DataCollect['ClipScore'])
        # temp['Artifact_mean']=np.mean(DataCollect['Artifact'])
        # temp['Artifact_std']=np.std(DataCollect['Artifact'])
        temp['Model'] = str(i)
        ModelSummary = pd.concat([ModelSummary, temp], axis=0)

    # save results
    new_order = ['Model', 'count_LAIONScore', 'mean_LAIONScore', 'std_LAIONScore',
                 'min_LAIONScore', '25%_LAIONScore', '50%_LAIONScore', '75%_LAIONScore',
                 'max_LAIONScore', 'skew_LAIONScore', 'kurtosis_LAIONScore',
                 'Clipscore_mean', 'Clipscore_std']
    # 使用 reindex() 方法重新排序列
    ModelSummary = ModelSummary.reindex(columns=new_order)

    DataSummaryDone = pd.concat([DataSummaryDone, ModelSummary], axis=0)
    DataSummaryDone.to_csv('./ImageRatingSummary/MyModelSummary_Total.csv')

    pd.set_option('display.max_rows', None)  # None表示没有限制
    pd.set_option('display.max_columns', None)  # None表示没有限制
    pd.set_option('display.width', 1000)  # 设置宽度为1000字符
    print(DataSummaryDone)

下图给出了本文所训练的SDXL-Poster与主流文生图模型的比较结果,注意其中包括Anything模型开始往下的结果都是笔者自己调用相关模型生成的180张图片计算得来,所以标准差都偏大;而上方则是GhostReview作者调用这些模型生成960张图片的计算而来的结果。由于样本数量不一致,请读者谨慎参考
在这里插入图片描述

生成图片样例

将本文训练的SDXL-Poster与SDXL-Base、CyberRealistic比较。

宠物包商品海报

A feline peering out from a striped transparent travel bag with a bicycle in the background. Outdoor setting, sunset ambiance. Product advertisement of pet bag, No humans, focus on cat and bag, vibrant colors, recreational theme
SDXL-Poster

(a) SDXL-Poster

在这里插入图片描述

(b) SDXL-Base

在这里插入图片描述

(c) CyberRealistic

护肤精华商品海报

Four amber glass bottles with droppers placed side by side, arranged on a white background, skincare product promotion, no individuals present, still life setup
在这里插入图片描述

(a) SDXL-Poster

在这里插入图片描述

(b) SDXL-Base

在这里插入图片描述

(c) CyberRealistic

一些Tips

Mata:EMU(Expressive Media Universe)

简单文字,5秒出图。论文:https://arxiv.org/abs/2309.15807
知乎详细解读:https://zhuanlan.zhihu.com/p/659476603

介绍了EMU的训练方式:quality-tuning,一种有监督微调。其有三个关键:

  • 微调数据集可以小得出奇,大约只有几千张图片;
  • 数据集的质量非常高,这使得数据整理难以完全自动化,需要人工标注;
  • 即使微调数据集很小,质量调整不仅能显著提高生成图片的美观度,而且不会牺牲通用性,因为通用性是根据输入提示的忠实度来衡量的。
  • 基础的预训练大模型生成过程没有被引导为生成与微调数据集统计分布一致的图像,而quality-tuning能有效地限制输出图像与微调子集保持分布一致
  • 分辨率低于1024x1024的图像使用pixel diffusion upsampler来提升分辨率

ideogram

生成图像中包含文字的生成模型:ideogram,2023年8月23日发布,免费,官网https://ideogram.ai/

DALL-E3

“文本渲染仍然不可靠,他们认为该模型很难将单词 token 映射为图像中的字母”

  • 增强模型的prompt following能力,训练一个image captioner来生成更为准确详细的图像caption
  • 合成的长caption能够提升模型性能,且与groud-truth caption混合比例为95%效果最佳。长caption通过GPT-4对人类的描述进行upsampling得到

关于模型优化

  • 将训练出的base模型与类似的LoRA模型mix会好些
  • base模型的作用就是多风格兼容,风格细化是LoRA做的事情
  • SDXL在生成文字以及手时出现问题:https://zhuanlan.zhihu.com/p/649308666
  • 微调迭代次数不能超过5k,否则导致明显的过拟合,降低视觉概念的通用性(来源:EMU模型tips)

Examples of Commonly Used Negative Prompts:

  1. Basic Negative Prompts: worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting.
  2. For Animated Characters: bad anatomy, bad hands, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, worst face, three crus, extra crus, fused crus, worst feet, three feet, fused feet, fused thigh, three thigh, fused thigh, extra thigh, worst thigh, missing fingers, extra fingers, ugly fingers, long fingers, horn, realistic photo, extra eyes, huge eyes, 2girl, amputation, disconnected limbs.
  3. For Realistic Characters: bad anatomy, bad hands, three hands, three legs, bad arms, missing legs, missing arms, poorly drawn face, bad face, fused face, cloned face, worst face, three crus, extra crus, fused crus, worst feet, three feet, fused feet, fused thigh, three thigh, fused thigh, extra thigh, worst thigh, missing fingers, extra fingers, ugly fingers, long fingers, horn, extra eyes, huge eyes, 2girl, amputation, disconnected limbs, cartoon, cg, 3d, unreal, animate.
  4. For Non-Adult Content: nsfw, nude, censored.
  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值