pytorch用法记录(torch.Storage与detach)

1.torch.Storage类

 

 

 

 使用storage()函数把Tensor数据转换为float类型的Storage数据,再使用tolist() 返回一个包含此存储中元素的列表。

2.detach 计算图截断

detach 的意思是,这个数据和生成它的计算图“脱钩”了,即detach就是截断反向传播的梯度流。GAN中,Train D on fake,G生成的数据会传入D,然后计算loss,再反向传播更新。由于backward()的操作是我们希望D(判别器)端的loss更新D但不要影响到 G(生成器)。

#  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()

如上例,对G生成的数据G(d_gen_input)执行detach()操作,判别器D梯度反向传播,就到它自己身上为止,不会继续反向传播到G。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值