经典分类模型回顾17-Resnet实现水果分类(Tensorflow2.0)

ResNet是何凯明等人提出的深度学习网络架构,通过残差块和跨层连接解决深度网络训练难题,允许更深层次的网络训练,提高性能。文章介绍了ResNet的结构,包括瓶颈层和全局平均池化,以及如何使用TensorFlow和Keras进行实现和训练。
摘要由CSDN通过智能技术生成

ResNet(Residual Network)是由何凯明(Kaiming He)等人提出的深度学习网络架构,是ImageNet竞赛中历史最好的结果之一。ResNet的主要特点是在深度较深的网络中,通过特殊的残差块(residual block)和跨层连接(skip connection)的方式,使得网络训练更加容易,使得网络深度可以进一步增加,从而获得更好的性能表现。

在ResNet中,每个残差块包含两个子层,每个子层都以一个卷积层和一个批量归一化层为主,其中第二个子层还包括了一个激活函数。在残差块中,通过跨层连接将输入数据直接传递到输出数据的过程中,网络直接“学习”残差的方式,从而避免了由于深度增加而产生的梯度消失或梯度爆炸等问题,使得网络的收敛速度更快、训练效果更好。

ResNet还引入了一种称为“bottleneck”的结构,通过在残差块中增加一个额外的瓶颈层,使得网络在保持较少的计算复杂度的情况下,能够更好地适应更深的网络结构。此外,ResNet还采用了全局平均池化(Global Average Pooling)的方式,减少了全连接层的数量,使得模型更加轻量化。

总之,ResNet是一种引入了残差结构的深度学习网络,通过跨层连接和特殊的残差块,使得网络更加易于训练,进而使得网络深度可以进一步增加,提升了网络的性能表现。

需要安装必要的库:tensorflow、matplotlib。

```python
!pip install tensorflow
!pip install matplotlib
```

导入必要的库:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
```

接下来,加载数据集,将训练集、验证集和测试集分别存放到不同的文件夹中,并使用ImageDataGenerator对图像进行数据增强。

```python
# 加载数据集
train_data_dir = "path/to/train/folder"
valid_data_dir = "path/to/valid/folder"
test_data_dir = "path/to/test/folder"

# 数据增强
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

valid_datagen = ImageDataGenerator(rescale=1./255)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')

valid_generator = valid_datagen.flow_from_directory(
    valid_data_dir,
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')

test_generator = test_datagen.flow_from_directory(
    test_data_dir,
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')
```

接下来,定义ResNet模型。这里使用的是ResNet50,可以通过调整`depth`参数来改变网络深度。

```python
# 定义ResNet
def identity_block(inputs, filters):
    filters1, filters2, filters3 = filters

    x = layers.Conv2D(filters1, (1, 1), padding='valid')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, (1, 1), padding='valid')(x)
    x = layers.BatchNormalization()(x)

    x = layers.Add()([x, inputs])
    x = layers.Activation('relu')(x)

    return x

def conv_block(inputs, filters, strides):
    filters1, filters2, filters3 = filters

    x = layers.Conv2D(filters1, (1, 1), strides=strides, padding='valid')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, (1, 1), padding='valid')(x)
    x = layers.BatchNormalization()(x)

    shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, padding='valid')(inputs)
    shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)

    return x

def ResNet50(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    x = layers.ZeroPadding2D((3, 3))(inputs)
    x = layers.Conv2D(64, (7, 7), strides=(2, 2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # stage 2
    x = conv_block(x, filters=[64, 64, 256], strides=(1, 1))
    x = identity_block(x, filters=[64, 64, 256])
    x = identity_block(x, filters=[64, 64, 256])

    # stage 3
    x = conv_block(x, filters=[128, 128, 512], strides=(2, 2))
    x = identity_block(x, filters=[128, 128, 512])
    x = identity_block(x, filters=[128, 128, 512])
    x = identity_block(x, filters=[128, 128, 512])

    # stage 4
    x = conv_block(x, filters=[256, 256, 1024], strides=(2, 2))
    x = identity_block(x, filters=[256, 256, 1024])
    x = identity_block(x, filters=[256, 256, 1024])
    x = identity_block(x, filters=[256, 256, 1024])
    x = identity_block(x, filters=[256, 256, 1024])
    x = identity_block(x, filters=[256, 256, 1024])

    # stage 5
    x = conv_block(x, filters=[512, 512, 2048], strides=(2, 2))
    x = identity_block(x, filters=[512, 512, 2048])
    x = identity_block(x, filters=[512, 512, 2048])

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs)

    return model

# 实例化模型
model = ResNet50(input_shape=(100, 100, 3), num_classes=60)
```

接下来,编译模型并训练。

```python
# 编译模型
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
              loss=keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 训练模型
history = model.fit(train_generator,
                    epochs=50,
                    validation_data=valid_generator)
```

最后,评估模型并绘制准确率和损失曲线。

```python
# 评估模型
test_loss, test_acc = model.evaluate(test_generator)
print('Test accuracy:', test_acc)

# 绘制准确率和损失曲线
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

share_data

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值