tf.keras搭建神经网络
- 六步法
- import
- train,test
- Sequential or Class
- model.compile
- model.fit
- model.summary
Q1:当有个自己本领域的数据,又有了标签。显然不能用 load_data
Q2:若数据量过少,模型训练量不够,泛化能力不够。应该怎么解决。(数据增强)
Q3:每次训练都从头开始,不利于效率。(断点续训,参数提取)
今日主要学习内容:
- 自制数据集,解决本领域应用
- 数据增强,扩充数据集
- 断点续训,存取模型
- 参数提取,把参数存入文本
- acc/loss可视化,查看训练效果
- 应用练手,给图识物
自制数据集
1、拿到数据集后,现根据数据集文件(图片等)和给的txt文件(这个txt文件记录了文件名和标签)导入路径train_path和train_txt。同时建立可被tensorflow直接使用的数据集文件的存储路径和名字
(x_train_savepath,y_train_savepath),
(x_test_savepath,y_test_savepath)
2、建立数据集生成函数(generateds(path,txt))用来对给的数据集路径和,数据集标注处理,生成数据集。
def generateds(path, txt):
f = open(txt, 'r') # 以只读形式打开txt文件
contents = f.readlines() # 读取文件中所有行
f.close() # 关闭txt文件
x, y_ = [], [] # 建立空列表
for content in contents: # 逐行取出
value = content.split() # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表
img_path = path + value[0] # 拼出图片路径和文件名
img = Image.open(img_path) # 读入图片
img = np.array(img.convert('L')) # 图片变为8位宽灰度值的np.array格式
img = img / 255. # 数据归一化 (实现预处理)
x.append(img) # 归一化后的数据,贴到列表x
y_.append(value[1]) # 标签贴到列表y_
print('loading : ' + content) # 打印状态提示
x = np.array(x) # 变为np.array格式
y_ = np.array(y_) # 变为np.array格式
y_ = y_.astype(np.int64) # 变为64位整型
return x, y_ # 返回输入特征x,返回标签y_
3、判断4个文件
(x_train_savepath,y_train_savepath,x_test_savepath,y_test_savepath)是否存在。若不存在。调用generateds(path,txt)生成上面的4个文件,用np.load()读取文件内的(x_train,y_train,x_test,y_test),并将x_train_save , x_test_save reshape成 (len(x_train_save), 28, 28)。
若存在。则直接用这4个文件生成(x_train, y_train),(x_test, y_test)
if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(
x_test_savepath) and os.path.exists(y_test_savepath):
print('-------------Load Datasets-----------------')
x_train_save = np.load(x_train_savepath)
y_train = np.load(y_train_savepath)
x_test_save = np.load(x_test_savepath)
y_test = np.load(y_test_savepath)
x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))
x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:
print('-------------Generate Datasets-----------------')
x_train, y_train = generateds(train_path, train_txt)
x_test, y_test = generateds(test_path, test_txt)
print('-------------Save Datasets-----------------')
x_train_save = np.reshape(x_train, (len(x_train), -1))
x_test_save = np.reshape(x_test, (len(x_test), -1))
np.save(x_train_savepath, x_train_save)
np.save(y_train_savepath, y_train)
np.save(x_test_savepath, x_test_save)
np.save(y_test_savepath, y_test)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
数据增强(增大数据量)
image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
rescale=所有数据将乘以该数值,
rotation_range=随机旋转角度数范围,
width_shift_range=随机宽度偏移量,
height_shift_range=随机高度偏移量,
horizontal_flip=是否随机水平翻转,
zoom_range=随机缩放的范围,
)
image_gen_train.fit(x_train)
这里的 fit 需要的是一个四维的x_train
所以将x_train进行reshape
x_train = x_train.reshape(x_train.shape[0],28,28,1)
其中 1 为单通道,表示灰度值。
存取模型
读取模型
load_weights(路径文件名)
checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
print('______________load the model _________________')
model.load_weights(checkpoint_save_path)
生成参数文件时,会同步生成索引表,可以通过索引表,即可直接查看是否有保存过模型参数
tf.keras.callbacks.ModelCheckpoint(
filepath=路径文件名,
save_best_only=True/False,
save_weights_only=True/False
)
history = model.fit
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train,y_train,batch_size=32,epochs=50,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback])
提取可训练参数
model.trainable_variables 返回模型中可训练的参数
可用print直接打印出来。
np.set_printoptions(threshold = 超过多少省略显示)
np.set_printoptions(threshold=np.inf) # np.inf 表示无限大
也可使用 for 循环,将所有参数存入文本
file = open('./weights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
acc/loss可视化,查看训练效果
其实在model.fit训练过程时,已经同步记录了训练集loss、测试集val_loss、训练集准确率sparse_categorical_accuracy、
测试集准确率:val_sparse_categorical_accuracy
acc = history.history['sparse_categorical_accuracy']
val_acc = istory.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']