CNN拟合-基于Keras

CNN拟合-基于Keras

此项目类似于由人脸预测年龄:输入一张图片,要求CNN预测一个近似连续的数值,利用CNN做自动特征提取。

数据增强

keras自带了ImageDataGenerator做数据增强,包含了基本的旋转,缩放,裁剪等功能。

from keras.preprocessing.image import ImageDataGenerator

# Build the image generator used for offline data augmentation.
aug = ImageDataGenerator(featurewise_center=False,  # no dataset-wide, per-feature mean centering
                         samplewise_center=False,  # no per-sample mean centering

                         featurewise_std_normalization=False,  # no dataset-wide, per-feature std scaling
                         samplewise_std_normalization=False,  # no per-sample std scaling

                         zca_whitening=False,
                         zca_epsilon=1e-06,

                         # NOTE(review): Keras documents rotation_range in degrees, so 0.2
                         # allows at most +/-0.2 degrees of rotation -- confirm this is intended.
                         rotation_range=0.2,
                         width_shift_range=0.05,
                         height_shift_range=0.05,
                         shear_range=0.05,  # shear intensity (counter-clockwise shear angle)
                         zoom_range=0.05,  # float z -> random zoom factor in [1 - z, 1 + z]

                         channel_shift_range=0.0,
                         fill_mode='nearest',

                         horizontal_flip=False,
                         vertical_flip=False,
                         rescale=None)


# flow_from_directory reads images from the sub-directories of this path.
# NOTE(review): save_to_dir is itself a sub-directory of the source path, so
# previously saved augmented images (and any other sub-directory such as
# imgs_for_reg/test/) are also read as input, which shifts the "_<index>_"
# numbering later used to look up labels -- consider saving outside imgs_for_reg.
train_generator = aug.flow_from_directory('E://kidney_data/imgs_for_reg',
                                          target_size=(434, 636),
                                          batch_size=32,
                                          save_to_dir='E://kidney_data/imgs_for_reg/aug',
                                          class_mode=None)


# Each draw generates and saves one batch, so roughly epochs * batch_size
# augmented images are written in total.
epochs = 100
for _ in range(epochs):
    next(train_generator)

aug.flow_from_directory() 函数的第一个输入参数为原始图片路径的上一级:
在这里插入图片描述
原始图片存放在 E://kidney_data/imgs_for_reg 下的子文件夹中(flow_from_directory 会把第一个参数下的每个子文件夹当作一个类别目录来读取图片)。
在这里插入图片描述
图片名称的存储格式为“姓名+拟合目标值”。

增强后的数据为
在这里插入图片描述
其中 _0_*.png 表示原始第一张图片(索引为 0)经过增强后的结果;后续读取数据时,正是用文件名中的这个索引到 MRS.xlsx 中查找对应的拟合目标。

拟合模型

拟合网络与传统的分类网络基本相同,区别在于最后一层只输出一个节点:
在这里插入图片描述
下面创建一个 MRSEstimateModel 类用于搭建网络,其中 estimate_based_on_regression() 用于构建并编译回归网络。注意:代码中骨干网络使用 weights=None,即从随机初始化开始训练;若要真正进行 fine-tuning,需要加载预训练权重(例如 weights='imagenet')。

from keras.applications import VGG16, inception_v3
from keras.layers import GlobalAveragePooling2D, Dropout, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.losses import categorical_crossentropy, mean_absolute_error, mean_squared_error

class MRSEstimateModel:
    """Factory for Keras CNNs that estimate MRS from a single image.

    Args:
        model_name: backbone used by ``estimate_based_on_regression``;
            ``'VGG16'`` (default) or ``'inceptionv3'``.
        input_shape: (height, width, channels) of the input images.
        weights: forwarded to the Keras backbone constructor. ``None`` (the
            default) starts from random initialisation; pass ``'imagenet'``
            to start from pretrained weights and fine-tune.
    """

    def __init__(self, model_name='VGG16', input_shape=(224, 224, 3), weights=None):
        self.input_shape = input_shape
        self.model_name = model_name
        self.weights = weights

    def estimate_based_on_VGG16_expected_classification(self):
        """Build a VGG16 classifier over the 101 integer MRS values 0..100.

        Head: global average pooling -> dropout(0.25) -> 101-way softmax.
        Compiled with Adam(1e-4) and categorical cross-entropy.

        Returns:
            The compiled Keras ``Model``.
        """
        based_model = VGG16(include_top=False, weights=self.weights,
                            input_shape=self.input_shape)

        x = based_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dropout(0.25)(x)

        predictions = Dense(101, activation='softmax', name='predictions')(x)

        model = Model(inputs=based_model.input, outputs=predictions)
        model.summary()
        model.compile(optimizer=Adam(1e-4), loss=categorical_crossentropy)

        return model

    def estimate_based_on_regression(self):
        """Build a single-output regression network on the selected backbone.

        Head: global average pooling -> dropout(0.25) -> Dense(256, relu)
        -> dropout(0.5) -> Dense(1, relu). Compiled with Adam(1e-4) and mean
        absolute error.

        Returns:
            The compiled Keras ``Model``.

        Raises:
            ValueError: if ``model_name`` is not ``'VGG16'`` or ``'inceptionv3'``.
        """
        if self.model_name not in {'VGG16', 'inceptionv3'}:
            raise ValueError('The `model_name` should be either `VGG16`, `inceptionv3`')

        if self.model_name == 'inceptionv3':
            # `inception_v3` is the keras.applications module; the model
            # constructor it provides is `InceptionV3`.
            base_model = inception_v3.InceptionV3(include_top=False, weights=self.weights,
                                                  input_shape=self.input_shape)
        else:  # VGG16 (default)
            base_model = VGG16(include_top=False, weights=self.weights,
                               input_shape=self.input_shape)

        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dropout(0.25)(x)

        x = Dense(256, activation='relu')(x)
        x = Dropout(0.5)(x)

        # A single ReLU unit: the regression target is one non-negative value.
        predictions = Dense(1, activation='relu', name='predictions')(x)

        model = Model(inputs=base_model.input, outputs=predictions)
        model.summary()
        model.compile(optimizer=Adam(1e-4), loss=mean_absolute_error)

        return model

其中,predictions = Dense(1, activation='relu', name='predictions')(x) 将网络输出为一个节点用于预测MRS。

Batch输入数据

增强后的数据无法一次全部读入内存,逐张读取又效率太低,因此需要按 batch 逐批读取。需要考虑的是:

  1. 一个epoch必须读入文件夹下所有的图片,如果最后一个batch的图片不足batch_size,可直接进入下一个epoch。
  2. 每个epoch图片的顺序必须重新打乱。

解决上述问题可通过os.listdir(path) 列出path下所有文件名称,也可通过files = glob.glob(os.path.join(path, '*.png'))列出path下匹配的文件路径名。然后将所有文件名append()到一个list容器,即:虽然文件无法一次加载内存,但文件名可以。
有了全部文件的文件名,每次只需指定每轮batch开始和结束的index即可;当结束index超出文件名列表的范围时,意味着剩余文件数量不足batch_size,此时直接打乱存放所有文件名的列表,并开始新一轮epoch。

from keras.preprocessing.image import load_img, img_to_array
from openpyxl import load_workbook
import numpy as np
import warnings
import os

from scipy import stats

class BatchKidneyDataGenerator(object):
    """Serve shuffled training batches from a directory of augmented images.

    Each file name in ``path_imgs`` is expected to start with ``_<index>_``
    (e.g. ``_0_1234.png``). ``<index>`` is a 0-based row into column 1 of the
    active sheet of the Excel file ``path_MRS``. That row holds the target
    value for the original image the file was augmented from.
    """

    epoch_count = 1  # current epoch number (1-based)
    batches_per_epoch = 0  # number of full batches per epoch
    number_of_images = 0  # total number of images found in path_imgs

    def __init__(self, path_imgs, path_MRS, batch_size, num_of_iter):
        """
        Args:
            path_imgs: directory holding the augmented images.
            path_MRS: path of the Excel file holding the targets.
            batch_size: number of images per yielded batch.
            num_of_iter: planned number of training iterations; only used to
                report how many full epochs that covers.
        """
        self.path_imgs = path_imgs
        self.path_MRS = path_MRS
        self.batch_size = batch_size
        self.num_of_iter = num_of_iter
        # Column-1 values of the Excel sheet, loaded on first use and cached.
        self._mrs_values = None

    def __get_multi_hot(self, label, std=5):
        """Return a length-101 soft label.

        The values are a normal pdf with mean ``label`` and standard deviation
        ``std``, evaluated at the integers 0..100.
        """
        x = np.arange(101)
        return stats.norm(label, std).pdf(x)

    def __load_mrs_values(self):
        """Return (and cache) the column-1 values of the active sheet of path_MRS.

        Reading the workbook once avoids re-parsing the whole Excel file for
        every single label lookup.
        """
        values = getattr(self, '_mrs_values', None)
        if values is None:
            workbook = load_workbook(self.path_MRS)
            booksheet = workbook.active
            # One entry per sheet row, starting at Excel row 1.
            values = [booksheet.cell(row=i, column=1).value
                      for i, _ in enumerate(booksheet.rows, start=1)]
            self._mrs_values = values
        return values

    def __get_img_labels(self, img_name):
        """Return the integer label (0..100) for an augmented image path.

        The base name is expected to look like ``_<index>_<suffix>``.
        ``<index>`` selects an entry of the Excel column. A tiny Gaussian
        jitter is added, then the value is multiplied by 100 and rounded.

        Raises:
            ValueError: if the resulting label falls outside [0, 100].
        """
        mrs_values = self.__load_mrs_values()

        # Parse from the base name so the result does not depend on which path
        # separator os.path.join put in front of the file name.
        name = os.path.basename(img_name)
        index = int(name.lstrip('_').split('_', 1)[0])

        label = mrs_values[index] + np.random.randn() * 0.0001
        label = int(round(label * 100, 0))

        if label < 0 or label > 100:
            raise ValueError("Invalid MRS label for " + img_name)

        return label

    def __load_batch(self, image_paths, type):
        """Load one batch of images (scaled to [0, 1]) and their labels.

        Returns ``(images, labels)``:
            - ``images`` is a list of arrays.
            - ``labels`` is a list of ints for ``type='reg'``, or an
              ``(n, 101)`` array of soft labels for ``type='cla'``.

        An unreadable image is skipped together with its label, so the two
        stay aligned.
        """
        images = []
        labels = []
        for path in image_paths:
            img = load_img(path)
            if img is None:
                print("\n[WARNING]: 读取图片 '{}' 失败".format(path))
                continue
            images.append(img_to_array(img, data_format='channels_last') / 255.0)

            label = self.__get_img_labels(path)
            labels.append(self.__get_multi_hot(label, 3) if type == 'cla' else label)

        if type == 'cla':
            labels = np.array(labels).reshape(-1, 101)
        return images, labels

    def batch_train_data_generator(self, type='reg'):
        """Infinite generator of training batches.

        Args:
            type: ``'reg'`` yields labels as a list of integers in [0, 100];
                ``'cla'`` yields a ``(n, 101)`` array of soft labels.

        Yields:
            ``(images, labels)`` for ``batch_size`` files at a time. File order
            is reshuffled at the start of every epoch, and a trailing partial
            batch is dropped.

        Raises:
            ValueError: if ``type`` is unknown, or if ``path_imgs`` holds fewer
                than ``batch_size`` images (no full batch could ever be formed).
        """
        if type not in ('reg', 'cla'):
            raise ValueError("`type` should be either 'reg' or 'cla', got %r" % (type,))

        # Only file names are kept in memory; images are loaded batch by batch.
        image_names = [os.path.join(self.path_imgs, file) for file in os.listdir(self.path_imgs)]
        np.random.shuffle(image_names)

        self.number_of_images = len(image_names)
        print("在%s路径下总共找到%d张图片" % (self.path_imgs, self.number_of_images))

        # Without at least one full batch, the loop below would never yield
        # and the epoch count would divide by zero, so fail fast instead.
        if self.batch_size <= 0 or self.number_of_images < self.batch_size:
            raise ValueError(
                "Need at least batch_size=%r images in %s, found %d"
                % (self.batch_size, self.path_imgs, self.number_of_images))

        self.batches_per_epoch = self.number_of_images // self.batch_size
        print("需%d个batches可完成一个epoch" % self.batches_per_epoch)

        num_of_epochs = self.num_of_iter // self.batches_per_epoch
        print("%d轮batch训练可迭代%d轮epoch(s)" % (self.num_of_iter, num_of_epochs))
        if num_of_epochs <= 0:
            warnings.warn("所设置的batch轮数不足一个epoch,num_of_epochs=0")

        batch_count = 0  # index of the next batch within the current epoch
        while True:
            start = self.batch_size * batch_count
            end = start + self.batch_size  # exclusive
            if end > len(image_names):
                # Not enough images left for a full batch: drop them,
                # reshuffle and start the next epoch.
                print('剩余图片数量不足batch_size=%d,开始新一轮batch' % self.batch_size)
                batch_count = 0
                self.epoch_count += 1
                np.random.shuffle(image_names)
                continue

            print("当前属于第%d轮epoch,总共有%d轮完整的epoch(s)。" % (self.epoch_count, num_of_epochs))
            batch_count += 1
            yield self.__load_batch(image_names[start:end], type)


if __name__ == '__main__':
    # Smoke test: pull num_of_iter classification ('cla') batches from the
    # augmented-image directory to exercise the generator end to end.
    path = 'E://kidney_data/imgs_for_reg/aug/'
    path_MRS = 'E://kidney_data/imgs_for_reg/MRS.xlsx'
    batch_size = 8
    num_of_iter = 400

    gen = BatchKidneyDataGenerator(path, path_MRS, batch_size, num_of_iter)
    batch_gen = gen.batch_train_data_generator(type='cla')

    for _ in range(num_of_iter):
        imgs, labels = next(batch_gen)


上述代码中,path_MRS路径下存放了未做数据增强原始图片的拟合target,是一个excel文件。
在这里插入图片描述
同一张原始图片的所有增强图片共用同一个target。这里在原始target上加了一个很小的正态分布随机数(标准差为1e-4),随后乘以100并取整,因此这个扰动对最终的整数label几乎没有影响:
label = MRS[index] + np.random.randn()*0.0001

主函数

from BatchKidneyDataGenerator import *
from gen_data_test import *
from MRSEstimateModel import *

import glob
import numpy as np
import matplotlib.pyplot as plt

def _parse_test_label(file):
    """Return the target encoded at the end of a test file name, times 100, rounded.

    The value is read from the 6 characters just before the 4-character
    extension, e.g. ``'.../name0.5000.png'`` -> ``50.0``. This fixed-width
    slice assumes every test file name ends in a 6-character number plus
    ``.png``.
    """
    return round(float(file[-10:-4]) * 100, 0)


def __test_on_batch(path_test, model, batch_size_test=10):
    """Predict every ``*.png`` in ``path_test`` and collect the matching targets.

    Args:
        path_test: directory containing the test images.
        model: Keras model; ``predict_on_batch`` must return one value per image.
        batch_size_test: number of images per ``predict_on_batch`` call.

    Returns:
        ``(preds, targets)``: two float arrays of shape ``(num_files, 1)``.
        Row i of each corresponds to the i-th file returned by ``glob``.

    Raises:
        ValueError: if ``batch_size_test`` is not positive.
    """
    if batch_size_test <= 0:
        raise ValueError("batch_size_test must be positive, got %r" % (batch_size_test,))

    files = glob.glob(os.path.join(path_test, '*.png'))

    preds = np.zeros((len(files), 1))
    targets = np.zeros((len(files), 1))
    # Ceil division so a trailing partial batch is evaluated as well; otherwise
    # its rows would stay 0 and enter the correlation as fake (0, 0) pairs.
    iter_test = -(-len(files) // batch_size_test)
    for j in range(iter_test):
        print('------测试进度------:%f' % ((j + 1) / iter_test))

        start = j * batch_size_test
        batch_files_test = files[start:start + batch_size_test]
        end = start + len(batch_files_test)  # exclusive

        imgs = np.array([img_to_array(load_img(f), data_format='channels_last') / 255.0
                         for f in batch_files_test])  # (n, height, width, channels)
        targets[start:end, 0] = [_parse_test_label(f) for f in batch_files_test]

        pred = model.predict_on_batch(imgs)  # (n, 1)
        preds[start:end, 0] = np.reshape(pred, (-1,))

    return preds, targets

if __name__ == '__main__':

    # Build the regression model (VGG16 backbone by default).
    mrsModel = MRSEstimateModel(input_shape=(434, 636, 3))
    model = mrsModel.estimate_based_on_regression()

    # Try to resume from locally saved weights; train from scratch otherwise.
    model_name = "./model_reg.h5"
    try:
        model.load_weights(model_name)
        print("本地模型加载成功!")
    except Exception as err:  # top-level script boundary: report and continue
        print("未加载到本地模型,开始训练新模型")
        print("  (reason: %r)" % (err,))

    batch_size = 8
    path_imgs = 'E://kidney_data/imgs_for_reg/aug/'
    path_MRS = 'E://kidney_data/imgs_for_reg/MRS.xlsx'
    path_test = 'E://kidney_data/imgs_for_reg/test/'

    num_of_iter = 1

    bkdGenerator = BatchKidneyDataGenerator(path_imgs, path_MRS, batch_size, num_of_iter)
    batch_gen = bkdGenerator.batch_train_data_generator(type='reg')

    best_perform = .0
    # Evaluate on the test set every N iterations.
    N = 1
    for iter_cnt in range(num_of_iter):
        print('=====训练进度=====>>:%f' % ((iter_cnt + 1) / num_of_iter))

        imgs, labels = next(batch_gen)

        arr_imgs = np.array(imgs)  # (batch, height, width, channels)
        # Match the (batch, 1) output of the single-unit Dense head.
        arr_labels = np.array(labels).reshape((-1, 1))

        cost = model.train_on_batch(arr_imgs, arr_labels)
        print("正在执行第%d轮迭代训练,当前批次的cost = %f" % (iter_cnt + 1, cost))

        if (iter_cnt % N) == 0:
            print('------------------开始进入测试阶段------------------')

            preds, targets = __test_on_batch(path_test, model)  # each (-1, 1)

            arr = np.corrcoef(preds.T, targets.T)
            corr = round(arr[0][1], 2)
            print('测试集相关系数为:%f' % corr)

            if corr > best_perform:  # keep only the model with the best test correlation
                best_perform = corr
                plt.scatter(preds, targets)
                plt.show()

                model.save_weights(model_name, overwrite=True)
                print("Saved model to disk")
            else:
                print('Will not save this model')

  • 主函数用 model.train_on_batch() 训练网络:一个batch的图片以形状为 batch_size×height×width×channels(channels_last 格式)的numpy数组输入,labels是形状为 batch_size×1 的numpy数组。

  • 测试数据通过 model.predict_on_batch() 进行预测。测试集数据的组织方式和训练集类似,区别是标签不从Excel中查询,而是直接从文件名末尾(扩展名前的6个字符)解析得到。

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值