现象:
训练过程中多个loss回传产生了GPU显存不够用的情况(即使是设置batch_size最小也不行),在backward函数中去掉retain_graph=True之后,情况没有出现。
注意:retain_graph=True是一个在训练代码中需要极力避免的一个设置。降低了训练速度,增大了显存开销。但是我的代码中暂时不能避免这个设置。
原因分析:
我这里出现这个情况的原因:因为不同loss求完之后没有算均值,可能返回的是一个tensor,在训练loop中会累积越来越大,要通过 .mean() 把它变成标量。
解决:
criterion = torch.nn.CrossEntropyLoss()
output = module_a(fc1Features