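Today I was training a DCGAN and hit a very strange problem. The reference code was written for tensorflow_gpu 1.15.0, which is painful now that TF 2.x is the norm. At first I assumed the TensorFlow environment was at fault and kept rebuilding it over and over. After upgrading to 2.2.0 the code did run through successfully once, but today training failed with a new error.
The training error was as follows: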
Epoch 0
2024-03-28 20:45:09.714512: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10
tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2024-03-28 21:09:06.971497: W tensorflow/stream_executor/gpu/asm_compiler.cc:81] Running ptxas --version returned 256
2024-03-28 21:09:07.115038: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] Internal: ptxas exited with non-zero error code 256, output:
Relying on driver to perform ptx compilation.
Modify $PATH to customize ptxas location.
This message will be only logged once.
step: 2 monitor: -1.0250881 reference: -1 tolerance: 0
Epoch: [ 0] [ 0/ 638] time: 1455.8668, d_loss: 1.62575448, g_loss: 0.60066640
Traceback (most recent call last):
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1365, in _do_call
return fn(*args)
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1349, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1441, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: d_
[[{{node d_}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "main.py", line 206, in <module>
main(dataset=args.dataset, build_type=args.mode, sph=args.sph)
File "main.py", line 101, in main
dcgan.train(FLAGS)
File "/root/autodl-tmp/seizure-prediction-GAN-master/dcgan/model.py", line 194, in train
_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={ self.z: batch_z }) # run the g_optim update and fetch the generator summary self.g_sum
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 957, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1180, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1358, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1384, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: d_
[[node d_ (defined at /root/autodl-tmp/seizure-prediction-GAN-master/dcgan/model.py:104) ]]
Errors may have originated from an input operation.
Input Source operations connected to node d_:
discriminator_1/Sigmoid (defined at /root/autodl-tmp/seizure-prediction-GAN-master/dcgan/model.py:248)
Original stack trace for 'd_':
File "main.py", line 206, in <module>
main(dataset=args.dataset, build_type=args.mode, sph=args.sph)
File "main.py", line 92, in main
dcgan = DCGAN(sess=sess,
File "/root/autodl-tmp/seizure-prediction-GAN-master/dcgan/model.py", line 77, in __init__
self.build_model()
File "/root/autodl-tmp/seizure-prediction-GAN-master/dcgan/model.py", line 104, in build_model
self.d__sum = histogram_summary("d_", self.D_) #D_
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/summary/summary.py", line 178, in histogram
val = _gen_logging_ops.histogram_summary(
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/gen_logging_ops.py", line 284, in histogram_summary
_, _, _op, _outputs = _op_def_library._apply_op_helper(
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py", line 742, in _apply_op_helper
op = g._create_op_internal(op_type_name, inputs, dtypes=None,
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3319, in _create_op_internal
ret = Operation(
File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1791, in __init__
self._traceback = tf_stack.extract_stack()
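The trace names discriminator_1/Sigmoid as the input to the failing "d_" histogram, so one quick check is to look at values fetched from the session for NaN or Inf before they reach a summary. The reference code does not do this; the sketch below is NumPy-only, and the helper name find_non_finite and the sample values are hypothetical.

```python
import numpy as np


def find_non_finite(named_arrays):
    """Return {name: number of NaN/Inf entries} for every array that has any."""
    report = {}
    for name, value in named_arrays.items():
        arr = np.asarray(value, dtype=np.float64)
        bad = int(np.count_nonzero(~np.isfinite(arr)))
        if bad:
            report[name] = bad
    return report


if __name__ == "__main__":
    # Made-up values standing in for whatever sess.run(...) returned.
    fetched = {
        "d_": np.array([0.3, np.nan, 0.9]),
        "d_loss": np.array(1.62575448),
    }
    print(find_non_finite(fetched))
```

If this reports non-finite entries right after a training step, the NaN is coming from the model values themselves rather than from the summary writer.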
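It took a full day in the end. The cause turned out, once again, to be package versions: rolling back to an older version fixed it. Thanks to ZXY, whose sincere good wishes brought me luck~ The fix: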
pip install six==1.12.0
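I strongly suspect the six package was the culprit, since the other packages do not seem to have changed much. My versions, for reference: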
scipy==1.0.1
np_utils==0.5.10.0
pandas==0.24.2
stft==0.5.2
mne==0.11.0
scikit_learn==0.21.3
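To see at a glance which of these pins differ from what is actually installed, a small check along the following lines may help. It is only a sketch: the function name check_pins is made up for illustration, it needs Python 3.8+ for importlib.metadata, and on some Python versions the lookup may be sensitive to "-" versus "_" in package names, in which case the package is reported as not found.

```python
from importlib import metadata  # standard library since Python 3.8

# Pins copied from the list above, plus the six version from the fix.
PINS = {
    "six": "1.12.0",
    "scipy": "1.0.1",
    "np_utils": "0.5.10.0",
    "pandas": "0.24.2",
    "stft": "0.5.2",
    "mne": "0.11.0",
    "scikit_learn": "0.21.3",
}


def check_pins(pins):
    """Return {name: (installed_version_or_None, pinned_version)} for mismatches."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # not installed, or registered under another spelling
        if installed != wanted:
            mismatches[name] = (installed, wanted)
    return mismatches


if __name__ == "__main__":
    for name, (installed, wanted) in check_pins(PINS).items():
        print(f"{name}: installed={installed}, pinned={wanted}")
```

An empty result means every pinned package matches; anything printed is a candidate for pip install name==version.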
Update 3/30: the error came back. This time, deleting the previous training records (the /logs folder and the log_loss file) really did solve it!
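The cleanup can also be scripted so that it runs before each fresh training session. This is a minimal sketch that assumes the logs folder and the log_loss file sit in the current working directory; adjust the paths to the project layout. The function name clear_training_logs is hypothetical, and the deletion is permanent, so back up anything worth keeping first.

```python
import os
import shutil


def clear_training_logs(paths):
    """Delete each existing directory or file in paths; return what was removed."""
    removed = []
    for path in paths:
        if os.path.isdir(path):
            shutil.rmtree(path)  # removes the folder and everything inside it
            removed.append(path)
        elif os.path.isfile(path):
            os.remove(path)
            removed.append(path)
        # paths that do not exist are skipped silently
    return removed


if __name__ == "__main__":
    # Assumed locations relative to the working directory.
    print(clear_training_logs(["logs", "log_loss"]))
```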
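Here are the articles I consulted earlier. Their fix is to delete the results of the previous training run, which still failed for me when I first tried it.. For readers' reference only:
"invalid argument: Nan in summary histogram for: image_pooling/BatchNorm/moving_variance_1" during training - zmbreathing - 博客园 (cnblogs.com)
Error during network training: Nan in summary histogram_nan in summary histogram for: bert_model/title_den - CSDN博客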