pytorch——迁移学习实战宝可梦精灵分类

本文档介绍了一个使用PyTorch进行图像分类的案例,数据集为宝可梦精灵图片,分为训练集、验证集和测试集。通过加载数据集,进行数据预处理,使用ResNet18模型进行迁移学习,训练新的分类器。在训练过程中,展示了数据可视化和模型评估。最终,模型在测试集上得到准确率,并保存了最佳模型权重。
摘要由CSDN通过智能技术生成

数据集

使用宝可梦精灵的图片数据集。数据集地址:

  • 链接:https://pan.baidu.com/s/1zDERMsV1AvwfZudhuae6Ew
  • 提取码:rs4h

数据集中的每一类别的图片放在一个文件夹中
在这里插入图片描述
数据集共包含5个类别的图片,我们取每个文件夹(类别):

  • 前60%做训练集
  • 60%~80%做验证集
  • 80%~100%做测试集
    在这里插入图片描述

数据集处理

'''
load图片数据集
'''
import torch
import os, glob
import random, csv

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):

    def __init__(self, root, resize, mode):
        '''
        :param root: 数据集目录
        :param resize: 图片的输出size
        :param mode: train/val/test
        '''
        super(Pokemon, self).__init__()

        self.root = root  # 根目录
        self.resize = resize  # 图片的输出size
        self.name2label = {} # 对目录名(类别)进行编码
        for name in sorted(os.listdir(os.path.join(root))):  # 遍历目录和文件
            if not os.path.isdir(os.path.join(root, name)):  # 如果不是目录(是图片)
                continue

            self.name2label[name] = len(self.name2label.keys())  # 用字典保存类别的编码
        # print(self.name2label)

        '''读入图片数据集'''
        # image, label
        self.images, self.labels = self.load_csv('images.csv')

        '''划分train、val、test集'''
        if mode=='train':  # train: 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val':  # val: 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else:  # test: 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]


    def load_csv(self, filename):
        '''
        一次加载进所有图片可能会造成内存不够用,因此我们可以把图片保存到一个csv文件
        :param filename:保存的文件名
        :return:
        '''

        # 如果csv文件不存在,就创建文件
        # 如果csv文件存在,就是之前已经创建过,直接读取就好了
        if not os.path.exists(os.path.join(self.root, filename)):

            '''把所有的文件放到一个list中去。文件的class可以通过路径名来判定'''
            images = []
            for name in self.name2label.keys():
                # 'pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            print(len(images), images)  # 1167

            random.shuffle(images)  # 打乱顺序

            '''写入csv文件'''
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                print('writen into csv file:', filename)

        '''read from csv file'''
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\\bulbasaur\\00000000.png', 0
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)  # 检查条件,不符合就终止

        return images, labels


    def __len__(self):
        '''
        返回总体样本数量
        :return:
        '''
        return len(self.images)


    def denormalize(self, x_hat):
        '''
        逆标准化处理
        :param x_hat: 标准化的tensor
        :return: 逆标准化的tensor
        '''
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x: [channel, high, wight]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print(mean.shape, std.shape)
        x = x_hat * std + mean

        return x


    def __getitem__(self, idx):
        '''
        取得当前位置图片
        :param idx: 图片索引
        :return:
        '''

        img, label = self.images[idx], self.labels[idx]

        '''数据增强之后将图片转换为tensor'''
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'), # string path= > image data
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),  # 图片放大1.25倍
            transforms.RandomRotation(15),  # 随机旋转,在-15° ~ +15°之间
            transforms.CenterCrop(self.resize),  # 中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化,这几个数是大范围统计出来的rgb三原色的均值和方差
                                 std=[0.229, 0.224, 0.225])
        ])

        # tf = transforms.Compose([
        #     lambda x:Image.open(x).convert('RGB'),  # string path= > image data
        #     transforms.Resize((self.resize, self.resize)),  # 图片放大1.25倍
        #     transforms.ToTensor(),
        # ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label


def main():
    '''
    可视化查看数据集

    此处需要安装并开启visdom
    安装:pip install visdom
    开启:python -m visdom.server
    '''
    import visdom
    import time
    import torchvision

    viz = visdom.Visdom()

    # 如果图片的存储很标准,可以用这种方法
    # tf = transforms.Compose([
    #                 transforms.Resize((64,64)),
    #                 transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    #
    # print(db.class_to_idx)
    #
    # for x,y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)


    # 通用的方法
    db = Pokemon('pokemon', 64, 'train')

    x,y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

    # 加载一张图片
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    # viz.image(x, win='sample_x', opts=dict(title='sample_x'))

    # 加载一个batch的图片
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)


if __name__ == '__main__':
    main()

迁移学习网络

原理

Pokemon和ImageNet都需要图片中提取特征,因此存在某些共性的knowledge。因此我们可以利用更加通用的ImageNet的模型,帮我们解决特定的图片分类任务。

我们采用torchvision.models中训练好的resnet18,使用它训练好的卷积部分提取图像特征,并训练新的分类器处理我们提取到的特征。

这样我们只需要训练分类器,而不用再训练特征提取器,因此可以减少所需训练量。
在这里插入图片描述

代码实现

辅助文件:utils.py

from matplotlib import pyplot as plt
import torch
from torch import nn

'''
定义一个神经网络层
第一个维度保持,其他维度打平成一个维度
'''
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


'''
把image打印在matplotlab上
'''
def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

实现网络构建,网络训练与评估的文件:train_transfer.py

'''
利用迁移学习

torchvision提供了训练好的resnet18、resnet34、resnet50...

此处需要安装并开启visdom
安装:pip install visdom
开启:python -m visdom.server
'''

import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader

from pokemon import Pokemon
from utils import Flatten

# 引入已经训练好的model
from torchvision.models import resnet18



batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
# 每次会开启num_work个线程,分别去加载dataset里面的数据,直到每个worker加载数据量为batch_size 大小(总共num_work*batch_size)才会进行下一步训练
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


viz = visdom.Visdom()

def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():  # 不计算梯度
            logits = model(x)  # 前向运算
            pred = logits.argmax(dim=1)  # 选出输出层最大的元素
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

def main():

    '''初始化网络'''
    trained_model = resnet18(pretrained=True)  # 已经训练好的model
    # x: [b, 3, 224, 224]
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 3, 224, 224] => [b, 512, 1, 1] # 取出从0到17层,作为特征提取器
                          Flatten(),  # [b, 512, 1, 1] => [b, 512] # 自己定义的类,改变tensor维度
                          nn.Linear(512, 5)  # [b, 512] => [b, 5] # 随机初始化的一个新的线性层,作为分类器
                          ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    '''记录实验结果参数'''
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    '''训练与评估'''
    for epoch in range(epochs):

        '''训练一次模型'''
        for step, (x, y) in enumerate(train_loader):  # 遍历
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)

            # logits: [b, 5]
            # y: [b]
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        '''评估模型'''
        if epoch % 1 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')  # 保存评估结果最好的模型

                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    '''加载最优模型'''
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    '''测试模型'''
    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)



if __name__ == '__main__':
    main()

PyTorch是一个流行的深度学习框架,常用于计算机视觉和自然语言处理任务。迁移学习(Transfer Learning)是利用预训练模型在一个大任务(比如ImageNet中的大量图像分类)上获得的知识,将其应用到一个小规模但相关的任务中的一种方法,例如精灵(如宝可梦)的分类。 在PyTorch中,你可以使用已经训练好的卷积神经网络(CNN),如ResNet、VGG或Inception等,作为基础模型来进行迁移学习。对于宝可梦精灵分类,首先你需要: 1. **准备数据集**:收集并整理包含宝可梦图片的数据集,确保它们被正确地标注为各个类别。 2. **加载预训练模型**:从 torchvision.models 中选择一个适合的模型,如resnet18、resnet50等,并设置其参数为不可训练(`.eval()`)以保持前几层不变。 3. **特征提取**:将模型应用于每个输入图像,仅取输出的特征向量(通常是`model.fc`之前的最后一层)而不是最终的分类结果。 4. **添加新层**:由于原始模型的最后一层可能不适合新的分类任务,通常会添加一层或多层全连接层(Linear Layer)以及适当的激活函数。 5. **微调**:如果希望进一步提升性能,可以选择部分或全部冻结的预训练层进行微调(`.train()`),调整这些层的权重以适应新任务。 6. **训练和评估**:使用训练集对模型进行训练,并用验证集监控性能,然后在测试集上评估模型的实际效果。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值