迁移学习 Transfer Learning

在上次的动画简介中, 我们大概了解了一些迁移学习的原理和为什么要使用迁移学习. 如果用一句话来概括迁移学习, 那务必就是: “为了偷懒, 在训练好了的模型上接着训练其他内容, 充分使用原模型的理解力”. 有时候也是为了避免再次花费特别长的时间重复训练大型模型.

CNN 通常都是大型模型, 下面我们拿 CNN 来举个例子. 我训练好了一个区分男人和女人的 CNN. 接着来了个任务, 说我下个任务是区分照片中人的年龄. 这看似完全不相干的两个模型, 但是我们却可以运用到迁移学习, 让之前那个 CNN 当我们的初始模型, 因为区分男女的 CNN 已经对人类有了理解. 基于这个理解开始训练, 总比完全重新开始训练强. 但是如果你下一个任务是区分飞机和大象. 这个 CNN 可能就没那么有用了, 因为这个 CNN 可能并没有对飞机大象有任何的理解.

这一次, 我们就来迁移一个图片分类的 CNN (VGG). 这个 VGG 在1000个类别中训练过. 我们提取这个 VGG 前面的 Conv layers, 重新组建后面的 fully connected layers, 让它做一个和分类完全不相干的事. 我们在网上下载那1000个分类数据中的猫和老虎的图片, 然后伪造一些猫和老虎长度的数据. 最后做到让迁移后的网络分辨出猫和老虎的长度 (regressor).

下载数据 

为了达到这次的目的, 我们不需要下载所有的1000个分类的所有图片, 只要找到自己感兴趣的类就好 (老虎和猫). 我选老虎和猫的目的就是因为他们是近亲, 还是有点像的, 可以增加点难度. 如果是飞机和大象的话, 学习难度就被降低了.

上图是这个网址, 你能在 Download 的那个 tag 中, 找到所有图片的 urls, 我将所有老虎和猫的 urls 文件给大家放在下面:

我们可以编一个 Python功能 逐个下载里面的图片. 这个功能我定义成 download(). 下载好后就会被放在 data 这个文件夹中了.

 

迁移学习 Transfer Learning

因为有些图片url已经过期了, 所以我自己也手动过滤了一遍, 将不是图片的和404的图片给清理掉了. 因为只有两个类, 图片不是很多, 比较好清理. 有网友说一些很多链接和图片已经”失联”, 我把我收集到的图片数据打包放在我的百度云, 如果用代码下图片感到有困难的同学们, 请直接在我百度云下载吧.

因为现在我们不是预测分类结果了, 所以我伪造了一些体长的数据. 老虎通常要比猫长, 所以它们的 distribution 就差不多是下面这种结构(单位cm).

 

原文:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-16-transfer-learning/

迁移学习 Transfer Learning

迁移 VGG 

处理好图片后, 我们可以开始弄 VGG 的 pre-trained model. 我使用的是machrisaa 改写的VGG16 的代码. 和他提供的 VGG16 train 好了的 model parameters, 你可以在这里下载 这些 parameters (有网友说这个文件下载不了,我把它放在了百度云共享了). 做好准备, 这个 parameter 文件有500+MB. 可见一般 CNN 的 model 有多大.

迁移学习 Transfer Learning

为了做迁移学习, 我对他的 tensorflow VGG16 代码进行了改写. 保留了所有 Conv 和 pooling 层, 将后面的所有 fc 层拆了, 改成可以被 train 的两层, 输出一个数字, 这个数字代表了这只猫或老虎的长度.

class Vgg16:
    def __init__(): # ...前面的层 pool5 = self.max_pool(conv5_3, 'pool5') # pool5 是最后的 conv 出来的结果 self.flatten = tf.reshape(pool5, [-1, 7*7*512]) self.fc6 = tf.layers.dense(self.flatten, 256, tf.nn.relu, name='fc6') self.out = tf.layers.dense(self.fc6, 1, name='out') 

在 self.flatten 之前的 layers, 都是不能被 train 的. 而 tf.layers.dense() 建立的 layers 是可以被 train 的. 到时候我们 train 好了, 再定义一个 Saver 来保存由 tf.layers.dense() 建立的 parameters.

class Vgg16:
    ...
    def save(self, path='./for_transfer_learning/model/transfer_learn'): saver = tf.train.Saver() saver.save(self.sess, path, write_meta_graph=False)

 

转载于:https://www.cnblogs.com/Ph-one/p/11345555.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值