代码链接:
https://download.csdn.net/download/qq_38649386/12667825
实验结果:
问题&解决办法:
1.“csv_{0}.log”,log日志文件,路径设置:
CSVLogger_ = keras.callbacks.CSVLogger('csv/csv_{0}.log'.format(N_ADDITION), separator=',', append=False)
优化:
# 增加的神经元数量
# N_ADDITION=0,128,256,512,1024,2048,并观察实验结果
N_ADDITION = 512
还能继续优化不用手动更新
产生文件:
2.RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)
错误代码:
CSVLogger_ = keras.callbacks.CSVLogger('csv/csv_{0}.log'.format(N_ADDITION), separator=',', append=False)
model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH, verbose=VERBOSE,validation_split=VALIDATION_SPLIT,callbacks=[CSVLogger_])
正确代码:
CSVLogger_ = keras.callbacks.CSVLogger('csv/csv_{0}.log'.format(N_ADDITION), separator=',', append=False)
model.compile(loss='categorical_crossentropy', optimizer=OPTIMIZER,metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH, verbose=VERBOSE,validation_split=VALIDATION_SPLIT,callbacks=[CSVLogger_])
3.错误代码:
path = 'weights/csv/csv_{0}.log'.format(j)
正确代码:依照上文自己设定的path路径填写
path = 'csv/csv_{0}.log'.format(j)
4.KeyError: "['acc' 'val_acc'] not in index"
错误代码:
sns.lineplot(data=data[['acc','val_acc']])
正确写法:
#绘制曲线
sns.lineplot(data=data[['accuracy','val_accuracy']])
5.AttributeError: 'numpy.float64' object has no attribute 'values'
代码错误定位:
正确:
6.plt.show() 图标重叠,如图:
添加代码:
# 方法一 效果最好
plt.tight_layout()
# 方法二 设置表格大小
#plt.figure(figsize=(16, 12))
# 方法三 设置尺寸
#plt.tight_layout(pad=0.1, w_pad=1.0, h_pad=1.0)
实验原理&代码讲解:
--》有空再写