Keras是一款比较容易上手的深度学习框架,在构建模型/训练数据方面比较方便使用
1 训练数据的传输
def prepare_input_data(img_width, img_height):
    """Build the training and validation image generators.

    Reads `Train_path`, `Val_path` and `Batch_size` from the module-level
    `config` dict. Training images get light augmentation (shear, zoom,
    horizontal flip); validation images are only rescaled and not shuffled,
    so predictions stay aligned with filenames.

    Returns:
        (train_generator, validation_generator) tuple of directory iterators
        yielding categorical (one-hot) labels.
    """
    augmenting_gen = image.ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
    )
    plain_gen = image.ImageDataGenerator(rescale=1. / 255)

    train_flow = augmenting_gen.flow_from_directory(
        config['Train_path'],
        target_size=(img_width, img_height),
        batch_size=int(config['Batch_size']),
        class_mode='categorical',
    )
    val_flow = plain_gen.flow_from_directory(
        config['Val_path'],
        target_size=(img_width, img_height),
        batch_size=int(config['Batch_size']),
        class_mode='categorical',
        shuffle=False,
    )
    # Show the class-name -> index mapping the generator inferred from folders.
    print(train_flow.class_indices)
    return train_flow, val_flow
2 模型构建
# NOTE: the original lines read `import keras.preprocessing import image`,
# which is a SyntaxError — `from ... import ...` is required to pull names
# out of a module. matplotlib is added because LossHistory.loss_plot uses
# `plt` and nothing else in the file imports it.
import keras
import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Conv2D, MaxPooling2D, Dense, Activation, Flatten, Dropout
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential, load_model
from keras.optimizers import SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax, Nadam
from keras.preprocessing import image
def get_model(config):
    """Assemble the 4-conv-block CNN classifier.

    Architecture: one 5x5 'valid' conv block on the raw input, then three
    identical 3x3 'same' conv blocks (each conv -> batch-norm -> 2x2 max-pool),
    then a 1024-unit dense head with dropout and an 8-way softmax output.

    Args:
        config: dict with string values for 'IMG_WIDTH' and 'IMG_HEIGHT'.

    Returns:
        An uncompiled keras Sequential model.
    """
    width = int(config['IMG_WIDTH'])
    height = int(config['IMG_HEIGHT'])

    net = Sequential()
    # Block 1: larger 5x5 kernel, 'valid' padding, fixes the input shape (RGB).
    net.add(Conv2D(32, (5, 5), padding='valid', activation='relu',
                   input_shape=(width, height, 3)))
    net.add(BatchNormalization())
    net.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
    # Blocks 2-4: identical conv -> BN -> pool pattern, widening then holding depth.
    for depth in (64, 128, 128):
        net.add(Conv2D(depth, (3, 3), padding='same', activation='relu'))
        net.add(BatchNormalization())
        net.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
    # Classifier head: flatten, dense + dropout, 8-class softmax.
    net.add(Flatten())
    net.add(Dense(1024, activation='relu'))
    net.add(Dropout(0.5))
    net.add(Dense(8, activation='softmax'))
    return net
3 加载模型
def train(config):
    """Compile the model and run a short training session.

    Pulls loss, learning rate and metric names from `config`, wires up the
    data generators from prepare_input_data(), and records per-batch /
    per-epoch metrics through a LossHistory callback.

    NOTE(review): steps_per_epoch=2 / epochs=1 look like smoke-test values,
    not a real training schedule — confirm before production use.
    """
    net = get_model(config)
    net.compile(
        loss=config['Loss'],
        optimizer=Adam(lr=float(config['learning_rate'])),
        metrics=[config['METRICS']],
    )
    width = int(config['IMG_WIDTH'])
    height = int(config['IMG_HEIGHT'])
    train_gen, val_gen = prepare_input_data(width, height)
    history_cb = LossHistory()
    net.fit_generator(
        train_gen,
        steps_per_epoch=2,
        epochs=1,
        validation_data=val_gen,
        validation_steps=1,
        verbose=1,
        callbacks=[history_cb],
    )
    print(history_cb.losses)
4 配置文件的解析
def get_cfg(path=r'C:/Users/zjunzhao/Desktop/cfg.txt'):
    """Parse a key:value config file into a dict.

    Each non-empty line is split on the FIRST ':' only, so values may
    themselves contain colons (e.g. Windows paths like ``Train_path:C:/data``).

    Args:
        path: config file location; defaults to the original hard-coded path
              so existing no-argument callers keep working.

    Returns:
        dict mapping key -> value (both raw strings, not individually stripped
        beyond the whole-line strip, matching the original behavior).
    """
    cfg = {}
    # 'with' guarantees the handle is closed (the original leaked it), and
    # iterating the file avoids loading everything via readlines().
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or ':' not in line:
                # Blank or malformed lines used to raise IndexError; skip them.
                continue
            key, value = line.split(':', 1)
            cfg[key] = value
    return cfg
5 绘制Loss/Acc图像
class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss/accuracy at batch and epoch granularity.

    Four parallel histories are kept (train loss/acc, val loss/acc), each a
    dict with "batch" and "epoch" lists, so loss_plot() can draw either view.

    Fixes vs. the original: the hooks used the mutable default ``logs={}``,
    which Python shares across calls; ``logs=None`` is the safe idiom (and
    what Keras itself passes).
    """

    def on_train_begin(self, logs=None):
        # Fresh containers at the start of every fit.
        self.losses = {"batch": [], "epoch": []}
        self.accuracy = {"batch": [], "epoch": []}
        self.val_loss = {"batch": [], "epoch": []}
        self.val_acc = {"batch": [], "epoch": []}

    def _record(self, bucket, logs):
        # Stash the four tracked metrics into the given bucket ("batch"/"epoch").
        # .get() returns None for metrics absent from this callback invocation
        # (e.g. val_* at batch end), matching the original behavior.
        logs = logs or {}
        self.losses[bucket].append(logs.get('loss'))
        self.accuracy[bucket].append(logs.get('acc'))
        self.val_loss[bucket].append(logs.get('val_loss'))
        self.val_acc[bucket].append(logs.get('val_acc'))

    def on_batch_end(self, batch, logs=None):
        self._record("batch", logs)

    def on_epoch_end(self, epoch, logs=None):
        self._record("epoch", logs)

    def loss_plot(self, loss_type):
        """Plot accuracy and loss curves for loss_type ("batch" or "epoch").

        Validation curves are only drawn for "epoch", since val metrics are
        only recorded at epoch end. Requires matplotlib.pyplot as `plt` in
        module scope.
        """
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.show()
6 python实现拷贝指定文件到指定目录
import os
import shutil

# Copy every file from the notes folder whose base name contains 'python'
# (case-insensitive) into d:\copy\, prefixed with "newname".
src_dir = u"D:\\notes\\python\\资料\\"
for entry in os.listdir(src_dir):
    # The original `aa, bb = entry.split(".")` raised ValueError for any name
    # without exactly one dot (e.g. "foo.tar.gz" or extensionless folders);
    # os.path.splitext handles all of those safely.
    base, ext = os.path.splitext(entry)
    if 'python' in base.lower():
        oldname = src_dir + entry
        # NOTE(review): "newname" is a literal filename prefix, not a folder —
        # confirm this is intended, and that d:\copy\ already exists.
        newname = u"d:\\copy\\newname" + entry
        shutil.copyfile(oldname, newname)
参考博客:
绘制loss:https://www.cnblogs.com/jzy996492849/p/7234233.html
https://blog.csdn.net/u011037837/article/details/51593099
https://blog.csdn.net/u013381011/article/details/78911848
保存checkpoints以及keypoints detection:https://blog.csdn.net/hjimce/article/details/49095199