数据集核心的两个类:Dataset(获取数据集)-》, DataLoader(按批次获取,传给模型)
from torch.utils.data import Dataset, DataLoader
-
GPU配置
超参数可以统一设置,参数初始化:
- batch size:每一批次数据集的样本大小
- 初始学习率(初始)
- 训练次数(max_epochs):跑几轮
- GPU配置
# 批次的大小
batch_size = 16 #可选32、64、128
# 优化器的学习率
lr = 1e-4
#运行epoch
max_epochs = 100
# 方案一:指定GPU的方式
//本例子指定0、1
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 指明调用的GPU为0,1号
//本例子智能选择,若有GPU则用第2块GPU1
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # 指明调用的GPU为1号
-
数据预处理
1、数据读取
# 数据读取
#cifar10数据集为例给出构建Dataset类的方式
from torchvision import datasets//数据集合
#“data_transform”可以对图像进行一定的变换,如翻转、裁剪、归一化等操作,可自己定义
data_transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])//tensor化和归一化
train_cifar_dataset = datasets.CIFAR10('cifar10',train=True, download=False,transform=data_transform)
//在下载完数据集后对数据集进行transform变换
test_cifar_dataset = datasets.CIFAR10('cifar10',train=False, download=False,transform=data_transform)
2、查看数据集及标签
#查看数据集
import matplotlib.pyplot as plt
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
dataiter = iter(test_cifar_dataset)
plt.show()
for i in range(10):
images, labels = dataiter.__next__()///核心!!!!!!
print(images.size())
print(str(classes[labels]))
# images = images.numpy().transpose(1, 2, 0) # 把channel那一维放到最后
# plt.title(str(classes[labels]))
# plt.imshow(images)
3、 使用dataload才能进行epoach(固定流程)
#构建好Dataset后,就可以使用DataLoader来按批次读入数据了
//一批次16个、4个线程处理、顺序打乱增强泛化能力、最后不满batch_size个的那组不要了
train_loader = torch.utils.data.DataLoader(train_cifar_dataset,
batch_size=batch_size, num_workers=4,
shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(test_cifar_dataset,
batch_size=batch_size, num_workers=4,
shuffle=False)
4、自己构建数据集
Dataset类主要包含三个函数:
- init: 用于向类中传入外部参数,同时定义样本集
- getitem: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据
- len: 用于返回数据集的样本数
#自定义 Dataset 类
class MyDataset(Dataset):
//初始化步骤获取路径
//数据集路径、标签路径、图片路径、是否对数据集转换
def __init__(self, data_dir, info_csv, image_list, transform=None):
"""
Args:
data_dir: path to image directory.
info_csv: path to the csv file containing image indexes
with corresponding labels.
image_list: path to the txt file contains image names to training/validation set
transform: optional transform to be applied on a sample.
"""
label_info = pd.read_csv(info_csv)
image_file = open(image_list).readlines()
self.data_dir = data_dir
self.image_file = image_file
self.label_info = label_info
self.transform = transform
//
def __getitem__(self, index):
"""
Args:
index: the index of item
Returns:
image and its labels
"""
image_name = self.image_file[index].strip('\n')
raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
label = raw_label.iloc[:,0]
image_name = os.path.join(self.data_dir, image_name)
image = Image.open(image_name).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.image_file)
-
划分训练集、验证集、测试集
见上
-
选择模型
-
设定损失函数&优化方法
-
模型效果评估