使用自己生成的OCR数据集进行迁移学习


为了防止遗忘,将实验过程记录于此。

数据集生成

在进行深度学习的过程中,不论是视频教程还是书籍的示例代码中,常常都是使用已经封装好的经典数据集进行示教演示的,但是为了将神经网络模型应用于自己的研究领域,需要使用自己研究领域的的数据集去训练神经网络。下面介绍生成一个英文和数字的OCR数据集的全过程。
首先,新建文件夹,命名为OCR-sets(命名可随意)。

挑选所需要的字体

本次任务需要生成的是不同字体的英文和数字的字符样本集,首先需要挑选出所需的字体文件。对于Windows10系统的用户来说,系统的字体文件一般保存在“C:\Windows\Fonts”文件夹下。
Windows10中的字体文件

在OCR-sets文件夹下新建文件夹“Chinese_fonts”(参考原博主.的代码时为了调试方便使用了同样的命名)。根据自己项目的需求,选择相应的字体,存入“Chinese_fonts”文件夹中。
项目所需字体

生成(ID:字符)映射表文件

此步骤的目的是生成(ID:字符)映射表文件,ID即字符的类别代号,用于后续的字符生成过程。本次任务需要识别的内容为:英文字母大、小写和数字,共计62个类别。
生成的映射表文件
下面代码的作用是生成OCR-set.txt文件,并将OCR-set.txt中的内容读入字典,使用pickle.dump将字典内容序列化输出,生成OCR-sets.txt文件,该文件即为(ID:字符)映射表文件。
OCR-sets.txt文件

# 本程序是用于生成OCR-set文件,其作用是将字符与label组成键值对,如:ID:char
# 用于后续生成字符训练样本

import os
import pickle
from os.path import exists

file = open("OCR-set.txt", 'w', encoding='utf-8')

i = 0

w_string = ''
for j in range(ord("A"), ord("Z")+1):
    w_string = str(i).zfill(3) + ':' + chr(j) + '\n'
    file.write(w_string)
    i = i + 1

for j in range(ord("a"), ord("z")+1):
    w_string = str(i).zfill(3) + ':' + chr(j) + '\n'
    file.write(w_string)
    i = i + 1

for j in range(0, 10):
    w_string = str(i).zfill(3) + ':' + chr(j + ord('0')) + '\n'
    file.write(w_string)
    i = i + 1

file.close()

# 将上述生成的文件读入并存入字典
txt_dict = {}
fopen = open("OCR-set.txt", 'r')
for line in fopen.readlines():
    line = str(line).replace("\n", "")  # 注意,必须是双引号,找了大半个小时,发现是这个问题。。
    txt_dict[line.split(':', 1)[0]] = line.split(':', 1)[1]
    # split()函数用法,逗号前面是以什么来分割,后面是分割成n+1个部分,且以数组形式从0开始
    # 初学python,感觉这样表达会理解一点。。
fopen.close()
print(len(txt_dict))

# 使用pickle.dump将字典内容序列化输出
file2 = open("OCR-sets.txt", 'wb')
pickle.dump(txt_dict, file2)
file2.close()

生成OCR字符样本

以下代码来自于这篇博客,根据自己的项目内容进行部分修改,即可生成所需样本集。

#! /usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function

from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
import pickle
import argparse
from argparse import RawTextHelpFormatter
import fnmatch
import os
import cv2
import json
import random
import numpy as np
import shutil
import traceback
import copy


class dataAugmentation(object):
    def __init__(self, noise=True, dilate=True, erode=True):
        self.noise = noise
        self.dilate = dilate
        self.erode = erode

    @classmethod
    def add_noise(cls, img):
        for i in range(20):  # 添加点噪声
            temp_x = np.random.randint(0, img.shape[0])
            temp_y = np.random.randint(0, img.shape[1])
            img[temp_x][temp_y] = 255
        return img

    @classmethod
    def add_erode(cls, img):
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        img = cv2.erode(img, kernel)
        return img

    @classmethod
    def add_dilate(cls, img):
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        img = cv2.dilate(img, kernel)
        return img

    def do(self, img_list=[]):
        aug_list = copy.deepcopy(img_list)
        for i in range(len(img_list)):
            im = img_list[i]
            if self.noise and random.random() < 0.5:
                im = self.add_noise(im)
            if self.dilate and random.random() < 0.5:
                im = self.add_dilate(im)
            elif self.erode:
                im = self.add_erode(im)
            aug_list.append(im)
        return aug_list


# 对字体图像做等比例缩放
class PreprocessResizeKeepRatio(object):

    def __init__(self, width, height):
        self.width = width
        self.height = height

    def do(self, cv2_img):
        max_width = self.width
        max_height = self.height

        cur_height, cur_width = cv2_img.shape[:2]

        ratio_w = float(max_width) / float(cur_width)
        ratio_h = float(max_height) / float(cur_height)
        ratio = min(ratio_w, ratio_h)

        new_size = (min(int(cur_width * ratio), max_width),
                    min(int(cur_height * ratio), max_height))

        new_size = (max(new_size[0], 1),
                    max(new_size[1], 1),)

        resized_img = cv2.resize(cv2_img, new_size)
        return resized_img


# 查找字体的最小包含矩形
class FindImageBBox(object):
    def __init__(self, ):
        pass

    def do(self, img):
        height = img.shape[0]
        width = img.shape[1]
        v_sum = np.sum(img, axis=0)
        h_sum = np.sum(img, axis=1)
        left = 0
        right = width - 1
        top = 0
        low = height - 1
        # 从左往右扫描,遇到非零像素点就以此为字体的左边界
        for i in range(width):
            if v_sum[i] > 0:
                left = i
                break
        # 从右往左扫描,遇到非零像素点就以此为字体的右边界
        for i in range(width - 1, -1, -1):
            if v_sum[i] > 0:
                right = i
                break
        # 从上往下扫描,遇到非零像素点就以此为字体的上边界
        for i in range(height):
            if h_sum[i] > 0:
                top = i
                break
        # 从下往上扫描,遇到非零像素点就以此为字体的下边界
        for i in range(height - 1, -1, -1):
            if h_sum[i] > 0:
                low = i
                break
        return (left, top, right, low)


# 把字体图像放到背景图像中
class PreprocessResizeKeepRatioFillBG(object):

    def __init__(self, width, height,
                 fill_bg=False,
                 auto_avoid_fill_bg=True,
                 margin=None):
        self.width = width
        self.height = height
        self.fill_bg = fill_bg
        self.auto_avoid_fill_bg = auto_avoid_fill_bg
        self.margin = margin

    @classmethod
    def is_need_fill_bg(cls, cv2_img, th=0.5, max_val=255):
        image_shape = cv2_img.shape
        height, width = image_shape
        if height * 3 < width:
            return True
        if width * 3 < height:
            return True
        return False

    @classmethod
    def put_img_into_center(cls, img_large, img_small, ):
        width_large = img_large.shape[1]
        height_large = img_large.shape[0]

        width_small = img_small.shape[1]
        height_small = img_small.shape[0]

        if width_large < width_small:
            raise ValueError("width_large <= width_small")
        if height_large < height_small:
            raise ValueError("height_large <= height_small")

        start_width = (width_large - width_small) // 2
        start_height = (height_large - height_small) // 2

        img_large[start_height:start_height + height_small,
        start_width:start_width + width_small] = img_small
        return img_large

    def do(self, cv2_img):
        # 确定有效字体区域,原图减去边缘长度就是字体的区域
        if self.margin is not None:
            width_minus_margin = max(2, self.width - self.margin)
            height_minus_margin = max(2, self.height - self.margin)
        else:
            width_minus_margin = self.width
            height_minus_margin = self.height

        cur_height, cur_width = cv2_img.shape[:2]
        if len(cv2_img.shape) > 2:
            pix_dim = cv2_img.shape[2]
        else:
            pix_dim = None

        preprocess_resize_keep_ratio = PreprocessResizeKeepRatio(
            width_minus_margin,
            height_minus_margin)
        resized_cv2_img = preprocess_resize_keep_ratio.do(cv2_img)

        if self.auto_avoid_fill_bg:
            need_fill_bg = self.is_need_fill_bg(cv2_img)
            if not need_fill_bg:
                self.fill_bg = False
            else:
                self.fill_bg = True

        ## should skip horizontal stroke
        if not self.fill_bg:
            ret_img = cv2.resize(resized_cv2_img, (width_minus_margin,
                                                   height_minus_margin))
        else:
            if pix_dim is not None:
                norm_img = np.zeros((height_minus_margin,
                                     width_minus_margin,
                                     pix_dim),
                                    np.uint8)
            else:
                norm_img = np.zeros((height_minus_margin,
                                     width_minus_margin),
                                    np.uint8)
            # 将缩放后的字体图像置于背景图像中央
            ret_img = self.put_img_into_center(norm_img, resized_cv2_img)

        if self.margin is not None:
            if pix_dim is not None:
                norm_img = np.zeros((self.height,
                                     self.width,
                                     pix_dim),
                                    np.uint8)
            else:
                norm_img = np.zeros((self.height,
                                     self.width),
                                    np.uint8)
            ret_img = self.put_img_into_center(norm_img, ret_img)
        return ret_img


# 检查字体文件是否可用
class FontCheck(object):

    def __init__(self, lang_chars, width=32, height=32):
        self.lang_chars = lang_chars
        self.width = width
        self.height = height

    def do(self, font_path):
        width = self.width
        height = self.height
        try:
            for i, char in enumerate(self.lang_chars):
                img = Image.new("RGB", (width, height), "black")  # 黑色背景
                draw = ImageDraw.Draw(img)
                font = ImageFont.truetype(font_path, int(width * 0.9), )
                # 白色字体
                draw.text((0, 0), char, (255, 255, 255),
                          font=font)
                data = list(img.getdata())
                sum_val = 0
                for i_data in data:
                    sum_val += sum(i_data)
                if sum_val < 2:
                    return False
        except:
            print("fail to load:%s" % font_path)
            traceback.print_exc(file=sys.stdout)
            return False
        return True


# 生成字体图像
class Font2Image(object):

    def __init__(self,
                 width, height,
                 need_crop, margin):
        self.width = width
        self.height = height
        self.need_crop = need_crop
        self.margin = margin

    def do(self, font_path, char, rotate=0):
        find_image_bbox = FindImageBBox()
        # 黑色背景
        img = Image.new("RGB", (self.width, self.height), "black")
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(font_path, int(self.width * 0.7), )
        # 白色字体
        draw.text((0, 0), char, (255, 255, 255),
                  font=font)
        if rotate != 0:
            img = img.rotate(rotate)
        data = list(img.getdata())
        sum_val = 0
        for i_data in data:
            sum_val += sum(i_data)
        if sum_val > 2:
            np_img = np.asarray(data, dtype='uint8')
            np_img = np_img[:, 0]
            np_img = np_img.reshape((self.height, self.width))
            cropped_box = find_image_bbox.do(np_img)
            left, upper, right, lower = cropped_box
            np_img = np_img[upper: lower + 1, left: right + 1]
            if not self.need_crop:
                preprocess_resize_keep_ratio_fill_bg = \
                    PreprocessResizeKeepRatioFillBG(self.width, self.height,
                                                    fill_bg=False,
                                                    margin=self.margin)
                np_img = preprocess_resize_keep_ratio_fill_bg.do(
                    np_img)
            # cv2.imwrite(path_img, np_img)
            return np_img
        else:
            print("img doesn't exist.")


# 注意,chinese_labels里面的映射关系是:(ID:汉字)
def get_label_dict():
    f = open('./OCR-sets.txt', 'rb')
    label_dict = pickle.load(f)
    f.close()
    return label_dict


def args_parse():
    # 解析输入参数
    parser = argparse.ArgumentParser(
        description=description, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--out_dir', dest='out_dir',
                        default=None, required=True,
                        help='write a caffe dir')
    parser.add_argument('--font_dir', dest='font_dir',
                        default=None, required=True,
                        help='font dir to to produce images')
    parser.add_argument('--test_ratio', dest='test_ratio',
                        default=0.2, required=False,
                        help='test dataset size')
    parser.add_argument('--width', dest='width',
                        default=None, required=True,
                        help='width')
    parser.add_argument('--height', dest='height',
                        default=None, required=True,
                        help='height')
    parser.add_argument('--no_crop', dest='no_crop',
                        default=True, required=False,
                        help='', action='store_true')
    parser.add_argument('--margin', dest='margin',
                        default=0, required=False,
                        help='', )
    parser.add_argument('--rotate', dest='rotate',
                        default=0, required=False,
                        help='max rotate degree 0-45')
    parser.add_argument('--rotate_step', dest='rotate_step',
                        default=0, required=False,
                        help='rotate step for the rotate angle')
    parser.add_argument('--need_aug', dest='need_aug',
                        default=False, required=False,
                        help='need data augmentation', action='store_true')
    args = vars(parser.parse_args())
    return args


if __name__ == "__main__":

    description = '''
python gen_printed_char.py --out_dir ./dataset \
			--font_dir ./chinese_fonts \
			--width 30 --height 30 --margin 4 --rotate 30 --rotate_step 1
    '''
    options = args_parse()

    out_dir = os.path.expanduser(options['out_dir'])
    font_dir = os.path.expanduser(options['font_dir'])
    test_ratio = float(options['test_ratio'])
    width = int(options['width'])
    height = int(options['height'])
    need_crop = not options['no_crop']
    margin = int(options['margin'])
    rotate = int(options['rotate'])
    need_aug = options['need_aug']
    rotate_step = int(options['rotate_step'])
    train_image_dir_name = "train"
    test_image_dir_name = "test"

    # 将dataset分为train和test两个文件夹分别存储
    train_images_dir = os.path.join(out_dir, train_image_dir_name)
    test_images_dir = os.path.join(out_dir, test_image_dir_name)

    if os.path.isdir(train_images_dir):
        shutil.rmtree(train_images_dir)
    os.makedirs(train_images_dir)

    if os.path.isdir(test_images_dir):
        shutil.rmtree(test_images_dir)
    os.makedirs(test_images_dir)

    # 将汉字的label读入,得到(ID:汉字)的映射表label_dict
    label_dict = get_label_dict()

    char_list = []  # 汉字列表
    value_list = []  # label列表
    for (value, chars) in label_dict.items():
        print(value, chars)
        char_list.append(chars)
        value_list.append(value)

    # 合并成新的映射关系表:(汉字:ID)
    lang_chars = dict(zip(char_list, value_list))
    font_check = FontCheck(lang_chars)

    if rotate < 0:
        roate = - rotate

    if rotate > 0 and rotate <= 45:
        all_rotate_angles = []
        for i in range(0, rotate + 1, rotate_step):
            all_rotate_angles.append(i)
        for i in range(-rotate, 0, rotate_step):
            all_rotate_angles.append(i)
        # print(all_rotate_angles)

    # 对于每类字体进行小批量测试
    verified_font_paths = []
    ## search for file fonts
    for font_name in os.listdir(font_dir):
        path_font_file = os.path.join(font_dir, font_name)
        if font_check.do(path_font_file):
            verified_font_paths.append(path_font_file)

    font2image = Font2Image(width, height, need_crop, margin)

    for (char, value) in lang_chars.items():  # 外层循环是字
        image_list = []
        print(char, value)
        # char_dir = os.path.join(images_dir, "%0.5d" % value)
        for j, verified_font_path in enumerate(verified_font_paths):  # 内层循环是字体
            if rotate == 0:
                image = font2image.do(verified_font_path, char)
                image_list.append(image)
            else:
                for k in all_rotate_angles:
                    image = font2image.do(verified_font_path, char, rotate=k)
                    image_list.append(image)

        if need_aug:
            data_aug = dataAugmentation()
            image_list = data_aug.do(image_list)

        test_num = len(image_list) * test_ratio
        random.shuffle(image_list)  # 图像列表打乱
        count = 0
        for i in range(len(image_list)):
            img = image_list[i]
            # print(img.shape)
            if count < test_num:
                char_dir = os.path.join(test_images_dir, "%0.5d" % int(value))
            else:
                char_dir = os.path.join(train_images_dir, "%0.5d" % int(value))

            if not os.path.isdir(char_dir):
                os.makedirs(char_dir)

            path_image = os.path.join(char_dir, "%d.png" % count)
            cv2.imwrite(path_image, img)
            count += 1

在pycharm的终端窗口中输入以下命令,即可开始生成:

python gen_printed_char.py --out_dir ./dataset2 --font_dir ./chinese_fonts --width 100 --height 100 --margin 10 --rotate 30 --rotate_step 1 --need_aug

生成的OCR样本集展示

训练集

训练样本集

训练集:大写字母E
训练集E
训练集:小写字母e
训练集e
训练集:数字5
训练集5

测试集

测试集:
测试集
测试集:小写字母m
测试集m

迁移学习训练

生成字符样本的(路径 标签)映射表

当我们已经有生成好的OCR样本集后,训练时会将图片加载到内存中去,因此,需要将每个需要参与训练的图片的保存路径以及其对应的标签存储在一个文本文件中。
以下代码作用是:遍历数据集存储文件夹下的每一个子文件夹中的每一个图片文件,并以子文件夹名称的后两位作为该类字符样本的标签。

import os
import re

def subdir_list(dirname):
    """获取目录下所有子目录名
    @param dirname: str 目录的完整路径
    @return: list(str) 所有子目录完整路径组成的列表
    """
    return list(filter(os.path.isdir,
                       map(lambda filename: os.path.join(dirname, filename),
                           os.listdir(dirname))
                       ))


def file_list(dirname, ext='.png'):
    """获取目录下所有特定后缀的文件
    @param dirname: str 目录的完整路径
    @param ext: str 后缀名, 以点号开头
    @return: list(str) 所有子文件名(不包含路径)组成的列表
    """
    return list(filter(lambda filename: os.path.splitext(filename)[1] == ext,
                       os.listdir(dirname)))

if __name__ == "__main__":
    dirs = subdir_list("C:\\Users\\leslie\\Desktop\\OCR-sets\\dataset2")

    for dir in dirs:
        print("*********************dir name:************************")
        print(dir)

        dir_name = re.findall("test", dir)
        if len(dir_name) != 0:
            dirs2 = subdir_list(dir)
            file = open("test.txt", 'w')
            for dir2 in dirs2:
                print("*********************dir name:************************")
                print(dir2)
                files = file_list(dir2)
                for file_name in files:
                    w_str = dir2.replace("\\", "\\\\") + '\\\\' + file_name + ' ' + dir2[-2:] + '\n'
                    file.write(w_str)
                print("-----------------------dir end-------------------------")
            file.close()
        else:
            dirs2 = subdir_list(dir)
            file = open("train.txt", 'w')
            for dir2 in dirs2:
                print("*********************dir name:************************")
                print(dir2)
                files = file_list(dir2)
                for file_name in files:
                    w_str = dir2.replace("\\", "\\\\") + '\\\\' + file_name + ' ' + dir2[-2:] + '\n'
                    file.write(w_str)
                print("-----------------------dir end-------------------------")
            file.close()

train.txt

加载数据集,并生成模型文件(.npy)

使用XCeption网络在ImageNet数据集上的训练结果进行迁移学习,加载自己生成的OCR字符样本数据集进行训练,代码(代码参考:小马视频)如下:

import tensorflow as tf
from PIL import Image
import os
import numpy as np
from tensorflow.keras.utils import to_categorical, normalize
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.datasets import cifar10, mnist
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt


train_txt = 'C:\\Users\\leslie\\Desktop\\OCR-sets\\train.txt'
x_train_savepath = './letters_x_train.npy'
y_train_savepath = './letters_y_train.npy'

test_txt = 'C:\\Users\\leslie\\Desktop\\OCR-sets\\test.txt'
x_test_savepath = './letters_x_test.npy'
y_test_savepath = './letters_y_test.npy'


def generateds(txt):
    f = open(txt, 'r')  # 以只读形式打开txt文件
    contents = f.readlines()  # 读取文件中所有行
    f.close()  # 关闭txt文件
    x, y_ = [], []  # 建立空列表
    for content in contents:  # 逐行取出
        value = content.split()  # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表
        img_path = value[0]  # 拼出图片路径和文件名
        img = Image.open(img_path)  # 读入图片
        img = np.array(img.convert('RGB'))  # 图片变为3通道彩色的np.array格式
        img = img / 255.  # 数据归一化 (实现预处理)
        x.append(img)  # 归一化后的数据,贴到列表x
        y_.append(value[1])  # 标签贴到列表y_
        print('loading : ' + content)  # 打印状态提示

    x = np.array(x)  # 变为np.array格式
    y_ = np.array(y_)  # 变为np.array格式
    y_ = y_.astype(np.int64)  # 变为64位整型
    return x, y_  # 返回输入特征x,返回标签y_


def run():
    # 加载数据
    if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(
            x_test_savepath) and os.path.exists(y_test_savepath):
        print('-------------Load Datasets-----------------')
        x_train_save = np.load(x_train_savepath)
        y_train = np.load(y_train_savepath)
        x_test_save = np.load(x_test_savepath)
        y_test = np.load(y_test_savepath)
        X_train = np.reshape(x_train_save, (len(x_train_save), 50, 50, 3))
        X_test = np.reshape(x_test_save, (len(x_test_save), 50, 50, 3))
    else:
        print('-------------Generate Datasets-----------------')
        X_train, y_train = generateds(train_txt)
        X_test, y_test = generateds(test_txt)

        print('-------------Save Datasets-----------------')
        x_train_save = np.reshape(X_train, (len(X_train), -1))
        x_test_save = np.reshape(X_test, (len(X_test), -1))
        np.save(x_train_savepath, x_train_save)
        np.save(y_train_savepath, y_train)
        np.save(x_test_savepath, x_test_save)
        np.save(y_test_savepath, y_test)
        # 接下部分代码

加载数据并训练

	# 接上部分代码
    # 数据前处理
    X_train = X_train.astype('float32') / 255
    X_test = X_test.astype('float32') / 255

    num_classes = 62
    y_train = to_categorical(y_train, num_classes = num_classes)
    y_test = to_categorical(y_test, num_classes = num_classes)

    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)

    # from tensorflow.keras.applications.resnet50 import ResNet50
    # base_model = ResNet50(
    #     include_top=False,
    #     weights="imagenet",
    #     input_shape=None
    # )

    from tensorflow.keras.applications.xception import Xception
    base_model = Xception(
        include_top=False,
        weights="imagenet",
        input_shape=None
    )

    # from tensorflow.keras.applications.vgg16 import VGG16
    # base_model = VGG16(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
    # base_model = MobileNetV2(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # from tensorflow.keras.applications.inception_v3 import InceptionV3
    # base_model = InceptionV3(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # from tensorflow.keras.applications.densenet import DenseNet121
    # base_model = DenseNet121(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    ###################################################################
    # 全连接层
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation = 'relu')(x)
    predictions = Dense(num_classes, activation = 'softmax')(x)

    # 模型网络定义
    model = Model(inputs = base_model.input, outputs = predictions)

    model.compile(
        optimizer = Adam(),
        loss = 'categorical_crossentropy',
        metrics = ["acc"]
    )

    # model.summary()
    print("模型网络定义:{}层".format(len(model.layers)))

    # EarlyStopping
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        verbose=1
    )

    # Reduce Learning Rate
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=3,
        verbose=1
    )

    ###################################################################
    # 训练参数
    p_batch_size = 128
    p_epochs = 5

    ###################################################################
    # 图片掺水训练
    # 准备图片:ImageDataGenerator
    train_gen  = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        width_shift_range=0.125,
        height_shift_range=0.125,
        horizontal_flip=True)
    test_gen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True)

    # 数据集前计算
    for data in (train_gen, test_gen):
        data.fit(X_train)

    history = model.fit(
        train_gen.flow(X_train, y_train, batch_size=p_batch_size),
        epochs=p_epochs,
        steps_per_epoch=X_train.shape[0] // p_batch_size,
        validation_data=test_gen.flow(X_test, y_test, batch_size=p_batch_size),
        validation_steps=X_test.shape[0] // p_batch_size,
        callbacks=[early_stopping, reduce_lr])

    # 显示训练结果
    plt.plot(history.history['acc'], label='acc')
    plt.plot(history.history['val_acc'], label='val_acc')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(loc='best')
    plt.show()

    # 模型保存
    model.save("model{}.h5".format(p_epochs))

    ###################################################################
    # 结果评价
    test_loss, test_acc = model.evaluate(
        test_gen.flow(X_test, y_test, batch_size=p_batch_size),
        steps=10)
    print('val_loss: {:.3f}\nval_acc: {:.3f}'.format(test_loss, test_acc ))

run()

获得识别模型与准确度

训练好的模型

准确度

注意:训练过程中需要在外网下载XCeption模型的文件,可能会由于网络原因下载失败。

  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值