一个项目学会 PyTorch

项目场景：

一个项目学会 PyTorch

动物二分类

网络搭建和迁移学习


动物二分类


1.数据预处理

import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import os
import random
from PIL import Image

class hym_data(Dataset):
    """Ants-vs-bees image dataset read from ``<path>/<split>/{ants,bees}/``.

    Each sample is ``(image, label)``: label 1 for files in the ``ants``
    folder and 0 for files in the ``bees`` folder.

    Args:
        img_h: Target image height after resizing / cropping.
        img_w: Target image width after resizing / cropping.
        path: Root directory containing ``train`` and ``val`` sub-folders.
        mode: ``'train'`` reads ``<path>/train`` with augmentation; any other
            value reads ``<path>/val`` with resize + normalize only.
        preprocess: If True, apply the transform pipeline (returns a tensor);
            otherwise return the RGB ``PIL.Image`` as loaded.
    """

    def __init__(self, img_h=256, img_w=256, path="./data/hyma_data",
                 mode='train', preprocess=True):

        self.mode = mode
        # Honour the requested size (previously hard-coded to 256x256).
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.preprocess = preprocess

        # Compare by value: ``is`` tests object identity, not string equality.
        if self.mode == 'train':
            self.path = self.path + '/train'
            self.transform = transforms.Compose([
                transforms.Resize(size=(self.img_h, self.img_w)),
                transforms.RandomRotation(15),
                # ``padding`` must be a number/sequence (a str makes torchvision
                # raise TypeError); crop back to (h, w) so non-square sizes work.
                transforms.RandomCrop((self.img_h, self.img_w), padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            self.path = self.path + '/val'
            self.transform = transforms.Compose([
                transforms.Resize(size=(self.img_h, self.img_w)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

        self.file_list = self.get_filename_list()
        random.shuffle(self.file_list)

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.file_list)

    def __getitem__(self, item):
        """Return ``(image, label)`` for the ``item``-th (shuffled) file."""
        img_name = self.file_list[item]
        label = self._label_from_name(img_name)
        # Close the file handle promptly, and force 3 channels so grayscale /
        # RGBA images match the 3-channel Normalize in the transform pipeline.
        with Image.open(os.path.join(self.path, img_name)) as raw:
            img = raw.convert('RGB')
        if self.preprocess:
            img = self.transform(img)

        return img, label

    @staticmethod
    def _label_from_name(img_name):
        """Return 1 for entries under the ``ants/`` folder, otherwise 0."""
        # Match the folder prefix rather than a substring: a bee image such as
        # 'bees/plants.jpg' contains 'ants' and used to be labelled as an ant.
        return 1 if img_name.startswith('ants/') else 0

    def get_filename_list(self):
        """List ``'ants/<file>'`` then ``'bees/<file>'`` entries under ``self.path``.

        Directory listings are sorted so the pre-shuffle order is deterministic
        (``os.listdir`` order is arbitrary).
        """
        file_list = []
        for class_dir in ('ants', 'bees'):
            names = sorted(os.listdir(os.path.join(self.path, class_dir)))
            file_list.extend(class_dir + '/' + name for name in names)

        return file_list
if __name__ == '__main__':
    # Quick smoke check: pull the first (shuffled) validation sample and dump it.
    val_set = hym_data(mode='val', preprocess=True)
    sample, target = next(iter(val_set))
    print(sample, target, sep='\n')

2.网络搭建和迁移学习

import torch
from torchvision import models
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
from torch.optim import lr_scheduler
from hym_data import *


class Trainer(object):
    """Train a 2-class ResNet-18 under one of three transfer-learning regimes.

    Args:
        lr: Adam learning rate.
        batch_size: Mini-batch size for both the train and test loaders.
        num_epoch: Number of training epochs run by :meth:`train`.
        train_data: Dataset yielding ``(image_tensor, label)`` for training.
        test_data: Dataset yielding ``(image_tensor, label)`` for evaluation.
        mode: ``"finetune"`` — ImageNet-pretrained, all layers trainable;
            ``"fixed"`` — pretrained backbone frozen, only the new ``fc`` head
            trains; any other value — random init, trained from scratch.
    """

    def __init__(self, lr=0.005, batch_size=32, num_epoch=64,
                 train_data=None, test_data=None, mode="finetune"):
        self.lr = lr  # previously hard-coded to 0.005, ignoring the argument
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.mode_path = "./mode1"

        self.data_loader = Data.DataLoader(dataset=train_data,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=0)
        # Order does not affect accuracy, so no need to shuffle at eval time.
        self.test_loader = Data.DataLoader(dataset=test_data,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers=0)

        self.loss = torch.nn.CrossEntropyLoss()
        self.device = torch.device("cuda" if torch.cuda.is_available()
                                   else "cpu")

        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
        # favour of ``weights=``; kept here for compatibility with older versions.
        if mode == "finetune":
            print("wei tiao xue xi")
            self.model = models.resnet18(pretrained=True)
        elif mode == "fixed":
            print("gu ding biao xue xi")
            self.model = models.resnet18(pretrained=True)
            # Was ``self.mode.parameters()`` -> AttributeError.
            for parm in self.model.parameters():
                parm.requires_grad = False
        else:
            self.model = models.resnet18(pretrained=False)

        # Replace the classifier head with a 2-way linear layer. It is created
        # after any freezing above, so it is always trainable.
        num_fc = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_fc, 2)
        self.model = self.model.to(self.device)
        # Only give the optimizer parameters that can actually be updated.
        self.optim = torch.optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr,
            betas=(0.5, 0.99))

        self.lr_sche = lr_scheduler.StepLR(self.optim,
                                           step_size=20,
                                           gamma=0.1)

    def train(self):
        """Train for ``num_epoch`` epochs and return per-epoch test accuracies.

        Whenever the test accuracy improves, the model ``state_dict`` is saved
        to ``<mode_path>/transfer.pkl``.
        """
        # Was ``[]``: ``float > list`` raised TypeError on the first epoch.
        # -inf guarantees the first epoch's weights are always checkpointed.
        best_acc = float("-inf")
        acc_list = []
        for epoch in range(self.num_epoch):
            self.model.train()
            epoch_loss = 0.0

            for bx, by in self.data_loader:
                bx = bx.to(self.device)
                by = by.to(self.device)

                pre_y = self.model(bx)
                loss = self.loss(input=pre_y, target=by)

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                epoch_loss += loss.item()

            self.lr_sche.step()
            curr_acc = self.test()
            acc_list.append(curr_acc)
            print("epoch :", epoch, "sun shi zhi :", epoch_loss,
                  "ce shi zheng que lv :", curr_acc)

            if curr_acc > best_acc:
                best_acc = curr_acc
                print("zheng que lv :", best_acc)
                os.makedirs(self.mode_path, exist_ok=True)
                torch.save(self.model.state_dict(), self.mode_path + "/transfer.pkl")
        return acc_list

    def test(self):
        """Return the fraction of correctly classified test samples (a float).

        Returns 0.0 for an empty test set instead of dividing by zero.
        """
        # eval(): BatchNorm uses running stats; no_grad(): no autograd graph.
        self.model.eval()
        correct = 0
        with torch.no_grad():
            for bx, by in self.test_loader:
                bx = bx.to(self.device)
                by = by.to(self.device)
                preds = self.model(bx).argmax(dim=1)
                correct += (preds == by).sum().item()
        total = len(self.test_loader.dataset)
        return correct / total if total else 0.0

3.模型训练和绘图


if __name__ == '__main__':
    print("1========================")

    train_data = hym_data(img_h=128, img_w=128, mode="train", preprocess=True)
    test_data = hym_data(img_h=128, img_w=128, mode="val", preprocess=True)
    torch.cuda.empty_cache()
    print("2========================")

    lr = 0.005
    batch_size = 64
    num_epoch = 32
    print("3=========================")

    # Train each regime on the same data and keep its per-epoch test accuracy.
    # "fixed" was previously misspelled "fiexd", which silently fell through
    # to the from-scratch branch, so two from-scratch-style runs were compared.
    acc_lists = {}
    for mode in ("finetune", "fixed", "other"):
        trainer = Trainer(lr=lr, batch_size=batch_size, num_epoch=num_epoch,
                          train_data=train_data, test_data=test_data,
                          mode=mode)
        acc_lists[mode] = trainer.train()
        del trainer  # release the model so empty_cache can actually free it
        torch.cuda.empty_cache()

    x = range(num_epoch)
    plt.figure()
    # One curve per regime; previously the finetune curve was drawn twice,
    # "other" never, and the misspelled ``lable=`` kwarg made matplotlib raise.
    for mode, accs in acc_lists.items():
        plt.plot(x, accs, label=mode)
    plt.title("acc======,lr" + str(lr))
    plt.xticks(x)
    plt.legend()
    os.makedirs("./saved", exist_ok=True)  # savefig does not create dirs
    plt.savefig("./saved/transfer_acc.jpg")
    plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值