Keras 框架构建

import os
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam,SGD
import cfg
#from network import East
from network_densenet import East

from data_generator import gen
import tensorflow as tf
from keras import backend as K
from keras import models
from losses import quad_loss
#from tensorflow.keras.callbacks import TensorBoard
config = tf.ConfigProto()
config.allow_soft_placement = True
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = "1"
sess = tf.Session(config=config)
K.set_session(sess)

#模型载入与参数模块
east = East()               #返回一个Model类,Model(inputs=self.input_img, outputs=east_detect) 
east_network = east.east_network()   #1*1的卷积操作
east_network.summary()   #打印
if cfg.load_weights and os.path.exists(cfg.saved_model_weights_file_path):  #cfg为参数文件
    print('加载模型成功')
    east_network.load_weights(cfg.saved_model_weights_file_path)

#载入模型且换参数模块
# models=models.load_model(cfg.load_model_file_path,custom_objects={'quad_loss': quad_loss})
# models.summary()
# east_network=models

#训练器配置模块
east_network.compile(loss=quad_loss, optimizer=Adam(lr=cfg.lr,    #配置训练模型
                                                    # clipvalue=cfg.clipvalue,
                                                    decay=cfg.decay))
#训练器训练初始化模块
east_network.fit_generator(generator=gen(),
                           steps_per_epoch=cfg.steps_per_epoch,
                           epochs=cfg.epoch_num,
                           validation_data=gen(is_val=True),
                           validation_steps=cfg.validation_steps,
                           verbose=1,
                           initial_epoch=cfg.initial_epoch,
                           callbacks=[
                               EarlyStopping(patience=cfg.patience, verbose=1),
                               ModelCheckpoint(filepath=cfg.model_weights_path, #每个周期保存一次
                                               save_best_only=True,
                                               save_weights_only=True,
                                               verbose=1)])
east_network.save(cfg.saved_model_file_path)
east_network.save_weights(cfg.saved_model_weights_file_path)

 

 

 

#生成器配置函数:generator=gen(),返回一个迭代器(图片,bacth*标签)

#训练生成器,如何读取训练图片并进行预处理
import os
import numpy as np
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import cfg
import imghdr

def gen(batch_size=cfg.batch_size, is_val=False):
    img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
    x = np.zeros((batch_size, img_h, img_w, cfg.num_channels), dtype=np.float32)
    #import pdb; pdb.set_trace()
    pixel_num_h = img_h // cfg.pixel_size
    pixel_num_w = img_w // cfg.pixel_size
    y = np.zeros((batch_size, pixel_num_h, pixel_num_w, 7), dtype=np.float32)
    if is_val:
        with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:  #txt文档
            f_list = f_val.readlines()
    else:
        with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
            f_list = f_train.readlines()
    while True:
        for i in range(batch_size):
            # random gen an image name
            random_img = np.random.choice(f_list)
            img_filename = str(random_img).strip().split(',')[0]
            # load img and img anno
            img_path = os.path.join(cfg.data_dir,
                                    cfg.train_image_dir_name,
                                    img_filename)
            try:    #处理图片异常时
                imghdr.what(img_path)
            except:
                print('88888888888888888888888888888888888')
                print('img_filename',img_path)
          
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值