判断GPU是否可用
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
os.path.join
# Dataset root directory: a relative path, so it resolves against the current working directory.
root = os.path.join("dataset")
将输入数据(图片等)转换为Tensor类型
# Convert input images into Tensors. The MNIST datasets below receive this object
# via `transform=transform`, so it must be bound to that name (previously it was
# constructed and immediately discarded, leaving `transform` undefined).
transform = transforms.ToTensor()
数据处理(以MNIST为例)
# Training split: download=True lets torchvision fetch the MNIST files into `root`
# when they are not already present.
train_mnist_set = datasets.MNIST(
    root=root,
    train=True,
    transform=transform,
    download=True,
)
# Training batches are shuffled.
train_dataloader = DataLoader(
    train_mnist_set,
    batch_size=batch_size,
    shuffle=True,
)

# Test split: download=False, so it expects the files to already be under `root`.
# NOTE(review): presumably they were fetched by the training-split call above; confirm.
test_mnist_set = datasets.MNIST(
    root=root,
    train=False,
    transform=transform,
    download=False,
)
# Evaluation batches keep a fixed order.
test_dataloader = DataLoader(
    test_mnist_set,
    batch_size=batch_size,
    shuffle=False,
)
网络搭建
nn.Conv2d()  # convolution layer; NOTE(review): constructor arguments (channels, kernel size) are still to be filled in
nn.BatchNorm2d()  # batch normalization of the data; NOTE(review): num_features still to be filled in
nn.ReLU(True)  # activation function: the class is nn.ReLU (nn.Relu does not exist); True is the inplace flag
nn.Conv2d()
nn.BatchNo