概述
生成任务不像理解类任务有明确客观的评价指标算法,生成任务中常用的FID,IS,CLIP-Score等指标均无法做到全面客观的评价,目前只能依赖人工标注,费时费力。因此,寻找合适的图像质量评价方案,并结合相应的业务特点,对生成的图像质量评价结果尽可能接近人的主观评价,对于指导我们模型训练,提升模型生成能力,降低模型算法落地成本,促进模型算法与业务更好的结合,都有重要的意义。
算法原理
MPS:Learning Multi-dimensional Human Preference for Text-to-Image GenerationRichHF:Rich Human Feedback for Text-to-Image Generation
ImageReward:Learning and Evaluating Human Preferences for Text-to-Image Generation
MPS
算法原理
整体网络结构:
cross attention模块计算公式:
注意:Mc是经过二值化的mask,大于规定的像素值阈值设为0,其余设为无穷大。这里是否也可以采用attention思路,把Mc训练成attention map,然后将结果和Mc相乘
condtion 语句:
训练loss:
算法结果
condtion mask可视化分析:
结果对比:
RichHF算法原理
算法原理
整体网络结构:
Tips:
预测keyword misalignment sequence时,是通过修改输入的prompt,在针对图文不符的相应词汇后加上_0
Implausibility/misalignment是通过热力图实现,打标只需要标注中心点,通过高斯热力图形成heat map。该分支具体生成Implausibility还是misalignment是在prompt后添加额外条件决定,比如输入prompt+Implausibility heatmap该分支就会输出Implausibility heatmap(scores的输出也是类似处理逻辑)
算法结果
预测质量分结果:
misalignment heatmaps结果:
implausibility heatmaps结果:
利用RichHF算法优化模型:
利用heatmap二次生成图片:
ImageReward算法
算法原理
官方中文博客
!](https://i-blog.csdnimg.cn/direct/92df60058f0a425eb723f115f7a8d605.png)
ImageReward解决方案由以下几个步骤组成:
- 专业的大规模数据集ImageRewardDB:约13.7万个⽐较pairs,完全开源。
- 通⽤的反映⼈类对于⽂本到图像偏好的模型ImageReward:文生图奖励模型之先锋,优于现有的⽂本-图像评分⽅法,例如CLIP、Aesthetic和BLIP;也是新的文生图自动评价指标。
- 借助ImageReward的直接优化⽅法ReFL:用人类偏好改进扩散生成模型。
评分模型训练:
ReFL算法:
通过观察去噪步骤中的ImageReward分数,我们得出了一个有趣的发现(参见上图左)。对于一个降噪过程,例如降噪步数为40步时,在降噪过程中途直接预测中间降噪结果对应的原图:
- 当t ≤ 15:ImageReward得分和最终结果的一致性很低;
- 当15 ≤ t ≤30:高质量生成结果的ImageReward得分开始脱颖而出,但总体上我们仍然无法根据目前的ImageReward分数清楚地判断所有生成结果的最终质量;
- 当t ≥ 30:不同生成结果对应的ImageReward分数的已经可以区分。
根据观察,我们得出结论,经过30步去噪(总步数为40步),而不需要到最后一步降噪,ImageReward分数可以作为改进LDM的可靠反馈。因此,我们提出了一种直接微调LDM的算法。算法流程可见上图右。将RM的分数视为人类的偏好损失,将梯度反向传播到去噪过程中随机挑选的后一步t(在我们的例子中t取值范围为30~40)。随机选择t而不是使用最后一步的原因是,如果只保留最后一个去噪步骤的梯度,训练被证明是非常不稳定的,结果是不好的。在实践中,为了避免快速过拟合和稳定微调,我们对ReFL Loss进行重新加权,并用Pre-training Loss进行正则化。
算法概览
算法 | MPS | RichHF |
---|---|---|
网络结构 | clip + cross attention | ViT + T5X + self-attention |
网络输出 | Aesthetics,Detail quality,Semantic alignment,Overall assessment四个维度的分数 | Plausibility,Alignment,Aesthetics,Overall四个维度的分数,Misaligned keyword sequence,Implausibility/misalignment heatmap |
数据量 | 60w(90w 图像pair) | 18k |
图片来源 | 使用KOLORS等diffusion类模型+GAN+Autoregressive共九种文生图模型生成 | 从开源的Pick-a-Pic数据集挑选 |
prompt构建 | 从开源数据集挑选6w个prompt,这对数量过少的类别使用GPT4生成额外类似的prompt,并人工标注 | 从开源的Pick-a-Pic数据集挑选 |
人工标注方式 | 每份数据3人标注,并抽20%数据进行质检 | 每份数据3人标注 |
开源情况 | 代码和数据开源,部分模型开源 | 仅数据开源 |
打标工具对比 |
应用点
- 清洗训练数据集
- 评测/挑选模型
- 作为loss加入网络训练
- 二次生成图像,修复badcase区域,优化最终生成图片效果
- 强化学习RLHF中作为reward model
- DPO算法中区分偏好/非偏好样本