keras搭建全连接神经网络问题小记

星月落
落入世间 使得花败人断肠
多执着 惊鸿怎让人遗忘
—《星月落》· 浮生梦

前言与背景

keras和tensorflow更新的速度可谓是非常之快,有一部分早年写的代码现在全是红色波浪,近期毕设的一部分需要搭建一个简单的DNN网络,谨以此文浅记所遇之问题,以供后来者参考,亦可作备忘复习。均是入门级问题,熟练者不必阅读。

环境配置上的问题

我是tensorflow2.15.0+tensorboard2.15.1+keras2.15.0,怎么确定这个版本,核心方法就是,在一个命令里面安装

pip install tensorflow tensorboard keras

根据你安装的顺序不同,pip解析包依赖的顺序也是不同的,譬如你先安装B,B在满足A的要求的同时选择了最新的版本,但是你的C依赖B的低版本,就导致你后续安装的C无法使用或者要回退B才能解决。一个命令里面安装是最好的解决办法。

然后就是我发现了一个脚本,是Google官方出的tensorboard环境问题检测脚本,复制下来运行就可以了,它会自动给出当前环境存在的冲突和修复建议,我就是按照它的建议修复的。

https://github.com/sjhsjhsjhsjh/tensorboard_env_repair.git

网络的搭建和训练

简单的FCNN网络一般以一个Flatten层、若干FC层组成的,这里我直接复制代码了:

def DNN():
    global model
    
    model.add(Flatten())

    model.add(Dense(2, activation='tanh', name = "dense"))

    model.add(Dense(64, activation='tanh', name = "dense_1"))

    model.add(Dense(16, activation='tanh', name = "dense_2"))

    model.add(Dense(5, activation='softmax', name = "dense_3"))

其中,最后一层的数量就是你类别总数。因为预测出来是one-hot标签。

不过这不是重点,重点在于设定网络输入层的shape:

model = keras.models.Sequential()
DNN(0)
# pic_width = 2, pic_height = 1300;  pic_width = 1
model.build(input_shape=[None, 2, 1300, 1])
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])

先使用keras.models.Sequential()获取顺序堆叠的网络容器,然后调用前面定义的DNN函数进行搭建,然后设定input_shape,第一个维度是None,因为这个参数实际上这个是一次输入多少张图片,而训练的时候一个batch输入多少张实际上是可以变的,所以是None。然后是你的数据维度。我的数据是一个1300行,2列的已经归一化的点集,那么这个点集显然是没有z轴维度,也就是深度的,所以最后一个维度填的是1.中间两个自然就是你的数据长宽,不知道是不是1300,2和2,1300都行。

然后就是compile,设定训练的优化器以及使用的损失函数

训练部分

        tb_config = TensorBoard(
            log_dir= './logs', 
            write_images= True, 
            histogram_freq= 5)
        cbks = [tb_config]

        model.fit(np.array(train, dtype=float), np.array(label, dtype=float), 
            batch_size=8, epochs=50,
            verbose=2,
            shuffle=True,
            callbacks=cbks)

定义了一个tensorboard的回调函数,并且在model.fit的时候作为callbacks参数传入,如此可以使用tensorboard观察训练。虽然50轮的训练也没什么好观察的就是了。

模型的保存和加载

上面的内容都是极其基础内容,即使是新手学个两天都能整出来,但是下面的这个问题我觉的有些小坑,就是模型的加载和保存。这个地方网上很多文章都没讲清楚,没有说清楚对应的API关系。这里贴一篇文章,讲的不错https://blog.csdn.net/aloveysz/article/details/124981316

  • 一种是在训练完成之后使用model.save("path/model_save.h5")保存为h5文件。此时保存的h5文件包含了网络层和神经元的参数,直接加载就可以了,连keras.models.Sequential()获取容器都不需要。加载的方式(能且只能)是model = keras.models.load_model("path/model_save.h5"),此时这个model具有predict成员函数,可以直接进行预测
    # 这里是个进度条,我是为了好看,可以不用tqdm
    for pic in tqdm(train, desc = "Test"):
          pre = []
          pre.append(pic)
          ans = model.predict(np.array(pre), 1, verbose=0)
    
    可见,在load_model之后,你只需要准备好图片就可以直接搞里头了。但是我发现模型保存之后其input_shape会发生变化,这个我没搞清楚。上文提到,我网络的输入shape是None*2*1300*1,但是保存之后,输入的shape变成了None*1*2*1300,不知道怎么回事z轴跑到前面去了,所以还得np.array(pic).reshape(1,2,1300)整一下
  • 另一种是使用save_weights,这个函数顾名思义,只保存了模型的神经元参数,并没有保存模型本身。这个函数对应的是load_weights,显然你load之前需要搭建好网络才能加载。但是这两个函数不知道为什么我用起来会有问题,所以此次笔记我主要记第一个方法。
  • 然后就是我当时尝试了一个checkpoint,就是创建了一个回调函数,这个回调函数可以以loss为评估参数,自动保存loss最优的网络模型,用法就是把这个cbks接入之前model.fit函数的callbacks参数。
    save_model_cbk = [keras.callbacks.ModelCheckpoint(filepath = 'path' + time_string, monitor='loss', save_best_only=True, mode='auto')]
              cbks = [save_model_cbk]
    

以时间为文件名保存

这没啥好说的,就是个小技巧。

from datetime import datetime
now = datetime.now()
time_string = now.strftime("%Y_%m_%d_%H_%M_%S")
model.save('path/' + time_string + '.h5')

加载进度条

也是属于花活,tqdm库,简单的用法如下:第一个参数就是你本来for A in B的B,第二个desc是这个进度条的描述文字,我这里就是说明当前在第几条路径

# 重点只有第一行
for file_name in tqdm(file_name_list, desc = str(path_index)):
            pic = []
            file = open(path + "\\" + file_name, mode = 'r')
            file_str = file.readlines()
            for line_str in file_str:
                line_str = line_str.replace('\r', '').replace('\n', '')  # 去除换行符
                pos = line_str.find(' ')
                
                if pos != -1:
                    point_x = float(line_str[0:pos])
                    point_y = float(line_str[pos+1:])
                    pic.append([point_x, point_y])
                else:
                    print("error")
                    continue

其它知识点就是line_str = line_str.replace('\r', '').replace('\n', ''),连续两个replace函数删除了这一行的换行符和行结束符

反思

今天看来也没啥好记的,其实就是个入门级的小东西,昨天竟然调了整整一天,回忆起来主要卡的位置就是模型的保存和加载,save/load_modelsave_weights/load_weights应当是成对使用的,当时没有注意到,对搭建好的model进行load_model,一直报错,现在想起来真是搞笑。
然后就是在模型保存之后的读入卡住了,当时发现它的input_shape莫名其妙z轴跑前面去了,但是我对numpy数组又不熟,寻思先把(1300,2,1)delete成(1300,2)在add(1,1300,2),再reshape,然后查了半天delete的资料;实际上一个reshape就行了,确实有点捞了当时。
还有一个值得注意的点就是python的list也就是列表和array也就是numpy数组的互相转换。转换有一个条件,就是你里面所有的子list的维度必须相同,不然会报错。然后是转换后的类型可以通过dtype参数指定。当时读数据集的时候不知道怎么地有的图没读到完整的,导致转换数组一直报错。
暂时记到这吧,后续会发不用onnx模型和cpp联调网络的方法。

锦绣词句本从天上来
狂写诗词三百
如何请这妙笔 入我梦中来
–《春涧》浅影阿

  • 17
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值