TensorFlow 2.0: CIFAR10 Image Classification with ResNet18

This post is a set of study notes on using a ResNet18 model with TensorFlow 2.0 to classify images from the CIFAR10 dataset. The dataset is bundled with the TensorFlow module, and the ResNet18 architecture can be adjusted to build deeper networks.

Preface

These are my study notes on TensorFlow 2.0, based on the open-source textbook at https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book. This is not original work, so please don't nitpick.

Notes

  1. Dataset: CIFAR10; the tf module can download it automatically (see the code for details)
  2. Model: ResNet18; the parameters can also be modified to build deeper networks (see the sketch after this list)
  3. For the theory, search on your own
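
As a quick illustration of point 2, here is a minimal sketch using the ResNet class implemented later in this post: the layer_dims argument sets how many BasicBlocks go into each of the four residual stages, so a deeper variant only needs a different list.

res18 = ResNet([2, 2, 2, 2])  # 2 blocks per stage -> 18 weighted layers (ResNet18)
res34 = ResNet([3, 4, 6, 3])  # 3+4+6+3 blocks -> ResNet34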

Miscellaneous notes

# On custom layers
import tensorflow as tf
from tensorflow.keras import layers

class MyDense(layers.Layer):
    def __init__(self, in_dim, out_dim):
        super(MyDense, self).__init__()
        # self.kernel = self.add_variable('w', [in_dim, out_dim], trainable=True)
        # add_variable (now deprecated in favor of add_weight) returns a Python reference
        # to the tensor W, while the variable name is managed internally by TensorFlow.
        # I find the add_variable interface cluttered, though, and prefer tf.Variable:
        self.kernel = tf.Variable(tf.random.normal([in_dim, out_dim]), trainable=True)

    def call(self, inputs, training=None):
        out = inputs @ self.kernel
        out = tf.nn.relu(out)
        return out

net = MyDense(4, 3)
print(net.variables, '\n', net.trainable_variables)
# a quick look at the layer's variables
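
To see the layer in action, here is a minimal usage sketch (the shapes are illustrative only):

x = tf.random.normal([2, 4])  # a batch of 2 samples with 4 features each
out = net(x)                  # (2, 4) @ (4, 3) -> (2, 3), followed by ReLU
print(out.shape)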

Implementation of ResNet18

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers, losses, optimizers, metrics


# enable memory growth so TensorFlow allocates GPU memory on demand
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

def preprocess(x, y):
    # preprocess a batch of data
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1
    # normalize the images from [0, 255] to [-1, 1]
    y = tf.squeeze(y, axis=1)
    # the labels ship with shape (N, 1); drop the second dimension (inspect the dataset to see why)
    y = tf.cast(y, dtype=tf.int32)
    # cast to int32; TF performs little automatic dtype conversion, presumably to avoid silent errors
    return x, y

(x,y), (x_test, y_test) = keras.datasets.cifar10.load_data()
train_db = tf.data.Dataset.from_tensor_slices((x,y)).shuffle(10000).batch(256).map(preprocess)
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(256).map(preprocess)
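
A quick sanity check on the input pipeline (a minimal sketch; the shapes assume the batch size of 256 chosen above):

x_batch, y_batch = next(iter(train_db))
print(x_batch.shape, y_batch.shape)  # -> (256, 32, 32, 3) (256,)
# after preprocessing, pixel values should lie in [-1, 1]
print(tf.reduce_min(x_batch).numpy(), tf.reduce_max(x_batch).numpy())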

class BasicBlock(layers.Layer):
    # residual block
    # see page 185 of the textbook, on custom layers
    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')

        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()

        if stride != 1:
            # if the stride is not 1, x must be downsampled (and its channels projected),
            # otherwise it cannot be added to the output of the residual branch
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            self.downsample = lambda x: x
            # the textbook code does no zero-padding of missing channels on the identity path;
            # this works because stride == 1 only occurs where the channel counts already match

    def call(self, inputs, training=None):
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        identity = self.downsample(inputs)
        output = layers.add([out, identity])
        # element-wise addition of the residual branch and the shortcut
        output = tf.nn.relu(output)
        return output
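
A shape check for the block (a minimal sketch; the input mimics a 32x32 feature map with 64 channels): with stride=2 the spatial resolution halves, and the 1x1 shortcut convolution matches the channel count so the two branches can be added.

blk = BasicBlock(128, stride=2)
x = tf.random.normal([4, 32, 32, 64])
print(blk(x, training=False).shape)  # -> (4, 16, 16, 128)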


class ResNet(keras.Model):
    def __init__(self, layer_dims, num_classes=10):
        super(ResNet, self).__init__()
        # stem: initial convolution + pooling ahead of the residual stages
        self.stem = Sequential([
            layers.Conv2D(64, (3, 3), strides=(1, 1)),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same')
        ])
        # four residual stages; all but the first downsample by a factor of 2
        self.layer1 = self.build_resblock(64,  layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        self.avgpool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes)  # outputs raw logits for the 10 classes

    def build_resblock(self, filter_num, blocks, stride=1):
        # helper: stack `blocks` BasicBlocks with the same filter count
        res_blocks = Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))
        # only the first BasicBlock may have a stride other than 1, to perform the downsampling
        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num))
        return res_blocks

    def call(self, inputs, training=None):
        x = self.stem(inputs, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.fc(self.avgpool(x))
        return x
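
Before training, it is worth checking that the network builds and inspecting its parameter count. A minimal sketch (probe is just a throwaway name; a subclassed Keras model must be called once before summary() works):

probe = ResNet([2, 2, 2, 2])
_ = probe(tf.random.normal([1, 32, 32, 3]), training=False)
probe.summary()  # on the order of 11M parameters for the ResNet18 configuration
del probe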

res = ResNet([2, 2, 2, 2])
criterion = losses.CategoricalCrossentropy(from_logits=True)  # the network outputs logits
optimizer = optimizers.Adam(learning_rate=1e-3)
loss_meter = metrics.Mean()      # running average of the loss
acc_meter = metrics.Accuracy()   # running training accuracy

for epoch in range(10):
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            out = res(x, training=True)  # training=True so BatchNorm uses batch statistics
            y_onehot = tf.one_hot(y, depth=10)
            loss = criterion(y_onehot, out)
        grads = tape.gradient(loss, res.trainable_variables)
        optimizer.apply_gradients(zip(grads, res.trainable_variables))
        loss_meter.update_state(loss)
        pred = tf.argmax(out, axis=1)
        acc_meter.update_state(y, pred)
        if (step + 1) % 10 == 0:
            print(f'Epoch:{epoch+1}, Step:{step+1}, '
                  f'Loss:{loss_meter.result().numpy():.4f}, '
                  f'Accuracy:{acc_meter.result().numpy():.4f}')
            loss_meter.reset_states()
            acc_meter.reset_states()
res.save_weights('ResNet18_on_CIFAR10.ckpt')
# a side note on saving and loading the model:
# because the model is not wrapped entirely in Sequential, calling save() directly
# raises an error here, so only the weights can be saved.
# to restore a model from saved weights, the architecture must be defined again first;
# since the class definition lives in this very script, nothing needs to be rewritten
del res


res = ResNet([2, 2, 2, 2])
# changing the layer_dims argument yields deeper variants, e.g. ResNet34 with [3, 4, 6, 3]
res.load_weights('ResNet18_on_CIFAR10.ckpt')

# without the acc meter
accuracy = []
for _, (x, y) in enumerate(test_db):
    out = res(x, training=False)
    pred = tf.cast(tf.argmax(out, axis=1), dtype=tf.int32)
    acc = tf.reduce_mean(tf.cast(tf.equal(y, pred), dtype=tf.float32))
    accuracy.append(acc)
# note: averaging per-batch accuracies gives the smaller final batch slightly
# more weight than its share, but it is fine as an approximation
accuracy = tf.reduce_mean(accuracy)
print(accuracy.numpy())

# with the acc meter
acc_meter.reset_states()  # clear any state left over from training
for _, (x, y) in enumerate(test_db):
    # compute accuracy on the test set
    out = res(x, training=False)
    pred = tf.argmax(out, axis=1)
    acc_meter.update_state(y, pred)
print(acc_meter.result().numpy())
