pytorch 模型训练过程(入门级代码)
加载数据集
Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader来更加快捷的对数据进行操作。
Dataset
数据集包括训练集和测试集,在pytorch torchvision中有些常见的数据集,可以直接写代码设置参数下载。例如下载CIFAR10数据集:
train_data=torchvision.datasets.CIFAR10(root="../data",transform=torchvision.transforms.ToTensor(),train=True,download=True)
test_data=torchvision.datasets.CIFAR10(root="../data",transform=torchvision.transforms.ToTensor(),train=False,download=True)
download为True时会自动下载,下载的数据为PIL格式,要转为tensor。
每个数据集下载时的参数设置可能不一样,可以去官方文档具体查看。
如果是自己的数据集,大多数时候需要重写dataset类以方便对数据操作,入门级先不做介绍。
DataLoader
train_dataloader=DataLoader(dataset=train_data,batch_size=64,shuffle=True,num_works=0,drop_last=