iGPT 原理及Pytorch代码实现

iGPT 原理及Pytorch代码实现

1.引言

无监督学习对机器学习是一个长期的挑战,在过去的一段时间内,它在自然语言方面取得了不可思议的进展。受无监学习在自然语言方面进展的启发,OpenAI尝试将这种无监督学习的思想用于图像领域,于在2020年提出一个方法——image GPT,即在大量的无标签数据集上预训练,在少量标记数据上进行微调,以更好地迁移到具体的下游任务。实验结果表明了iGPT拥有强大的图像理解能力,不仅在诸多分类数据集上取得了领先的分类效果,更惊艳的是它在图像补全上的表现。

2.iGPT原理介绍

iGPT包括两个阶段,第一个阶段预训练阶段(Pre-training),在这个阶段iGPT对比了自回归(Auto Regressive,AR)预测下一个像素和BERT掩码语言模型(Masked Language Mode,MLM)预测被mask掉的像素这两个方法。

第二个阶段为微调(fine-tune),通过在模型上添加一个分类头,将模型运用于图像分类,来衡量预训练模型提取特征的质量。

另一种衡量模型提取的特征的质量的方法是将预训练模型作为一个特征提取器,利用线性探测(Linear Probe)进行验证,基于的原理是如果模型能够比较好的提取特征,那么如果在这个特征上直接进行分类,那么分类任务应该也会取得非常好的效果。

在这里插入图片描述

3. 代码复现

代码来源
GitHub: https://github.com/teddykoker/image-gpt
作者:Teddy Koker

3.1计算质心
目的:为进行K-means聚类,计算num_clusters个质心

#num_clusters 分类数
  python3  src/compute_centroids.py --dataset mnist --num_clusters=8

利用计算出的num_clusters个质心,在进行预训练之前对每个像素点进行聚类,以达到降低分辨率的目的。

#x:一组无标签数据,即图像信息
def quantize(x, centroids):
    b, c, h, w = x.shape
    # [B, C, H, W] => [B, H, W, C]
    x = x.permute(0, 2, 3, 1).contiguous()
    x = x.view(-1, c)  # flatten to pixels
    d = squared_euclidean_distance(x, centroids) #计算平方欧式距离
    x = torch.argmin(d, 1)        #像素点到某个质心的距离最小,即将该像素点分配给最近的聚类
    x = x.view(b, h, w)              
    return x

3.2 预训练(自回归)

python3 src/run.py --dataset mnist train configs/xxs_gen.yml
#x:经过聚类之后的一组数据 
#logits:GPT-2网络的输出,logits为三维结构[28*28,batch_size,num_vocab]   (28*28为mnist数据集的图像尺寸;batch_size:样本数,可以理解成一次处理的照片数量;num_vocab:transformer中像素颜色可能值的数量),代表着每张图片的每个像素低点在num_vocab种像素颜色上的可能性
logits = self.gpt(x)
#计算模型生成值与实际中的损失,生成预训练模型。
loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1))
return {"val_loss": loss}

3.3 微调

python3 src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt

#如果进行的是微调,则取出预训练的模型,在有标签的目标数据集上继续训练。
if args.pretrained is not None:
    model = ImageGPT.load_from_checkpoint(args.pretrained)
    #potentially modify model for finetuning
    model.learning_rate = config["learning_rate"]
    model.classify = config["classify"]
#y:数据的标签;x:经过聚类之后的一组数据 ;
#clf_logits:一个尺寸为[batch_size,num_classes]的矩阵,代表着batch中每一个sample属于哪一类;
#logits:同预训练;clf_loss:分类损失;gen_loss:生成损失。
if self.classify:
    clf_logits, logits = self.gpt(x, classify=True)
    clf_loss = self.criterion(clf_logits, y)
    gen_loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1))
    # joint loss for classification
    loss = clf_loss + gen_loss

3.4 实例——手写数字图像的补全

python3  src/sample.py models/mnist_gen.ckpt

生成下图,第1列为待补全图像,第7列为完整图像,第2-6列为模型根据第1列补全的图像。
在这里插入图片描述

#model:事先训练好的模型。
#context:一半的图像信息。
#length:一半图像所包含的像素点的个数。
def sample(model, context, length, num_samples=1, temperature=1.0):
 
    output = context.unsqueeze(-1).repeat_interleave(
        num_samples, dim=-1
    ) #将图像信息复制num_samples份,用于对同一半张图像产生num_samples种预测(如上图中2-6列共5份)

    pad = torch.zeros(1, num_samples, dtype=torch.long).cuda()  # to pad prev output  
    with torch.no_grad():
        for _ in tqdm(range(length), leave=False):
            logits = model(torch.cat((output, pad), dim=0))    
            logits = logits[-1, :, :] / temperature          
            probs = F.softmax(logits, dim=-1)                  
            pred = torch.multinomial(probs, num_samples=1).transpose(1, 0)  
            output = torch.cat((output, pred), dim=0)                       
    return output

这里我们主要思考的问题在于对于同一张不完整的图像,模型如何产生不同的预测图像?
这就需要我们了解函数
torch.multinomial(input, num_samples,replacement=False)
作用是对input的每一行做num_samples次取值,输出的张量是每一次取值时input张量对应行的下标,replacement指的是取样时是否是有放回的取样,True是有放回,False无放回。

weights = torch.Tensor([0, 10, 3, 0])
torch.multinomial(weights, 4)
尝试重复运行,发现只会有2种结果:[1 2 0 0]以及[2 1 0 0],以[1 2 0 0]这种情况居多。
这其实很好理解,第1个元素权重比第2个元素权重要大,所以先取第1个元素的概率就会大。

这就说明当我们计算出某个像素点在num_vocab种像素颜色上的概率时,经过上述函数最终确定的像素预测值具有随机性,这也解释了为什么对于同一半张图像,最终会生成不同的预测图像。

另一个困惑的点在于,在将图像送入模型之前,我们对图像信息进行聚类处理,本例中num_clusters=8,也就是将其分为8类。模型中num_vocab=16,最终经过模型处理之后,得到的是[num_pixels,num_samples,num_vocab]这样一个三维结构,表示的是各个像素在num_vocab种可能值上的权重。那这个num_clusters与num_vocab是否需要统一呢?

通过代码我们发现,虽然模型最后得到的是在num_vocab种可能值上的权重,但是由于自回归模型是由之前的信息来预测的,所以我们由某个像素点的权重(下图)可知,其在[0,num_clusters]上num_clusters种可能值(也就是下图中的0-7列)的权重大,大概率被确定为预测值。
但需要考虑num_samples,当num_samples>num_clusters,且在取值时不放回,那么,尽管权重小也是会被取到的。
在这里插入图片描述

4.参考文献

https://openai.com/blog/image-gpt/
https://zhuanlan.zhihu.com/p/352350329

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值