ResNet网络训练与验证(二)

import os
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import model
'''
加载数据集
1.根路径 dir/train dir/val
2.数据类型 type=train val
'''
def get_dataLoader(dir,batch_size,type=None):
    #训练集
    if type=="train":
        #转换为tensor
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
        #制作数据集
        train_dataset=datasets.ImageFolder(os.path.join(dir,"train\\"),transform=transform)
        #加载数据集为loader
        train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
        return train_loader

    elif type=="val":
        #转换为tensor
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
        #制作数据集
        val_dataset=datasets.ImageFolder(os.path.join(str,'val\\'),transform=transform)
        #加载数据集为loader
        val_loader=DataLoader(val_dataset,batch_size=batch_size,shuffle=True)
        return val_loader

from torch import optim
from torch import nn as nn
import torch
from tqdm import tqdm

if __name__ == '__main__':
    #设置超参数
    epoch_num=100
    lr=0.001
    batch_size=64

    #数据集根目录
    str=r"E:\data"

    #首先获取数据集
    train_loader=get_dataLoader(str,batch_size,"train")
    val_loader=get_dataLoader(str,batch_size,"val")

    #调用gpu
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #调用模型
    model=model.resnet18().to(device)

    #设置loss和优化器
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(model.parameters(),lr=lr)

    for epoch in range(1,epoch_num+1):
        #开始加载数据进行批次训练 处理后的数据按照批次加载如模型中
        for i,(img,label) in enumerate(tqdm(train_loader)):
            #将数据和标签加入到设备中
            img,label=img.to(device),label.to(device)
            #进入训练模式
            model.train()
            #梯度归零
            optimizer.zero_grad()
            #前向传播
            output=model(img)
            #计算loss
            loss=criterion(output,label)
            #反向传播
            loss.backward()
            #更新梯度
            optimizer.step()
            #每十个批次记录一下 acc和loss
            if i%10==0:
                correct=0
                total=0
                #对output进行处理,返回的值为batch_size行,类别列
                _,predicted=torch.max(output.data,1)
                #计算acc,label的格式为批次个标签值[1,1,1,1,1,1,1,1]
                total+=label.size(0)
                correct+=(predicted==label).sum()
                acc=(correct/total)
                print("[epoch:%d] iter:%d  acc:%.3f loss:%.3f"%(epoch,i*batch_size,acc*100,loss.item()))
        #经历一个epoch使用val验证一下模型效果
        with torch.no_grad:
            correct = 0
            total = 0
            for img,label in tqdm(val_loader):
                #模型进入验证模式
                model.eval()
                #将图像,标签送入设备中
                img,label=img.to(device),label.to(device)
                #将图片送入模型中
                output=model(img)
                # 对output进行处理,返回的值为batch_size行,类别列
                _, predicted = torch.max(output.data, 1)
                # 计算acc,label的格式为批次个标签值[1,1,1,1,1,1,1,1]
                total += label.size(0)
                correct += (predicted == label).sum()
                acc = (correct / total)
                print("Val’s acc:%.3f " % acc * 100)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吃鱼不卡次

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值