- 代码:
# Script: build MNIST train/test loaders, train a `Net` model with SGD for a
# few epochs, evaluate after each training pass, and save the weights.
# `Net`, `train`, and `test` are defined outside this snippet.
# NOTE(review): indentation was lost in this paste; the `for` loop body
# presumably covers the train(...) and test(...) calls — confirm.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
# Training split: stored under ./mnist_data (download=True), converted to
# tensors and normalized with mean 0.1307 and std 0.3081.
train_dataloader = torch.utils.data.DataLoader(
datasets.MNIST("./mnist_data", train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True
)
# Test split: same pipeline as above, but with train=False and one worker.
test_dataloader = torch.utils.data.DataLoader(
datasets.MNIST("./mnist_data", train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
# BUG (kept on purpose, this is the error being demonstrated): `(0.3081)`
# has no trailing comma, so it is a plain float rather than a one-element
# tuple. Per the note below, this is what raises the IndexError; the
# training loader above uses the correct `(0.3081,)`.
transforms.Normalize((0.1307,), (0.3081))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
# SGD hyperparameters.
lr = 0.01
momentum = 0.5
# Instantiate the model and move it to the selected device.
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
num_epochs = 2
# One training pass followed by one evaluation pass per epoch.
for epoch in range(num_epochs):
train(model, device, train_dataloader, optimizer, epoch)
test(model, device, test_dataloader)
# Persist only the model parameters (state_dict) to mnist_cnn.pt.
torch.save(model.state_dict(), "mnist_cnn.pt")
错误提示:
IndexError: too many indices for tensor of dimension 0.
解决办法:
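test_dataloader 的 transforms.Normalize 中,0.3081 后面缺少逗号:`(0.3081)` 只是一个浮点数,而不是单元素元组 `(0.3081,)`,因此 std 会被转换成 0 维张量,对其做索引时就会抛出上面的 IndexError。应改为: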
transforms.Normalize((0.1307,), (0.3081,))
])),