hi!大家好久haojiu不见,拖延症+休息,时隔这么久才继续推进,哎,come on吧!
这次主要解决GPU显存不足的问题,看能否通过添加一些代码语句清除缓存,或者更改内存地址等等手段来实现,如果都不行只能换设备试了
1.查看GPU的使用情况
在终端输入:
nvidia-smi -l
可以查看实时的GPU使用情况
如上方式会显示历史信息和当前信息,如果只想看当前信息,则可以执行如下命令实现每1s刷新一次:
watch -n 1 nvidia-smi
我的GPU显存为2048MiB,使用最大达到1600+后报错不再执行,我觉得应该是我的显存过小的问题,可能加载模型都不太够,更别说后面的图像处理过程了/(ㄒoㄒ)/~~
2.神经网络显存占用
神经网络模型占用的显存包括:
- 模型自身的参数(二维数组)
- 模型的输出(二维数组)
(1)模型占用
只有有参数的层,才会有显存占用。这部份的显存占用和输入无关,模型加载完成之后就会占用。
有参数的层主要包括:
- 卷积
- 全连接
- BatchNorm
- Embedding层
- ... ...
无参数的层:
- 多数的激活层(Sigmoid/ReLU)
- 池化层
- Dropout
- ... ...
更具体的来说,模型的参数数目计算公式(这里均不考虑偏置项b)为:
- Linear(M->N): 参数数目:M×N
- Conv2d(Cin, Cout, K): 参数数目:Cin × Cout × K × K
- BatchNorm(N): 参数数目: 2N
- Embedding(N,W): 参数数目: N × W
参数占用显存 = 参数数目×n
n = 4 :float32
n = 2 : float16
n = 8 : double64
(2)模型输出的显存占用
总结如下:
- 需要计算每一层的feature map的形状(多维数组的形状)
- 需要保存输出对应的梯度用以反向传播(链式法则)
- 显存占用与 batch size 成正比
- 模型输出不需要存储相应的动量信息。
(3)深度学习中神经网络的显存总占用
显存占用 = 模型显存占用 + batch_size × 每个样本的显存占用
3.节省显存的方法
在深度学习中,一般占用显存最多的是卷积等层的输出,模型参数占用的显存相对较少,而且不太好优化。节省显存一般有如下方法:
- 降低batch-size
- 下采样(NCHW -> (1/4)*NCHW)
- 减少全连接层(一般只留最后一层分类用的全连接层)
假定GPU处理单元已经充分利用的情况下:
- 增大batch size能增大速度,但是很有限(主要是并行计算的优化)
- 增大batch size能减缓梯度震荡,需要更少的迭代优化次数,收敛的更快,但是每次迭代耗时更长。
- 增大batch size使得一个epoch所能进行的优化次数变少,收敛可能变慢,从而需要更多时间才能收敛(比如batch_size 变成全部样本数目)。
4.tensorflow中“缓解”GPU显存不足的方法
(1)降低batch_size
从32到16到10到5,都不行
(2)在主函数中加入以下代码(tensorflow版本为2.0)
os.environ["CUDA_VISIBLE_DEVICES"] = '0' #指定第一块GPU可用
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5 # 程序最多只能占用指定gpu50%的显存
config.gpu_options.allow_growth = True #程序按需申请内存
session = tf.compat.v1.Session(config=config)
结果还是不行:
5.总结
后续尝试在云服务器或者实验室的电脑上试试,目前的computer只能是显存过小无法解决
参考文章
【工具篇】如何优雅地监控显卡(GPU)使用情况? - 知乎 (zhihu.com)
科普帖:深度学习中GPU和显存分析 - 知乎 (zhihu.com)
解决tensorflow gpu报错: ran out of memory (OOM)-CSDN博客
搞定tensorflow的报错:GPU显存篇 - 知乎 (zhihu.com)
Tensorflow显存不足解决方式 - 知乎 (zhihu.com)
tensorflow2.x(一) 显存不够或内存不够要怎么办?_tensorflow显存存不够怎么借用内存-CSDN博客
tensorflow显存不够使用如何解决(windows系统)-CSDN博客
AttributeError: module 'tensorflow' has no attribute 'ConfigProto'-CSDN博客