【Tensorflow 2.0 正式版教程】ImageNet(二)模型加载与训练

前面的教程都只在小模型、小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet。教程会分为三部分:数据增强、模型加载与训练、模型测试,最终在ResNet50上可以达到77.72%的top-1准确率,复现出了ResNet原文的结果。

完整的代码可以在我的github上找到。https://github.com/Apm5/ImageNet_Tensorflow2.0

提供ResNet-18和ResNet-50的预训练模型,以供大家做迁移使用。
链接:https://pan.baidu.com/s/1nwvkt3Ei5Hp5Pis35cBSmA
提取码:y4wo

还提供百度云链接的ImageNet原始数据,但是这份资源只能创建临时链接以供下载,有需要的还请私信联系。下面开始正文。

模型加载与训练

初始化

github项目中提供了tensorflow 2.0版本实现的ResNet,包括各种层数18、34、50、101和152,以及ResNet后续改进的v2版本以供直接调用。

from model.ResNet import ResNet
model = ResNet(50)

或者也可以使用官方实现的经典模型,具体参考keras applications

from tensorflow.keras.applications.resnet50 import ResNet50
model = ResNet50(weights=None)

训练过程中,可以通过model.save_weights()保存权重,也可以在中断训练时通过model.load_weights()加载权重继续训练。

数据迭代器

采用tf.data.Dataset()并行加载图像并进行数据增强

def train_iterator(list_path=c.train_list_path):
    images, labels = load_list(list_path, c.train_data_path)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(len(images))
    dataset = dataset.repeat()
    dataset = dataset.map(lambda x, y: tf.py_function(load_image, inp=[x, y, True, False], Tout=[tf.float32, tf.float32]),
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(c.batch_size)
    it = dataset.__iter__()
    return it

调用该函数得到迭代器后,可以实现GPU进行图计算时,CPU并行加载并处理图像。

images, labels = data_iterator.next()
ce, prediction = train_step(model, images, labels, optimizer)
模型训练

其中train_step完成前向计算、梯度反向传播和参数更新。

@tf.function
def train_step(model, images, labels, optimizer):
    with tf.GradientTape() as tape:
        prediction = model(images, training=True)
        ce = cross_entropy_batch(labels, prediction, label_smoothing=c.label_smoothing)
        l2 = l2_loss(model)
        loss = ce + l2
        gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return ce, prediction

可以通过@tf.function控制代码进行静态图或动态图模式计算,注释掉@tf.function修饰可以使tensorflow进入动态图模式,可以直接在网络中print中间层结果进行调试。开启修饰后进行静态图计算,可以极大的提升网络的计算速度。

损失函数

模型的损失这里包括两部分:交叉熵和参数正则化。
交叉熵需要计算一个batch内的多组数据的平均值,label_smoothing可以微弱的增强模型泛化性。

def cross_entropy_batch(y_true, y_pred, label_smoothing=0.0):
    cross_entropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred, label_smoothing=label_smoothing)
    cross_entropy = tf.reduce_mean(cross_entropy)
    return cross_entropy

参数正则化为

def l2_loss(model, weights=c.weight_decay):
    variable_list = []
    for v in model.trainable_variables:
        if 'kernel' in v.name:
            variable_list.append(tf.nn.l2_loss(v))
    return tf.add_n(variable_list) * weights

这里只统计卷积核,对于bn层等其他参数不作约束。

优化器与变学习率

优化器选用sgd,
optimizer = optimizers.SGD(learning_rate=learning_rate_schedules, momentum=0.9, nesterov=True),学习率变化分为warm up阶段和余弦下降阶段,warm up指网络训练初期学习率从0线性增长到最大学习率,余弦下降是让学习率大致遵循:前期维持大学习率,中期学习率线性下降,后期维持小学习率。

learning_rate_schedules = optimizers.schedules.PolynomialDecay(initial_learning_rate=c.minimum_learning_rate,
                                                               decay_steps=c.warm_iterations,
                                                               end_learning_rate=c.initial_learning_rate)

learning_rate_schedules = tf.keras.experimental.CosineDecay(initial_learning_rate=c.initial_learning_rate,
                                                            decay_steps=c.epoch_num * c.iterations_per_epoch,
                                                            alpha=c.minimum_learning_rate)

tf.keras.experimental中还有其他一些官方实现的变学习率策略,可自行了解。

代码使用

在github的项目中,直接执行train文件即可。

python train.py

训练的各种配置均在config.py中设置,如学习率、训练轮次、数据位置、增强策略等。
我的硬件配置是CPU i7 6850K @ 3.6GHz,显卡TITAN Xp 12G,对于默认的配置训练ResNet50可以达到大约每秒2个batch的速度,训练50轮ImageNet的大约需要3天的时间。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值