1. Project Background
Training HiDDeN with its jpeg() noise layer expects torch 1.0, but my environment has 1.11. I considered downgrading, but that would also mean changing my Python and CUDA versions. Besides, since I am just starting to reproduce papers, I want to build up my code-reading and debugging skills, so I decided to fix the problems directly on top of my current setup.
```python
def train_on_batch(self, batch: list):
    """
    Trains the network on a single batch consisting of images and messages
    :param batch: batch of training data, in the form [images, messages]
    :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
    """
    images, messages = batch
    batch_size = images.shape[0]
    self.encoder_decoder.train()
    self.discriminator.train()
    with torch.enable_grad():
        # ---------------- Train the discriminator -----------------------------
        self.optimizer_discrim.zero_grad()
        # train on cover
        d_target_label_cover = torch.full((batch_size, 1), self.cover_label, device=self.device).float()
        d_target_label_encoded = torch.full((batch_size, 1), self.encoded_label, device=self.device).float()
        g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device).float()
        d_on_cover = self.discriminator(images)
        d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover)
        d_loss_on_cover.backward()
        # train on fake
        encoded_images, noised_images, decoded_messages = self.encoder_decoder(images, messages)
        d_on_encoded = self.discriminator(encoded_images.detach())
        d_loss_on_encoded = self.bce_with_logits_loss(d_on_encoded, d_target_label_encoded)
        d_loss_on_encoded.backward()
        self.optimizer_discrim.step()
        # -------------- Train the generator (encoder-decoder) ---------------------
        self.optimizer_enc_dec.zero_grad()
        # target label for encoded images should be 'cover', because we want to fool the discriminator
        d_on_encoded_for_enc = self.discriminator(encoded_images)
        g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc, g_target_label_encoded)
        if self.vgg_loss is None:
            g_loss_enc = self.mse_loss(encoded_images, images)
        else:
            vgg_on_cov = self.vgg_loss(images)
            vgg_on_enc = self.vgg_loss(encoded_images)
            g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)
        g_loss_dec = self.mse_loss(decoded_messages, messages)
        g_loss = (self.config.adversarial_loss * g_loss_adv
                  + self.config.encoder_loss * g_loss_enc
                  + self.config.decoder_loss * g_loss_dec)
        g_loss.backward()
        self.optimizer_enc_dec.step()

    decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(0, 1)
    bitwise_avg_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
        batch_size * messages.shape[1])

    losses = {
        'loss           ': g_loss.item(),
        'encoder_mse    ': g_loss_enc.item(),
        'dec_mse        ': g_loss_dec.item(),
        'bitwise-error  ': bitwise_avg_err,
        'adversarial_bce': g_loss_adv.item(),
        'discr_cover_bce': d_loss_on_cover.item(),
        'discr_encod_bce': d_loss_on_encoded.item()
    }
    return losses, (encoded_images, noised_images, decoded_messages)
```
2. The Error
Running the code produced the following traceback:
```
Traceback (most recent call last):
  File "K:\python_project\test_HiDDeN\HiDDeN\main.py", line 150, in <module>
    main()
  File "K:\python_project\test_HiDDeN\HiDDeN\main.py", line 146, in main
    train(model, device, hidden_config, train_options, this_run_folder, tb_logger)
  File "K:\python_project\test_HiDDeN\HiDDeN\train.py", line 52, in train
    losses, _ = model.train_on_batch([image, message])
  File "K:\python_project\test_HiDDeN\HiDDeN\model\hidden.py", line 101, in train_on_batch
    g_loss.backward()
  File "K:\python_project\test_HiDDeN\HiDDeN\venv\lib\site-packages\torch\_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "K:\python_project\test_HiDDeN\HiDDeN\venv\lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:
[torch.cuda.FloatTensor [48, 1, 128, 128]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead.
Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
The error says: during the gradient computation triggered by g_loss.backward(), one of the required variables has been modified by an in-place operation.
The hint suggests enabling anomaly detection with torch.autograd.set_detect_anomaly(True) to locate the operation that failed to compute its gradient.
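This class of error is easy to reproduce in isolation. Below is a minimal standalone sketch (my own example, independent of the HiDDeN code; variable names x and y are mine): sigmoid() saves its output for the backward pass, so editing that output in place bumps its version counter and backward() refuses to run.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.sigmoid()   # autograd saves y itself: d(sigmoid)/dx = y * (1 - y)
y.add_(1)         # in-place edit: y's version counter no longer matches
try:
    y.sum().backward()
    raised = False
except RuntimeError as err:
    # message: "... has been modified by an inplace operation ..."
    raised = "inplace" in str(err)
```

Removing the `y.add_(1)` line (or replacing it with out-of-place `y = y + 1`) makes the backward pass succeed.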
3. Solution
3.1 Following the error's hint
I added torch.autograd.set_detect_anomaly(True) near the failing line, but rerunning still pointed at the same location and did not narrow the anomaly down any further. Apparently this method did not fit my case.
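For reference, this is how anomaly detection is typically used (again a standalone sketch, not the HiDDeN code): it is enabled once before the forward pass, and when backward fails, PyTorch additionally prints the traceback of the forward operation that produced the offending tensor. Note that slice assignment also counts as an in-place write, which may be related to the AsStridedBackward0 node mentioned in my traceback.

```python
import torch

torch.autograd.set_detect_anomaly(True)  # enable BEFORE the forward pass
x = torch.ones(3, requires_grad=True)
y = x.sigmoid()
y[0] = 0.0               # slice assignment writes into y in place
try:
    y.sum().backward()   # anomaly mode also reports where y was created
    raised = False
except RuntimeError as err:
    raised = "inplace" in str(err)
torch.autograd.set_detect_anomaly(False)  # restore the default afterwards
```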
3.2 Looking for tutorials
1. Reference 1 [not solved]
The problem description looked identical to mine, so I applied the fix without much thought, but I still got the error.
In hindsight, I did not even know what had caused *their* error; being that rash was bound to go wrong.
The post also suggested checking for any inplace=True settings; I checked and confirmed there were none, so I moved on.
2. Reference 2 [not solved]
This one discussed the retain_graph parameter: whether to free the computation graph after backward finishes, which matters when several objectives need gradients from the same graph. I did not end up using it here, but I expect this post will come in handy later.
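To make the retain_graph idea concrete, here is a small standalone sketch (my own, not from that post): backward() normally frees the graph after running, so a second backward through the same graph requires retain_graph=True on the first call, and the gradients accumulate in x.grad.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()
y.backward(retain_graph=True)  # keep the graph alive for another pass
y.backward()                   # second pass; would fail without retain_graph
# each backward adds d(x^2)/dx = 2x, so x.grad accumulates to 4x
```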
3. Reference 3 [not solved]
It suggested moving the gradient-update steps later and grouping them together.
I tried restructuring my code to match that layout, but the same error appeared.
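That advice does target a real failure mode, even though it was not mine: optimizer.step() updates parameters in place, and if a later backward() still needs the pre-step parameter values, it raises exactly this kind of RuntimeError. A hypothetical minimal sketch (names lin, opt, loss_d, loss_g are mine, not HiDDeN's):

```python
import torch

lin = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)
x = torch.randn(4, 2, requires_grad=True)
out = lin(x)                 # forward saves lin.weight for backward w.r.t. x

loss_d = out.pow(2).mean()
loss_d.backward(retain_graph=True)
opt.step()                   # in-place update of lin.weight

loss_g = out.mean()
try:
    loss_g.backward()        # still needs the pre-step weight values
    raised = False
except RuntimeError as err:
    raised = "inplace" in str(err)
```

Moving opt.step() after the final backward() makes this sketch run cleanly, which is presumably what that tutorial was aiming at; it just was not the cause in my case.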
4. My own exploration [solved]
Knowing my fundamentals were shaky (not even past the beginner stage), I decided to work through Li Mu's Dive into Deep Learning course to get familiar with the basics, only to get tangled up by a line like if isinstance(updater, torch.optim.Optimizer).
So I went back to the problem, and suddenly the idea clicked:
the error says a variable needed for gradient computation was modified, and the failure happens at g_loss.backward(), so one of g_loss's terms must be the problem. Here is g_loss:
```python
g_loss = (self.config.adversarial_loss * g_loss_adv
          + self.config.encoder_loss * g_loss_enc
          + self.config.decoder_loss * g_loss_dec)
```
I tested the terms one by one, replacing each with a constant. Replacing g_loss_dec with a constant made the error go away, which means that tensor is the one being changed during the computation. Tracing it back to where it is computed:
```python
g_loss_dec = self.mse_loss(decoded_messages, messages)
```
The detach idea did not come to me immediately. At first I suspected that the computations behind g_loss_adv and g_loss_enc were somehow mutating the inputs of g_loss_dec, so I added .detach() to them, with no effect. Then I tried .detach() on decoded_messages and messages individually, still nothing. After mulling it over for half a day, I thought: why not try .detach() on g_loss_dec itself? I did not expect much, thinking it could not possibly be that simple, but it actually ran!
So the resolution was:
- Locate the failing line [ g_loss.backward() ]
- Find which term of g_loss is at fault [ g_loss_dec ]
- Since the error says the variable was modified in place, use the variable's .detach() version in the expression instead of the variable itself [ add the g_loss_dec.detach() operation ]
The modified code snippet:
```python
self.optimizer_enc_dec.zero_grad()
# target label for encoded images should be 'cover', because we want to fool the discriminator
d_on_encoded_for_enc = self.discriminator(encoded_images)
g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc, g_target_label_encoded)
if self.vgg_loss is None:
    g_loss_enc = self.mse_loss(encoded_images, images)
else:
    vgg_on_cov = self.vgg_loss(images)
    vgg_on_enc = self.vgg_loss(encoded_images)
    g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)
g_loss_dec = self.mse_loss(decoded_messages, messages)
g_loss_dec_detach = g_loss_dec.detach()
g_loss = (self.config.adversarial_loss * g_loss_adv
          + self.config.encoder_loss * g_loss_enc
          + self.config.decoder_loss * g_loss_dec_detach)
g_loss.backward()
self.optimizer_enc_dec.step()
```
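One caveat about this fix worth being aware of: .detach() returns a tensor that is cut out of the autograd graph, which is exactly why backward() stops complaining about it, but it also means the detached term contributes no gradients at all. A standalone sketch of that behavior (my own example):

```python
import torch

x = torch.randn(3, requires_grad=True)
loss_a = (2 * x).sum()
loss_b = (3 * x).sum()
total = loss_a + loss_b.detach()  # loss_b is now a constant to autograd
total.backward()
# x.grad is 2 everywhere, not 2 + 3: the detached term contributed nothing
```

Applied to my case, this means the decoder MSE no longer drives the encoder-decoder update after the change, so it is worth watching whether the bitwise error still decreases during training.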