# Train the RNN classifier for EPOCH passes over the training data.
# (The pasted version had lost its indentation, which is an IndentationError;
# the loop nesting is restored here.)
for epoch in range(EPOCH):
    for step, (x, batch_y) in enumerate(train_loader):
        # x arrives image-shaped; a debug print showed torch.Size([64, 1, 28, 28]),
        # i.e. (batch_size, channels, height, width) with a single channel.
        # An RNN needs a 3-D input, so reshape to (batch_size, 28, 28),
        # e.g. torch.Size([64, 28, 28]).
        # NOTE(review): assumes rnn was built with batch_first=True, so each of
        # the 28 image rows is read as one time step of 28 features - confirm
        # against the rnn definition.
        batch_x = x.view(-1, 28, 28)

        output = rnn(batch_x)              # forward pass through the model
        loss = loss_func(output, batch_y)  # compare predictions with the labels
        optimizer.zero_grad()              # clear gradients left over from the previous step
        loss.backward()                    # backpropagate to compute fresh gradients
        optimizer.step()                   # apply the parameter update
rnn需要三个维度（以 batch_first=True 为例，顺序如下；默认 batch_first=False 时顺序为 time_step, batch_size, input_size）：
batch_size
time_step
input_size
cnn（如 nn.Conv2d）需要四个维度，顺序为：
batch_size
channels
height
width