pytorch detach().numpy()

 

    for epoch in range(EPOCH):
        sum_D = 0
        sum_G = 0
        for step, (images, imagesLabel) in enumerate(train_loader):
            print(step)
            G_ideas = t.randn((BATCH_SIZE, Len_Z, 1, 1))

            G_paintings = G(G_ideas)
            prob_artist0 = D(images)  # D try to increase this prob
            prob_artist1 = D(G_paintings)
            p0 = t.squeeze(prob_artist0)
            p1 = t.squeeze(prob_artist1)

            errD_real = criterion(p0, label_Real)

            errD_fake = criterion(p1, label_Fake)
            # errD_fake.backward()

            errD = errD_fake + errD_real
            errG = criterion(p1, label_Real)
            sum_D=sum_D+errD.detach().numpy()
            sum_G=sum_G+errG.detach().numpy()
            #print("errD is %f"%errD)
            #print("sumD is %f"%sum_D)
            optimD.zero_grad()
            errD.backward(retain_graph=True)
            optimD.step()

            optimG.zero_grad()
            errG.backward(retain_graph=True)
            optimG.step()

今天在实验时直接使用sum_D=sum_D+errD,发现内存快速飙升。后来改成sum_D=sum_D+errD.detach().numpy(),总算没问题了,因为第一种表达式等于是在搭网络节点,当然会不断提升网络容量,提高内存消耗量。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值