1. Data preprocessing
torch.utils.data — PyTorch 1.7.1 documentation
(Part 1) The three musketeers of PyTorch data preprocessing: Dataset, DataLoader, Transform (MIss-Y's blog, CSDN)
1.1 Subclass torch.utils.data.Dataset and override __getitem__ and __len__ to read data from disk; apply the transform inside __getitem__
import json

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class CIFAR10_IMG(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(CIFAR10_IMG, self).__init__()
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        # Pick the annotation file and image folder for the requested split.
        if self.train:
            file_annotation = root + '/annotations/cifar10_train.json'
            img_folder = root + '/train_cifar10/'
        else:
            file_annotation = root + '/annotations/cifar10_test.json'
            img_folder = root + '/test_cifar10/'

        # The annotation file holds two parallel lists: image file names and labels.
        with open(file_annotation, 'r') as fp:
            data_dict = json.load(fp)
        assert len(data_dict['images']) == len(data_dict['categories'])

        self.img_folder = img_folder
        self.filenames = list(data_dict['images'])
        self.labels = list(data_dict['categories'])

    def __getitem__(self, index):
        img_name = self.img_folder + self.filenames[index]
        label = self.labels[index]
        img = plt.imread(img_name)
        # Apply the optional transforms; skip them when none were given.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        # Whatever is returned here is what each batch yields in the training loop.
        return img, label

    def __len__(self):
        return len(self.filenames)
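The constructor above reads an annotation JSON that holds two parallel lists, 'images' and 'categories'. A minimal sketch of that layout follows. It writes the file to a temporary directory so it runs without the real dataset, and the file names and label values are invented purely for illustration.

```python
import json
import os
import tempfile

# Hypothetical annotation content: file names and labels are invented for illustration.
annotation = {
    "images": ["0.png", "1.png", "2.png"],
    "categories": [6, 9, 9],
}

# Write the file under <root>/annotations/, mirroring the path used by CIFAR10_IMG.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "annotations"))
path = os.path.join(root, "annotations", "cifar10_train.json")
with open(path, "w") as fp:
    json.dump(annotation, fp)

# Read it back the same way CIFAR10_IMG.__init__ does.
with open(path, "r") as fp:
    data_dict = json.load(fp)
assert len(data_dict["images"]) == len(data_dict["categories"])

# Pair each file name with its label, index by index.
samples = list(zip(data_dict["images"], data_dict["categories"]))
print(samples)
```

Because the two lists are matched by position, the length check is the only guard against a misaligned annotation file.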
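1.2 Load the data with torch.utils.data.DataLoader (instantiate it around the dataset; no subclassing is needed)
# Assumption: `datasets` is the local module (e.g. datasets.py) that defines CIFAR10_IMG above.
import datasets
train_dataset = datasets.CIFAR10_IMG('./datasets', train=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10_IMG('./datasets', train=False, transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=True)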
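To see what a loader yields per iteration, here is a minimal sketch that swaps the on-disk CIFAR10_IMG dataset for an in-memory TensorDataset of random tensors. The stand-in data, the sample count of 100, and the variable names are assumptions made only so the example runs without any image files.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 100 random "images" shaped like CIFAR-10 tensors (3 x 32 x 32)
# with integer labels in [0, 10). This replaces CIFAR10_IMG for illustration only.
images = torch.rand(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

batch_shapes = []
for img_batch, label_batch in loader:
    # Each iteration yields what __getitem__ returns, stacked along a new batch dimension.
    batch_shapes.append((tuple(img_batch.shape), tuple(label_batch.shape)))

print(batch_shapes)
```

With 100 samples and batch_size=64, the last batch is smaller (36 samples) because drop_last defaults to False.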
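2. Building the model (ResNet as an example)
2.1 ResNet architecture
2.2 Using torchvision
import torchvision
# pretrained=True matches the PyTorch 1.7-era API; torchvision 0.13 and later prefer the weights= argument.
model = torchvision.models.resnet50(pretrained=True)
2.3 Subclass nn.Module to define your own model, overriding the __init__ constructor and the forward method
PyTorch tutorial: nn.Module explained, using the Module class to define custom models (MIss-Y's blog, CSDN)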
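As a rough illustration of the nn.Module pattern, the sketch below defines layers in __init__ and wires them together in forward. ResidualBlock and TinyResNet are hypothetical names for a heavily simplified network with one identity shortcut. It is not the torchvision ResNet implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Simplified residual block: two 3x3 convolutions plus an identity shortcut."""

    def __init__(self, channels):
        super().__init__()
        # Layers are created once here and reused on every forward pass.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The shortcut adds the input back before the final activation.
        return F.relu(out + x)


class TinyResNet(nn.Module):
    """Hypothetical toy model: stem convolution, one residual block, linear classifier."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.block = ResidualBlock(16)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = F.relu(self.stem(x))
        x = self.block(x)
        # Global average pooling reduces each feature map to a single value.
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.fc(x)


model = TinyResNet()
logits = model(torch.rand(4, 3, 32, 32))
print(logits.shape)
```

Calling model(x) runs forward through nn.Module's __call__, so forward is normally not invoked directly.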
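3. Loss functions
All of the following are provided by torch.nn.
Binary cross-entropy (the input must already be a probability in [0, 1], so apply sigmoid to the model output first; torch.nn.BCEWithLogitsLoss applies the sigmoid internally):
torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')
Multi-class cross-entropy (takes raw, unnormalized scores; log-softmax is applied internally, so do not add a softmax layer before it):
torch.nn.CrossEntropyLoss()
L1 loss, mean absolute error:
torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')
L2 loss, mean squared error:
torch.nn.MSELoss()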
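The following sketch calls each of these losses on small hand-made tensors, mainly to show the expected input form for each one. All tensor values are invented for illustration, and the manual cross-entropy line is included only as a cross-check.

```python
import torch
import torch.nn as nn

# Regression-style losses on a tiny example.
pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 1.0, 1.0])
l1 = nn.L1Loss()(pred, target)    # mean of |0|, |1|, |2|
mse = nn.MSELoss()(pred, target)  # mean of 0, 1, 4

# Binary cross-entropy: BCELoss needs probabilities, so apply sigmoid first.
logits = torch.tensor([0.5, -1.0, 2.0])
binary_target = torch.tensor([1.0, 0.0, 1.0])
bce = nn.BCELoss()(torch.sigmoid(logits), binary_target)
# BCEWithLogitsLoss takes the raw logits and applies the sigmoid internally.
bce_logits = nn.BCEWithLogitsLoss()(logits, binary_target)

# Multi-class cross-entropy: raw scores of shape (batch, classes) and integer class indices.
scores = torch.tensor([[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]])
classes = torch.tensor([0, 2])
ce = nn.CrossEntropyLoss()(scores, classes)
# Cross-check: negative log-softmax of the true class, averaged over the batch.
manual_ce = -torch.log_softmax(scores, dim=1)[torch.arange(2), classes].mean()

print(l1.item(), mse.item(), bce.item(), ce.item())
```

With the default reduction='mean', every loss returns a scalar averaged over all elements (or over the batch for CrossEntropyLoss).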