nn.CrossEntropyLoss() 接收两个参数(input 和 target):当 target 为类别索引时,标签必须是 long 型(int64),不能是 float32。
# Convert the labels straight to an int64 ("long") tensor, which is the
# dtype used for class-index targets of nn.CrossEntropyLoss.  Building the
# tensor directly as long avoids the float32 intermediate that
# torch.Tensor(...) creates, so integer label values are never rounded.
hwLabels = torch.as_tensor(hwLabels, dtype=torch.long)

# Loss criterion used by the training loop.
loss_func = nn.CrossEntropyLoss()
# Train the network for EPOCH passes over train_loader, recording and
# printing the loss every 5 steps.
# NOTE(review): the indentation of this section was lost in the source; the
# nesting below (print inside the every-5-steps branch) is the assumed
# structure -- confirm against the original script.
for epoch in range(EPOCH):
    # Each item of train_loader unpacks into a batch of inputs and labels.
    for step, (b_x, b_y) in enumerate(train_loader):
        # The model's return value is indexed; element 0 is what gets
        # passed to the loss as the prediction.
        output = cnn(b_x)[0]
        loss = loss_func(output, b_y)  # cross-entropy loss for this batch

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()  # backpropagation: compute gradients
        optimizer.step()  # apply the gradient update

        if step % 5 == 0:
            # .cpu() keeps the NumPy conversion working when the loss lives
            # on a GPU; for a CPU tensor it is a no-op.
            loss_count.append(loss.detach().cpu().numpy())
            print('{}:\t'.format(step), "\tloss:", loss.item())
            # Optional checkpoint (disabled): torch.save(cnn, r'save/model')