PyTorch数据读取及训练(猫狗大战为例)

以下对接口或者函数的描述从简,建议参照官方文档(记得左上角选择对应版本!)

PyTorch的torchvision下内置了一些常用数据集的接口,比如MNISTCIFARCOCO等,可用以下方式调用:

#  MNIST为例
from torchvision import datasets
dataset = datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

root:数据集根路径
train:用于区分训练或者测试
transform:预处理操作(数据增强、归一化等)
然后使用torch.utils.data.DataLoaderdataset进行封装(封装为Tensor),就可以作为模型训练的输入了。

但是,对于很多公共数据集以及私人数据集,是没有官方接口的,所以就需要用户去实现数据的读取。可以使用torchvision.datasets.ImageFolder接口实现数据导入,但是要注意,这个接口假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名即为类名。例如,可以将猫狗大战的数据集中,可以将cat和dog的图像分别存储到两个文件夹下面,才可以使用这个接口,不过这也很容易操作:

#  在命令行里面运行
#  首先要进入train文件夹,该文件夹下有25000的训练数据,Windows下面可以:
mkdir dog cat
move dog.* dog/
move cat.* cat/

ImageFolder接口使用与MNIST等接口类似,可以查阅官方文档:

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None)

又但是,很多数据集(例如猫狗原本维护方式就是train文件夹下面猫狗数据集都包含)也不像上面这也是分类别进行维护的,那就需要用户自定义接口来完成数据读取。PyTorch中自定义的接口需要继承自Dataset类,同时主要需要实现两个类的方法:

  • __getitem__:返回一个样本
  • __len__:返回样本总数

同样以猫狗数据集为例,简要接口可以实现如下,至于不简要的,建议看源码依样画葫芦也很容易实现 >_<

import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T

class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True):

        imgs = [os.path.join(root, img) for img in os.listdir(root)]
        imgs_num = len(imgs)
        
        # 划分训练、验证集,训练:验证 = 4:1
        if  train:
            self.imgs = imgs[:int(0.8 * imgs_num)] # 训练集
        else:
            self.imgs = imgs[int(0.8 * imgs_num):] # 验证集

        if transforms is None:
            if  train:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])

    def __getitem__(self, index):
    
        img_path = self.imgs[index]  # './catdog/train/cat.0.jpg'
        label = 1 if 'dog' in img_path.split('/')[-1] else 0   
        data = Image.open(img_path)
        data = self.transforms(data)

        return data, label

    def __len__(self):
    
        return len(self.imgs)

有了数据读取接口之后,使用torch.utils.data.DataLoaderdataset进行封装就可以作为模型输入了:

# 训练集
train_data = DogCat(data_dir, train=True)  
train_dataloader = DataLoader(train_data, batch_size, shuffle=True, num_workers=4)

# 验证集
val_data = DogCat(data_dir, train=False)  
val_dataloader = DataLoader(val_data, batch_size, shuffle=False, num_workers=4)

以上train_dataloaderval_dataloader就可以输入模型进行迭代了,建议使用迁移学习(偷懒)。但是预训练的模型拿过来是不能直接使用的,需要修改最后的全连接层输出类别数为猫狗例子中的classes = 2,然后选择是对模型进行微调还是直接将模型作为特征提取器(设置相应层是否需要参数更新),代码如下:

# 获取resnet18预训练模型
model = models.resnet18(pretrained=True)

def init_model(model, num_classes, feature_extract):
    # 如果不进行微调,则设置相应层的梯度设置
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False
    # 修改全连接层输出类别为2
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes) 
    return model

model = init_model(model, num_classes, feature_extract)

现在有了模型,和模型输入数据,还需要损失函数和优化器:

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
	lr=0.001, momentum=0.9)

接下来定义训练函数就可以跑起来了:

def train(model, device, train_dataloader, val_dataloader, criterion, optimizer, num_epochs):

    model.to(device)

    # train
    for epoch in range(num_epochs):
        
        model.train()  
        
        for idx, (data,label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            
            pred = model(data)
            loss = criterion(pred, label)

            optimizer.zero_grad()  # 梯度清零
            loss.backward()
            optimizer.step()  # 参数更新
            
            if idx % 100 == 0:
                print("Train Epoch: {}, iteration: {}, Loss: {}".format(epoch, idx, loss.item()))
                
        model.eval() 
        
        total_loss = []
        correct = 0  
        with torch.no_grad():
            for idx, (data, target) in enumerate(val_dataloader):
                data, target = data.to(device), target.to(device)

                output = model(data)
                total_loss.append(criterion(output, target).item())

                pred = output.argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum()
        total_loss = np.mean(total_loss)
        acc = correct.item() / len(val_dataloader.dataset) * 100  
        print("Val loss: {}, Accuracy: {}".format(total_loss, acc)) 

总之遇事不决、遇事不懂,看源码(‾◡◝)

本文猫狗简单代码

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值