解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题
问题描述
学习Pytorch官方的tutorial时,在教程的第四部分(Training a Classifier)中会看到作者展示的代码:
etc.
于是,我就把上面展示代码的如数复制到了pycharm上,想着直接运行。
直接copy得到这样的运行结果:
出错了!!
解决办法
好端端的怎么会出错呢??仔细看看异常对象和python提供的描述:
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.
原来问题可以通过添加一段代码解决:
if __name__ == '__main__':
只要把它放在运行的文件的开头就行了
像这样:
if __name__=='__main__':
import torch
import torchvision
import torchvision.transforms as transforms
# function to show an image
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
print(trainloader)
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))