星月落
落入世间 使得花败人断肠
多执着 惊鸿怎让人遗忘
—《星月落》· 浮生梦
前言与背景
keras和tensorflow更新的速度可谓是非常之快,有一部分早年写的代码现在全是红色波浪,近期毕设的一部分需要搭建一个简单的DNN网络,谨以此文浅记所遇之问题,以供后来者参考,亦可作备忘复习。均是入门级问题,熟练者不必阅读。
环境配置上的问题
我是tensorflow2.15.0+tensorboard2.15.1+keras2.15.0,怎么确定这个版本,核心方法就是,在一个命令里面安装
pip install tensorflow tensorboard keras
根据你安装的顺序不同,pip解析包依赖的顺序也是不同的,譬如你先安装B,B在满足A的要求的同时选择了最新的版本,但是你的C依赖B的低版本,就导致你后续安装的C无法使用或者要回退B才能解决。一个命令里面安装是最好的解决办法。
然后就是我发现了一个脚本,是Google官方出的tensorboard环境问题检测脚本,复制下来运行就可以了,它会自动给出当前环境存在的冲突和修复建议,我就是按照它的建议修复的。
https://github.com/sjhsjhsjhsjh/tensorboard_env_repair.git
网络的搭建和训练
简单的FCNN网络一般以一个Flatten层、若干FC层组成的,这里我直接复制代码了:
def DNN():
global model
model.add(Flatten())
model.add(Dense(2, activation='tanh', name = "dense"))
model.add(Dense(64, activation='tanh', name = "dense_1"))
model.add(Dense(16, activation='tanh', name = "dense_2"))
model.add(Dense(5, activation='softmax', name = "dense_3"))
其中,最后一层的数量就是你类别总数。因为预测出来是one-hot标签。
不过这不是重点,重点在于设定网络输入层的shape:
model = keras.models.Sequential()
DNN(0)
# pic_width = 2, pic_height = 1300; pic_width = 1
model.build(input_shape=[None, 2, 1300, 1])
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
先使用keras.models.Sequential()获取顺序堆叠的网络容器,然后调用前面定义的DNN函数进行搭建,然后设定input_shape,第一个维度是None,因为这个参数实际上这个是一次输入多少张图片,而训练的时候一个batch输入多少张实际上是可以变的,所以是None。然后是你的数据维度。我的数据是一个1300行,2列的已经归一化的点集,那么这个点集显然是没有z轴维度,也就是深度的,所以最后一个维度填的是1.中间两个自然就是你的数据长宽,不知道是不是1300,2和2,1300都行。
然后就是compile,设定训练的优化器以及使用的损失函数
训练部分
tb_config = TensorBoard(
log_dir= './logs',
write_images= True,
histogram_freq= 5)
cbks = [tb_config]
model.fit(np.array(train, dtype=float), np.array(label, dtype=float),
batch_size=8, epochs=50,
verbose=2,
shuffle=True,
callbacks=cbks)
定义了一个tensorboard的回调函数,并且在model.fit的时候作为callbacks参数传入,如此可以使用tensorboard观察训练。虽然50轮的训练也没什么好观察的就是了。
模型的保存和加载
上面的内容都是极其基础内容,即使是新手学个两天都能整出来,但是下面的这个问题我觉的有些小坑,就是模型的加载和保存。这个地方网上很多文章都没讲清楚,没有说清楚对应的API关系。这里贴一篇文章,讲的不错https://blog.csdn.net/aloveysz/article/details/124981316
- 一种是在训练完成之后使用
model.save("path/model_save.h5")
保存为h5文件。此时保存的h5文件包含了网络层和神经元的参数,直接加载就可以了,连keras.models.Sequential()获取容器都不需要。加载的方式(能且只能)是model = keras.models.load_model("path/model_save.h5")
,此时这个model具有predict
成员函数,可以直接进行预测
可见,在load_model之后,你只需要准备好图片就可以直接搞里头了。但是我发现模型保存之后其input_shape会发生变化,这个我没搞清楚。上文提到,我网络的输入shape是None*2*1300*1,但是保存之后,输入的shape变成了None*1*2*1300,不知道怎么回事z轴跑到前面去了,所以还得# 这里是个进度条,我是为了好看,可以不用tqdm for pic in tqdm(train, desc = "Test"): pre = [] pre.append(pic) ans = model.predict(np.array(pre), 1, verbose=0)
np.array(pic).reshape(1,2,1300)
整一下 - 另一种是使用save_weights,这个函数顾名思义,只保存了模型的神经元参数,并没有保存模型本身。这个函数对应的是load_weights,显然你load之前需要搭建好网络才能加载。但是这两个函数不知道为什么我用起来会有问题,所以此次笔记我主要记第一个方法。
- 然后就是我当时尝试了一个checkpoint,就是创建了一个回调函数,这个回调函数可以以loss为评估参数,自动保存loss最优的网络模型,用法就是把这个cbks接入之前model.fit函数的callbacks参数。
save_model_cbk = [keras.callbacks.ModelCheckpoint(filepath = 'path' + time_string, monitor='loss', save_best_only=True, mode='auto')] cbks = [save_model_cbk]
以时间为文件名保存
这没啥好说的,就是个小技巧。
from datetime import datetime
now = datetime.now()
time_string = now.strftime("%Y_%m_%d_%H_%M_%S")
model.save('path/' + time_string + '.h5')
加载进度条
也是属于花活,tqdm库,简单的用法如下:第一个参数就是你本来for A in B的B,第二个desc是这个进度条的描述文字,我这里就是说明当前在第几条路径
# 重点只有第一行
for file_name in tqdm(file_name_list, desc = str(path_index)):
pic = []
file = open(path + "\\" + file_name, mode = 'r')
file_str = file.readlines()
for line_str in file_str:
line_str = line_str.replace('\r', '').replace('\n', '') # 去除换行符
pos = line_str.find(' ')
if pos != -1:
point_x = float(line_str[0:pos])
point_y = float(line_str[pos+1:])
pic.append([point_x, point_y])
else:
print("error")
continue
其它知识点就是line_str = line_str.replace('\r', '').replace('\n', '')
,连续两个replace函数删除了这一行的换行符和行结束符
反思
今天看来也没啥好记的,其实就是个入门级的小东西,昨天竟然调了整整一天,回忆起来主要卡的位置就是模型的保存和加载,save/load_model
与save_weights/load_weights
应当是成对使用的,当时没有注意到,对搭建好的model进行load_model,一直报错,现在想起来真是搞笑。
然后就是在模型保存之后的读入卡住了,当时发现它的input_shape莫名其妙z轴跑前面去了,但是我对numpy数组又不熟,寻思先把(1300,2,1)delete成(1300,2)在add(1,1300,2),再reshape,然后查了半天delete的资料;实际上一个reshape就行了,确实有点捞了当时。
还有一个值得注意的点就是python的list也就是列表和array也就是numpy数组的互相转换。转换有一个条件,就是你里面所有的子list的维度必须相同,不然会报错。然后是转换后的类型可以通过dtype参数指定。当时读数据集的时候不知道怎么地有的图没读到完整的,导致转换数组一直报错。
暂时记到这吧,后续会发不用onnx模型和cpp联调网络的方法。
锦绣词句本从天上来
狂写诗词三百
如何请这妙笔 入我梦中来
–《春涧》浅影阿