resnet18实现猫狗图片的分类

简介

使用猫狗分类数据集中的训练集,共25000张图片。将原始训练集进行拆分,其中20000张用于训练,其余5000张用于测试。分类网络使用ResNet-18,使用了交叉熵损失函数和SGD优化方法。

环境配置

建立Conda虚拟环境,python3.7,几个重要的库:
(1)pytorch 1.7.0

(2)torchvision 0.8.0

(3)opencv-python 4.5.2.52

(4)tqdm 4.61.0

目录结构

在这里插入图片描述

运行方法

必须下载数据集。数据集下载完成后,存放在[工程主目录]/data路径下,首先运行如下命令完成数据集划分。

python prepare_data.py

运行完成后,[工程主目录]/data路径下会生成newtrain和newtest这2个路径,分别存放训练集和测试集。
训练过程

python train.py

训练完成后,在工程主目录下会生成名为resnet18_Cat_Dog.pth的权重文件,推理时会读取该权重文件。

测试过程

python test.py

推理完成后会打印出推理的正确率。

以下是prepare_data.py文件内容,该模块的主要功能是用来划分数据,将全部的数据一部分划分为分类训练集,一部分划分为测试集。

import os
import shutil

def main():
# Step 1:创建训练集路径和测试集路径
    new_train_data_path = os.path.join(os.getcwd(), 'data/newtrain')
    new_test_data_path = os.path.join(os.getcwd(), 'data/newtest')

    if os.path.exists(new_train_data_path) is False:
        os.makedirs(new_train_data_path)
    
    if os.path.exists(new_test_data_path) is False:
        os.makedirs(new_test_data_path)
    
#Step 2:将Cat和Dog类别中id>=10000的图片存到测试集路径中,其他图片存到训练集路径中
    origin_dataset_path = os.path.join(os.getcwd(), 'data/train')
    img_list = os.listdir(origin_dataset_path)
    for img_name in img_list:
        img_name_split = img_name.split('.')
        src_img = os.path.join(origin_dataset_path, img_name)

        if int(img_name_split[1])>=10000:
            shutil.copy(src_img, new_test_data_path)
        else:
            shutil.copy(src_img, new_train_data_path)
        

if __name__ == '__main__':
    main()

以下是DogCatDataset.py。

import os
import cv2
from torch.utils.data import Dataset

class DogCatDataset(Dataset):
    def __init__(self, root_path, transform=None):
        self.label_name = {"Cat": 0, "Dog": 1}
        self.root_path = root_path
        self.transform = transform
        self.get_train_img_info()



    def __getitem__(self, index):
        self.img = cv2.imread(os.path.join(self.root_path, self.train_img_name[index]))
        if self.transform is not None:
            self.img = self.transform(self.img)
        self.label = self.train_img_label[index]
        return self.img, self.label

    def __len__(self):
        return len(self.train_img_name)

    def get_train_img_info(self):
        self.train_img_name = os.listdir(self.root_path)
        self.train_img_label = [0 if 'cat' in imgname else 1 for imgname in self.train_img_name]

    

以下是train.py。

import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torchvision.models as models
import DogCatDataset


def main():
        
    #Step 0:查看torch版本、设置device
    print(torch.__version__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #Step 1:准备数据集
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    
    train_data = DogCatDataset.DogCatDataset(root_path=os.path.join(os.getcwd(), 'data/newtrain'),
                                            transform=train_transform)
    train_dataloader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)

    #Step 2: 初始化模型
    model = models.resnet18()
    

    #修改网络结构,将fc层1000个输出改为2个输出
    fc_input_feature = model.fc.in_features
    model.fc = nn.Linear(fc_input_feature, 2)

    #load除最后一层的预训练权重
    pretrained_weight = torch.hub.load_state_dict_from_url(url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)
    del pretrained_weight['fc.weight']
    del pretrained_weight['fc.bias']
    model.load_state_dict(pretrained_weight, strict=False)

    model.to(device)

    #Step 3:设置损失函数
    criterion = nn.CrossEntropyLoss()     #交叉熵损失函数

    #Step 4:选择优化器
    LR = 0.01
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)    

    #Step 5:设置学习率下降策略
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  

    #Step 6:训练网络
    model.train() 
    MAX_EPOCH = 20    #设置epoch=20

    for epoch in range(MAX_EPOCH):
        loss_log = 0
        total_sample = 0
        train_correct_sample = 0    
        for data in tqdm(train_dataloader):
            img, label = data
            img, label = img.to(device), label.to(device)
            output = model(img)

            optimizer.zero_grad()
            loss = criterion(output, label)
            loss.backward()

            optimizer.step()
            
            _, predicted_label = torch.max(output, 1)

            total_sample += label.size(0)
            train_correct_sample += (predicted_label == label).cpu().sum().numpy()
            
            loss_log += loss.item()

            # if total_sample == 2400:
            #     print('mark!')

        #打印信息
        print('epoch: ', epoch)
        print("accuracy:", train_correct_sample/total_sample)
        print('loss:', loss_log/total_sample)
  
        scheduler.step()   #更新学习率
    print('train finish!')
    #Step 7: 存储权重
    torch.save(model.state_dict(), './resnet18_Cat_Dog.pth')


if __name__ == '__main__':
    main()

以下是test.py部分。

import os
import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F 
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import DogCatDataset


def main():
        
    #Step 0:查看torch版本、设置device
    print(torch.__version__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #Step 1:准备数据集
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    
    test_data = DogCatDataset.DogCatDataset(root_path=os.path.join(os.getcwd(), 'data/newtest'),
                                            transform=test_transform)
    test_dataloader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

    #Step 2: 初始化网络
    model = models.resnet18()

    #修改网络结构,将fc层1000个输出改为2个输出
    fc_input_feature = model.fc.in_features
    model.fc = nn.Linear(fc_input_feature, 2)

    #Step 3:加载训练好的权重
    trained_weight = torch.load('./resnet18_Cat_Dog.pth')
    model.load_state_dict(trained_weight)
    model.to(device)

    #Steo 4:网络推理
    model.eval()

    correct_sample = 0
    total_sample = 0
    with torch.no_grad():
        for data in test_dataloader:
            img, label = data  #这里的label啥意思
            img = img.to(device)
            label = label.to(device)
            output = model(img)

            _, predicted_label = torch.max(output, 1)

           
            correct_sample += (predicted_label==label).cpu().numpy()
            total_sample += 1
            #print('Image Name:{},predict:{}'.format(, predicted_label))#这里想提取出文件名


    
    #Step 5:打印分类准确率
    print(correct_sample/total_sample)
  


if __name__ == '__main__':
    main()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值