java读取网络图片_Dataloader读取图片集并训练网络

本文介绍了如何使用Java读取网络上的图片集合,并结合Dataset和Dataloader进行处理。利用ResNet34模型进行训练,详细展示了数据读取、网络构建、训练前的准备、网络训练过程、验证网络以及模型保存和运行的步骤。
摘要由CSDN通过智能技术生成

使用Dataset制作好数据集之后,可以用Dataloader进行读取,然后用resnet34进行训练。

具体代码及注释如下

1 模块导入

其中data_read是利用Dataset制作数据集时写的文件

# 从data_read文件中读取函数
# data_read是创建的数据集制作函数
from data_read import ImageFloder, train_transform, test_transform
import numpy as np
import torch
# torch.nn用于网络的自定义
import torch.nn as nn
# torch.optim用于训练过程中参数的更新
import torch.optim as optim
# DataLoader用于储存数据,方便使用
from torch.utils.data import DataLoader
import torchvision
# 如果不是自己建立网络,可以从torchvision.models读取到已有的网络
# 然后对网络做适当的修改
from torchvision.models import vgg16, resnet34
import os
from os.path import join

2 数据的读取

batch_size为每一次处理数据的数量,以60张图片为一组

# 读取数据
trainset = ImageFloder(root='D:/Anaconda3/data/tiny-imagenet-200', subdir='train', transform=train_transform)
traindataloader = DataLoader(trainset, batch_size=60, shuffle=True, num_workers=0)
valset = ImageFloder(root='D:/Anaconda3/data/tiny-imagenet-200', subdir='val', transform=test_transform)
valdataloader = DataLoader(valset, batch_size=60, num_workers=0)

3 网络构建

# 因为batch_size的存在,因此输入的x的size实际为([60,1,224,224])
# 网络开始搭建,自定义类均要继承 nn.Module 类
# 只需要定义 __init__ 和 forward 即可
# class_number表明最后分类的数量,200意味着一共有200个类型的图片,编号为0到199
class Net(nn.Module):
    def __init__(self, class_number=200):
        super(Net, self).__init__()
        self.net = resnet34(pretrained=True)
        # 修改 resnet34 的全连接层,定义第一次Linear的输出通道数
        self.cov2fc = 1024
        self.net.fc = nn.Sequential(
            nn.Linear(512*1*1, self.cov2fc),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(self.cov2fc, self.cov2fc),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(self.cov2fc, class_number)
        )
    # 建立前向传播
    def forward(self, x):
        x = self.net(x)
        return x

4 网络训练的前期准备

# e_epoch为迭代次数,8代表迭代8次,但实际e_epoch是从0到7
# lr为初始学习率
# step用于学习率的更新
# checkpoint是指输出loss时的节点
e_epoch = 8
lr = 0.01
step = 20
checkpoint = 100

# device用于判断是不是满足cuda运行的条件,如果不满足则使用cpu
# 需要注意的是,用了device定义之后,往后的代码中,涉及到网络内容、参数、及数据集的情况都要加上to(device)
# 用net = Net().to(device)来调用上面创建的网络
# nn.CrossEntropyLoss()用来计算损失loss
# optim.SGD用来更新参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = Net().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

# 制定学习率更新的条件
def re_learningrate(epoch):
    re_lr = lr *(0.1 ** (epoch // step))
    return re_lr

5 定义训练网络

我的理解中,每一个epoch中,要完成整个所有图片数据的训练,因此设置了batch_size=60,数据集的大小为100000,因此需要循环的次数为100000//60+1次(//表示取商,向下取整,在此处表示1666),因此i的范围是(0,100000//60+1-1),当检查点checkpoint设置为100时,即当i=99、199、299、399…时,已经处理的数据量大小应该为6000、12000、18000…

def train(epoch):
    net.train()
    # 更新梯度下降中的学习率
    for p in optimizer.param_groups:
        p['lr'] = re_learningrate(epoch)
    total_loss = 0.0
    # enumerate可以将数据集排序,因此i提取到每一个数据集的序号
    # 每一个数据集中有60张图片,意味着每一个i运行结束后有60张图片进行了训练
    # 调用的data_read函数中已经将图片读取到内存中
    # 因此traindataloader中,imgs可以直接读取到图片,labels则是对应的标签,是用0-199的数值代表的标签
    for i, (imgs, labels) in enumerate(traindataloader):
        imgs, labels = imgs.to(device), labels.to(device)
        # 运行前需要清除原有的grad
        optimizer.zero_grad()
        output = net(imgs)
        # 计算损失,这个loss代表的是一个batch_size大小的数据计算出的损失,不是指每一张图片的损失
        loss = criterion(output, labels)
        # 反向传播
        loss.backward()
        # 梯度下降方法更新参数
        optimizer.step()
        # 计算累加的损失
        total_loss += loss.item()
        # 以checkpoint为检查点,计算平均的损失loss
        # 其中imgs.size()应该为([60,1,224,224])
        # 因为batch_size=60,所以一个epoch中,当i=99时,代表已经循环了训练了100次
        # 因此已经训练的图片量为(i+1)*imgs.size()[0]
        # 总共要训练的图片量为len(traindataloader.dataset)
        # 因为已经训练了100次(checkpoint设置的数量),因此平均的loss为total_loss/float(checkpoint)
        if (i+1) % checkpoint == 0:
            print('Epoch: {} [{}/{}] loss: {}'.format(epoch+1, (i+1)*imgs.size()[0], len(traindataloader.dataset),
                                                      total_loss/float(checkpoint)))
            # 最后需要将总的loss归0,防止循环叠加出现误判
            total_loss = 0.0

6 定义验证网络

def val(epoch):
    net.eval()
    correct = 0.0
    # 仅对网络训练结果进行验证,不需要反向传播,也不需要计算梯度,所以要使用with torch.no_grad()
    with torch.no_grad():
        # 此处的循环次数与train函数中的一样
        # 因为不用每个循环都输出correct,只需要最后整个测试集的correct,因此不用enumerate进行编号
        for imgs, labels in valdataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            output = net(imgs)
            # 以下注释没有考虑总数据集的大小,是将总的图片看为了60张,其实60只是一个batch_size的大小
            # output中其实为60行200列的数组,size大小为([60,200])
            # 假如一个epoch中的60张图片分别属于第1类、第2类、第3类...第200类
            # 逻辑判断的结果,1则代表判断结果是属于这一类
            # 则里面的内容为([[1,0,0,...,0],[0,1,0,...,0],...,[0,0,0,...,1]])
            # dim=1代表在列方向进行处理
            # argmax(dim=1)可以提取出列方向上最大值所在的序号值
            # 结果为([0,1,2,...,199])
            # predict.shape输入预测结果的维数,结果为([60])
            # labels.view()可以将标签文件中的数据按照predict.shape()的维度重组
            # 然后用predict.eq()判断两者是不是相等,相等则为1,不相等为0
            # 然后对判断结果进行累加,累加之后需要用item()将其转换为纯数字
            predict = output.argmax(dim=1)
            correct += predict.eq(labels.view(predict.shape)).sum().item()
        #然后计算出识别的正确率,此处为计算每一个epoch的正确率,因此要跳出循环
        print('Epoch: {} Accuracy: {}'.format(epoch+1, correct/float(len(valdataloader.dataset))*100))

7 网络保存及运行

# 每次完成一个epoch保存一次
def save(epoch):
    root = 'D:/file/deep learning/pytorch practice/resnet34'
    stats = {
        'epoch': epoch,
        'model': net.state_dict()
    }
    if not os.path.exists(root):
        os.makedirs(root)
    savepath = join(root, 'model_{}.pth'.format(epoch+1))
    torch.save(stats, savepath)
    print('saving checkpoint in {}'.format(savepath))

if __name__ == '__main__':
    for epoch in range(e_epoch):
        train(epoch)
        val(epoch)
        save(epoch)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值