import os
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam,SGD
import cfg
#from network import East
from network_densenet import East
from data_generator import gen
import tensorflow as tf
from keras import backend as K
from keras import models
from losses import quad_loss
#from tensorflow.keras.callbacks import TensorBoard
# TensorFlow session shared with Keras:
#   * allow_soft_placement - let TF fall back to another device when an op
#     cannot run on the requested one;
#   * allow_growth         - claim GPU memory on demand instead of up front;
#   * visible_device_list  - expose only GPU index "1" to this process.
config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True, visible_device_list="1"),
)
sess = tf.Session(config=config)
# Make Keras use this session rather than creating its own.
K.set_session(sess)
# Model construction and (optional) weight restoration.
east = East()  # network builder imported from network_densenet
# Per the original note this yields Model(inputs=self.input_img,
# outputs=east_detect); the definition lives in network_densenet.
east_network = east.east_network()
east_network.summary()  # print the layer-by-layer architecture

# Resume from previously saved weights when cfg asks for it.
if cfg.load_weights:
    if os.path.exists(cfg.saved_model_weights_file_path):
        east_network.load_weights(cfg.saved_model_weights_file_path)
        # Report success only after the weights have actually been loaded,
        # so a failing load never logs a misleading success message.
        print('加载模型成功')
    else:
        # Previously this case was silent and training started from scratch
        # without any hint that the requested weights were missing.
        print('cfg.load_weights is set but no weights file was found at',
              cfg.saved_model_weights_file_path)

# Alternative: load a complete saved model and use it in place of the
# freshly built network.
# models=models.load_model(cfg.load_model_file_path,custom_objects={'quad_loss': quad_loss})
# models.summary()
# east_network=models
# Training configuration: compile the model with the quad loss and an Adam
# optimizer whose learning rate and decay are taken from cfg.
adam_optimizer = Adam(
    lr=cfg.lr,
    # clipvalue=cfg.clipvalue,
    decay=cfg.decay,
)
east_network.compile(loss=quad_loss, optimizer=adam_optimizer)
# Training run: batches come from the data generators, one for training and
# one for validation (is_val=True).
train_batches = gen()
val_batches = gen(is_val=True)

# Stop once the monitored metric has not improved for cfg.patience epochs.
early_stopping = EarlyStopping(patience=cfg.patience, verbose=1)
# Evaluated every epoch; with save_best_only=True only an improved model is
# written, and only its weights are stored.
best_weights_checkpoint = ModelCheckpoint(
    filepath=cfg.model_weights_path,
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

east_network.fit_generator(
    generator=train_batches,
    steps_per_epoch=cfg.steps_per_epoch,
    epochs=cfg.epoch_num,
    validation_data=val_batches,
    validation_steps=cfg.validation_steps,
    verbose=1,
    initial_epoch=cfg.initial_epoch,
    callbacks=[early_stopping, best_weights_checkpoint],
)

# Persist the final state: the full model first, then the weights on their own.
east_network.save(cfg.saved_model_file_path)
east_network.save_weights(cfg.saved_model_weights_file_path)
# Generator setup: generator=gen() returns an iterator of (images, batch * labels)
# Training generator: how the training images are read and preprocessed
import os
import numpy as np
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import cfg
import imghdr
def gen(batch_size=cfg.batch_size, is_val=False):
img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
x = np.zeros((batch_size, img_h, img_w, cfg.num_channels), dtype=np.float32)
#import pdb; pdb.set_trace()
pixel_num_h = img_h // cfg.pixel_size
pixel_num_w = img_w // cfg.pixel_size
y = np.zeros((batch_size, pixel_num_h, pixel_num_w, 7), dtype=np.float32)
if is_val:
with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val: #txt文档
f_list = f_val.readlines()
else:
with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
f_list = f_train.readlines()
while True:
for i in range(batch_size):
# random gen an image name
random_img = np.random.choice(f_list)
img_filename = str(random_img).strip().split(',')[0]
# load img and img anno
img_path = os.path.join(cfg.data_dir,
cfg.train_image_dir_name,
img_filename)
try: #处理图片异常时
imghdr.what(img_path)
except:
print('88888888888888888888888888888888888')
print('img_filename',img_path)