TensorFlow 2.2: Flower Recognition with ResNet50

Introduction to ResNet50

ResNet50 is almost identical to the ResNet34 described in the previous post. The only difference is that
each residual block now uses three convolution layers (a 1x1 → 3x3 → 1x1 bottleneck) instead of two, which makes the network deeper:

# Bottleneck residual block
def block(x, filters, strides=1, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters*4, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    # three convolution layers (1x1 -> 3x3 -> 1x1)
    x = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters*4, kernel_size=1, strides=1, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)

    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x
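As a quick sanity check, a single block can be wrapped in a tiny Model to inspect its output shape. This is only a sketch: the 56x56x64 input is an illustrative feature-map size, and it assumes the Keras layer imports listed in train.py below.

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

inp = Input(shape=(56, 56, 64))   # illustrative feature-map size entering the first stage
out = block(inp, filters=64, strides=1, conv_short=True)
print(Model(inp, out).output_shape)   # (None, 56, 56, 256) -> the block expands channels 4x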

Let's go straight to the code.

1. Code walkthrough

Create train.py:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, ZeroPadding2D, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler

# Bottleneck residual block
def block(x, filters, strides=1, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters*4, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    # three convolution layers (1x1 -> 3x3 -> 1x1)
    x = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters*4, kernel_size=1, strides=1, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)

    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

    
def Resnet50(inputs, classes):
    x = ZeroPadding2D((3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=7, strides=2, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPool2D(pool_size=3, strides=2, padding='valid')(x)

    x = block(x, filters=64, strides=1, conv_short=True)
    x = block(x, filters=64, conv_short=False)
    x = block(x, filters=64, conv_short=False)

    x = block(x, filters=128, strides=2, conv_short=True)
    x = block(x, filters=128, conv_short=False)
    x = block(x, filters=128, conv_short=False)
    x = block(x, filters=128, conv_short=False)

    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)

    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, conv_short=False)
    x = block(x, filters=512, conv_short=False)

    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    
    return x
def data_process_func(datasets):
    # ---------------------------------- #
    #   Data augmentation for the training set:
    #   1. rotation_range     -> random rotation angle
    #   2. width_shift_range  -> random horizontal shift
    #   3. height_shift_range -> random vertical shift
    #   4. rescale            -> normalize pixel values
    #   5. shear_range        -> random shear transform
    #   6. zoom_range         -> random zoom
    #   7. horizontal_flip    -> random horizontal flip
    #   8. brightness_range   -> random brightness change
    #   9. fill_mode          -> how newly created pixels are filled
    # ---------------------------------- #
    train_data = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        rescale=1/255.0,
        shear_range=10,
        zoom_range=0.1,
        horizontal_flip=True,
        brightness_range=(0.7, 1.3),
        fill_mode='nearest'
    )
    # ---------------------------------- #
    #   Test set: only rescaling, no augmentation
    # ---------------------------------- #
    test_data = ImageDataGenerator(
        rescale=1/255.0
    )
    # ---------------------------------- #
    #   Training-set and test-set generators
    # ---------------------------------- #
    train_generator = train_data.flow_from_directory(
        f'{datasets}/train',
        target_size=(224, 224),
        batch_size=8
    )
    test_generator = test_data.flow_from_directory(
        f'{datasets}/test',
        target_size=(224, 224),
        batch_size=8
    )
    return train_generator, test_generator
# Learning-rate schedule: keep the initial rate for the first 6 epochs,
# then decay the current rate by 7% every epoch
def adjust_lr(epoch, lr=1e-3):
    print("Setting learning rate to %s" % (lr))
    if epoch < 6:
        return lr
    else:
        return lr * 0.93
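To see what this schedule actually does, here is a small sketch (not part of the original script) that simulates how LearningRateScheduler feeds the current rate back into adjust_lr at the start of each epoch:

lr = 1e-3
for epoch in range(10):
    lr = adjust_lr(epoch, lr)   # LearningRateScheduler calls the schedule with (epoch, current_lr)
    print('epoch %d: lr = %.6f' % (epoch, lr))
# epochs 0-5 stay at 0.001000; from epoch 6 on the rate shrinks by 7% per epoch
# (0.000930, 0.000865, 0.000804, ...)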

2. Main function

  • Set the dataset path datasets (the expected folder layout is sketched below)
    Link: https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
    Extraction code: bhjx
  • Set the pretrained-weight path weight
    Link: https://pan.baidu.com/s/1AhsAA8ww5GurK-pWNQ4aHg
    Extraction code: y1c4
  • Note: create a logs folder first; the checkpoints are written there
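flow_from_directory infers the classes from the sub-folder names, so the dataset is expected to be organised roughly as follows. The flower0/flower1/... folder names are only an illustration based on the prediction output shown later:

dataset/data_flower/
    train/
        flower0/    xxx.jpg, ...
        flower1/    ...
        ...         (17 class folders in total)
    test/
        flower0/    ...
        ...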
if __name__ == '__main__':
    datasets = './dataset/data_flower'
    weight = './model_data/test_acc0.860-val_loss0.557-resnet50-flower.h5'
    img_size = 224
    classes = 17
    epochs = 30      # number of training epochs; not specified in the original post, adjust as needed

    # Let GPU memory grow on demand instead of being fully allocated up front
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    inputs = Input(shape=(img_size, img_size, 3))
    # Data generators
    train_generator, test_generator = data_process_func(datasets)
    model = Model(inputs=inputs, outputs=Resnet50(inputs=inputs, classes=classes))

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=10, verbose=1),
        ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5', monitor='val_loss',
                        save_weights_only=True, save_best_only=False, period=1),
        LearningRateScheduler(adjust_lr)
    ]

    print('----------> loading weights --------->')
    model.load_weights(weight, by_name=True, skip_mismatch=True)
    model.compile(optimizer=Adam(learning_rate=1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(
        x               = train_generator,
        validation_data = test_generator,
        workers         = 1,
        epochs          = epochs,
        callbacks       = callbacks
    )
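Optionally, the end of the `if __name__ == '__main__':` block can report the final metrics on the test generator. This is a minimal sketch; the saved filename is just a placeholder:

    loss, acc = model.evaluate(test_generator)
    print('test loss: %.3f, test acc: %.3f' % (loss, acc))
    model.save_weights('logs/resnet50-flower-final.h5')   # placeholder filename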

3. Predicting images

Create predict.py:

import tensorflow as tf
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, ZeroPadding2D, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add


# Bottleneck residual block
def block(x, filters, strides=1, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters*4, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    # three convolution layers (1x1 -> 3x3 -> 1x1)
    x = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters*4, kernel_size=1, strides=1, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)

    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

    
def Resnet50(inputs, classes):
    x = ZeroPadding2D((3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=7, strides=2, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPool2D(pool_size=3, strides=2, padding='valid')(x)

    x = block(x, filters=64, strides=1, conv_short=True)
    x = block(x, filters=64, conv_short=False)
    x = block(x, filters=64, conv_short=False)

    x = block(x, filters=128, strides=2, conv_short=True)
    x = block(x, filters=128, conv_short=False)
    x = block(x, filters=128, conv_short=False)
    x = block(x, filters=128, conv_short=False)

    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)
    x = block(x, filters=256, conv_short=False)

    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, conv_short=False)
    x = block(x, filters=512, conv_short=False)

    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    
    return x

# This time the weight path should point to the weights saved after training
# flow_from_directory assigns class indices in alphanumeric order, so sort the folder names the same way
names = sorted(os.listdir('./dataset/data_flower/test'))
weight = './model_data/test_acc0.860-val_loss0.557-resnet50-flower.h5'
net = Resnet50
classes = 17
img_size = 224

def preprocess_input(x):
    # Same rescaling as used during training
    x /= 255.0
    return x

def cvtColor(image):
    # Make sure the image has 3 channels (convert grayscale/RGBA to RGB)
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image

inputs = Input(shape=(img_size,img_size,3))
model = Model(inputs=inputs, outputs=net(inputs=inputs, classes=classes))
model.load_weights(weight)

while True:
    
    img_path = input('input img_path:')
    try:
        img = Image.open(img_path)
        img = cvtColor(img)
        img = img.resize((224, 224))
        image_data = np.expand_dims(preprocess_input(np.array(img, np.float32)), 0)
    except Exception:
        print('Invalid image path!')
        continue
    else:
        plt.imshow(img)
        plt.axis('off')
        p = model.predict(image_data)[0]
        pred_name = names[np.argmax(p)]
        plt.title('%s:%.3f'%(pred_name, np.max(p)))
        plt.show()

The result looks like this:
the image is predicted as flower0 with probability 0.801.
[prediction figure]
