PyTorch 数据读取

数据集核心的两个类:Dataset(获取数据集)-》, DataLoader(按批次获取,传给模型)

from torch.utils.data import Dataset, DataLoader
  • GPU配置

超参数可以统一设置,参数初始化:

  • batch size:每一批次数据集的样本大小
  • 初始学习率(初始)
  • 训练次数(max_epochs):跑几轮
  • GPU配置
# 批次的大小
batch_size = 16 #可选32、64、128
# 优化器的学习率
lr = 1e-4
#运行epoch
max_epochs = 100
# 方案一:指定GPU的方式
//本例子指定0、1
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 指明调用的GPU为0,1号
//本例子智能选择,若有GPU则用第2块GPU1
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # 指明调用的GPU为1号
  • 数据预处理

1、数据读取

# 数据读取
#cifar10数据集为例给出构建Dataset类的方式
from torchvision import datasets//数据集合

#“data_transform”可以对图像进行一定的变换,如翻转、裁剪、归一化等操作,可自己定义
data_transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                   ])//tensor化和归一化


train_cifar_dataset = datasets.CIFAR10('cifar10',train=True, download=False,transform=data_transform)
//在下载完数据集后对数据集进行transform变换
test_cifar_dataset = datasets.CIFAR10('cifar10',train=False, download=False,transform=data_transform)

2、查看数据集及标签

#查看数据集
import matplotlib.pyplot as plt
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

dataiter = iter(test_cifar_dataset)
plt.show()
for i in range(10):
    images, labels = dataiter.__next__()///核心!!!!!!
    print(images.size())
    print(str(classes[labels]))
#     images = images.numpy().transpose(1, 2, 0)  # 把channel那一维放到最后
#     plt.title(str(classes[labels]))
#     plt.imshow(images)

3、 使用dataload才能进行epoach(固定流程)

#构建好Dataset后,就可以使用DataLoader来按批次读入数据了
//一批次16个、4个线程处理、顺序打乱增强泛化能力、最后不满batch_size个的那组不要了

train_loader = torch.utils.data.DataLoader(train_cifar_dataset, 
                                           batch_size=batch_size, num_workers=4, 
                                           shuffle=True, drop_last=True)

val_loader = torch.utils.data.DataLoader(test_cifar_dataset, 
                                         batch_size=batch_size, num_workers=4, 
                                         shuffle=False)

 4、自己构建数据集

Dataset类主要包含三个函数:

  • init: 用于向类中传入外部参数,同时定义样本集
  • getitem: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据
  • len: 用于返回数据集的样本数
#自定义 Dataset 类
class MyDataset(Dataset):
//初始化步骤获取路径
//数据集路径、标签路径、图片路径、是否对数据集转换
    def __init__(self, data_dir, info_csv, image_list, transform=None):
        """
        Args:
            data_dir: path to image directory.
            info_csv: path to the csv file containing image indexes
                with corresponding labels.
            image_list: path to the txt file contains image names to training/validation set
            transform: optional transform to be applied on a sample.
        """
        label_info = pd.read_csv(info_csv)
        image_file = open(image_list).readlines()
        self.data_dir = data_dir
        self.image_file = image_file
        self.label_info = label_info
        self.transform = transform
//
    def __getitem__(self, index):
        """
        Args:
            index: the index of item
        Returns:
            image and its labels
        """
        image_name = self.image_file[index].strip('\n')
        raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
        label = raw_label.iloc[:,0]
        image_name = os.path.join(self.data_dir, image_name)
        image = Image.open(image_name).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_file)
  • 划分训练集、验证集、测试集

见上

  • 选择模型

  • 设定损失函数&优化方法

  • 模型效果评估

  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值