A baffling error:
data = output[0]
data = torch.squeeze(data)
print(data.shape)
Output:
torch.Size([28, 28])
data = output[0]
output[0] = torch.squeeze(output[0])
print(output[0].shape)
Output:
torch.Size([1, 28, 28])
As if squeeze didn't exist?
Tried this variation:
data = output
data[0] = torch.squeeze(data[0])
print(data[0].shape)
Output:
torch.Size([1, 28, 28])
At first glance this suggests squeeze can't be used on higher-dimensional tensors... but that's not it. `data = torch.squeeze(data)` rebinds the name `data` to the new, squeezed tensor. `output[0] = torch.squeeze(output[0])`, on the other hand, is an in-place slice assignment: PyTorch copies the values back into the existing `[1, 28, 28]` slice (broadcasting them), so the stored slice's shape cannot change. squeeze works fine on high-dimensional tensors; its result was simply never kept.
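A minimal sketch of the two cases, using a dummy tensor with the same `[batch, 1, 28, 28]` layout as the network output (the shape is an assumption based on the printed sizes above):

```python
import torch

output = torch.zeros(2, 1, 28, 28)  # stand-in for the real network output

# Rebinding: `data` now refers to the new, squeezed tensor.
data = torch.squeeze(output[0])
print(data.shape)       # torch.Size([28, 28])

# Slice assignment: values are copied (broadcast) back INTO the existing
# [1, 28, 28] slice, so the stored shape stays exactly as it was.
output[0] = torch.squeeze(output[0])
print(output[0].shape)  # torch.Size([1, 28, 28])
```

The same applies to `data = output; data[0] = ...`: since `data` and `output` are the same tensor object, indexing on either side still performs an in-place slice assignment.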
Context:
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):  # gives batch data, normalize x when iterating train_loader
        output = cnn(b_x)               # cnn output
        loss = loss_func(output, b_x)   # loss against the input b_x (reconstruction-style)
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            print('Training...')
            data = output[0]
            data = torch.squeeze(output[0])
            print(output[0].shape)      # still [1, 28, 28]: the squeezed tensor is in `data`, output[0] was never changed
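The fix inside the loop is to keep the squeezed image under its own name and print that, rather than assigning back into `output[0]`. A runnable sketch, with minimal stand-ins (assumptions) for the original `cnn`, `loss_func`, `optimizer`, and batch:

```python
import torch
import torch.nn as nn

# Stand-ins so the fix runs in isolation; the real cnn/train_loader differ.
cnn = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # preserves the 1x28x28 shape
loss_func = nn.MSELoss()                          # reconstruction-style loss
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
b_x = torch.rand(8, 1, 28, 28)                    # fake batch

output = cnn(b_x)
loss = loss_func(output, b_x)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Keep the squeezed result under its own name; detach it for inspection.
data = torch.squeeze(output[0]).detach()
print(data.shape)  # torch.Size([28, 28])
```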