# GAN training loop: alternately update the discriminator D and the generator G.
# Assumes EPOCH, BATCH_SIZE, Len_Z, train_loader, G, D, criterion,
# label_Real, label_Fake, optimD, optimG and `t` (torch) are defined elsewhere.
for epoch in range(EPOCH):
    sum_D = 0.0  # running loss sums as plain floats — no autograd graph retained
    sum_G = 0.0
    for step, (images, imagesLabel) in enumerate(train_loader):
        print(step)

        # ---- Discriminator step ----
        G_ideas = t.randn((BATCH_SIZE, Len_Z, 1, 1))
        G_paintings = G(G_ideas)
        prob_artist0 = D(images)                # D should push this toward "real"
        # detach(): no gradient flows into G while updating D, so D's
        # backward pass frees its graph and no retain_graph hack is needed
        prob_artist1 = D(G_paintings.detach())
        errD_real = criterion(t.squeeze(prob_artist0), label_Real)
        errD_fake = criterion(t.squeeze(prob_artist1), label_Fake)
        errD = errD_real + errD_fake
        optimD.zero_grad()
        errD.backward()
        optimD.step()

        # ---- Generator step ----
        # Re-score the fakes with the just-updated D so G's gradient is
        # computed against the current discriminator parameters (also avoids
        # backprop through D weights that optimD.step() modified in place).
        prob_artist1 = D(G_paintings)
        errG = criterion(t.squeeze(prob_artist1), label_Real)
        optimG.zero_grad()
        errG.backward()
        optimG.step()

        # .item() yields a Python float, cutting all references to the
        # computation graph (accumulating the raw tensor leaks memory);
        # unlike .detach().numpy(), it also works for CUDA tensors.
        sum_D += errD.item()
        sum_G += errG.item()
今天在实验时直接使用 sum_D = sum_D + errD,发现内存快速飙升。后来改成 sum_D = sum_D + errD.detach().numpy(),问题消失。原因是:errD 是挂在计算图上的张量,直接累加会让 sum_D 一直引用每个 batch 的整个计算图,使其无法被释放,内存自然不断增长;先 detach() 切断对计算图的引用即可。更简洁的写法是 sum_D = sum_D + errD.item(),直接取出 Python 标量,且在 GPU 张量上同样适用。