PyTorch (6): Custom Datasets, Pokemon Classification with ResNet18 + Transfer Learning, and Hands-On Auto-Encoder & VAE Code

This post shows how to build a custom PyTorch dataset for Pokemon image classification, walking through data loading, preprocessing, ResNet model construction, and training. With transfer learning, test-set accuracy rises to 93%. It also discusses Auto-Encoders and their variants, including Denoising AutoEncoders, Dropout AutoEncoders, Adversarial AutoEncoders, and Variational Auto-Encoders, with training strategies and hands-on code.

1. Custom Dataset in Practice

Step 1: load data

torch.utils.data.Dataset
__len__: returns the number of samples
__getitem__: returns one sample by index

Data preprocessing (see the transforms sketch after this list):

  • image resize
  • data augmentation: rotate, crop
  • normalize: mean, std
  • to tensor
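
As a concrete illustration, the four steps above can be chained into a single torchvision.transforms pipeline; the sizes, rotation angle, and ImageNet mean/std below are illustrative assumptions, not values taken from the original code:

from torchvision import transforms

# illustrative preprocessing for a 64x64 target size
tf = transforms.Compose([
    transforms.Resize((80, 80)),              # image resize, slightly larger than the crop
    transforms.RandomRotation(15),            # augmentation: random rotation
    transforms.CenterCrop(64),                # augmentation: crop back to 64x64
    transforms.ToTensor(),                    # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # normalize with ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])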

Store the class names in a dictionary (name → integer label).

Collect the file path of every image.

load_csv

Store the images and labels in separate lists (the image set and the label set); a sketch of the whole class follows.
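
Putting Step 1 together, a minimal sketch of the Pokemon dataset class that the training script below imports might look like this; the folder layout (one sub-directory per class under the root), the images.csv cache name, and the 60/20/20 split are assumptions:

import os, glob, csv, random

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    # hypothetical layout: one sub-folder per class under root, e.g. pokemon/pikachu/*.png

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root, self.resize = root, resize

        # store the class names in a dictionary: name -> integer label
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)

        # collect every image path and its label (cached in a csv file)
        self.images, self.labels = self.load_csv('images.csv')

        # assumed 60% / 20% / 20% split for train / validation / test
        n = len(self.images)
        if mode == 'train':
            s = slice(0, int(0.6 * n))
        elif mode == 'validation':
            s = slice(int(0.6 * n), int(0.8 * n))
        else:
            s = slice(int(0.8 * n), n)
        self.images, self.labels = self.images[s], self.labels[s]

    def load_csv(self, filename):
        path = os.path.join(self.root, filename)
        if not os.path.exists(path):
            # first run: scan the class folders, shuffle, and cache "image,label" rows
            images = []
            for name in self.name2label:
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            random.shuffle(images)
            with open(path, 'w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    label = self.name2label[os.path.basename(os.path.dirname(img))]
                    writer.writerow([img, label])

        # read the csv back into an image list and a label list
        images, labels = [], []
        with open(path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # load one image, apply the preprocessing, return (tensor, integer label)
        tf = transforms.Compose([
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),   # augmentation omitted here; see the pipeline above
        ])
        img = tf(Image.open(self.images[idx]).convert('RGB'))
        return img, torch.tensor(self.labels[idx])

A DataLoader can then wrap this class exactly as shown in train_scratch below.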

Step 2: build model

Step 3: train and test

Step 4: transfer learning: initializing from a pretrained model
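
The transfer-learning code itself is not reproduced in this section. A minimal sketch of how the model can be initialized from torchvision's ImageNet-pretrained resnet18 (reusing the backbone and replacing only the classifier head is an assumption about the exact setup) might look like this:

import torch
from torch import nn
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# reuse every layer of the pretrained resnet18 except its final fc layer,
# then attach a fresh 5-way head for the Pokemon classes
# (torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT over pretrained=True)
trained_model = models.resnet18(pretrained=True)
model = nn.Sequential(
    *list(trained_model.children())[:-1],   # backbone output: [b, 512, 1, 1]
    nn.Flatten(),                            # -> [b, 512]
    nn.Linear(512, 5),                       # new classifier head
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Only the model construction changes; the training and evaluation loop can stay the same as in train_scratch below.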

train_scratch code:

import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoader

from    pokemon import Pokemon
from    resnet18 import ResNet18

batchsize=64
lr=1e-3
epochs=10

device=torch.device("cuda")
torch.manual_seed(1234)

train_db=Pokemon("pokemon",64,mode="train")
val_db=Pokemon("pokemon",64,mode="validation")
test_db=Pokemon("pokemon",64,mode="test")

train_loader = DataLoader(train_db,batch_size=batchsize,shuffle=True,num_workers=4)
val_loader=DataLoader(val_db,batch_size=batchsize,shuffle=True,num_workers=2)
test_loader=DataLoader(test_db,batch_size=batchsize,shuffle=True,num_workers=2)

viz=visdom.Visdom()

def evaluate(model,loader):
    # compute classification accuracy of the model over the given loader
    model.eval()

    correct=0
    total=len(loader.dataset)

    for x, y in loader:
        x,y=x.to(device),y.to(device)
        with torch.no_grad():
            logits=model(x)
            pred=logits.argmax(dim=1)
            correct+=torch.eq(pred,y).sum().float().item()

    return correct/total

def main():

    model=ResNet18(5).to(device)
    optimizer=optim.Adam(model.parameters(),lr=lr)
    criterion=nn.CrossEntropyLoss().to(device)

    best_acc,best_epoch=0,0
    global_step=0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        for step,(x,y) in enumerate(train_loader):

            x,y=x.to(device),y.to(device)

            model.train()
            logits=model(x)
            loss=criterion(logits,y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch%1==0:
            # validate every epoch; keep the checkpoint with the best val accuracy
            val_acc=evaluate(model,val_loader)
            if val_acc>best_acc:
                best_epoch=epoch
                best_acc=val_acc

                torch.save(model.state_dict(),"best.mdl")
                viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best acc:', best_acc, 'best epoch:', best_epoch)

    # reload the best checkpoint before evaluating on the test set
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)

if __name__ == '__main__':
    main()

ResNet code:

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):

    def __init__(self,ch_in,ch_out,stride=1):
        super(ResBlk, self).__init__()

        self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn1=nn.BatchNorm2d(ch_out)
        self.conv2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn2=nn.BatchNorm2d(ch_out)

        # shortcut branch: identity when the shapes already match,
        # otherwise a 1x1 conv + BN to adapt channels and spatial size
        self.extra=nn.Sequential()
        if ch_in!=ch_out:
            self.extra=nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                                     nn.BatchNorm2d(ch_out))

    def forward(self,x):
        out=F.relu(self.bn1(self.conv1(x)))
        out=self.bn2(self.conv2(out))
        # residual connection: shortcut + main path
        out=self.extra(x)+out
        out=F.relu(out)
        return out
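
The rest of the resnet18.py listing is cut off here. A minimal sketch of the ResNet18 wrapper that train_scratch imports might look as follows, continuing the file above (it reuses nn, F, and ResBlk); the channel widths, strides, and global pooling are assumptions, not necessarily the original values:

class ResNet18(nn.Module):

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # stem: 3 -> 16 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # four residual stages, doubling the channels each time
        self.blk1 = ResBlk(16, 32, stride=3)
        self.blk2 = ResBlk(32, 64, stride=3)
        self.blk3 = ResBlk(64, 128, stride=2)
        self.blk4 = ResBlk(128, 256, stride=2)

        self.outlayer = nn.Linear(256, num_class)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # global average pooling to [b, 256, 1, 1], then flatten and classify
        x = F.adaptive_avg_pool2d(x, [1, 1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x

ResNet18(5) then matches the five Pokemon classes used in the training script.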