tensorflow2.2_实现Resnet34_花的识别

残差块

    Resnet是由许多残差块组成的,而残差块可以解决网络越深,效果越差的问题
    残差块的结构如下图所示。
在这里插入图片描述
其中:

  1. weight layer表示卷积层,用于特征提取。
  2. F ( x ) F(x) F(x)表示经过两层卷积得到的结果。
  3. x x x表示恒等映射
  4. F ( x ) + x F(x)+x F(x)+x表示经过两层卷积后与之前的卷积层进行结合。

所以 F ( x ) F(x) F(x) x x x代表的是相同的信号。

  • 作用:将浅层网络的信号递给深层网络,使网络得到更好的结果。

批量归一化(Batch Normalization)

    我们暂时简称它为BN。
    BN可以对网络中的每一层的输入,输出特征进行标准化处理,将他们变成均值为0,方差为1的分布。
标准化的公式如下:
在这里插入图片描述
其中:

  • x n x_n xn表示第n个维度的数据
  • μ μ μ为该维度的平均值
  • σ σ σ表示该维度的方差
  • ϵ ϵ ϵ表示一个很小很小的值,防止分母为零

BN的主要作用:

  1. 加快模型的收敛速度。
  2. 增强正则化的作用。

Resnet34网络结构

如下图:
在这里插入图片描述
其中:

  • 7x7 conv 表示7x7大小的卷积核的窗口
  • 3x3 conv 表示3x3大小的卷积核的窗口
  • 64、128、256、512表示特征图的数量
  • /2 表示卷积核的步长,没写就默认为1
  • 虚线表示无法直接连接,因为生成的特征图数量是不一样的,也就是说shape是不一样的,一般是使用步长为2、大小为1的卷积核来对输入信号进行特征提取,使输入信号和输出信号的shape一致,再进行结合。

代码演示

1. 导入相关库

可新建一个train.py文件

from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam

2. 定义网络结构

# 结构快
def block(x, filters, strides=2, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    
    # 2层卷积
    x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)



    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

def Resnet34(inputs, classes):
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=64, strides=1, conv_short=False)
    
    x = block(x, filters=128, strides=2, conv_short=True)
    x = block(x, filters=128, strides=1, conv_short=False)
    x = block(x, filters=128, strides=1, conv_short=False)
    x = block(x, filters=128, strides=1, conv_short=False)
    
    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    
    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, strides=1, conv_short=False)
    x = block(x, filters=512, strides=1, conv_short=False)
    
    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    return x

3. 定义超参数

数据集:
链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
提取码:bhjx
复制这段内容后打开百度网盘手机App,操作更方便哦
权重文件:
链接:https://pan.baidu.com/s/1JotFy2G5wdThj409K87ExA
提取码:4vi5
复制这段内容后打开百度网盘手机App,操作更方便哦

数据集格式:
test和train文件夹里面需要按类别存放,如下

- dataset
	- data1_dog_cat
		- test
			- cat
				- cat.10000.jpg
				- cat.10001.jpg
				- ...
			- dog
				- dog.10000.jpg
				- dog.10001.jpg
		- train
			- cat
				- cat.0.jpg
				- cat.1.jpg
				- ...
			- dog
				- dog.0.jpg
				- dog.1.jpg
				- ...
classes = 17 # 需要分类的类别
batch_size = 16 # 批次大小
epochs = 100 # 轮次
img_size = 224 # 图片大小
lr = 1e-3 # 学习率大小
datasets = './dataset/data_flower' # 数据集的路径
weight = './model_data/test_acc0.794-resnet18val_loss0.857-flower.h5' # 权重文件的路径
# ------------------------------- #
#	我们使用加载权重的方式进行训练,效果会更好

4. 定义数据处理的构造器

train_data = ImageDataGenerator(
        rotation_range=20, 
        width_shift_range=0.1, 
        height_shift_range=0.1,
        rescale=1/255.0,
        shear_range=10,
        zoom_range=0.1,
        horizontal_flip=True,
        brightness_range=(0.7, 1.3),
        fill_mode='nearest'
    )

test_data = ImageDataGenerator(
        rescale=1/255
    )

train_generator = train_data.flow_from_directory(
        f'{datasets}/train',
        target_size=(img_size, img_size),
        batch_size=batch_size
    )
    
test_generator = test_data.flow_from_directory(
    f'{datasets}/test',
    target_size=(img_size, img_size),
    batch_size=batch_size
)

5. 定义学习率回调函数

def adjust_lr(epoch, lr=lr):
    print("Seting to %s" % (lr))
    if epoch < 10:
        return lr
    else:
        return lr * 0.93

6. 主函数

需要新建一个logs文件夹,保存权重文件。

if __name__ == '__main__':
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    
    inputs = Input(shape=(img_size,img_size,3))
    model = Model(inputs=inputs, outputs=Resnet34(inputs=inputs, classes=classes))
    callbackss = [
            EarlyStopping(monitor='val_loss', patience=10, verbose=1),
            ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss',
                            save_weights_only=True, save_best_only=False, period=1),
            LearningRateScheduler(adjust_lr)
        ]
    

    if weight:
        print('---------->loding weight--------->')
        model.load_weights(weight, by_name=True, skip_mismatch=True)
        model.compile(optimizer=Adam(lr=lr), loss='categorical_crossentropy', metrics=['accuracy'])
        history = model.fit(
        x                      = train_generator,
        validation_data        = test_generator,
        workers                = 1,
        epochs                 = epochs,
        callbacks              = callbackss
    )
    else:
        print('---------->epoch0 starting--------->')
        model.compile(optimizer=Adam(lr=lr), loss='categorical_crossentropy', metrics=['accuracy'])
        history = model.fit(
            x                    = train_generator,
            validation_data      = test_generator,
            epochs               = epochs,
            workers              = 1,
            callbacks            = callbackss
        )
     

7. 预测图片

可新建一个predict.py文件
导入库


from PIL import Image
from tensorflow.keras.layers import Input
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add

定义归一化函数

def preprocess_input(x):
    x /= 255
   
    return x

定义转RGB函数

def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image 
    else:
        image = image.convert('RGB')
        return image 

定义参数
注意:weight需要指定训练好的权重文件

datasets = './dataset/data_flower/test'
names = os.listdir(datasets)
weight = './model_data/test_acc0.860-val_loss0.599-resnet34-flower.h5'
net = Resnet34
classes = 17
img_size = 224

定义网络模型

# 结构快
def block(x, filters, strides=2, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    
    # 2层卷积
    x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)



    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

def Resnet34(inputs, classes):
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=64, strides=1, conv_short=False)
    
    x = block(x, filters=128, strides=2, conv_short=True)
    x = block(x, filters=128, strides=1, conv_short=False)
    x = block(x, filters=128, strides=1, conv_short=False)
    x = block(x, filters=128, strides=1, conv_short=False)
    
    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=256, strides=1, conv_short=False)
    
    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, strides=1, conv_short=False)
    x = block(x, filters=512, strides=1, conv_short=False)
    
    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    return x
inputs = Input(shape=(img_size,img_size,3))
model = Model(inputs=inputs, outputs=Resnet34(inputs=inputs, classes=classes))
# -------------------------------------------------#
#   载入模型
# -------------------------------------------------#
model.load_weights(weight)
while True:
    
    img_path = input('input img_path:')
    try:
        img = Image.open(img_path)
        img = cvtColor(img)
        img = img.resize((224, 224))
        image_data = np.expand_dims(preprocess_input(np.array(img, np.float32)), 0)
    except:
        print('The path is error!')
        continue
    else:
        plt.imshow(img)
        plt.axis('off')
        p =model.predict(image_data)[0]
        pred_name = names[np.argmax(p)]
        plt.title('%s:%.3f'%(pred_name, np.max(p)))
        plt.show()

效果如下:
在这里插入图片描述

  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值